xref: /aosp_15_r20/external/pytorch/torch/csrc/Dtype.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/Dtype.h>
2 
3 #include <c10/core/ScalarType.h>
4 #include <structmember.h>
5 #include <torch/csrc/DynamicTypes.h>
6 #include <torch/csrc/Exceptions.h>
7 #include <torch/csrc/utils/object_ptr.h>
8 #include <torch/csrc/utils/python_numbers.h>
9 #include <torch/csrc/utils/python_strings.h>
10 #include <torch/csrc/utils/pythoncapi_compat.h>
11 #include <torch/csrc/utils/tensor_dtypes.h>
12 #include <torch/csrc/utils/tensor_types.h>
13 #include <cstring>
14 
THPDtype_New(at::ScalarType scalar_type,const std::string & name)15 PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name) {
16   HANDLE_TH_ERRORS
17   AT_ASSERT(name.length() < DTYPE_NAME_LEN);
18   auto type = (PyTypeObject*)&THPDtypeType;
19   auto self = THPObjectPtr{type->tp_alloc(type, 0)};
20   if (!self)
21     throw python_error();
22   auto self_ = reinterpret_cast<THPDtype*>(self.get());
23   self_->scalar_type = scalar_type;
24   std::strncpy(self_->name, name.c_str(), DTYPE_NAME_LEN);
25   return self.release();
26   END_HANDLE_TH_ERRORS
27 }
28 
THPDtype_is_floating_point(THPDtype * self,PyObject * noargs)29 PyObject* THPDtype_is_floating_point(THPDtype* self, PyObject* noargs) {
30   HANDLE_TH_ERRORS
31   if (at::isFloatingType(self->scalar_type)) {
32     Py_RETURN_TRUE;
33   } else {
34     Py_RETURN_FALSE;
35   }
36   END_HANDLE_TH_ERRORS
37 }
38 
THPDtype_itemsize(THPDtype * self,PyObject * noargs)39 PyObject* THPDtype_itemsize(THPDtype* self, PyObject* noargs) {
40   HANDLE_TH_ERRORS
41   return THPUtils_packUInt64(
42       scalarTypeToTypeMeta(self->scalar_type).itemsize());
43   END_HANDLE_TH_ERRORS
44 }
45 
THPDtype_is_complex(THPDtype * self,PyObject * noargs)46 PyObject* THPDtype_is_complex(THPDtype* self, PyObject* noargs) {
47   HANDLE_TH_ERRORS
48   if (at::isComplexType(self->scalar_type)) {
49     Py_RETURN_TRUE;
50   } else {
51     Py_RETURN_FALSE;
52   }
53   END_HANDLE_TH_ERRORS
54 }
55 
THPDtype_is_signed(THPDtype * self,PyObject * noargs)56 PyObject* THPDtype_is_signed(THPDtype* self, PyObject* noargs) {
57   HANDLE_TH_ERRORS
58   if (at::isSignedType(self->scalar_type)) {
59     Py_RETURN_TRUE;
60   } else {
61     Py_RETURN_FALSE;
62   }
63   END_HANDLE_TH_ERRORS
64 }
65 
THPDtype_reduce(PyObject * _self,PyObject * noargs)66 PyObject* THPDtype_reduce(PyObject* _self, PyObject* noargs) {
67   HANDLE_TH_ERRORS
68   /*
69    * For singletons, a string is returned. The string should be interpreted
70    * as the name of a global variable.
71    */
72   auto self = (THPDtype*)_self;
73   return THPUtils_packString(self->name);
74   END_HANDLE_TH_ERRORS
75 }
76 
THPDtype_to_real(PyObject * _self,PyObject * noargs)77 PyObject* THPDtype_to_real(PyObject* _self, PyObject* noargs) {
78   HANDLE_TH_ERRORS
79   auto* self = (THPDtype*)_self;
80   auto scalar_type = self->scalar_type;
81   if (!at::isFloatingType(self->scalar_type)) {
82     scalar_type = at::toRealValueType(self->scalar_type);
83   }
84   return Py_NewRef(torch::getTHPDtype(scalar_type));
85   END_HANDLE_TH_ERRORS
86 }
87 
THPDtype_to_complex(PyObject * _self,PyObject * noargs)88 PyObject* THPDtype_to_complex(PyObject* _self, PyObject* noargs) {
89   HANDLE_TH_ERRORS
90   auto* self = (THPDtype*)_self;
91   auto scalar_type = self->scalar_type;
92   if (!at::isComplexType(self->scalar_type)) {
93     scalar_type = at::toComplexType(self->scalar_type);
94   }
95   return Py_NewRef(torch::getTHPDtype(scalar_type));
96   END_HANDLE_TH_ERRORS
97 }
98 
99 typedef PyObject* (*getter)(PyObject*, void*);
100 
101 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
102 static struct PyGetSetDef THPDtype_properties[] = {
103     {"is_floating_point",
104      (getter)THPDtype_is_floating_point,
105      nullptr,
106      nullptr,
107      nullptr},
108     {"is_complex", (getter)THPDtype_is_complex, nullptr, nullptr, nullptr},
109     {"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr},
110     {"itemsize", (getter)THPDtype_itemsize, nullptr, nullptr, nullptr},
111     {nullptr}};
112 
113 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
114 static PyMethodDef THPDtype_methods[] = {
115     {"__reduce__", THPDtype_reduce, METH_NOARGS, nullptr},
116     {"to_real", THPDtype_to_real, METH_NOARGS, nullptr},
117     {"to_complex", THPDtype_to_complex, METH_NOARGS, nullptr},
118     {nullptr} /* Sentinel */
119 };
120 
THPDtype_repr(THPDtype * self)121 PyObject* THPDtype_repr(THPDtype* self) {
122   return THPUtils_packString(std::string("torch.") + self->name);
123 }
124 
125 PyTypeObject THPDtypeType = {
126     PyVarObject_HEAD_INIT(nullptr, 0) "torch.dtype", /* tp_name */
127     sizeof(THPDtype), /* tp_basicsize */
128     0, /* tp_itemsize */
129     nullptr, /* tp_dealloc */
130     0, /* tp_vectorcall_offset */
131     nullptr, /* tp_getattr */
132     nullptr, /* tp_setattr */
133     nullptr, /* tp_reserved */
134     (reprfunc)THPDtype_repr, /* tp_repr */
135     nullptr, /* tp_as_number */
136     nullptr, /* tp_as_sequence */
137     nullptr, /* tp_as_mapping */
138     nullptr, /* tp_hash  */
139     nullptr, /* tp_call */
140     nullptr, /* tp_str */
141     nullptr, /* tp_getattro */
142     nullptr, /* tp_setattro */
143     nullptr, /* tp_as_buffer */
144     Py_TPFLAGS_DEFAULT, /* tp_flags */
145     nullptr, /* tp_doc */
146     nullptr, /* tp_traverse */
147     nullptr, /* tp_clear */
148     nullptr, /* tp_richcompare */
149     0, /* tp_weaklistoffset */
150     nullptr, /* tp_iter */
151     nullptr, /* tp_iternext */
152     THPDtype_methods, /* tp_methods */
153     nullptr, /* tp_members */
154     THPDtype_properties, /* tp_getset */
155     nullptr, /* tp_base */
156     nullptr, /* tp_dict */
157     nullptr, /* tp_descr_get */
158     nullptr, /* tp_descr_set */
159     0, /* tp_dictoffset */
160     nullptr, /* tp_init */
161     nullptr, /* tp_alloc */
162     nullptr, /* tp_new */
163 };
164 
THPDtype_init(PyObject * module)165 void THPDtype_init(PyObject* module) {
166   // Set a __dict__ with `__module__` = `torch`. This means
167   // `__module__` value will be inherited by instances
168   // (i.e. `torch.float32.__module__ == "torch"`). This will prevent
169   // Pickle from having to search all of sys.modules in order to find
170   // the module when pickling a dtype instance.
171   //
172   // We have to do this in C++ because extension types are not mutable
173   // from Python code.
174   //
175   // See https://github.com/pytorch/pytorch/issues/65077
176   TORCH_INTERNAL_ASSERT(THPDtypeType.tp_dict == nullptr);
177   auto dict = THPObjectPtr(PyDict_New());
178   if (!dict)
179     throw python_error();
180   auto torch = THPUtils_packString("torch");
181   if (!torch)
182     throw python_error();
183   if (PyDict_SetItemString(dict, "__module__", torch) < 0) {
184     throw python_error();
185   }
186   THPDtypeType.tp_dict = dict.release();
187 
188   if (PyType_Ready(&THPDtypeType) < 0) {
189     throw python_error();
190   }
191   Py_INCREF(&THPDtypeType);
192   if (PyModule_AddObject(module, "dtype", (PyObject*)&THPDtypeType) != 0) {
193     throw python_error();
194   }
195 }
196