1 #pragma once 2 #include <pybind11/pybind11.h> 3 #include <pybind11/stl.h> 4 #include <torch/csrc/jit/api/module.h> 5 #include <torch/csrc/utils/pybind.h> 6 7 namespace py = pybind11; 8 9 namespace torch::jit { 10 as_module(py::handle obj)11inline std::optional<Module> as_module(py::handle obj) { 12 static py::handle ScriptModule = 13 py::module::import("torch.jit").attr("ScriptModule"); 14 if (py::isinstance(obj, ScriptModule)) { 15 return py::cast<Module>(obj.attr("_c")); 16 } 17 return std::nullopt; 18 } 19 as_object(py::handle obj)20inline std::optional<Object> as_object(py::handle obj) { 21 static py::handle ScriptObject = 22 py::module::import("torch").attr("ScriptObject"); 23 if (py::isinstance(obj, ScriptObject)) { 24 return py::cast<Object>(obj); 25 } 26 27 static py::handle RecursiveScriptClass = 28 py::module::import("torch.jit").attr("RecursiveScriptClass"); 29 if (py::isinstance(obj, RecursiveScriptClass)) { 30 return py::cast<Object>(obj.attr("_c")); 31 } 32 return std::nullopt; 33 } 34 35 } // namespace torch::jit 36