1 #include <torch/csrc/Dtype.h>
2
3 #include <c10/core/ScalarType.h>
4 #include <structmember.h>
5 #include <torch/csrc/DynamicTypes.h>
6 #include <torch/csrc/Exceptions.h>
7 #include <torch/csrc/utils/object_ptr.h>
8 #include <torch/csrc/utils/python_numbers.h>
9 #include <torch/csrc/utils/python_strings.h>
10 #include <torch/csrc/utils/pythoncapi_compat.h>
11 #include <torch/csrc/utils/tensor_dtypes.h>
12 #include <torch/csrc/utils/tensor_types.h>
13 #include <cstring>
14
THPDtype_New(at::ScalarType scalar_type,const std::string & name)15 PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name) {
16 HANDLE_TH_ERRORS
17 AT_ASSERT(name.length() < DTYPE_NAME_LEN);
18 auto type = (PyTypeObject*)&THPDtypeType;
19 auto self = THPObjectPtr{type->tp_alloc(type, 0)};
20 if (!self)
21 throw python_error();
22 auto self_ = reinterpret_cast<THPDtype*>(self.get());
23 self_->scalar_type = scalar_type;
24 std::strncpy(self_->name, name.c_str(), DTYPE_NAME_LEN);
25 return self.release();
26 END_HANDLE_TH_ERRORS
27 }
28
THPDtype_is_floating_point(THPDtype * self,PyObject * noargs)29 PyObject* THPDtype_is_floating_point(THPDtype* self, PyObject* noargs) {
30 HANDLE_TH_ERRORS
31 if (at::isFloatingType(self->scalar_type)) {
32 Py_RETURN_TRUE;
33 } else {
34 Py_RETURN_FALSE;
35 }
36 END_HANDLE_TH_ERRORS
37 }
38
THPDtype_itemsize(THPDtype * self,PyObject * noargs)39 PyObject* THPDtype_itemsize(THPDtype* self, PyObject* noargs) {
40 HANDLE_TH_ERRORS
41 return THPUtils_packUInt64(
42 scalarTypeToTypeMeta(self->scalar_type).itemsize());
43 END_HANDLE_TH_ERRORS
44 }
45
THPDtype_is_complex(THPDtype * self,PyObject * noargs)46 PyObject* THPDtype_is_complex(THPDtype* self, PyObject* noargs) {
47 HANDLE_TH_ERRORS
48 if (at::isComplexType(self->scalar_type)) {
49 Py_RETURN_TRUE;
50 } else {
51 Py_RETURN_FALSE;
52 }
53 END_HANDLE_TH_ERRORS
54 }
55
THPDtype_is_signed(THPDtype * self,PyObject * noargs)56 PyObject* THPDtype_is_signed(THPDtype* self, PyObject* noargs) {
57 HANDLE_TH_ERRORS
58 if (at::isSignedType(self->scalar_type)) {
59 Py_RETURN_TRUE;
60 } else {
61 Py_RETURN_FALSE;
62 }
63 END_HANDLE_TH_ERRORS
64 }
65
THPDtype_reduce(PyObject * _self,PyObject * noargs)66 PyObject* THPDtype_reduce(PyObject* _self, PyObject* noargs) {
67 HANDLE_TH_ERRORS
68 /*
69 * For singletons, a string is returned. The string should be interpreted
70 * as the name of a global variable.
71 */
72 auto self = (THPDtype*)_self;
73 return THPUtils_packString(self->name);
74 END_HANDLE_TH_ERRORS
75 }
76
THPDtype_to_real(PyObject * _self,PyObject * noargs)77 PyObject* THPDtype_to_real(PyObject* _self, PyObject* noargs) {
78 HANDLE_TH_ERRORS
79 auto* self = (THPDtype*)_self;
80 auto scalar_type = self->scalar_type;
81 if (!at::isFloatingType(self->scalar_type)) {
82 scalar_type = at::toRealValueType(self->scalar_type);
83 }
84 return Py_NewRef(torch::getTHPDtype(scalar_type));
85 END_HANDLE_TH_ERRORS
86 }
87
THPDtype_to_complex(PyObject * _self,PyObject * noargs)88 PyObject* THPDtype_to_complex(PyObject* _self, PyObject* noargs) {
89 HANDLE_TH_ERRORS
90 auto* self = (THPDtype*)_self;
91 auto scalar_type = self->scalar_type;
92 if (!at::isComplexType(self->scalar_type)) {
93 scalar_type = at::toComplexType(self->scalar_type);
94 }
95 return Py_NewRef(torch::getTHPDtype(scalar_type));
96 END_HANDLE_TH_ERRORS
97 }
98
99 typedef PyObject* (*getter)(PyObject*, void*);
100
101 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
102 static struct PyGetSetDef THPDtype_properties[] = {
103 {"is_floating_point",
104 (getter)THPDtype_is_floating_point,
105 nullptr,
106 nullptr,
107 nullptr},
108 {"is_complex", (getter)THPDtype_is_complex, nullptr, nullptr, nullptr},
109 {"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr},
110 {"itemsize", (getter)THPDtype_itemsize, nullptr, nullptr, nullptr},
111 {nullptr}};
112
113 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
114 static PyMethodDef THPDtype_methods[] = {
115 {"__reduce__", THPDtype_reduce, METH_NOARGS, nullptr},
116 {"to_real", THPDtype_to_real, METH_NOARGS, nullptr},
117 {"to_complex", THPDtype_to_complex, METH_NOARGS, nullptr},
118 {nullptr} /* Sentinel */
119 };
120
THPDtype_repr(THPDtype * self)121 PyObject* THPDtype_repr(THPDtype* self) {
122 return THPUtils_packString(std::string("torch.") + self->name);
123 }
124
// CPython type object for `torch.dtype`.
//
// Behavior is supplied by three slots:
//   - tp_repr    -> THPDtype_repr
//   - tp_methods -> THPDtype_methods
//   - tp_getset  -> THPDtype_properties
//
// Instances are created from C++ by THPDtype_New, which calls tp_alloc;
// tp_new is left null here. tp_alloc (and the other null slots) are
// presumably filled in by PyType_Ready, which THPDtype_init calls.
// NOTE(review): this relies on CPython's default slot inheritance from
// `object`; revisit if THPDtype gains state needing a custom tp_dealloc,
// tp_hash or tp_richcompare.
//
// tp_dict must stay null at static-initialization time. THPDtype_init
// asserts this, then installs a dict holding `__module__` before
// PyType_Ready runs.
//
// The initializer is positional: entries must stay in PyTypeObject field
// order (the trailing comments name each slot).
PyTypeObject THPDtypeType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.dtype", /* tp_name */
    sizeof(THPDtype), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPDtype_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPDtype_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPDtype_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */ // installed by THPDtype_init before PyType_Ready
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */ // must be populated before THPDtype_New runs
    nullptr, /* tp_new */
};
164
THPDtype_init(PyObject * module)165 void THPDtype_init(PyObject* module) {
166 // Set a __dict__ with `__module__` = `torch`. This means
167 // `__module__` value will be inherited by instances
168 // (i.e. `torch.float32.__module__ == "torch"`). This will prevent
169 // Pickle from having to search all of sys.modules in order to find
170 // the module when pickling a dtype instance.
171 //
172 // We have to do this in C++ because extension types are not mutable
173 // from Python code.
174 //
175 // See https://github.com/pytorch/pytorch/issues/65077
176 TORCH_INTERNAL_ASSERT(THPDtypeType.tp_dict == nullptr);
177 auto dict = THPObjectPtr(PyDict_New());
178 if (!dict)
179 throw python_error();
180 auto torch = THPUtils_packString("torch");
181 if (!torch)
182 throw python_error();
183 if (PyDict_SetItemString(dict, "__module__", torch) < 0) {
184 throw python_error();
185 }
186 THPDtypeType.tp_dict = dict.release();
187
188 if (PyType_Ready(&THPDtypeType) < 0) {
189 throw python_error();
190 }
191 Py_INCREF(&THPDtypeType);
192 if (PyModule_AddObject(module, "dtype", (PyObject*)&THPDtypeType) != 0) {
193 throw python_error();
194 }
195 }
196