xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/tensor_dtypes.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/Dtype.h>
2 #include <torch/csrc/DynamicTypes.h>
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/python_headers.h>
5 #include <torch/csrc/utils/object_ptr.h>
6 #include <torch/csrc/utils/tensor_dtypes.h>
7 
8 namespace torch::utils {
9 
initializeDtypes()10 void initializeDtypes() {
11   auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
12   if (!torch_module)
13     throw python_error();
14 
15 #define DEFINE_SCALAR_TYPE(_1, n) at::ScalarType::n,
16 
17   auto all_scalar_types = {
18       AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
19 
20 #undef DEFINE_SCALAR_TYPE
21 
22   for (at::ScalarType scalarType : all_scalar_types) {
23     auto [primary_name, legacy_name] = c10::getDtypeNames(scalarType);
24     PyObject* dtype = THPDtype_New(scalarType, primary_name);
25     torch::registerDtypeObject((THPDtype*)dtype, scalarType);
26     Py_INCREF(dtype);
27     if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) !=
28         0) {
29       throw python_error();
30     }
31     if (!legacy_name.empty()) {
32       Py_INCREF(dtype);
33       if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) !=
34           0) {
35         throw python_error();
36       }
37     }
38   }
39 }
40 
41 } // namespace torch::utils
42