1 #include <torch/csrc/utils/tensor_qschemes.h> 2 3 #include <c10/core/QScheme.h> 4 #include <c10/util/irange.h> 5 #include <torch/csrc/DynamicTypes.h> 6 #include <torch/csrc/Exceptions.h> 7 #include <torch/csrc/QScheme.h> 8 9 #include <torch/csrc/python_headers.h> 10 #include <torch/csrc/utils/object_ptr.h> 11 12 namespace torch::utils { 13 14 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) 15 static std::array<PyObject*, at::COMPILE_TIME_NUM_QSCHEMES> thp_qscheme_array; 16 initializeQSchemes()17void initializeQSchemes() { 18 auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); 19 if (!torch_module) { 20 throw python_error(); 21 } 22 23 for (const auto i : c10::irange(at::COMPILE_TIME_NUM_QSCHEMES)) { 24 auto qscheme = static_cast<at::QScheme>(i); 25 PyObject* qscheme_obj = THPQScheme_New(qscheme, toString(qscheme)); 26 thp_qscheme_array[static_cast<int>(qscheme)] = qscheme_obj; 27 Py_INCREF(qscheme_obj); 28 if (PyModule_AddObject( 29 torch_module, toString(qscheme).c_str(), qscheme_obj) != 0) { 30 throw python_error(); 31 } 32 } 33 } 34 getTHPQScheme(at::QScheme qscheme)35PyObject* getTHPQScheme(at::QScheme qscheme) { 36 auto qscheme_ = thp_qscheme_array[static_cast<int>(qscheme)]; 37 if (!qscheme_) { 38 throw std::invalid_argument("unsupported QScheme"); 39 } 40 return qscheme_; 41 } 42 } // namespace torch::utils 43