#pragma once

#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/ordered_dict.h>
#include <torch/types.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_tuples.h>

#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace python {
namespace detail {
inline Device py_object_to_device(py::object object) {
  PyObject* obj = object.ptr();
  if (THPDevice_Check(obj)) {
    return reinterpret_cast<THPDevice*>(obj)->device;
  }
  throw TypeError("Expected device");
}

inline Dtype py_object_to_dtype(py::object object) {
  PyObject* obj = object.ptr();
  if (THPDtype_Check(obj)) {
    return reinterpret_cast<THPDtype*>(obj)->scalar_type;
  }
  throw TypeError("Expected dtype");
}

template <typename ModuleType>
using PyModuleClass =
    py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;

/// Dynamically creates a subclass of `torch.nn.cpp.ModuleWrapper` that is also
/// a subclass of `torch.nn.Module`, and passes it the user-provided C++ module
/// to which it delegates all calls.
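///
/// Roughly, for a C++ class bound as `Net` via `bind_module` below, the
/// dynamically created wrapper behaves as if one had written the following
/// Python (a sketch for intuition, not the literal generated code; `cpp.Net`
/// stands for the pybind11 class registered in the `cpp` submodule):
/// \rst
/// .. code-block:: python
///
///   class Net(torch.nn.cpp.ModuleWrapper):
///       def __init__(self, *args, **kwargs):
///           ModuleWrapper.__init__(self, cpp.Net(*args, **kwargs))
/// \endrst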
template <typename ModuleType>
void bind_cpp_module_wrapper(
    py::module module,
    PyModuleClass<ModuleType> cpp_class,
    const char* name) {
  // Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass
  // with a dynamically created class below.
  py::object cpp_module =
      py::module::import("torch.nn.cpp").attr("ModuleWrapper");

  // Grab the `type` class which we'll use as a metaclass to create a new class
  // dynamically.
  py::object type_metaclass =
      py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type);

  // `ModuleWrapper` copies all functions into its own `__dict__` in its
  // constructor, but we still need to give our dynamic class a constructor of
  // its own. Inside it, we construct an instance of the original C++ module
  // we're binding (the `torch::nn::Module` subclass), and then forward it to
  // the `ModuleWrapper` constructor.
  py::dict attributes;

  // `type()` expects the class name as a `str`, which `py::str(name)` gives us.
  py::object name_str = py::str(name);

  // Dynamically create the subclass of `ModuleWrapper`, which is a subclass of
  // `torch.nn.Module`, and will delegate all calls to the C++ module we're
  // binding.
  py::object wrapper_class =
      type_metaclass(name_str, py::make_tuple(cpp_module), attributes);

  // The constructor of the dynamic class calls `ModuleWrapper.__init__()`,
  // which replaces its methods with those of the C++ module.
  wrapper_class.attr("__init__") = py::cpp_function(
      [cpp_module, cpp_class](
          py::object self, py::args args, py::kwargs kwargs) {
        cpp_module.attr("__init__")(self, cpp_class(*args, **kwargs));
      },
      py::is_method(wrapper_class));

  // `my_module.my_class` is now a subclass of `ModuleWrapper` whose methods
  // call into the C++ module we're binding.
  module.attr(name) = wrapper_class;
}
} // namespace detail

/// Adds method bindings for a pybind11 `class_` that binds an `nn::Module`
/// subclass.
///
/// Say you have a pybind11 class object created with `py::class_<Net>(m,
/// "Net")`. This function will add all the necessary `.def()` calls to bind
/// the `nn::Module` base class' methods, such as `train()`, `eval()` etc.,
/// into Python.
///
/// Users should prefer to use `bind_module` if possible.
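///
/// A minimal usage sketch (assuming `Net` is an `nn::Module` subclass with an
/// `(int, int)` constructor and a `forward()` method; the extra class template
/// arguments mirror `detail::PyModuleClass`):
/// \rst
/// .. code-block:: cpp
///
///   PYBIND11_MODULE(my_module, m) {
///     torch::python::add_module_bindings(
///         py::class_<Net, torch::nn::Module, std::shared_ptr<Net>>(m, "Net"))
///       .def(py::init<int, int>())
///       .def("forward", &Net::forward);
///   }
/// \endrst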
template <typename ModuleType, typename... Extra>
py::class_<ModuleType, Extra...> add_module_bindings(
    py::class_<ModuleType, Extra...> module) {
  // clang-format off
  return module
      .def("train",
          [](ModuleType& module, bool mode) { module.train(mode); },
          py::arg("mode") = true)
      .def("eval", [](ModuleType& module) { module.eval(); })
      .def("clone", [](ModuleType& module) { return module.clone(); })
      .def_property_readonly(
          "training", [](ModuleType& module) { return module.is_training(); })
      .def("zero_grad", [](ModuleType& module) { module.zero_grad(); })
      .def_property_readonly("_parameters", [](ModuleType& module) {
            return module.named_parameters(/*recurse=*/false);
          })
      .def("parameters", [](ModuleType& module, bool recurse) {
            return module.parameters(recurse);
          },
          py::arg("recurse") = true)
      .def("named_parameters", [](ModuleType& module, bool recurse) {
            return module.named_parameters(recurse);
          },
          py::arg("recurse") = true)
      .def_property_readonly("_buffers", [](ModuleType& module) {
            return module.named_buffers(/*recurse=*/false);
          })
      .def("buffers", [](ModuleType& module, bool recurse) {
            return module.buffers(recurse);
          },
          py::arg("recurse") = true)
      .def("named_buffers", [](ModuleType& module, bool recurse) {
            return module.named_buffers(recurse);
          },
          py::arg("recurse") = true)
      .def_property_readonly(
          "_modules", [](ModuleType& module) { return module.named_children(); })
      .def("modules", [](ModuleType& module) { return module.modules(); })
      .def("named_modules",
          [](ModuleType& module, py::object /* unused */, std::string prefix, bool remove_duplicate /* unused */) {
            return module.named_modules(std::move(prefix));
          },
          py::arg("memo") = py::none(),
          py::arg("prefix") = std::string(),
          py::arg("remove_duplicate") = true)
      .def("children", [](ModuleType& module) { return module.children(); })
      .def("named_children",
          [](ModuleType& module) { return module.named_children(); })
      .def("to", [](ModuleType& module, py::object object, bool non_blocking) {
            if (THPDevice_Check(object.ptr())) {
              module.to(
                  reinterpret_cast<THPDevice*>(object.ptr())->device,
                  non_blocking);
            } else {
              module.to(detail::py_object_to_dtype(object), non_blocking);
            }
          },
          py::arg("dtype_or_device"),
          py::arg("non_blocking") = false)
      .def("to",
          [](ModuleType& module,
             py::object device,
             py::object dtype,
             bool non_blocking) {
              if (device.is_none()) {
                module.to(detail::py_object_to_dtype(dtype), non_blocking);
              } else if (dtype.is_none()) {
                module.to(detail::py_object_to_device(device), non_blocking);
              } else {
                module.to(
                    detail::py_object_to_device(device),
                    detail::py_object_to_dtype(dtype),
                    non_blocking);
              }
          },
          py::arg("device"),
          py::arg("dtype"),
          py::arg("non_blocking") = false)
      .def("cuda", [](ModuleType& module) { module.to(kCUDA); })
      .def("cpu", [](ModuleType& module) { module.to(kCPU); })
      .def("float", [](ModuleType& module) { module.to(kFloat32); })
      .def("double", [](ModuleType& module) { module.to(kFloat64); })
      .def("half", [](ModuleType& module) { module.to(kFloat16); })
      .def("__str__", [](ModuleType& module) { return module.name(); })
      .def("__repr__", [](ModuleType& module) { return module.name(); });
  // clang-format on
}

/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
/// default bindings.
///
/// After adding the default bindings, the class object is returned, such that
/// you can add more bindings.
///
/// Example usage:
/// \rst
/// .. code-block:: cpp
///
///   struct Net : torch::nn::Module {
///     Net(int in, int out) { }
///     torch::Tensor forward(torch::Tensor x) { return x; }
///   };
///
///   PYBIND11_MODULE(my_module, m) {
///     torch::python::bind_module<Net>(m, "Net")
///       .def(py::init<int, int>())
///       .def("forward", &Net::forward);
///   }
/// \endrst
template <typename ModuleType, bool force_enable = false>
std::enable_if_t<
    !torch::detail::has_forward<ModuleType>::value || force_enable,
    detail::PyModuleClass<ModuleType>>
bind_module(py::module module, const char* name) {
  py::module cpp = module.def_submodule("cpp");
  auto cpp_class =
      add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name));
  detail::bind_cpp_module_wrapper(module, cpp_class, name);
  return cpp_class;
}

/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
/// default bindings.
///
/// After adding the default bindings, the class object is returned, such that
/// you can add more bindings.
///
/// If the class has a `forward()` method, it is automatically exposed as
/// `forward()` and `__call__` in Python.
///
/// Example usage:
/// \rst
/// .. code-block:: cpp
///
///   struct Net : torch::nn::Module {
///     Net(int in, int out) { }
///     torch::Tensor forward(torch::Tensor x) { return x; }
///   };
///
///   PYBIND11_MODULE(my_module, m) {
///     torch::python::bind_module<Net>(m, "Net")
///       .def(py::init<int, int>());
///   }
/// \endrst
template <
    typename ModuleType,
    typename = std::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
detail::PyModuleClass<ModuleType> bind_module(
    py::module module,
    const char* name) {
  return bind_module<ModuleType, /*force_enable=*/true>(module, name)
      .def("forward", &ModuleType::forward)
      .def("__call__", &ModuleType::forward);
}
} // namespace python
} // namespace torch