xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_hook.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/autograd/python_hook.h>
2 
3 #include <c10/util/irange.h>
4 #include <pybind11/pybind11.h>
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/PyInterpreter.h>
7 #include <torch/csrc/THP.h>
8 #include <torch/csrc/autograd/python_variable.h>
9 #include <torch/csrc/dynamo/compiled_autograd.h>
10 #include <torch/csrc/utils/object_ptr.h>
11 #include <torch/csrc/utils/pybind.h>
12 #include <torch/csrc/utils/python_strings.h>
13 
14 #include <iostream>
15 #include <sstream>
16 
17 using torch::autograd::Variable;
18 using torch::autograd::variable_list;
19 
20 static PyObject* wrap_variables(const variable_list& c_variables);
21 static variable_list unwrap_variables(PyObject* py_variables);
22 static std::string hook_name(PyObject* hook);
23 static void check_result(PyObject* original, PyObject* result, PyObject* hook);
24 static void check_single_result(
25     PyObject* original,
26     PyObject* result,
27     PyObject* hook);
28 
29 namespace torch::autograd {
30 
31 namespace {
32 
33 // This function is called in 4 different cases:
34 //   1) TensorPreHook
35 //   2) PreHook
36 //   3) PostHook
37 //   4) TensorPostAccGradHook
38 //
39 // Depending on the case, args and res can hold different types of objects:
40 //
41 // args:
42 // TensorPreHook   (Tensor,)
43 // PreHook         ((Tensor, ...),)                (grad_outputs,)
44 // PostHook        ((Tensor, ...), (Tensor, ...))  (grad_inputs, grad_outputs)
45 // TensorPostAccGradHook  ((Tensor), ())                  (tensor,)
46 //
47 // res:
48 // TensorPreHook          Tensor
49 // PreHook                ((Tensor, ...),)                (grad_outputs,)
50 // PostHook               ((Tensor, ...),)                (grad_inputs,)
51 // TensorPostAccGradHook  None
52 //
53 // This function returns True if any hook returned non-None value, and False
54 // otherwise.
// Calls every hook stored in `dict` with `args`, threading any non-None
// result of one hook into the first slot of `args` for the next hook.
//
// `args` is a tuple; slot 0 holds the value the hooks may replace (a Tensor
// for TensorPreHook / TensorPostAccGradHook, a tuple of Tensors for
// PreHook / PostHook — see the table above).
//
// Returns true if any hook returned a non-None value (i.e. args[0] was
// replaced), false otherwise.
bool _call_hooks(PyObject* dict, PyObject* args) {
  // Note: [Extend Hook Lifetime]
  // Hold a reference to hooks till we iterate over them.
  // This is to handle the case when hook calls `handle.remove` inside it
  // and it's refcount goes to `0`, Python is free to GC it.
  // We hold onto a stale pointer and subsequent call to
  // `check_single_result`, which tries to fetch the `hook`'s name segfaults.
  // So, we use `PyDict_Values` which returns a new reference to the values
  // i.e. we hold the reference to the hooks till we have iterated over them.
  // Reference: https://github.com/pytorch/pytorch/issues/58354
  auto hooks = THPObjectPtr{PyDict_Values(dict)};
  bool is_modified = false;
  const auto len = PyList_Size(hooks);
  for (Py_ssize_t idx = 0; idx < len; ++idx) {
    // Borrowed reference; kept alive by the `hooks` list above.
    const auto hook = PyList_GetItem(hooks, idx);

    THPObjectPtr res(PyObject_CallObject(hook, args));
    if (!res)
      throw python_error();
    // None means "leave the value unchanged".
    if (res == Py_None)
      continue;

    // Borrowed reference to the current value in slot 0.
    PyObject* args0 = PyTuple_GetItem(args, 0);
    // Returning the input object unchanged also counts as "no modification".
    if (res == args0)
      continue;

    // PreHook/PostHook pass a tuple of grads in slot 0; TensorPreHook and
    // TensorPostAccGradHook pass a single Tensor.
    if (PyTuple_CheckExact(args0)) {
      check_result(args0, res, hook);
    } else {
      check_single_result(args0, res, hook);
    }
    // PyTuple_SetItem steals the reference to `res` and decrefs the
    // previous occupant of slot 0.
    PyTuple_SetItem(args, 0, res.release());

    is_modified = true;
  }
  return is_modified;
}
92 
93 } // namespace
94 
PyFunctionTensorPreHook(PyObject * dict,size_t value_idx)95 PyFunctionTensorPreHook::PyFunctionTensorPreHook(
96     PyObject* dict,
97     size_t value_idx)
98     : dict(dict), value_idx(value_idx) {
99   Py_INCREF(dict);
100 }
101 
102 // NOLINTNEXTLINE(bugprone-exception-escape)
~PyFunctionTensorPreHook()103 PyFunctionTensorPreHook::~PyFunctionTensorPreHook() {
104   // If python is already dead, leak the wrapped python objects
105   if (Py_IsInitialized()) {
106     pybind11::gil_scoped_acquire gil;
107     Py_DECREF(dict);
108   }
109 }
110 
operator ()(const variable_list & values)111 auto PyFunctionTensorPreHook::operator()(const variable_list& values)
112     -> variable_list {
113   pybind11::gil_scoped_acquire gil;
114   THPObjectPtr value(THPVariable_Wrap(values.at(value_idx)));
115   if (!value)
116     throw python_error();
117   THPObjectPtr tup(PyTuple_New(1));
118   PyTuple_SET_ITEM(tup.get(), 0, value.release());
119   bool is_tup_modified = _call_hooks(dict, tup.get());
120   variable_list results(values);
121   if (is_tup_modified) {
122     results[value_idx] = THPVariable_Unpack(PyTuple_GetItem(tup.get(), 0));
123   }
124   return results;
125 }
126 
PyFunctionPreHook(PyObject * dict)127 PyFunctionPreHook::PyFunctionPreHook(PyObject* dict) : dict(dict) {
128   Py_INCREF(dict);
129 }
130 
131 // NOLINTNEXTLINE(bugprone-exception-escape)
~PyFunctionPreHook()132 PyFunctionPreHook::~PyFunctionPreHook() {
133   // If python is already dead, leak the wrapped python objects
134   if (Py_IsInitialized()) {
135     pybind11::gil_scoped_acquire gil;
136     Py_DECREF(dict);
137   }
138 }
139 
operator ()(const variable_list & grad_outputs_)140 auto PyFunctionPreHook::operator()(const variable_list& grad_outputs_)
141     -> variable_list {
142   pybind11::gil_scoped_acquire gil;
143   THPObjectPtr grad_outputs(wrap_variables(grad_outputs_));
144   THPObjectPtr tup(PyTuple_New(1));
145   PyTuple_SET_ITEM(tup.get(), 0, grad_outputs.release());
146   _call_hooks(dict, tup.get());
147   return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
148 }
149 
PyFunctionPostHook(PyObject * dict)150 PyFunctionPostHook::PyFunctionPostHook(PyObject* dict) : dict(dict) {
151   Py_INCREF(dict);
152 }
153 
154 // NOLINTNEXTLINE(bugprone-exception-escape)
~PyFunctionPostHook()155 PyFunctionPostHook::~PyFunctionPostHook() {
156   // If python is already dead, leak the wrapped python objects
157   if (Py_IsInitialized()) {
158     pybind11::gil_scoped_acquire gil;
159     Py_DECREF(dict);
160   }
161 }
162 
operator ()(const variable_list & _outputs,const variable_list & _inputs)163 auto PyFunctionPostHook::operator()(
164     const variable_list& _outputs, /* grad_inputs */
165     const variable_list& _inputs /* grad_outputs */) -> variable_list {
166   pybind11::gil_scoped_acquire gil;
167   THPObjectPtr grad_inputs(wrap_variables(_outputs));
168   THPObjectPtr grad_outputs(wrap_variables(_inputs));
169   THPObjectPtr tup(PyTuple_New(2));
170   PyTuple_SET_ITEM(tup.get(), 0, grad_inputs.release());
171   PyTuple_SET_ITEM(tup.get(), 1, grad_outputs.release());
172   _call_hooks(dict, tup.get());
173   return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
174 }
175 
compiled_args(CompiledNodeArgs & args)176 void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
177   PyObject *key = nullptr, *value = nullptr;
178   Py_ssize_t pos = 0;
179   while (PyDict_Next(dict, &pos, &key, &value)) {
180     Py_INCREF(value);
181     args.add_tensor_pre_hook(
182         c10::SafePyObject(value, getPyInterpreter()),
183         static_cast<int>(value_idx));
184   }
185 }
186 
compiled_args(CompiledNodeArgs & args)187 void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
188   PyObject *key = nullptr, *value = nullptr;
189   Py_ssize_t pos = 0;
190   while (PyDict_Next(dict, &pos, &key, &value)) {
191     Py_INCREF(value);
192     args.add_pre_hook(c10::SafePyObject(value, getPyInterpreter()));
193   }
194 }
195 
compiled_args(CompiledNodeArgs & args)196 void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
197   PyObject *key = nullptr, *value = nullptr;
198   Py_ssize_t pos = 0;
199   while (PyDict_Next(dict, &pos, &key, &value)) {
200     Py_INCREF(value);
201     args.add_post_hook(c10::SafePyObject(value, getPyInterpreter()));
202   }
203 }
204 
PyFunctionTensorPostAccGradHooks(PyObject * dict)205 PyFunctionTensorPostAccGradHooks::PyFunctionTensorPostAccGradHooks(
206     PyObject* dict)
207     : dict(dict) {
208   Py_INCREF(dict);
209 }
210 
211 // NOLINTNEXTLINE(bugprone-exception-escape)
~PyFunctionTensorPostAccGradHooks()212 PyFunctionTensorPostAccGradHooks::~PyFunctionTensorPostAccGradHooks() {
213   // If python is already dead, leak the wrapped python objects
214   if (Py_IsInitialized()) {
215     pybind11::gil_scoped_acquire gil;
216     Py_DECREF(dict);
217   }
218 }
219 
operator ()(const Variable & tensor)220 auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
221     -> void {
222   pybind11::gil_scoped_acquire gil;
223   THPObjectPtr tup(PyTuple_New(1));
224   PyTuple_SET_ITEM(tup.get(), 0, THPVariable_Wrap(tensor));
225   bool returned_none = !_call_hooks(dict, tup.get());
226   TORCH_CHECK(
227       returned_none, "Tensor post accumulate grad hooks should return None.");
228 }
229 
compiled_args(torch::dynamo::autograd::CompiledNodeArgs & args)230 void PyFunctionTensorPostAccGradHooks::compiled_args(
231     torch::dynamo::autograd::CompiledNodeArgs& args) {
232   PyObject *key = nullptr, *value = nullptr;
233   Py_ssize_t pos = 0;
234   while (PyDict_Next(dict, &pos, &key, &value)) {
235     Py_INCREF(value);
236     c10::SafePyObject hook_obj(value, getPyInterpreter());
237     args.add_post_acc_grad_hook(std::move(hook_obj));
238   }
239 }
240 
apply_with_saved(Variable & tensor,torch::dynamo::autograd::SwapSavedVariables & saved)241 void PyFunctionTensorPostAccGradHooks::apply_with_saved(
242     Variable& tensor,
243     torch::dynamo::autograd::SwapSavedVariables& saved) {
244   for (const auto hook : saved.get_curr_node_call().post_acc_grad_hooks) {
245     THPObjectPtr py_var(THPVariable_Wrap(tensor));
246     PyObject_CallMethod(
247         saved.get_py_compiler(),
248         "post_acc_grad_hook",
249         "Oi",
250         py_var.get(),
251         hook);
252   }
253 }
254 
255 } // namespace torch::autograd
256 
wrap_variables(const variable_list & c_variables)257 static PyObject* wrap_variables(const variable_list& c_variables) {
258   size_t num_vars = c_variables.size();
259   THPObjectPtr tuple(PyTuple_New(static_cast<Py_ssize_t>(num_vars)));
260   if (!tuple)
261     throw python_error();
262   for (const auto i : c10::irange(num_vars)) {
263     THPObjectPtr var(THPVariable_Wrap(c_variables[i]));
264     if (!var)
265       throw python_error();
266     PyTuple_SET_ITEM(tuple.get(), i, var.release());
267   }
268   return tuple.release();
269 }
270 
unwrap_variables(PyObject * py_variables)271 static variable_list unwrap_variables(PyObject* py_variables) {
272   variable_list results(PyTuple_GET_SIZE(py_variables));
273   for (const auto i : c10::irange(results.size())) {
274     PyObject* item = PyTuple_GET_ITEM(py_variables, i);
275     if (item == Py_None) {
276       continue;
277     } else if (THPVariable_Check(item)) {
278       results[i] = THPVariable_Unpack(item);
279     } else {
280       // this should never happen, but just in case...
281       std::stringstream ss;
282       ss << "expected variable but got " << Py_TYPE(item)->tp_name;
283       throw std::runtime_error(ss.str());
284     }
285   }
286   return results;
287 }
288 
check_result(PyObject * prev,PyObject * result,PyObject * hook)289 static void check_result(PyObject* prev, PyObject* result, PyObject* hook) {
290   if (!PyTuple_Check(result)) {
291     PyErr_Format(
292         PyExc_TypeError,
293         "expected tuple, but hook returned '%s'",
294         THPUtils_typename(result));
295     throw python_error();
296   }
297 
298   auto prev_size = PyTuple_GET_SIZE(prev);
299   auto result_size = PyTuple_GET_SIZE(result);
300   if (prev_size != result_size) {
301     std::stringstream ss;
302     auto name = hook_name(hook);
303     ss << "hook '" << name << "' has returned an incorrect number ";
304     ss << "of values (got " << result_size << ", but expected ";
305     ss << prev_size << ")";
306     throw std::runtime_error(ss.str());
307   }
308 
309   for (const auto i : c10::irange(prev_size)) {
310     check_single_result(
311         PyTuple_GET_ITEM(prev, i), PyTuple_GET_ITEM(result, i), hook);
312   }
313 }
314 
check_single_result(PyObject * _original,PyObject * _result,PyObject * hook)315 static void check_single_result(
316     PyObject* _original,
317     PyObject* _result,
318     PyObject* hook) {
319   if (_result == Py_None)
320     return;
321 
322   if (_original == Py_None) {
323     throw std::runtime_error(
324         "can't replace a None gradient with a non-None value");
325   }
326 
327   if (!PyObject_IsInstance(_result, THPVariableClass)) {
328     PyErr_Format(
329         PyExc_TypeError,
330         "expected Variable, but hook returned '%s'",
331         THPUtils_typename(_result));
332     throw python_error();
333   }
334 
335   const auto& original = THPVariable_Unpack(_original);
336   const auto& result = THPVariable_Unpack(_result);
337 
338   torch::autograd::check_variable_result(original, result, hook_name(hook));
339 }
340 
hook_name(PyObject * hook)341 static std::string hook_name(PyObject* hook) {
342   if (PyObject_HasAttrString(hook, "__name__")) {
343     THPObjectPtr name(PyObject_GetAttrString(hook, "__name__"));
344     if (!name)
345       throw python_error();
346 
347     if (name && THPUtils_checkString(name.get())) {
348       return THPUtils_unpackString(name.get());
349     }
350   }
351   return "<unknown>";
352 }
353