xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_runner/pybind.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
2 #ifdef USE_CUDA
3 #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
4 #endif
5 #include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
6 #include <torch/csrc/inductor/aoti_torch/utils.h>
7 
8 #include <torch/csrc/utils/pybind.h>
9 
10 namespace torch::inductor {
11 
initAOTIRunnerBindings(PyObject * module)12 void initAOTIRunnerBindings(PyObject* module) {
13   auto rootModule = py::handle(module).cast<py::module>();
14   auto m = rootModule.def_submodule("_aoti");
15 
16   py::class_<AOTIModelContainerRunnerCpu>(m, "AOTIModelContainerRunnerCpu")
17       .def(py::init<const std::string&, int>())
18       .def("run", &AOTIModelContainerRunnerCpu::run)
19       .def("get_call_spec", &AOTIModelContainerRunnerCpu::get_call_spec)
20       .def(
21           "get_constant_names_to_original_fqns",
22           &AOTIModelContainerRunnerCpu::getConstantNamesToOriginalFQNs)
23       .def(
24           "get_constant_names_to_dtypes",
25           &AOTIModelContainerRunnerCpu::getConstantNamesToDtypes);
26 
27 #ifdef USE_CUDA
28   py::class_<AOTIModelContainerRunnerCuda>(m, "AOTIModelContainerRunnerCuda")
29       .def(py::init<const std::string&, int>())
30       .def(py::init<const std::string&, int, const std::string&>())
31       .def(py::init<
32            const std::string&,
33            int,
34            const std::string&,
35            const std::string&>())
36       .def("run", &AOTIModelContainerRunnerCuda::run)
37       .def("get_call_spec", &AOTIModelContainerRunnerCuda::get_call_spec)
38       .def(
39           "get_constant_names_to_original_fqns",
40           &AOTIModelContainerRunnerCuda::getConstantNamesToOriginalFQNs)
41       .def(
42           "get_constant_names_to_dtypes",
43           &AOTIModelContainerRunnerCuda::getConstantNamesToDtypes);
44 #endif
45 
46   m.def(
47       "unsafe_alloc_void_ptrs_from_tensors",
48       [](std::vector<at::Tensor>& tensors) {
49         std::vector<AtenTensorHandle> handles =
50             torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(tensors);
51         std::vector<void*> result(
52             reinterpret_cast<void**>(handles.data()),
53             reinterpret_cast<void**>(handles.data()) + handles.size());
54         return result;
55       });
56   m.def("unsafe_alloc_void_ptr_from_tensor", [](at::Tensor& tensor) {
57     return reinterpret_cast<void*>(
58         torch::aot_inductor::new_tensor_handle(std::move(tensor)));
59   });
60   m.def(
61       "alloc_tensors_by_stealing_from_void_ptrs",
62       [](std::vector<void*>& raw_handles) {
63         return torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
64             reinterpret_cast<AtenTensorHandle*>(raw_handles.data()),
65             raw_handles.size());
66       });
67   m.def("alloc_tensor_by_stealing_from_void_ptr", [](void* raw_handle) {
68     return *torch::aot_inductor::tensor_handle_to_tensor_pointer(
69         reinterpret_cast<AtenTensorHandle>(raw_handle));
70   });
71 }
72 } // namespace torch::inductor
73