xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/module_python.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <pybind11/pybind11.h>
3 #include <pybind11/stl.h>
4 #include <torch/csrc/jit/api/module.h>
5 #include <torch/csrc/utils/pybind.h>
6 
7 namespace py = pybind11;
8 
9 namespace torch::jit {
10 
as_module(py::handle obj)11 inline std::optional<Module> as_module(py::handle obj) {
12   static py::handle ScriptModule =
13       py::module::import("torch.jit").attr("ScriptModule");
14   if (py::isinstance(obj, ScriptModule)) {
15     return py::cast<Module>(obj.attr("_c"));
16   }
17   return std::nullopt;
18 }
19 
as_object(py::handle obj)20 inline std::optional<Object> as_object(py::handle obj) {
21   static py::handle ScriptObject =
22       py::module::import("torch").attr("ScriptObject");
23   if (py::isinstance(obj, ScriptObject)) {
24     return py::cast<Object>(obj);
25   }
26 
27   static py::handle RecursiveScriptClass =
28       py::module::import("torch.jit").attr("RecursiveScriptClass");
29   if (py::isinstance(obj, RecursiveScriptClass)) {
30     return py::cast<Object>(obj.attr("_c"));
31   }
32   return std::nullopt;
33 }
34 
35 } // namespace torch::jit
36