1 #pragma once
2
3 #include <torch/detail/static.h>
4 #include <torch/nn/module.h>
5 #include <torch/ordered_dict.h>
6 #include <torch/types.h>
7
8 #include <torch/csrc/Device.h>
9 #include <torch/csrc/Dtype.h>
10 #include <torch/csrc/DynamicTypes.h>
11 #include <torch/csrc/Exceptions.h>
12 #include <torch/csrc/autograd/python_variable.h>
13 #include <torch/csrc/python_headers.h>
14 #include <torch/csrc/utils/pybind.h>
15 #include <torch/csrc/utils/python_numbers.h>
16 #include <torch/csrc/utils/python_tuples.h>
17
18 #include <iterator>
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23
24 namespace torch {
25 namespace python {
26 namespace detail {
py_object_to_device(py::object object)27 inline Device py_object_to_device(py::object object) {
28 PyObject* obj = object.ptr();
29 if (THPDevice_Check(obj)) {
30 return reinterpret_cast<THPDevice*>(obj)->device;
31 }
32 throw TypeError("Expected device");
33 }
34
py_object_to_dtype(py::object object)35 inline Dtype py_object_to_dtype(py::object object) {
36 PyObject* obj = object.ptr();
37 if (THPDtype_Check(obj)) {
38 return reinterpret_cast<THPDtype*>(obj)->scalar_type;
39 }
40 throw TypeError("Expected dtype");
41 }
42
43 template <typename ModuleType>
44 using PyModuleClass =
45 py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;
46
47 /// Dynamically creates a subclass of `torch.nn.cpp.ModuleWrapper` that is also
48 /// a subclass of `torch.nn.Module`, and passes it the user-provided C++ module
49 /// to which it delegates all calls.
50 template <typename ModuleType>
bind_cpp_module_wrapper(py::module module,PyModuleClass<ModuleType> cpp_class,const char * name)51 void bind_cpp_module_wrapper(
52 py::module module,
53 PyModuleClass<ModuleType> cpp_class,
54 const char* name) {
55 // Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass
56 // with a dynamically created class below.
57 py::object cpp_module =
58 py::module::import("torch.nn.cpp").attr("ModuleWrapper");
59
60 // Grab the `type` class which we'll use as a metaclass to create a new class
61 // dynamically.
62 py::object type_metaclass =
63 py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type);
64
65 // The `ModuleWrapper` constructor copies all functions to its own `__dict__`
66 // in its constructor, but we do need to give our dynamic class a constructor.
67 // Inside, we construct an instance of the original C++ module we're binding
68 // (the `torch::nn::Module` subclass), and then forward it to the
69 // `ModuleWrapper` constructor.
70 py::dict attributes;
71
72 // `type()` always needs a `str`, but pybind11's `str()` method always creates
73 // a `unicode` object.
74 py::object name_str = py::str(name);
75
76 // Dynamically create the subclass of `ModuleWrapper`, which is a subclass of
77 // `torch.nn.Module`, and will delegate all calls to the C++ module we're
78 // binding.
79 py::object wrapper_class =
80 type_metaclass(name_str, py::make_tuple(cpp_module), attributes);
81
82 // The constructor of the dynamic class calls `ModuleWrapper.__init__()`,
83 // which replaces its methods with those of the C++ module.
84 wrapper_class.attr("__init__") = py::cpp_function(
85 [cpp_module, cpp_class](
86 py::object self, py::args args, py::kwargs kwargs) {
87 cpp_module.attr("__init__")(self, cpp_class(*args, **kwargs));
88 },
89 py::is_method(wrapper_class));
90
91 // Calling `my_module.my_class` now means that `my_class` is a subclass of
92 // `ModuleWrapper`, and whose methods call into the C++ module we're binding.
93 module.attr(name) = wrapper_class;
94 }
95 } // namespace detail
96
97 /// Adds method bindings for a pybind11 `class_` that binds an `nn::Module`
98 /// subclass.
99 ///
100 /// Say you have a pybind11 class object created with `py::class_<Net>(m,
101 /// "Net")`. This function will add all the necessary `.def()` calls to bind the
102 /// `nn::Module` base class' methods, such as `train()`, `eval()` etc. into
103 /// Python.
104 ///
105 /// Users should prefer to use `bind_module` if possible.
106 template <typename ModuleType, typename... Extra>
add_module_bindings(py::class_<ModuleType,Extra...> module)107 py::class_<ModuleType, Extra...> add_module_bindings(
108 py::class_<ModuleType, Extra...> module) {
109 // clang-format off
110 return module
111 .def("train",
112 [](ModuleType& module, bool mode) { module.train(mode); },
113 py::arg("mode") = true)
114 .def("eval", [](ModuleType& module) { module.eval(); })
115 .def("clone", [](ModuleType& module) { return module.clone(); })
116 .def_property_readonly(
117 "training", [](ModuleType& module) { return module.is_training(); })
118 .def("zero_grad", [](ModuleType& module) { module.zero_grad(); })
119 .def_property_readonly( "_parameters", [](ModuleType& module) {
120 return module.named_parameters(/*recurse=*/false);
121 })
122 .def("parameters", [](ModuleType& module, bool recurse) {
123 return module.parameters(recurse);
124 },
125 py::arg("recurse") = true)
126 .def("named_parameters", [](ModuleType& module, bool recurse) {
127 return module.named_parameters(recurse);
128 },
129 py::arg("recurse") = true)
130 .def_property_readonly("_buffers", [](ModuleType& module) {
131 return module.named_buffers(/*recurse=*/false);
132 })
133 .def("buffers", [](ModuleType& module, bool recurse) {
134 return module.buffers(recurse); },
135 py::arg("recurse") = true)
136 .def("named_buffers", [](ModuleType& module, bool recurse) {
137 return module.named_buffers(recurse);
138 },
139 py::arg("recurse") = true)
140 .def_property_readonly(
141 "_modules", [](ModuleType& module) { return module.named_children(); })
142 .def("modules", [](ModuleType& module) { return module.modules(); })
143 .def("named_modules",
144 [](ModuleType& module, py::object /* unused */, std::string prefix, bool remove_duplicate /* unused */) {
145 return module.named_modules(std::move(prefix));
146 },
147 py::arg("memo") = py::none(),
148 py::arg("prefix") = std::string(),
149 py::arg("remove_duplicate") = true)
150 .def("children", [](ModuleType& module) { return module.children(); })
151 .def("named_children",
152 [](ModuleType& module) { return module.named_children(); })
153 .def("to", [](ModuleType& module, py::object object, bool non_blocking) {
154 if (THPDevice_Check(object.ptr())) {
155 module.to(
156 reinterpret_cast<THPDevice*>(object.ptr())->device,
157 non_blocking);
158 } else {
159 module.to(detail::py_object_to_dtype(object), non_blocking);
160 }
161 },
162 py::arg("dtype_or_device"),
163 py::arg("non_blocking") = false)
164 .def("to",
165 [](ModuleType& module,
166 py::object device,
167 py::object dtype,
168 bool non_blocking) {
169 if (device.is_none()) {
170 module.to(detail::py_object_to_dtype(dtype), non_blocking);
171 } else if (dtype.is_none()) {
172 module.to(detail::py_object_to_device(device), non_blocking);
173 } else {
174 module.to(
175 detail::py_object_to_device(device),
176 detail::py_object_to_dtype(dtype),
177 non_blocking);
178 }
179 },
180 py::arg("device"),
181 py::arg("dtype"),
182 py::arg("non_blocking") = false)
183 .def("cuda", [](ModuleType& module) { module.to(kCUDA); })
184 .def("cpu", [](ModuleType& module) { module.to(kCPU); })
185 .def("float", [](ModuleType& module) { module.to(kFloat32); })
186 .def("double", [](ModuleType& module) { module.to(kFloat64); })
187 .def("half", [](ModuleType& module) { module.to(kFloat16); })
188 .def("__str__", [](ModuleType& module) { return module.name(); })
189 .def("__repr__", [](ModuleType& module) { return module.name(); });
190 // clang-format on
191 }
192
193 /// Creates a pybind11 class object for an `nn::Module` subclass type and adds
194 /// default bindings.
195 ///
196 /// After adding the default bindings, the class object is returned, such that
197 /// you can add more bindings.
198 ///
199 /// Example usage:
200 /// \rst
201 /// .. code-block:: cpp
202 ///
203 /// struct Net : torch::nn::Module {
204 /// Net(int in, int out) { }
205 /// torch::Tensor forward(torch::Tensor x) { return x; }
206 /// };
207 ///
208 /// PYBIND11_MODULE(my_module, m) {
209 /// torch::python::bind_module<Net>(m, "Net")
210 /// .def(py::init<int, int>())
211 /// .def("forward", &Net::forward);
212 /// }
213 /// \endrst
214 template <typename ModuleType, bool force_enable = false>
215 std::enable_if_t<
216 !torch::detail::has_forward<ModuleType>::value || force_enable,
217 detail::PyModuleClass<ModuleType>>
bind_module(py::module module,const char * name)218 bind_module(py::module module, const char* name) {
219 py::module cpp = module.def_submodule("cpp");
220 auto cpp_class =
221 add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name));
222 detail::bind_cpp_module_wrapper(module, cpp_class, name);
223 return cpp_class;
224 }
225
226 /// Creates a pybind11 class object for an `nn::Module` subclass type and adds
227 /// default bindings.
228 ///
229 /// After adding the default bindings, the class object is returned, such that
230 /// you can add more bindings.
231 ///
232 /// If the class has a `forward()` method, it is automatically exposed as
233 /// `forward()` and `__call__` in Python.
234 ///
235 /// Example usage:
236 /// \rst
237 /// .. code-block:: cpp
238 ///
239 /// struct Net : torch::nn::Module {
240 /// Net(int in, int out) { }
241 /// torch::Tensor forward(torch::Tensor x) { return x; }
242 /// };
243 ///
244 /// PYBIND11_MODULE(my_module, m) {
245 /// torch::python::bind_module<Net>(m, "Net")
246 /// .def(py::init<int, int>())
247 /// .def("forward", &Net::forward);
248 /// }
249 /// \endrst
250 template <
251 typename ModuleType,
252 typename = std::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
bind_module(py::module module,const char * name)253 detail::PyModuleClass<ModuleType> bind_module(
254 py::module module,
255 const char* name) {
256 return bind_module<ModuleType, /*force_enable=*/true>(module, name)
257 .def("forward", &ModuleType::forward)
258 .def("__call__", &ModuleType::forward);
259 }
260 } // namespace python
261 } // namespace torch
262