xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/dynamo/init.h>
2 #include <torch/csrc/dynamo/utils.h>
3 
4 #include <pybind11/stl_bind.h>
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/dynamo/cache_entry.h>
7 #include <torch/csrc/dynamo/cpython_defs.h>
8 #include <torch/csrc/dynamo/eval_frame.h>
9 #include <torch/csrc/dynamo/extra_state.h>
10 #include <torch/csrc/dynamo/guards.h>
11 #include <torch/csrc/dynamo/python_compiled_autograd.h>
12 #include <torch/csrc/utils/pybind.h>
13 #include <torch/csrc/utils/python_compat.h>
14 
15 static struct PyModuleDef _module =
16     {PyModuleDef_HEAD_INIT, "torch._C._dynamo", "", -1, nullptr};
17 
18 PYBIND11_MAKE_OPAQUE(std::vector<uint8_t>);
19 
20 namespace torch::dynamo {
21 
22 #if IS_PYTHON_3_11_PLUS
23 
24 std::vector<uint8_t> _PyOpcode_Caches_vec(
25     THP_PyOpcode_Caches,
26     THP_PyOpcode_Caches + THP_PyOpcode_Caches_size);
27 
28 #else
29 
30 std::vector<uint8_t> _PyOpcode_Caches_vec;
31 
32 #endif
33 
34 using torch::dynamo::autograd::torch_c_dynamo_compiled_autograd_init;
35 
initDynamoBindings(PyObject * torch)36 void initDynamoBindings(PyObject* torch) {
37   PyObject* dynamo = PyModule_Create(&_module);
38   if (dynamo == nullptr || PyModule_AddObject(torch, "_dynamo", dynamo) != 0) {
39     throw python_error();
40   }
41 
42   PyObject* eval_frame = torch_c_dynamo_eval_frame_init();
43   if (eval_frame == nullptr ||
44       PyModule_AddObject(dynamo, "eval_frame", eval_frame) != 0) {
45     throw python_error();
46   }
47 
48   PyObject* utils = torch_c_dynamo_utils_init();
49   if (utils == nullptr || PyModule_AddObject(dynamo, "utils", utils) != 0) {
50     throw python_error();
51   }
52 
53   PyObject* guards = torch_c_dynamo_guards_init();
54   if (guards == nullptr || PyModule_AddObject(dynamo, "guards", guards) != 0) {
55     throw python_error();
56   }
57 
58   PyObject* compiled_autograd = torch_c_dynamo_compiled_autograd_init();
59   if (compiled_autograd == nullptr ||
60       PyModule_AddObject(dynamo, "compiled_autograd", compiled_autograd) != 0) {
61     throw python_error();
62   }
63 
64   auto m = py::handle(eval_frame).cast<py::module>();
65 
66   py::class_<CacheEntry>(m, "_CacheEntry")
67       .def_readonly("check_fn", &CacheEntry::check_fn)
68       .def_readonly("code", &CacheEntry::code)
69       .def_readonly("compile_id", &CacheEntry::compile_id)
70       .def_property_readonly("next", &CacheEntry::next);
71 
72   py::class_<ExtraState>(m, "_ExtraState")
73       .def("invalidate", &ExtraState::invalidate);
74 
75   m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
76   py::bind_vector<std::vector<uint8_t>>(m, "VectorUInt8");
77   m.attr("py_opcode_caches") = _PyOpcode_Caches_vec;
78 }
79 
80 } // namespace torch::dynamo
81