#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_cpp_function.h>

#include <torch/csrc/python_headers.h>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
20
21 using namespace torch::autograd;
22
23 namespace torch::autograd {
24
25 namespace {
26
THPCppFunction_call(PyObject * self,PyObject * args,PyObject * kwargs)27 PyObject* THPCppFunction_call(
28 PyObject* self,
29 PyObject* args,
30 PyObject* kwargs) {
31 if (kwargs && PyDict_Size(kwargs) != 0) {
32 return PyErr_Format(PyExc_TypeError, "keyword arguments are not supported");
33 }
34
35 auto num_inputs = PyTuple_GET_SIZE(args);
36 auto num_inputs_required = ((THPCppFunction*)self)->cdata->num_inputs();
37 if (num_inputs != num_inputs_required) {
38 return PyErr_Format(
39 PyExc_TypeError,
40 "expected %d arguments, got %d instead",
41 num_inputs_required,
42 num_inputs);
43 }
44 variable_list vars(num_inputs);
45 for (int i = 0; i != num_inputs; ++i) {
46 PyObject* arg = PyTuple_GET_ITEM(args, i);
47 if (arg == Py_None) {
48 continue;
49 }
50 if (!THPVariable_Check(arg)) {
51 return PyErr_Format(PyExc_TypeError, "argument %d is not a Variable", i);
52 }
53 vars[i] = THPVariable_Unpack(arg);
54 }
55
56 variable_list output;
57
58 HANDLE_TH_ERRORS {
59 pybind11::gil_scoped_release nogil;
60 output = (*((THPCppFunction*)self)->cdata)(std::move(vars));
61 }
62 END_HANDLE_TH_ERRORS
63
64 auto num_outputs = output.size();
65 if (num_outputs == 1) {
66 // assume we want to unpack one element tuples for now
67 return THPVariable_Wrap(output[0]);
68 }
69
70 THPObjectPtr tuple(PyTuple_New(static_cast<Py_ssize_t>(num_outputs)));
71 for (size_t i = 0; i != num_outputs; ++i) {
72 PyTuple_SET_ITEM(tuple.get(), i, THPVariable_Wrap(output[i]));
73 }
74 return tuple.release();
75 }
76
THPCppFunction_traverse(PyObject * self,visitproc visit,void * arg)77 int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
78 if ((((THPCppFunction*)self)->cdata).use_count() == 1) {
79 // The fields traversed below are owned by the cpp grad_fn, which we own a
80 // reference to. We should only them traverse however if we are the only
81 // owner of the grad_fn, otherwise we risk prematurely gc'ing the grad_fn.
82 //
83 // See: https://github.com/pytorch/pytorch/issues/102174
84 auto& fn = *((THPCppFunction*)self)->cdata;
85 for (const auto& hook : fn.tensor_pre_hooks()) {
86 if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
87 Py_VISIT(pyhook->dict);
88 }
89 }
90 // NOTE [retains_grad_hook PyObject traversal]
91 // In theory this shouldn't be necessary, because retains_grad_hooks should
92 // not contain any PyFunctionTensorPreHooks. The alternative is to have a
93 // check that actually guarantees this.
94 for (const auto& pair : fn.retains_grad_hooks()) {
95 if (auto pyhook =
96 dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
97 Py_VISIT(pyhook->dict);
98 }
99 }
100 for (const auto& hook : fn.pre_hooks()) {
101 if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
102 Py_VISIT(pyhook->dict);
103 }
104 }
105 for (const auto& hook : fn.post_hooks()) {
106 if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
107 Py_VISIT(pyhook->dict);
108 }
109 }
110 }
111 return 0;
112 }
113
THPCppFunction_clear(PyObject * self)114 int THPCppFunction_clear(PyObject* self) {
115 auto f = (THPCppFunction*)self;
116 // Remove the weak ref of the c++ object if it exist
117 if (f->cdata) {
118 f->cdata->set_pyobj(nullptr);
119 }
120 f->cdata.reset();
121 return 0;
122 }
123
THPCppFunction_dealloc(PyObject * self)124 void THPCppFunction_dealloc(PyObject* self) {
125 PyObject_GC_UnTrack(self);
126 THPCppFunction_clear(self);
127 ((THPCppFunction*)self)->cdata.~shared_ptr();
128 Py_TYPE(self)->tp_free(self);
129 }
130
131 } // namespace
132
THPCppFunction_next_functions(PyObject * self,void * _unused)133 PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused) {
134 auto cdata = reinterpret_cast<const THPCppFunction*>(self)->cdata;
135 const auto num_next = cdata->num_outputs();
136 THPObjectPtr py_functions(PyTuple_New(num_next));
137 if (!py_functions)
138 return nullptr;
139 for (const auto i : c10::irange(num_next)) {
140 auto& c_tuple = cdata->next_edge(i);
141 THPObjectPtr tuple(PyTuple_New(2));
142 if (!tuple)
143 return nullptr;
144 PyObject* py_fn = functionToPyObject(c_tuple.function);
145 if (!py_fn)
146 return nullptr;
147 PyTuple_SET_ITEM(tuple.get(), 0, py_fn);
148 PyObject* py_idx = THPUtils_packUInt32(c_tuple.input_nr);
149 if (!py_idx)
150 return nullptr;
151 PyTuple_SET_ITEM(tuple.get(), 1, py_idx);
152 PyTuple_SET_ITEM(py_functions.get(), i, tuple.release());
153 }
154 return py_functions.release();
155 }
156
THPCppFunction_metadata(PyObject * self,void * _unused)157 PyObject* THPCppFunction_metadata(PyObject* self, void* _unused) {
158 auto* metadata =
159 static_cast<PyAnomalyMetadata*>(
160 reinterpret_cast<THPCppFunction*>(self)->cdata->metadata())
161 ->dict();
162
163 Py_XINCREF(metadata);
164 return metadata;
165 }
166
THPCppFunction_requires_grad(PyObject * self,void * unused)167 PyObject* THPCppFunction_requires_grad(PyObject* self, void* unused) {
168 Py_RETURN_TRUE;
169 }
170
THPCppFunction_register_hook_dict(PyObject * self,PyObject * _var)171 PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) {
172 if (!THPVariable_Check(_var)) {
173 return PyErr_Format(
174 PyExc_TypeError, "_register_hook_dict expected a variable");
175 }
176 auto var = (THPVariable*)_var;
177 auto& fn = *((THPCppFunction*)self)->cdata;
178 fn.add_tensor_pre_hook(std::make_unique<PyFunctionTensorPreHook>(
179 var->backward_hooks, THPVariable_Unpack(var).output_nr()));
180 Py_RETURN_NONE;
181 }
182
THPCppFunction_register_hook(PyObject * self,PyObject * hook)183 PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) {
184 auto& fn = *((THPCppFunction*)self)->cdata;
185 return registerFunctionHook(fn, hook);
186 }
187
THPCppFunction_register_prehook(PyObject * self,PyObject * hook)188 PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook) {
189 auto& fn = *((THPCppFunction*)self)->cdata;
190 return registerFunctionPreHook(fn, hook);
191 }
192
THPCppFunction_name(PyObject * self,PyObject * noargs)193 PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs) {
194 auto& fn = *((THPCppFunction*)self)->cdata;
195 return THPUtils_packString(fn.name());
196 }
197
THPCppFunction_sequence_nr(PyObject * self,PyObject * noargs)198 PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs) {
199 auto& fn = *((THPCppFunction*)self)->cdata;
200 return THPUtils_packUInt64(fn.sequence_nr());
201 }
202
THPCppFunction_set_sequence_nr(PyObject * self,PyObject * sequence_nr)203 PyObject* THPCppFunction_set_sequence_nr(
204 PyObject* self,
205 PyObject* sequence_nr) {
206 HANDLE_TH_ERRORS
207 auto& fn = *((THPCppFunction*)self)->cdata;
208 fn.set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
209 Py_RETURN_NONE;
210 END_HANDLE_TH_ERRORS
211 }
212
THPCppFunction_input_metadata(PyObject * self,void * closure)213 PyObject* THPCppFunction_input_metadata(PyObject* self, void* closure) {
214 HANDLE_TH_ERRORS;
215 auto& fn = *((THPCppFunction*)self)->cdata;
216 const auto num_inputs =
217 fn.num_inputs(); // Assuming there's a method to get the number of inputs
218 THPObjectPtr list(PyTuple_New(num_inputs));
219 if (!list) {
220 return nullptr;
221 }
222 for (size_t i = 0; i < num_inputs; ++i) {
223 const auto& metadata = fn.input_metadata(i);
224 THPObjectPtr item(py::cast(metadata).release().ptr());
225 if (!item) {
226 return nullptr;
227 }
228 PyTuple_SET_ITEM(list.get(), i, item.release());
229 }
230 return list.release();
231 END_HANDLE_TH_ERRORS
232 }
233
234 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
235 static struct PyMethodDef default_methods[] = {
236 THP_FUNCTION_DEFAULT_METHODS,
237 {nullptr}};
238
239 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
240 static struct PyGetSetDef default_properties[] = {
241 THP_FUNCTION_DEFAULT_PROPERTIES,
242 {nullptr}};
243
_initFunctionPyTypeObject(PyTypeObject & type,const char * name,PyGetSetDef * function_properties,PyMethodDef * function_methods)244 PyTypeObject* _initFunctionPyTypeObject(
245 PyTypeObject& type,
246 const char* name,
247 PyGetSetDef* function_properties,
248 PyMethodDef* function_methods) {
249 type.ob_base = {PyObject_HEAD_INIT(nullptr) 0};
250 // NOLINTNEXTLINE(misc-redundant-expression)
251 type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
252 type.tp_name = name;
253 type.tp_basicsize = sizeof(THPCppFunction);
254 type.tp_call = THPCppFunction_call;
255 type.tp_methods = function_methods ? function_methods : default_methods;
256 type.tp_getset =
257 function_properties ? function_properties : default_properties;
258 type.tp_dealloc = THPCppFunction_dealloc;
259 type.tp_traverse = THPCppFunction_traverse;
260 type.tp_clear = THPCppFunction_clear;
261 if (PyType_Ready(&type) < 0) {
262 auto msg = std::string("Unable to instantiate PyTypeObject for ") + name;
263 throw std::runtime_error(msg);
264 }
265 return &type;
266 }
267
268 static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types_map;
269 static std::unordered_set<PyTypeObject*> cpp_function_types_set;
270
271 struct DefaultFunctionType {
DefaultFunctionTypetorch::autograd::DefaultFunctionType272 DefaultFunctionType() : type() {
273 _initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
274 }
275
276 PyTypeObject type;
277 };
278
get_default_type()279 PyTypeObject* get_default_type() {
280 static DefaultFunctionType default_type;
281 return &(default_type.type);
282 }
283
functionToPyObject(const std::shared_ptr<Node> & cdata)284 PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
285 if (!cdata) {
286 Py_RETURN_NONE;
287 }
288
289 if (auto pfw = dynamic_cast<PyNode*>(cdata.get())) {
290 PyObject* obj = pfw->obj;
291 Py_INCREF(obj);
292 return obj;
293 }
294
295 if (cdata->pyobj()) {
296 Py_INCREF(cdata->pyobj());
297 } else {
298 auto& fn = *cdata;
299 auto it = cpp_function_types_map.find(std::type_index(typeid(fn)));
300 PyTypeObject* type = nullptr;
301 if (it == cpp_function_types_map.end()) {
302 type = get_default_type();
303 } else {
304 type = (PyTypeObject*)it->second.get();
305 }
306
307 THPObjectPtr obj(type->tp_alloc(type, 0));
308 if (!obj)
309 return nullptr;
310 THPCppFunction* f = (THPCppFunction*)obj.get();
311 new (&f->cdata) std::shared_ptr<Node>(cdata);
312
313 // No INCREF here as we only have a weak reference
314 cdata->set_pyobj(obj.release());
315 }
316
317 return cdata->pyobj();
318 }
319
registerCppFunction(const std::type_info & type,PyTypeObject * pytype)320 void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
321 Py_INCREF((PyObject*)pytype);
322 cpp_function_types_map[std::type_index(type)] =
323 THPObjectPtr((PyObject*)pytype);
324 cpp_function_types_set.insert(pytype);
325 }
326
THPCppFunction_Check(PyObject * obj)327 bool THPCppFunction_Check(PyObject* obj) {
328 THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
329 if ((PyTypeObject*)type.get() == get_default_type()) {
330 return true;
331 }
332 if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
333 cpp_function_types_set.end()) {
334 return false;
335 } else {
336 return true;
337 }
338 }
339
callRegisterFn(PyObject * dict,PyObject * hook)340 PyObject* callRegisterFn(PyObject* dict, PyObject* hook) {
341 THPObjectPtr register_fn(
342 PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
343 if (!register_fn) {
344 return nullptr;
345 }
346 THPObjectPtr res(
347 PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
348 if (!res) {
349 return nullptr;
350 }
351 return res.release();
352 }
353
registerFunctionHook(Node & fn,PyObject * hook)354 PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
355 PyObject* dict = Py_None;
356 for (const auto& hook : fn.post_hooks()) {
357 if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
358 dict = pyhook->dict;
359 break;
360 }
361 }
362 THPObjectPtr res{callRegisterFn(dict, hook)};
363 if (!res) {
364 return nullptr;
365 }
366 if (dict == Py_None) {
367 dict = PyTuple_GET_ITEM(res.get(), 0);
368 fn.add_post_hook(std::make_unique<PyFunctionPostHook>(dict));
369 }
370
371 PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
372 Py_INCREF(handle);
373 return handle;
374 }
375
376 // This is almost a copy of the function above except post -> pre
registerFunctionPreHook(Node & fn,PyObject * hook)377 PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
378 PyObject* dict = Py_None;
379 for (const auto& hook : fn.pre_hooks()) {
380 if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
381 dict = pyhook->dict;
382 break;
383 }
384 }
385 THPObjectPtr res{callRegisterFn(dict, hook)};
386 if (!res) {
387 return nullptr;
388 }
389 if (dict == Py_None) {
390 dict = PyTuple_GET_ITEM(res.get(), 0);
391 fn.add_pre_hook(std::make_unique<PyFunctionPreHook>(dict));
392 }
393
394 PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
395 Py_INCREF(handle);
396 return handle;
397 }
398
399 } // namespace torch::autograd
400