xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_cpp_function.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_cpp_function.h>

#include <torch/csrc/python_headers.h>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>

#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
20 
21 using namespace torch::autograd;
22 
23 namespace torch::autograd {
24 
25 namespace {
26 
THPCppFunction_call(PyObject * self,PyObject * args,PyObject * kwargs)27 PyObject* THPCppFunction_call(
28     PyObject* self,
29     PyObject* args,
30     PyObject* kwargs) {
31   if (kwargs && PyDict_Size(kwargs) != 0) {
32     return PyErr_Format(PyExc_TypeError, "keyword arguments are not supported");
33   }
34 
35   auto num_inputs = PyTuple_GET_SIZE(args);
36   auto num_inputs_required = ((THPCppFunction*)self)->cdata->num_inputs();
37   if (num_inputs != num_inputs_required) {
38     return PyErr_Format(
39         PyExc_TypeError,
40         "expected %d arguments, got %d instead",
41         num_inputs_required,
42         num_inputs);
43   }
44   variable_list vars(num_inputs);
45   for (int i = 0; i != num_inputs; ++i) {
46     PyObject* arg = PyTuple_GET_ITEM(args, i);
47     if (arg == Py_None) {
48       continue;
49     }
50     if (!THPVariable_Check(arg)) {
51       return PyErr_Format(PyExc_TypeError, "argument %d is not a Variable", i);
52     }
53     vars[i] = THPVariable_Unpack(arg);
54   }
55 
56   variable_list output;
57 
58   HANDLE_TH_ERRORS {
59     pybind11::gil_scoped_release nogil;
60     output = (*((THPCppFunction*)self)->cdata)(std::move(vars));
61   }
62   END_HANDLE_TH_ERRORS
63 
64   auto num_outputs = output.size();
65   if (num_outputs == 1) {
66     // assume we want to unpack one element tuples for now
67     return THPVariable_Wrap(output[0]);
68   }
69 
70   THPObjectPtr tuple(PyTuple_New(static_cast<Py_ssize_t>(num_outputs)));
71   for (size_t i = 0; i != num_outputs; ++i) {
72     PyTuple_SET_ITEM(tuple.get(), i, THPVariable_Wrap(output[i]));
73   }
74   return tuple.release();
75 }
76 
THPCppFunction_traverse(PyObject * self,visitproc visit,void * arg)77 int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
78   if ((((THPCppFunction*)self)->cdata).use_count() == 1) {
79     // The fields traversed below are owned by the cpp grad_fn, which we own a
80     // reference to. We should only them traverse however if we are the only
81     // owner of the grad_fn, otherwise we risk prematurely gc'ing the grad_fn.
82     //
83     // See: https://github.com/pytorch/pytorch/issues/102174
84     auto& fn = *((THPCppFunction*)self)->cdata;
85     for (const auto& hook : fn.tensor_pre_hooks()) {
86       if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
87         Py_VISIT(pyhook->dict);
88       }
89     }
90     // NOTE [retains_grad_hook PyObject traversal]
91     // In theory this shouldn't be necessary, because retains_grad_hooks should
92     // not contain any PyFunctionTensorPreHooks. The alternative is to have a
93     // check that actually guarantees this.
94     for (const auto& pair : fn.retains_grad_hooks()) {
95       if (auto pyhook =
96               dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
97         Py_VISIT(pyhook->dict);
98       }
99     }
100     for (const auto& hook : fn.pre_hooks()) {
101       if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
102         Py_VISIT(pyhook->dict);
103       }
104     }
105     for (const auto& hook : fn.post_hooks()) {
106       if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
107         Py_VISIT(pyhook->dict);
108       }
109     }
110   }
111   return 0;
112 }
113 
THPCppFunction_clear(PyObject * self)114 int THPCppFunction_clear(PyObject* self) {
115   auto f = (THPCppFunction*)self;
116   // Remove the weak ref of the c++ object if it exist
117   if (f->cdata) {
118     f->cdata->set_pyobj(nullptr);
119   }
120   f->cdata.reset();
121   return 0;
122 }
123 
THPCppFunction_dealloc(PyObject * self)124 void THPCppFunction_dealloc(PyObject* self) {
125   PyObject_GC_UnTrack(self);
126   THPCppFunction_clear(self);
127   ((THPCppFunction*)self)->cdata.~shared_ptr();
128   Py_TYPE(self)->tp_free(self);
129 }
130 
131 } // namespace
132 
THPCppFunction_next_functions(PyObject * self,void * _unused)133 PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused) {
134   auto cdata = reinterpret_cast<const THPCppFunction*>(self)->cdata;
135   const auto num_next = cdata->num_outputs();
136   THPObjectPtr py_functions(PyTuple_New(num_next));
137   if (!py_functions)
138     return nullptr;
139   for (const auto i : c10::irange(num_next)) {
140     auto& c_tuple = cdata->next_edge(i);
141     THPObjectPtr tuple(PyTuple_New(2));
142     if (!tuple)
143       return nullptr;
144     PyObject* py_fn = functionToPyObject(c_tuple.function);
145     if (!py_fn)
146       return nullptr;
147     PyTuple_SET_ITEM(tuple.get(), 0, py_fn);
148     PyObject* py_idx = THPUtils_packUInt32(c_tuple.input_nr);
149     if (!py_idx)
150       return nullptr;
151     PyTuple_SET_ITEM(tuple.get(), 1, py_idx);
152     PyTuple_SET_ITEM(py_functions.get(), i, tuple.release());
153   }
154   return py_functions.release();
155 }
156 
THPCppFunction_metadata(PyObject * self,void * _unused)157 PyObject* THPCppFunction_metadata(PyObject* self, void* _unused) {
158   auto* metadata =
159       static_cast<PyAnomalyMetadata*>(
160           reinterpret_cast<THPCppFunction*>(self)->cdata->metadata())
161           ->dict();
162 
163   Py_XINCREF(metadata);
164   return metadata;
165 }
166 
THPCppFunction_requires_grad(PyObject * self,void * unused)167 PyObject* THPCppFunction_requires_grad(PyObject* self, void* unused) {
168   Py_RETURN_TRUE;
169 }
170 
THPCppFunction_register_hook_dict(PyObject * self,PyObject * _var)171 PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) {
172   if (!THPVariable_Check(_var)) {
173     return PyErr_Format(
174         PyExc_TypeError, "_register_hook_dict expected a variable");
175   }
176   auto var = (THPVariable*)_var;
177   auto& fn = *((THPCppFunction*)self)->cdata;
178   fn.add_tensor_pre_hook(std::make_unique<PyFunctionTensorPreHook>(
179       var->backward_hooks, THPVariable_Unpack(var).output_nr()));
180   Py_RETURN_NONE;
181 }
182 
THPCppFunction_register_hook(PyObject * self,PyObject * hook)183 PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) {
184   auto& fn = *((THPCppFunction*)self)->cdata;
185   return registerFunctionHook(fn, hook);
186 }
187 
THPCppFunction_register_prehook(PyObject * self,PyObject * hook)188 PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook) {
189   auto& fn = *((THPCppFunction*)self)->cdata;
190   return registerFunctionPreHook(fn, hook);
191 }
192 
THPCppFunction_name(PyObject * self,PyObject * noargs)193 PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs) {
194   auto& fn = *((THPCppFunction*)self)->cdata;
195   return THPUtils_packString(fn.name());
196 }
197 
THPCppFunction_sequence_nr(PyObject * self,PyObject * noargs)198 PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs) {
199   auto& fn = *((THPCppFunction*)self)->cdata;
200   return THPUtils_packUInt64(fn.sequence_nr());
201 }
202 
THPCppFunction_set_sequence_nr(PyObject * self,PyObject * sequence_nr)203 PyObject* THPCppFunction_set_sequence_nr(
204     PyObject* self,
205     PyObject* sequence_nr) {
206   HANDLE_TH_ERRORS
207   auto& fn = *((THPCppFunction*)self)->cdata;
208   fn.set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
209   Py_RETURN_NONE;
210   END_HANDLE_TH_ERRORS
211 }
212 
THPCppFunction_input_metadata(PyObject * self,void * closure)213 PyObject* THPCppFunction_input_metadata(PyObject* self, void* closure) {
214   HANDLE_TH_ERRORS;
215   auto& fn = *((THPCppFunction*)self)->cdata;
216   const auto num_inputs =
217       fn.num_inputs(); // Assuming there's a method to get the number of inputs
218   THPObjectPtr list(PyTuple_New(num_inputs));
219   if (!list) {
220     return nullptr;
221   }
222   for (size_t i = 0; i < num_inputs; ++i) {
223     const auto& metadata = fn.input_metadata(i);
224     THPObjectPtr item(py::cast(metadata).release().ptr());
225     if (!item) {
226       return nullptr;
227     }
228     PyTuple_SET_ITEM(list.get(), i, item.release());
229   }
230   return list.release();
231   END_HANDLE_TH_ERRORS
232 }
233 
234 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
235 static struct PyMethodDef default_methods[] = {
236     THP_FUNCTION_DEFAULT_METHODS,
237     {nullptr}};
238 
239 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
240 static struct PyGetSetDef default_properties[] = {
241     THP_FUNCTION_DEFAULT_PROPERTIES,
242     {nullptr}};
243 
_initFunctionPyTypeObject(PyTypeObject & type,const char * name,PyGetSetDef * function_properties,PyMethodDef * function_methods)244 PyTypeObject* _initFunctionPyTypeObject(
245     PyTypeObject& type,
246     const char* name,
247     PyGetSetDef* function_properties,
248     PyMethodDef* function_methods) {
249   type.ob_base = {PyObject_HEAD_INIT(nullptr) 0};
250   // NOLINTNEXTLINE(misc-redundant-expression)
251   type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
252   type.tp_name = name;
253   type.tp_basicsize = sizeof(THPCppFunction);
254   type.tp_call = THPCppFunction_call;
255   type.tp_methods = function_methods ? function_methods : default_methods;
256   type.tp_getset =
257       function_properties ? function_properties : default_properties;
258   type.tp_dealloc = THPCppFunction_dealloc;
259   type.tp_traverse = THPCppFunction_traverse;
260   type.tp_clear = THPCppFunction_clear;
261   if (PyType_Ready(&type) < 0) {
262     auto msg = std::string("Unable to instantiate PyTypeObject for ") + name;
263     throw std::runtime_error(msg);
264   }
265   return &type;
266 }
267 
268 static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types_map;
269 static std::unordered_set<PyTypeObject*> cpp_function_types_set;
270 
271 struct DefaultFunctionType {
DefaultFunctionTypetorch::autograd::DefaultFunctionType272   DefaultFunctionType() : type() {
273     _initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
274   }
275 
276   PyTypeObject type;
277 };
278 
get_default_type()279 PyTypeObject* get_default_type() {
280   static DefaultFunctionType default_type;
281   return &(default_type.type);
282 }
283 
functionToPyObject(const std::shared_ptr<Node> & cdata)284 PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
285   if (!cdata) {
286     Py_RETURN_NONE;
287   }
288 
289   if (auto pfw = dynamic_cast<PyNode*>(cdata.get())) {
290     PyObject* obj = pfw->obj;
291     Py_INCREF(obj);
292     return obj;
293   }
294 
295   if (cdata->pyobj()) {
296     Py_INCREF(cdata->pyobj());
297   } else {
298     auto& fn = *cdata;
299     auto it = cpp_function_types_map.find(std::type_index(typeid(fn)));
300     PyTypeObject* type = nullptr;
301     if (it == cpp_function_types_map.end()) {
302       type = get_default_type();
303     } else {
304       type = (PyTypeObject*)it->second.get();
305     }
306 
307     THPObjectPtr obj(type->tp_alloc(type, 0));
308     if (!obj)
309       return nullptr;
310     THPCppFunction* f = (THPCppFunction*)obj.get();
311     new (&f->cdata) std::shared_ptr<Node>(cdata);
312 
313     // No INCREF here as we only have a weak reference
314     cdata->set_pyobj(obj.release());
315   }
316 
317   return cdata->pyobj();
318 }
319 
registerCppFunction(const std::type_info & type,PyTypeObject * pytype)320 void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
321   Py_INCREF((PyObject*)pytype);
322   cpp_function_types_map[std::type_index(type)] =
323       THPObjectPtr((PyObject*)pytype);
324   cpp_function_types_set.insert(pytype);
325 }
326 
THPCppFunction_Check(PyObject * obj)327 bool THPCppFunction_Check(PyObject* obj) {
328   THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
329   if ((PyTypeObject*)type.get() == get_default_type()) {
330     return true;
331   }
332   if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
333       cpp_function_types_set.end()) {
334     return false;
335   } else {
336     return true;
337   }
338 }
339 
callRegisterFn(PyObject * dict,PyObject * hook)340 PyObject* callRegisterFn(PyObject* dict, PyObject* hook) {
341   THPObjectPtr register_fn(
342       PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
343   if (!register_fn) {
344     return nullptr;
345   }
346   THPObjectPtr res(
347       PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
348   if (!res) {
349     return nullptr;
350   }
351   return res.release();
352 }
353 
registerFunctionHook(Node & fn,PyObject * hook)354 PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
355   PyObject* dict = Py_None;
356   for (const auto& hook : fn.post_hooks()) {
357     if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
358       dict = pyhook->dict;
359       break;
360     }
361   }
362   THPObjectPtr res{callRegisterFn(dict, hook)};
363   if (!res) {
364     return nullptr;
365   }
366   if (dict == Py_None) {
367     dict = PyTuple_GET_ITEM(res.get(), 0);
368     fn.add_post_hook(std::make_unique<PyFunctionPostHook>(dict));
369   }
370 
371   PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
372   Py_INCREF(handle);
373   return handle;
374 }
375 
376 // This is almost a copy of the function above except post -> pre
registerFunctionPreHook(Node & fn,PyObject * hook)377 PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
378   PyObject* dict = Py_None;
379   for (const auto& hook : fn.pre_hooks()) {
380     if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
381       dict = pyhook->dict;
382       break;
383     }
384   }
385   THPObjectPtr res{callRegisterFn(dict, hook)};
386   if (!res) {
387     return nullptr;
388   }
389   if (dict == Py_None) {
390     dict = PyTuple_GET_ITEM(res.get(), 0);
391     fn.add_pre_hook(std::make_unique<PyFunctionPreHook>(dict));
392   }
393 
394   PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
395   Py_INCREF(handle);
396   return handle;
397 }
398 
399 } // namespace torch::autograd
400