1 #include <torch/csrc/Dtype.h> 2 #include <torch/csrc/DynamicTypes.h> 3 #include <torch/csrc/Exceptions.h> 4 #include <torch/csrc/python_headers.h> 5 #include <torch/csrc/utils/object_ptr.h> 6 #include <torch/csrc/utils/tensor_dtypes.h> 7 8 namespace torch::utils { 9 initializeDtypes()10void initializeDtypes() { 11 auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); 12 if (!torch_module) 13 throw python_error(); 14 15 #define DEFINE_SCALAR_TYPE(_1, n) at::ScalarType::n, 16 17 auto all_scalar_types = { 18 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)}; 19 20 #undef DEFINE_SCALAR_TYPE 21 22 for (at::ScalarType scalarType : all_scalar_types) { 23 auto [primary_name, legacy_name] = c10::getDtypeNames(scalarType); 24 PyObject* dtype = THPDtype_New(scalarType, primary_name); 25 torch::registerDtypeObject((THPDtype*)dtype, scalarType); 26 Py_INCREF(dtype); 27 if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) != 28 0) { 29 throw python_error(); 30 } 31 if (!legacy_name.empty()) { 32 Py_INCREF(dtype); 33 if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) != 34 0) { 35 throw python_error(); 36 } 37 } 38 } 39 } 40 41 } // namespace torch::utils 42