1 #include <torch/csrc/TypeInfo.h>
2
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 #include <torch/csrc/utils/pybind.h>
6 #include <torch/csrc/utils/python_arg_parser.h>
7 #include <torch/csrc/utils/python_numbers.h>
8 #include <torch/csrc/utils/python_strings.h>
9 #include <torch/csrc/utils/tensor_dtypes.h>
10
11 #include <ATen/Dispatch_v2.h>
12
13 #include <c10/util/Exception.h>
14
15 #include <structmember.h>
16 #include <cstring>
17 #include <limits>
18 #include <sstream>
19
THPFInfo_New(const at::ScalarType & type)20 PyObject* THPFInfo_New(const at::ScalarType& type) {
21 auto finfo = (PyTypeObject*)&THPFInfoType;
22 auto self = THPObjectPtr{finfo->tp_alloc(finfo, 0)};
23 if (!self)
24 throw python_error();
25 auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get());
26 self_->type = c10::toRealValueType(type);
27 return self.release();
28 }
29
THPIInfo_New(const at::ScalarType & type)30 PyObject* THPIInfo_New(const at::ScalarType& type) {
31 auto iinfo = (PyTypeObject*)&THPIInfoType;
32 auto self = THPObjectPtr{iinfo->tp_alloc(iinfo, 0)};
33 if (!self)
34 throw python_error();
35 auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get());
36 self_->type = type;
37 return self.release();
38 }
39
THPFInfo_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)40 PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
41 HANDLE_TH_ERRORS
42 static torch::PythonArgParser parser({
43 "finfo(ScalarType type)",
44 "finfo()",
45 });
46
47 torch::ParsedArgs<1> parsed_args;
48 auto r = parser.parse(args, kwargs, parsed_args);
49 TORCH_CHECK(r.idx < 2, "Not a type");
50 at::ScalarType scalar_type = at::ScalarType::Undefined;
51 if (r.idx == 1) {
52 scalar_type = torch::tensors::get_default_scalar_type();
53 // The default tensor type can only be set to a floating point type/
54 AT_ASSERT(at::isFloatingType(scalar_type));
55 } else {
56 scalar_type = r.scalartype(0);
57 if (!at::isFloatingType(scalar_type) && !at::isComplexType(scalar_type)) {
58 return PyErr_Format(
59 PyExc_TypeError,
60 "torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'",
61 type->tp_name);
62 }
63 }
64 return THPFInfo_New(scalar_type);
65 END_HANDLE_TH_ERRORS
66 }
67
THPIInfo_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)68 PyObject* THPIInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
69 HANDLE_TH_ERRORS
70 static torch::PythonArgParser parser({
71 "iinfo(ScalarType type)",
72 });
73 torch::ParsedArgs<1> parsed_args;
74 auto r = parser.parse(args, kwargs, parsed_args);
75 TORCH_CHECK(r.idx == 0, "Not a type");
76
77 at::ScalarType scalar_type = r.scalartype(0);
78 if (scalar_type == at::ScalarType::Bool) {
79 return PyErr_Format(
80 PyExc_TypeError, "torch.bool is not supported by torch.iinfo");
81 }
82 if (!at::isIntegralType(scalar_type, /*includeBool=*/false) &&
83 !at::isQIntType(scalar_type)) {
84 return PyErr_Format(
85 PyExc_TypeError,
86 "torch.iinfo() requires an integer input type. Use torch.finfo to handle '%s'",
87 type->tp_name);
88 }
89 return THPIInfo_New(scalar_type);
90 END_HANDLE_TH_ERRORS
91 }
92
THPDTypeInfo_compare(THPDTypeInfo * a,THPDTypeInfo * b,int op)93 PyObject* THPDTypeInfo_compare(THPDTypeInfo* a, THPDTypeInfo* b, int op) {
94 switch (op) {
95 case Py_EQ:
96 if (a->type == b->type) {
97 Py_RETURN_TRUE;
98 } else {
99 Py_RETURN_FALSE;
100 }
101 case Py_NE:
102 if (a->type != b->type) {
103 Py_RETURN_TRUE;
104 } else {
105 Py_RETURN_FALSE;
106 }
107 }
108 return Py_INCREF(Py_NotImplemented), Py_NotImplemented;
109 }
110
THPDTypeInfo_bits(THPDTypeInfo * self,void *)111 static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
112 uint64_t bits = elementSize(self->type) * CHAR_BIT;
113 return THPUtils_packUInt64(bits);
114 }
115
// Dispatches TYPE over every dtype the torch.finfo getters below handle:
// the floating and complex types from AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
// plus Half, BFloat16 and the four Float8 variants listed here. The getters
// map each scalar_t to its real component via at::scalar_value_type, so the
// complex entries share the limits of their real type.
#define _AT_DISPATCH_FINFO_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6(    \
      at::kHalf,                                  \
      at::ScalarType::BFloat16,                   \
      at::ScalarType::Float8_e5m2,                \
      at::ScalarType::Float8_e5m2fnuz,            \
      at::ScalarType::Float8_e4m3fn,              \
      at::ScalarType::Float8_e4m3fnuz,            \
      TYPE,                                       \
      NAME,                                       \
      __VA_ARGS__)
127
THPFInfo_eps(THPFInfo * self,void *)128 static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
129 HANDLE_TH_ERRORS
130 return _AT_DISPATCH_FINFO_TYPES(self->type, "epsilon", [] {
131 return PyFloat_FromDouble(
132 std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon());
133 });
134 END_HANDLE_TH_ERRORS
135 }
136
THPFInfo_max(THPFInfo * self,void *)137 static PyObject* THPFInfo_max(THPFInfo* self, void*) {
138 HANDLE_TH_ERRORS
139 return _AT_DISPATCH_FINFO_TYPES(self->type, "max", [] {
140 return PyFloat_FromDouble(
141 std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
142 });
143 END_HANDLE_TH_ERRORS
144 }
145
THPFInfo_min(THPFInfo * self,void *)146 static PyObject* THPFInfo_min(THPFInfo* self, void*) {
147 HANDLE_TH_ERRORS
148 return _AT_DISPATCH_FINFO_TYPES(self->type, "lowest", [] {
149 return PyFloat_FromDouble(
150 std::numeric_limits<at::scalar_value_type<scalar_t>::type>::lowest());
151 });
152 END_HANDLE_TH_ERRORS
153 }
154
// Dispatches TYPE over the AT_INTEGRAL_TYPES_V2 dtype list using the V2
// dispatch machinery. The iinfo min/max getters use it only after checking
// at::isIntegralType and handle quantized dtypes through a separate dispatch.
#define AT_DISPATCH_IINFO_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_V2(                                \
      TYPE, NAME, AT_WRAP(__VA_ARGS__), AT_EXPAND(AT_INTEGRAL_TYPES_V2))
158
THPIInfo_max(THPIInfo * self,void *)159 static PyObject* THPIInfo_max(THPIInfo* self, void*) {
160 HANDLE_TH_ERRORS
161 if (at::isIntegralType(self->type, /*includeBool=*/false)) {
162 return AT_DISPATCH_IINFO_TYPES(self->type, "max", [] {
163 if (std::is_unsigned_v<scalar_t>) {
164 return THPUtils_packUInt64(std::numeric_limits<scalar_t>::max());
165 } else {
166 return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());
167 }
168 });
169 }
170 // Quantized Type
171 return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "max", [] {
172 return THPUtils_packInt64(std::numeric_limits<underlying_t>::max());
173 });
174 END_HANDLE_TH_ERRORS
175 }
176
THPIInfo_min(THPIInfo * self,void *)177 static PyObject* THPIInfo_min(THPIInfo* self, void*) {
178 HANDLE_TH_ERRORS
179 if (at::isIntegralType(self->type, /*includeBool=*/false)) {
180 return AT_DISPATCH_IINFO_TYPES(self->type, "min", [] {
181 if (std::is_unsigned_v<scalar_t>) {
182 return THPUtils_packUInt64(std::numeric_limits<scalar_t>::lowest());
183 } else {
184 return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());
185 }
186 });
187 }
188 // Quantized Type
189 return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "min", [] {
190 return THPUtils_packInt64(std::numeric_limits<underlying_t>::lowest());
191 });
192 END_HANDLE_TH_ERRORS
193 }
194
THPIInfo_dtype(THPIInfo * self,void *)195 static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
196 HANDLE_TH_ERRORS
197 auto primary_name = c10::getDtypeNames(self->type).first;
198 return AT_DISPATCH_IINFO_TYPES(self->type, "dtype", [&primary_name] {
199 return PyUnicode_FromString(primary_name.data());
200 });
201 END_HANDLE_TH_ERRORS
202 }
203
THPFInfo_smallest_normal(THPFInfo * self,void *)204 static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
205 HANDLE_TH_ERRORS
206 return _AT_DISPATCH_FINFO_TYPES(self->type, "min", [] {
207 return PyFloat_FromDouble(
208 std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
209 });
210 END_HANDLE_TH_ERRORS
211 }
212
THPFInfo_tiny(THPFInfo * self,void *)213 static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
214 // see gh-70909, essentially the array_api prefers smallest_normal over tiny
215 return THPFInfo_smallest_normal(self, nullptr);
216 }
217
THPFInfo_resolution(THPFInfo * self,void *)218 static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
219 HANDLE_TH_ERRORS
220 return _AT_DISPATCH_FINFO_TYPES(self->type, "digits10", [] {
221 return PyFloat_FromDouble(std::pow(
222 10,
223 -std::numeric_limits<at::scalar_value_type<scalar_t>::type>::digits10));
224 });
225 END_HANDLE_TH_ERRORS
226 }
227
THPFInfo_dtype(THPFInfo * self,void *)228 static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
229 HANDLE_TH_ERRORS
230 auto primary_name = c10::getDtypeNames(self->type).first;
231 return _AT_DISPATCH_FINFO_TYPES(self->type, "dtype", [&primary_name] {
232 return PyUnicode_FromString(primary_name.data());
233 });
234 END_HANDLE_TH_ERRORS
235 }
236
THPFInfo_str(THPFInfo * self)237 PyObject* THPFInfo_str(THPFInfo* self) {
238 std::ostringstream oss;
239 const auto dtypeStr = THPFInfo_dtype(self, nullptr);
240 oss << "finfo(resolution="
241 << PyFloat_AsDouble(THPFInfo_resolution(self, nullptr));
242 oss << ", min=" << PyFloat_AsDouble(THPFInfo_min(self, nullptr));
243 oss << ", max=" << PyFloat_AsDouble(THPFInfo_max(self, nullptr));
244 oss << ", eps=" << PyFloat_AsDouble(THPFInfo_eps(self, nullptr));
245 oss << ", smallest_normal="
246 << PyFloat_AsDouble(THPFInfo_smallest_normal(self, nullptr));
247 oss << ", tiny=" << PyFloat_AsDouble(THPFInfo_tiny(self, nullptr));
248 if (dtypeStr != nullptr) {
249 oss << ", dtype=" << PyUnicode_AsUTF8(dtypeStr) << ")";
250 }
251 return !PyErr_Occurred() ? THPUtils_packString(oss.str().c_str()) : nullptr;
252 }
253
THPIInfo_str(THPIInfo * self)254 PyObject* THPIInfo_str(THPIInfo* self) {
255 std::ostringstream oss;
256
257 const auto dtypeStr = THPIInfo_dtype(self, nullptr);
258 oss << "iinfo(min=" << PyLong_AsDouble(THPIInfo_min(self, nullptr));
259 oss << ", max=" << PyLong_AsDouble(THPIInfo_max(self, nullptr));
260 if (dtypeStr) {
261 oss << ", dtype=" << PyUnicode_AsUTF8(dtypeStr) << ")";
262 }
263
264 return !PyErr_Occurred() ? THPUtils_packString(oss.str().c_str()) : nullptr;
265 }
266
267 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
268 static struct PyGetSetDef THPFInfo_properties[] = {
269 {"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
270 {"eps", (getter)THPFInfo_eps, nullptr, nullptr, nullptr},
271 {"max", (getter)THPFInfo_max, nullptr, nullptr, nullptr},
272 {"min", (getter)THPFInfo_min, nullptr, nullptr, nullptr},
273 {"smallest_normal",
274 (getter)THPFInfo_smallest_normal,
275 nullptr,
276 nullptr,
277 nullptr},
278 {"tiny", (getter)THPFInfo_tiny, nullptr, nullptr, nullptr},
279 {"resolution", (getter)THPFInfo_resolution, nullptr, nullptr, nullptr},
280 {"dtype", (getter)THPFInfo_dtype, nullptr, nullptr, nullptr},
281 {nullptr}};
282
283 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
284 static PyMethodDef THPFInfo_methods[] = {
285 {nullptr} /* Sentinel */
286 };
287
// Python type object for torch.finfo.
// - Instances are created by THPFInfo_pynew (tp_new) or THPFInfo_New.
// - repr() and str() both use THPFInfo_str.
// - == / != go through THPDTypeInfo_compare.
// - Attributes come from THPFInfo_properties.
// Slots not listed explicitly are left null/zero; tp_hash is not provided.
PyTypeObject THPFInfoType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.finfo", /* tp_name */
    sizeof(THPFInfo), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPFInfo_str, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    (reprfunc)THPFInfo_str, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPFInfo_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPFInfo_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPFInfo_pynew, /* tp_new */
};
327
328 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
329 static struct PyGetSetDef THPIInfo_properties[] = {
330 {"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
331 {"max", (getter)THPIInfo_max, nullptr, nullptr, nullptr},
332 {"min", (getter)THPIInfo_min, nullptr, nullptr, nullptr},
333 {"dtype", (getter)THPIInfo_dtype, nullptr, nullptr, nullptr},
334 {nullptr}};
335
336 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
337 static PyMethodDef THPIInfo_methods[] = {
338 {nullptr} /* Sentinel */
339 };
340
// Python type object for torch.iinfo.
// - Instances are created by THPIInfo_pynew (tp_new) or THPIInfo_New.
// - repr() and str() both use THPIInfo_str.
// - == / != go through THPDTypeInfo_compare.
// - Attributes come from THPIInfo_properties.
// Slots not listed explicitly are left null/zero; tp_hash is not provided.
PyTypeObject THPIInfoType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.iinfo", /* tp_name */
    sizeof(THPIInfo), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPIInfo_str, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    (reprfunc)THPIInfo_str, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPIInfo_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPIInfo_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPIInfo_pynew, /* tp_new */
};
380
THPDTypeInfo_init(PyObject * module)381 void THPDTypeInfo_init(PyObject* module) {
382 if (PyType_Ready(&THPFInfoType) < 0) {
383 throw python_error();
384 }
385 Py_INCREF(&THPFInfoType);
386 if (PyModule_AddObject(module, "finfo", (PyObject*)&THPFInfoType) != 0) {
387 throw python_error();
388 }
389 if (PyType_Ready(&THPIInfoType) < 0) {
390 throw python_error();
391 }
392 Py_INCREF(&THPIInfoType);
393 if (PyModule_AddObject(module, "iinfo", (PyObject*)&THPIInfoType) != 0) {
394 throw python_error();
395 }
396 }
397