xref: /aosp_15_r20/external/pytorch/torch/csrc/TypeInfo.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/TypeInfo.h>
2 
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 #include <torch/csrc/utils/pybind.h>
6 #include <torch/csrc/utils/python_arg_parser.h>
7 #include <torch/csrc/utils/python_numbers.h>
8 #include <torch/csrc/utils/python_strings.h>
9 #include <torch/csrc/utils/tensor_dtypes.h>
10 
11 #include <ATen/Dispatch_v2.h>
12 
13 #include <c10/util/Exception.h>
14 
15 #include <structmember.h>
16 #include <cstring>
17 #include <limits>
18 #include <sstream>
19 
THPFInfo_New(const at::ScalarType & type)20 PyObject* THPFInfo_New(const at::ScalarType& type) {
21   auto finfo = (PyTypeObject*)&THPFInfoType;
22   auto self = THPObjectPtr{finfo->tp_alloc(finfo, 0)};
23   if (!self)
24     throw python_error();
25   auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get());
26   self_->type = c10::toRealValueType(type);
27   return self.release();
28 }
29 
THPIInfo_New(const at::ScalarType & type)30 PyObject* THPIInfo_New(const at::ScalarType& type) {
31   auto iinfo = (PyTypeObject*)&THPIInfoType;
32   auto self = THPObjectPtr{iinfo->tp_alloc(iinfo, 0)};
33   if (!self)
34     throw python_error();
35   auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get());
36   self_->type = type;
37   return self.release();
38 }
39 
THPFInfo_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)40 PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
41   HANDLE_TH_ERRORS
42   static torch::PythonArgParser parser({
43       "finfo(ScalarType type)",
44       "finfo()",
45   });
46 
47   torch::ParsedArgs<1> parsed_args;
48   auto r = parser.parse(args, kwargs, parsed_args);
49   TORCH_CHECK(r.idx < 2, "Not a type");
50   at::ScalarType scalar_type = at::ScalarType::Undefined;
51   if (r.idx == 1) {
52     scalar_type = torch::tensors::get_default_scalar_type();
53     // The default tensor type can only be set to a floating point type/
54     AT_ASSERT(at::isFloatingType(scalar_type));
55   } else {
56     scalar_type = r.scalartype(0);
57     if (!at::isFloatingType(scalar_type) && !at::isComplexType(scalar_type)) {
58       return PyErr_Format(
59           PyExc_TypeError,
60           "torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'",
61           type->tp_name);
62     }
63   }
64   return THPFInfo_New(scalar_type);
65   END_HANDLE_TH_ERRORS
66 }
67 
THPIInfo_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)68 PyObject* THPIInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
69   HANDLE_TH_ERRORS
70   static torch::PythonArgParser parser({
71       "iinfo(ScalarType type)",
72   });
73   torch::ParsedArgs<1> parsed_args;
74   auto r = parser.parse(args, kwargs, parsed_args);
75   TORCH_CHECK(r.idx == 0, "Not a type");
76 
77   at::ScalarType scalar_type = r.scalartype(0);
78   if (scalar_type == at::ScalarType::Bool) {
79     return PyErr_Format(
80         PyExc_TypeError, "torch.bool is not supported by torch.iinfo");
81   }
82   if (!at::isIntegralType(scalar_type, /*includeBool=*/false) &&
83       !at::isQIntType(scalar_type)) {
84     return PyErr_Format(
85         PyExc_TypeError,
86         "torch.iinfo() requires an integer input type. Use torch.finfo to handle '%s'",
87         type->tp_name);
88   }
89   return THPIInfo_New(scalar_type);
90   END_HANDLE_TH_ERRORS
91 }
92 
THPDTypeInfo_compare(THPDTypeInfo * a,THPDTypeInfo * b,int op)93 PyObject* THPDTypeInfo_compare(THPDTypeInfo* a, THPDTypeInfo* b, int op) {
94   switch (op) {
95     case Py_EQ:
96       if (a->type == b->type) {
97         Py_RETURN_TRUE;
98       } else {
99         Py_RETURN_FALSE;
100       }
101     case Py_NE:
102       if (a->type != b->type) {
103         Py_RETURN_TRUE;
104       } else {
105         Py_RETURN_FALSE;
106       }
107   }
108   return Py_INCREF(Py_NotImplemented), Py_NotImplemented;
109 }
110 
THPDTypeInfo_bits(THPDTypeInfo * self,void *)111 static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
112   uint64_t bits = elementSize(self->type) * CHAR_BIT;
113   return THPUtils_packUInt64(bits);
114 }
115 
// Dispatch helper used by the finfo getters below: the floating and complex
// dispatch set extended with Half, BFloat16 and the four Float8 variants.
// Inside the callback, `scalar_t` names the selected C++ type.
#define _AT_DISPATCH_FINFO_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6(    \
      at::ScalarType::Half,                       \
      at::ScalarType::BFloat16,                   \
      at::ScalarType::Float8_e5m2,                \
      at::ScalarType::Float8_e5m2fnuz,            \
      at::ScalarType::Float8_e4m3fn,              \
      at::ScalarType::Float8_e4m3fnuz,            \
      TYPE,                                       \
      NAME,                                       \
      __VA_ARGS__)
127 
THPFInfo_eps(THPFInfo * self,void *)128 static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
129   HANDLE_TH_ERRORS
130   return _AT_DISPATCH_FINFO_TYPES(self->type, "epsilon", [] {
131     return PyFloat_FromDouble(
132         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon());
133   });
134   END_HANDLE_TH_ERRORS
135 }
136 
THPFInfo_max(THPFInfo * self,void *)137 static PyObject* THPFInfo_max(THPFInfo* self, void*) {
138   HANDLE_TH_ERRORS
139   return _AT_DISPATCH_FINFO_TYPES(self->type, "max", [] {
140     return PyFloat_FromDouble(
141         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
142   });
143   END_HANDLE_TH_ERRORS
144 }
145 
THPFInfo_min(THPFInfo * self,void *)146 static PyObject* THPFInfo_min(THPFInfo* self, void*) {
147   HANDLE_TH_ERRORS
148   return _AT_DISPATCH_FINFO_TYPES(self->type, "lowest", [] {
149     return PyFloat_FromDouble(
150         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::lowest());
151   });
152   END_HANDLE_TH_ERRORS
153 }
154 
// Dispatch helper used by the iinfo getters below: AT_DISPATCH_V2 over the
// AT_INTEGRAL_TYPES_V2 list. Inside the callback, `scalar_t` names the
// selected C++ type.
#define AT_DISPATCH_IINFO_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_V2(                                \
      TYPE,                                      \
      NAME,                                      \
      AT_WRAP(__VA_ARGS__),                      \
      AT_EXPAND(AT_INTEGRAL_TYPES_V2))
158 
THPIInfo_max(THPIInfo * self,void *)159 static PyObject* THPIInfo_max(THPIInfo* self, void*) {
160   HANDLE_TH_ERRORS
161   if (at::isIntegralType(self->type, /*includeBool=*/false)) {
162     return AT_DISPATCH_IINFO_TYPES(self->type, "max", [] {
163       if (std::is_unsigned_v<scalar_t>) {
164         return THPUtils_packUInt64(std::numeric_limits<scalar_t>::max());
165       } else {
166         return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());
167       }
168     });
169   }
170   // Quantized Type
171   return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "max", [] {
172     return THPUtils_packInt64(std::numeric_limits<underlying_t>::max());
173   });
174   END_HANDLE_TH_ERRORS
175 }
176 
THPIInfo_min(THPIInfo * self,void *)177 static PyObject* THPIInfo_min(THPIInfo* self, void*) {
178   HANDLE_TH_ERRORS
179   if (at::isIntegralType(self->type, /*includeBool=*/false)) {
180     return AT_DISPATCH_IINFO_TYPES(self->type, "min", [] {
181       if (std::is_unsigned_v<scalar_t>) {
182         return THPUtils_packUInt64(std::numeric_limits<scalar_t>::lowest());
183       } else {
184         return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());
185       }
186     });
187   }
188   // Quantized Type
189   return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "min", [] {
190     return THPUtils_packInt64(std::numeric_limits<underlying_t>::lowest());
191   });
192   END_HANDLE_TH_ERRORS
193 }
194 
THPIInfo_dtype(THPIInfo * self,void *)195 static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
196   HANDLE_TH_ERRORS
197   auto primary_name = c10::getDtypeNames(self->type).first;
198   return AT_DISPATCH_IINFO_TYPES(self->type, "dtype", [&primary_name] {
199     return PyUnicode_FromString(primary_name.data());
200   });
201   END_HANDLE_TH_ERRORS
202 }
203 
THPFInfo_smallest_normal(THPFInfo * self,void *)204 static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
205   HANDLE_TH_ERRORS
206   return _AT_DISPATCH_FINFO_TYPES(self->type, "min", [] {
207     return PyFloat_FromDouble(
208         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
209   });
210   END_HANDLE_TH_ERRORS
211 }
212 
THPFInfo_tiny(THPFInfo * self,void *)213 static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
214   // see gh-70909, essentially the array_api prefers smallest_normal over tiny
215   return THPFInfo_smallest_normal(self, nullptr);
216 }
217 
THPFInfo_resolution(THPFInfo * self,void *)218 static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
219   HANDLE_TH_ERRORS
220   return _AT_DISPATCH_FINFO_TYPES(self->type, "digits10", [] {
221     return PyFloat_FromDouble(std::pow(
222         10,
223         -std::numeric_limits<at::scalar_value_type<scalar_t>::type>::digits10));
224   });
225   END_HANDLE_TH_ERRORS
226 }
227 
THPFInfo_dtype(THPFInfo * self,void *)228 static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
229   HANDLE_TH_ERRORS
230   auto primary_name = c10::getDtypeNames(self->type).first;
231   return _AT_DISPATCH_FINFO_TYPES(self->type, "dtype", [&primary_name] {
232     return PyUnicode_FromString(primary_name.data());
233   });
234   END_HANDLE_TH_ERRORS
235 }
236 
THPFInfo_str(THPFInfo * self)237 PyObject* THPFInfo_str(THPFInfo* self) {
238   std::ostringstream oss;
239   const auto dtypeStr = THPFInfo_dtype(self, nullptr);
240   oss << "finfo(resolution="
241       << PyFloat_AsDouble(THPFInfo_resolution(self, nullptr));
242   oss << ", min=" << PyFloat_AsDouble(THPFInfo_min(self, nullptr));
243   oss << ", max=" << PyFloat_AsDouble(THPFInfo_max(self, nullptr));
244   oss << ", eps=" << PyFloat_AsDouble(THPFInfo_eps(self, nullptr));
245   oss << ", smallest_normal="
246       << PyFloat_AsDouble(THPFInfo_smallest_normal(self, nullptr));
247   oss << ", tiny=" << PyFloat_AsDouble(THPFInfo_tiny(self, nullptr));
248   if (dtypeStr != nullptr) {
249     oss << ", dtype=" << PyUnicode_AsUTF8(dtypeStr) << ")";
250   }
251   return !PyErr_Occurred() ? THPUtils_packString(oss.str().c_str()) : nullptr;
252 }
253 
THPIInfo_str(THPIInfo * self)254 PyObject* THPIInfo_str(THPIInfo* self) {
255   std::ostringstream oss;
256 
257   const auto dtypeStr = THPIInfo_dtype(self, nullptr);
258   oss << "iinfo(min=" << PyLong_AsDouble(THPIInfo_min(self, nullptr));
259   oss << ", max=" << PyLong_AsDouble(THPIInfo_max(self, nullptr));
260   if (dtypeStr) {
261     oss << ", dtype=" << PyUnicode_AsUTF8(dtypeStr) << ")";
262   }
263 
264   return !PyErr_Occurred() ? THPUtils_packString(oss.str().c_str()) : nullptr;
265 }
266 
267 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
268 static struct PyGetSetDef THPFInfo_properties[] = {
269     {"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
270     {"eps", (getter)THPFInfo_eps, nullptr, nullptr, nullptr},
271     {"max", (getter)THPFInfo_max, nullptr, nullptr, nullptr},
272     {"min", (getter)THPFInfo_min, nullptr, nullptr, nullptr},
273     {"smallest_normal",
274      (getter)THPFInfo_smallest_normal,
275      nullptr,
276      nullptr,
277      nullptr},
278     {"tiny", (getter)THPFInfo_tiny, nullptr, nullptr, nullptr},
279     {"resolution", (getter)THPFInfo_resolution, nullptr, nullptr, nullptr},
280     {"dtype", (getter)THPFInfo_dtype, nullptr, nullptr, nullptr},
281     {nullptr}};
282 
283 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
284 static PyMethodDef THPFInfo_methods[] = {
285     {nullptr} /* Sentinel */
286 };
287 
288 PyTypeObject THPFInfoType = {
289     PyVarObject_HEAD_INIT(nullptr, 0) "torch.finfo", /* tp_name */
290     sizeof(THPFInfo), /* tp_basicsize */
291     0, /* tp_itemsize */
292     nullptr, /* tp_dealloc */
293     0, /* tp_vectorcall_offset */
294     nullptr, /* tp_getattr */
295     nullptr, /* tp_setattr */
296     nullptr, /* tp_reserved */
297     (reprfunc)THPFInfo_str, /* tp_repr */
298     nullptr, /* tp_as_number */
299     nullptr, /* tp_as_sequence */
300     nullptr, /* tp_as_mapping */
301     nullptr, /* tp_hash  */
302     nullptr, /* tp_call */
303     (reprfunc)THPFInfo_str, /* tp_str */
304     nullptr, /* tp_getattro */
305     nullptr, /* tp_setattro */
306     nullptr, /* tp_as_buffer */
307     Py_TPFLAGS_DEFAULT, /* tp_flags */
308     nullptr, /* tp_doc */
309     nullptr, /* tp_traverse */
310     nullptr, /* tp_clear */
311     (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
312     0, /* tp_weaklistoffset */
313     nullptr, /* tp_iter */
314     nullptr, /* tp_iternext */
315     THPFInfo_methods, /* tp_methods */
316     nullptr, /* tp_members */
317     THPFInfo_properties, /* tp_getset */
318     nullptr, /* tp_base */
319     nullptr, /* tp_dict */
320     nullptr, /* tp_descr_get */
321     nullptr, /* tp_descr_set */
322     0, /* tp_dictoffset */
323     nullptr, /* tp_init */
324     nullptr, /* tp_alloc */
325     THPFInfo_pynew, /* tp_new */
326 };
327 
328 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
329 static struct PyGetSetDef THPIInfo_properties[] = {
330     {"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
331     {"max", (getter)THPIInfo_max, nullptr, nullptr, nullptr},
332     {"min", (getter)THPIInfo_min, nullptr, nullptr, nullptr},
333     {"dtype", (getter)THPIInfo_dtype, nullptr, nullptr, nullptr},
334     {nullptr}};
335 
336 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
337 static PyMethodDef THPIInfo_methods[] = {
338     {nullptr} /* Sentinel */
339 };
340 
341 PyTypeObject THPIInfoType = {
342     PyVarObject_HEAD_INIT(nullptr, 0) "torch.iinfo", /* tp_name */
343     sizeof(THPIInfo), /* tp_basicsize */
344     0, /* tp_itemsize */
345     nullptr, /* tp_dealloc */
346     0, /* tp_vectorcall_offset */
347     nullptr, /* tp_getattr */
348     nullptr, /* tp_setattr */
349     nullptr, /* tp_reserved */
350     (reprfunc)THPIInfo_str, /* tp_repr */
351     nullptr, /* tp_as_number */
352     nullptr, /* tp_as_sequence */
353     nullptr, /* tp_as_mapping */
354     nullptr, /* tp_hash  */
355     nullptr, /* tp_call */
356     (reprfunc)THPIInfo_str, /* tp_str */
357     nullptr, /* tp_getattro */
358     nullptr, /* tp_setattro */
359     nullptr, /* tp_as_buffer */
360     Py_TPFLAGS_DEFAULT, /* tp_flags */
361     nullptr, /* tp_doc */
362     nullptr, /* tp_traverse */
363     nullptr, /* tp_clear */
364     (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
365     0, /* tp_weaklistoffset */
366     nullptr, /* tp_iter */
367     nullptr, /* tp_iternext */
368     THPIInfo_methods, /* tp_methods */
369     nullptr, /* tp_members */
370     THPIInfo_properties, /* tp_getset */
371     nullptr, /* tp_base */
372     nullptr, /* tp_dict */
373     nullptr, /* tp_descr_get */
374     nullptr, /* tp_descr_set */
375     0, /* tp_dictoffset */
376     nullptr, /* tp_init */
377     nullptr, /* tp_alloc */
378     THPIInfo_pynew, /* tp_new */
379 };
380 
THPDTypeInfo_init(PyObject * module)381 void THPDTypeInfo_init(PyObject* module) {
382   if (PyType_Ready(&THPFInfoType) < 0) {
383     throw python_error();
384   }
385   Py_INCREF(&THPFInfoType);
386   if (PyModule_AddObject(module, "finfo", (PyObject*)&THPFInfoType) != 0) {
387     throw python_error();
388   }
389   if (PyType_Ready(&THPIInfoType) < 0) {
390     throw python_error();
391   }
392   Py_INCREF(&THPIInfoType);
393   if (PyModule_AddObject(module, "iinfo", (PyObject*)&THPIInfoType) != 0) {
394     throw python_error();
395   }
396 }
397