xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/tensor_qschemes.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/tensor_qschemes.h>
2 
3 #include <c10/core/QScheme.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/DynamicTypes.h>
6 #include <torch/csrc/Exceptions.h>
7 #include <torch/csrc/QScheme.h>
8 
9 #include <torch/csrc/python_headers.h>
10 #include <torch/csrc/utils/object_ptr.h>
11 
12 namespace torch::utils {
13 
14 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
15 static std::array<PyObject*, at::COMPILE_TIME_NUM_QSCHEMES> thp_qscheme_array;
16 
initializeQSchemes()17 void initializeQSchemes() {
18   auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
19   if (!torch_module) {
20     throw python_error();
21   }
22 
23   for (const auto i : c10::irange(at::COMPILE_TIME_NUM_QSCHEMES)) {
24     auto qscheme = static_cast<at::QScheme>(i);
25     PyObject* qscheme_obj = THPQScheme_New(qscheme, toString(qscheme));
26     thp_qscheme_array[static_cast<int>(qscheme)] = qscheme_obj;
27     Py_INCREF(qscheme_obj);
28     if (PyModule_AddObject(
29             torch_module, toString(qscheme).c_str(), qscheme_obj) != 0) {
30       throw python_error();
31     }
32   }
33 }
34 
getTHPQScheme(at::QScheme qscheme)35 PyObject* getTHPQScheme(at::QScheme qscheme) {
36   auto qscheme_ = thp_qscheme_array[static_cast<int>(qscheme)];
37   if (!qscheme_) {
38     throw std::invalid_argument("unsupported QScheme");
39   }
40   return qscheme_;
41 }
42 } // namespace torch::utils
43