Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/type_info.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ eps float The smallest representable number such that ``1.0 + eps != 1
max float The largest representable number.
min float The smallest representable number (typically ``-max``).
tiny float The smallest positive representable number.
resolution float The approximate decimal resolution of this type, i.e., ``10**-precision``.
========= ===== ========================================

.. note::
Expand Down
30 changes: 24 additions & 6 deletions test/test_type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,30 @@
class TestDTypeInfo(TestCase):

def test_invalid_input(self):
for dtype in [torch.float32, torch.float64]:
for dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.bool]:
with self.assertRaises(TypeError):
_ = torch.iinfo(dtype)

for dtype in [torch.int64, torch.int32, torch.int16, torch.uint8]:
for dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool]:
with self.assertRaises(TypeError):
_ = torch.finfo(dtype)

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_iinfo(self):
for dtype in [torch.int64, torch.int32, torch.int16, torch.uint8]:
for dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8]:
x = torch.zeros((2, 2), dtype=dtype)
xinfo = torch.iinfo(x.dtype)
xn = x.cpu().numpy()
xninfo = np.iinfo(xn.dtype)
self.assertEqual(xinfo.bits, xninfo.bits)
self.assertEqual(xinfo.max, xninfo.max)
self.assertEqual(xinfo.min, xninfo.min)
self.assertEqual(xinfo.dtype, xninfo.dtype)

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_finfo(self):
initial_default_type = torch.get_default_dtype()
for dtype in [torch.float32, torch.float64]:
for dtype in [torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128]:
x = torch.zeros((2, 2), dtype=dtype)
xinfo = torch.finfo(x.dtype)
xn = x.cpu().numpy()
Expand All @@ -46,8 +47,25 @@ def test_finfo(self):
self.assertEqual(xinfo.min, xninfo.min)
self.assertEqual(xinfo.eps, xninfo.eps)
self.assertEqual(xinfo.tiny, xninfo.tiny)
torch.set_default_dtype(dtype)
self.assertEqual(torch.finfo(dtype), torch.finfo())
self.assertEqual(xinfo.resolution, xninfo.resolution)
self.assertEqual(xinfo.dtype, xninfo.dtype)
if not dtype.is_complex:
torch.set_default_dtype(dtype)
self.assertEqual(torch.finfo(dtype), torch.finfo())

# Special test case for BFloat16 type
x = torch.zeros((2, 2), dtype=torch.bfloat16)
xinfo = torch.finfo(x.dtype)
self.assertEqual(xinfo.bits, 16)
self.assertEqual(xinfo.max, 3.38953e+38)
self.assertEqual(xinfo.min, -3.38953e+38)
self.assertEqual(xinfo.eps, 0.0078125)
self.assertEqual(xinfo.tiny, 1.17549e-38)
self.assertEqual(xinfo.resolution, 0.01)
self.assertEqual(xinfo.dtype, "bfloat16")
torch.set_default_dtype(x.dtype)
self.assertEqual(torch.finfo(x.dtype), torch.finfo())

# Restore the default type to ensure that the test has no side effect
torch.set_default_dtype(initial_default_type)

Expand Down
3 changes: 3 additions & 0 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class iinfo:
bits: _int
min: _int
max: _int
dtype: str

def __init__(self, dtype: _dtype) -> None: ...

Expand All @@ -68,6 +69,8 @@ class finfo:
max: _float
eps: _float
tiny: _float
resolution: _float
dtype: str

@overload
def __init__(self, dtype: _dtype) -> None: ...
Expand Down
80 changes: 60 additions & 20 deletions torch/csrc/TypeInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>

#include <c10/util/Exception.h>

Expand All @@ -20,7 +21,7 @@ PyObject* THPFInfo_New(const at::ScalarType& type) {
if (!self)
throw python_error();
auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get());
self_->type = type;
self_->type = c10::toValueType(type);
return self.release();
}

Expand All @@ -34,18 +35,6 @@ PyObject* THPIInfo_New(const at::ScalarType& type) {
return self.release();
}

PyObject* THPFInfo_str(THPFInfo* self) {
std::ostringstream oss;
oss << "finfo(type=" << self->type << ")";
return THPUtils_packString(oss.str().c_str());
}

PyObject* THPIInfo_str(THPIInfo* self) {
std::ostringstream oss;
oss << "iinfo(type=" << self->type << ")";
return THPUtils_packString(oss.str().c_str());
}

PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
HANDLE_TH_ERRORS
static torch::PythonArgParser parser({
Expand All @@ -63,7 +52,7 @@ PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
AT_ASSERT(at::isFloatingType(scalar_type));
} else {
scalar_type = r.scalartype(0);
if (!at::isFloatingType(scalar_type)) {
if (!at::isFloatingType(scalar_type) && !at::isComplexType(scalar_type)) {
return PyErr_Format(
PyExc_TypeError,
"torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'",
Expand Down Expand Up @@ -123,7 +112,7 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
}

static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf,
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16,
self->type, "epsilon", [] {
return PyFloat_FromDouble(
std::numeric_limits<
Expand All @@ -132,20 +121,20 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
}

static PyObject* THPFInfo_max(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "max", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
});
}

static PyObject* THPFInfo_min(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "min", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::lowest());
});
}

static PyObject* THPIInfo_max(THPFInfo* self, void*) {
static PyObject* THPIInfo_max(THPIInfo* self, void*) {
if (at::isIntegralType(self->type, /*includeBool=*/false)) {
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "max", [] {
return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());
Expand All @@ -157,7 +146,7 @@ static PyObject* THPIInfo_max(THPFInfo* self, void*) {
});
}

static PyObject* THPIInfo_min(THPFInfo* self, void*) {
static PyObject* THPIInfo_min(THPIInfo* self, void*) {
if (at::isIntegralType(self->type, /*includeBool=*/false)) {
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "min", [] {
return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());
Expand All @@ -169,19 +158,69 @@ static PyObject* THPIInfo_min(THPFInfo* self, void*) {
});
}

static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
std::string primary_name, legacy_name;
std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(self->type);
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "dtype", [primary_name] {
return PyUnicode_FromString((char*)primary_name.data());
});
}

static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "min", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
});
}

static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] {
return PyFloat_FromDouble(
std::pow(10, -std::numeric_limits<at::scalar_value_type<scalar_t>::type>::digits10));
});
}

static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
std::string primary_name, legacy_name;
std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(self->type);
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "dtype", [primary_name] {
return PyUnicode_FromString((char*)primary_name.data());
});
}

PyObject* THPFInfo_str(THPFInfo* self) {
std::ostringstream oss;
oss << "finfo(resolution=" << PyFloat_AsDouble(THPFInfo_resolution(self, nullptr));
oss << ", min=" << PyFloat_AsDouble(THPFInfo_min(self, nullptr));
oss << ", max=" << PyFloat_AsDouble(THPFInfo_max(self, nullptr));
oss << ", eps=" << PyFloat_AsDouble(THPFInfo_eps(self, nullptr));
oss << ", tiny=" << PyFloat_AsDouble(THPFInfo_tiny(self, nullptr));
oss << ", dtype=" << PyUnicode_AsUTF8(THPFInfo_dtype(self, nullptr)) << ")";

return THPUtils_packString(oss.str().c_str());
}

PyObject* THPIInfo_str(THPIInfo* self) {
auto type = self->type;
std::string primary_name, legacy_name;
std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(type);
std::ostringstream oss;

oss << "iinfo(min=" << PyFloat_AsDouble(THPIInfo_min(self, nullptr));
oss << ", max=" << PyFloat_AsDouble(THPIInfo_max(self, nullptr));
oss << ", dtype=" << PyUnicode_AsUTF8(THPIInfo_dtype(self, nullptr)) << ")";

return THPUtils_packString(oss.str().c_str());
}

static struct PyGetSetDef THPFInfo_properties[] = {
{"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
{"eps", (getter)THPFInfo_eps, nullptr, nullptr, nullptr},
{"max", (getter)THPFInfo_max, nullptr, nullptr, nullptr},
{"min", (getter)THPFInfo_min, nullptr, nullptr, nullptr},
{"tiny", (getter)THPFInfo_tiny, nullptr, nullptr, nullptr},
{"resolution", (getter)THPFInfo_resolution, nullptr, nullptr, nullptr},
{"dtype", (getter)THPFInfo_dtype, nullptr, nullptr, nullptr},
{nullptr}};

static PyMethodDef THPFInfo_methods[] = {
Expand Down Expand Up @@ -232,6 +271,7 @@ static struct PyGetSetDef THPIInfo_properties[] = {
{"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
{"max", (getter)THPIInfo_max, nullptr, nullptr, nullptr},
{"min", (getter)THPIInfo_min, nullptr, nullptr, nullptr},
{"dtype", (getter)THPIInfo_dtype, nullptr, nullptr, nullptr},
{nullptr}};

static PyMethodDef THPIInfo_methods[] = {
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/utils/tensor_dtypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace torch {
namespace utils {

static std::pair<std::string, std::string> getDtypeNames(
std::pair<std::string, std::string> getDtypeNames(
at::ScalarType scalarType) {
switch (scalarType) {
case at::ScalarType::Byte:
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/utils/tensor_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

namespace torch { namespace utils {

std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType);

void initializeDtypes();

}} // namespace torch::utils