1 #include <torch/csrc/autograd/python_hook.h>
2
3 #include <c10/util/irange.h>
4 #include <pybind11/pybind11.h>
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/PyInterpreter.h>
7 #include <torch/csrc/THP.h>
8 #include <torch/csrc/autograd/python_variable.h>
9 #include <torch/csrc/dynamo/compiled_autograd.h>
10 #include <torch/csrc/utils/object_ptr.h>
11 #include <torch/csrc/utils/pybind.h>
12 #include <torch/csrc/utils/python_strings.h>
13
14 #include <iostream>
15 #include <sstream>
16
17 using torch::autograd::Variable;
18 using torch::autograd::variable_list;
19
20 static PyObject* wrap_variables(const variable_list& c_variables);
21 static variable_list unwrap_variables(PyObject* py_variables);
22 static std::string hook_name(PyObject* hook);
23 static void check_result(PyObject* original, PyObject* result, PyObject* hook);
24 static void check_single_result(
25 PyObject* original,
26 PyObject* result,
27 PyObject* hook);
28
29 namespace torch::autograd {
30
31 namespace {
32
33 // This function is called in 4 different cases:
34 // 1) TensorPreHook
35 // 2) PreHook
36 // 3) PostHook
37 // 4) TensorPostAccGradHook
38 //
39 // Depending on the case, args and res can hold different types of objects:
40 //
41 // args:
42 // TensorPreHook (Tensor,)
43 // PreHook ((Tensor, ...),) (grad_outputs,)
44 // PostHook ((Tensor, ...), (Tensor, ...)) (grad_inputs, grad_outputs)
45 // TensorPostAccGradHook ((Tensor), ()) (tensor,)
46 //
47 // res:
48 // TensorPreHook Tensor
49 // PreHook ((Tensor, ...),) (grad_outputs,)
50 // PostHook ((Tensor, ...),) (grad_inputs,)
51 // TensorPostAccGradHook None
52 //
53 // This function returns True if any hook returned non-None value, and False
54 // otherwise.
_call_hooks(PyObject * dict,PyObject * args)55 bool _call_hooks(PyObject* dict, PyObject* args) {
56 // Note: [Extend Hook Lifetime]
57 // Hold a reference to hooks till we iterate over them.
58 // This is to handle the case when hook calls `handle.remove` inside it
59 // and it's refcount goes to `0`, Python is free to GC it.
60 // We hold onto a stale pointer and subsequent call to
61 // `check_single_result`, which tries to fetch the `hook`'s name segfaults.
62 // So, we use `PyDict_Values` which returns a new reference to the values
63 // i.e. we hold the reference to the hooks till we have iterated over them.
64 // Reference: https://github.com/pytorch/pytorch/issues/58354
65 auto hooks = THPObjectPtr{PyDict_Values(dict)};
66 bool is_modified = false;
67 const auto len = PyList_Size(hooks);
68 for (Py_ssize_t idx = 0; idx < len; ++idx) {
69 const auto hook = PyList_GetItem(hooks, idx);
70
71 THPObjectPtr res(PyObject_CallObject(hook, args));
72 if (!res)
73 throw python_error();
74 if (res == Py_None)
75 continue;
76
77 PyObject* args0 = PyTuple_GetItem(args, 0);
78 if (res == args0)
79 continue;
80
81 if (PyTuple_CheckExact(args0)) {
82 check_result(args0, res, hook);
83 } else {
84 check_single_result(args0, res, hook);
85 }
86 PyTuple_SetItem(args, 0, res.release());
87
88 is_modified = true;
89 }
90 return is_modified;
91 }
92
93 } // namespace
94
PyFunctionTensorPreHook(PyObject * dict,size_t value_idx)95 PyFunctionTensorPreHook::PyFunctionTensorPreHook(
96 PyObject* dict,
97 size_t value_idx)
98 : dict(dict), value_idx(value_idx) {
99 Py_INCREF(dict);
100 }
101
102 // NOLINTNEXTLINE(bugprone-exception-escape)
~PyFunctionTensorPreHook()103 PyFunctionTensorPreHook::~PyFunctionTensorPreHook() {
104 // If python is already dead, leak the wrapped python objects
105 if (Py_IsInitialized()) {
106 pybind11::gil_scoped_acquire gil;
107 Py_DECREF(dict);
108 }
109 }
110
operator ()(const variable_list & values)111 auto PyFunctionTensorPreHook::operator()(const variable_list& values)
112 -> variable_list {
113 pybind11::gil_scoped_acquire gil;
114 THPObjectPtr value(THPVariable_Wrap(values.at(value_idx)));
115 if (!value)
116 throw python_error();
117 THPObjectPtr tup(PyTuple_New(1));
118 PyTuple_SET_ITEM(tup.get(), 0, value.release());
119 bool is_tup_modified = _call_hooks(dict, tup.get());
120 variable_list results(values);
121 if (is_tup_modified) {
122 results[value_idx] = THPVariable_Unpack(PyTuple_GetItem(tup.get(), 0));
123 }
124 return results;
125 }
126
PyFunctionPreHook(PyObject * dict)127 PyFunctionPreHook::PyFunctionPreHook(PyObject* dict) : dict(dict) {
128 Py_INCREF(dict);
129 }
130
131 // NOLINTNEXTLINE(bugprone-exception-escape)
~PyFunctionPreHook()132 PyFunctionPreHook::~PyFunctionPreHook() {
133 // If python is already dead, leak the wrapped python objects
134 if (Py_IsInitialized()) {
135 pybind11::gil_scoped_acquire gil;
136 Py_DECREF(dict);
137 }
138 }
139
operator ()(const variable_list & grad_outputs_)140 auto PyFunctionPreHook::operator()(const variable_list& grad_outputs_)
141 -> variable_list {
142 pybind11::gil_scoped_acquire gil;
143 THPObjectPtr grad_outputs(wrap_variables(grad_outputs_));
144 THPObjectPtr tup(PyTuple_New(1));
145 PyTuple_SET_ITEM(tup.get(), 0, grad_outputs.release());
146 _call_hooks(dict, tup.get());
147 return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
148 }
149
PyFunctionPostHook(PyObject * dict)150 PyFunctionPostHook::PyFunctionPostHook(PyObject* dict) : dict(dict) {
151 Py_INCREF(dict);
152 }
153
154 // NOLINTNEXTLINE(bugprone-exception-escape)
~PyFunctionPostHook()155 PyFunctionPostHook::~PyFunctionPostHook() {
156 // If python is already dead, leak the wrapped python objects
157 if (Py_IsInitialized()) {
158 pybind11::gil_scoped_acquire gil;
159 Py_DECREF(dict);
160 }
161 }
162
operator ()(const variable_list & _outputs,const variable_list & _inputs)163 auto PyFunctionPostHook::operator()(
164 const variable_list& _outputs, /* grad_inputs */
165 const variable_list& _inputs /* grad_outputs */) -> variable_list {
166 pybind11::gil_scoped_acquire gil;
167 THPObjectPtr grad_inputs(wrap_variables(_outputs));
168 THPObjectPtr grad_outputs(wrap_variables(_inputs));
169 THPObjectPtr tup(PyTuple_New(2));
170 PyTuple_SET_ITEM(tup.get(), 0, grad_inputs.release());
171 PyTuple_SET_ITEM(tup.get(), 1, grad_outputs.release());
172 _call_hooks(dict, tup.get());
173 return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
174 }
175
compiled_args(CompiledNodeArgs & args)176 void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
177 PyObject *key = nullptr, *value = nullptr;
178 Py_ssize_t pos = 0;
179 while (PyDict_Next(dict, &pos, &key, &value)) {
180 Py_INCREF(value);
181 args.add_tensor_pre_hook(
182 c10::SafePyObject(value, getPyInterpreter()),
183 static_cast<int>(value_idx));
184 }
185 }
186
compiled_args(CompiledNodeArgs & args)187 void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
188 PyObject *key = nullptr, *value = nullptr;
189 Py_ssize_t pos = 0;
190 while (PyDict_Next(dict, &pos, &key, &value)) {
191 Py_INCREF(value);
192 args.add_pre_hook(c10::SafePyObject(value, getPyInterpreter()));
193 }
194 }
195
compiled_args(CompiledNodeArgs & args)196 void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
197 PyObject *key = nullptr, *value = nullptr;
198 Py_ssize_t pos = 0;
199 while (PyDict_Next(dict, &pos, &key, &value)) {
200 Py_INCREF(value);
201 args.add_post_hook(c10::SafePyObject(value, getPyInterpreter()));
202 }
203 }
204
PyFunctionTensorPostAccGradHooks(PyObject * dict)205 PyFunctionTensorPostAccGradHooks::PyFunctionTensorPostAccGradHooks(
206 PyObject* dict)
207 : dict(dict) {
208 Py_INCREF(dict);
209 }
210
211 // NOLINTNEXTLINE(bugprone-exception-escape)
~PyFunctionTensorPostAccGradHooks()212 PyFunctionTensorPostAccGradHooks::~PyFunctionTensorPostAccGradHooks() {
213 // If python is already dead, leak the wrapped python objects
214 if (Py_IsInitialized()) {
215 pybind11::gil_scoped_acquire gil;
216 Py_DECREF(dict);
217 }
218 }
219
operator ()(const Variable & tensor)220 auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
221 -> void {
222 pybind11::gil_scoped_acquire gil;
223 THPObjectPtr tup(PyTuple_New(1));
224 PyTuple_SET_ITEM(tup.get(), 0, THPVariable_Wrap(tensor));
225 bool returned_none = !_call_hooks(dict, tup.get());
226 TORCH_CHECK(
227 returned_none, "Tensor post accumulate grad hooks should return None.");
228 }
229
compiled_args(torch::dynamo::autograd::CompiledNodeArgs & args)230 void PyFunctionTensorPostAccGradHooks::compiled_args(
231 torch::dynamo::autograd::CompiledNodeArgs& args) {
232 PyObject *key = nullptr, *value = nullptr;
233 Py_ssize_t pos = 0;
234 while (PyDict_Next(dict, &pos, &key, &value)) {
235 Py_INCREF(value);
236 c10::SafePyObject hook_obj(value, getPyInterpreter());
237 args.add_post_acc_grad_hook(std::move(hook_obj));
238 }
239 }
240
apply_with_saved(Variable & tensor,torch::dynamo::autograd::SwapSavedVariables & saved)241 void PyFunctionTensorPostAccGradHooks::apply_with_saved(
242 Variable& tensor,
243 torch::dynamo::autograd::SwapSavedVariables& saved) {
244 for (const auto hook : saved.get_curr_node_call().post_acc_grad_hooks) {
245 THPObjectPtr py_var(THPVariable_Wrap(tensor));
246 PyObject_CallMethod(
247 saved.get_py_compiler(),
248 "post_acc_grad_hook",
249 "Oi",
250 py_var.get(),
251 hook);
252 }
253 }
254
255 } // namespace torch::autograd
256
wrap_variables(const variable_list & c_variables)257 static PyObject* wrap_variables(const variable_list& c_variables) {
258 size_t num_vars = c_variables.size();
259 THPObjectPtr tuple(PyTuple_New(static_cast<Py_ssize_t>(num_vars)));
260 if (!tuple)
261 throw python_error();
262 for (const auto i : c10::irange(num_vars)) {
263 THPObjectPtr var(THPVariable_Wrap(c_variables[i]));
264 if (!var)
265 throw python_error();
266 PyTuple_SET_ITEM(tuple.get(), i, var.release());
267 }
268 return tuple.release();
269 }
270
unwrap_variables(PyObject * py_variables)271 static variable_list unwrap_variables(PyObject* py_variables) {
272 variable_list results(PyTuple_GET_SIZE(py_variables));
273 for (const auto i : c10::irange(results.size())) {
274 PyObject* item = PyTuple_GET_ITEM(py_variables, i);
275 if (item == Py_None) {
276 continue;
277 } else if (THPVariable_Check(item)) {
278 results[i] = THPVariable_Unpack(item);
279 } else {
280 // this should never happen, but just in case...
281 std::stringstream ss;
282 ss << "expected variable but got " << Py_TYPE(item)->tp_name;
283 throw std::runtime_error(ss.str());
284 }
285 }
286 return results;
287 }
288
check_result(PyObject * prev,PyObject * result,PyObject * hook)289 static void check_result(PyObject* prev, PyObject* result, PyObject* hook) {
290 if (!PyTuple_Check(result)) {
291 PyErr_Format(
292 PyExc_TypeError,
293 "expected tuple, but hook returned '%s'",
294 THPUtils_typename(result));
295 throw python_error();
296 }
297
298 auto prev_size = PyTuple_GET_SIZE(prev);
299 auto result_size = PyTuple_GET_SIZE(result);
300 if (prev_size != result_size) {
301 std::stringstream ss;
302 auto name = hook_name(hook);
303 ss << "hook '" << name << "' has returned an incorrect number ";
304 ss << "of values (got " << result_size << ", but expected ";
305 ss << prev_size << ")";
306 throw std::runtime_error(ss.str());
307 }
308
309 for (const auto i : c10::irange(prev_size)) {
310 check_single_result(
311 PyTuple_GET_ITEM(prev, i), PyTuple_GET_ITEM(result, i), hook);
312 }
313 }
314
check_single_result(PyObject * _original,PyObject * _result,PyObject * hook)315 static void check_single_result(
316 PyObject* _original,
317 PyObject* _result,
318 PyObject* hook) {
319 if (_result == Py_None)
320 return;
321
322 if (_original == Py_None) {
323 throw std::runtime_error(
324 "can't replace a None gradient with a non-None value");
325 }
326
327 if (!PyObject_IsInstance(_result, THPVariableClass)) {
328 PyErr_Format(
329 PyExc_TypeError,
330 "expected Variable, but hook returned '%s'",
331 THPUtils_typename(_result));
332 throw python_error();
333 }
334
335 const auto& original = THPVariable_Unpack(_original);
336 const auto& result = THPVariable_Unpack(_result);
337
338 torch::autograd::check_variable_result(original, result, hook_name(hook));
339 }
340
hook_name(PyObject * hook)341 static std::string hook_name(PyObject* hook) {
342 if (PyObject_HasAttrString(hook, "__name__")) {
343 THPObjectPtr name(PyObject_GetAttrString(hook, "__name__"));
344 if (!name)
345 throw python_error();
346
347 if (name && THPUtils_checkString(name.get())) {
348 return THPUtils_unpackString(name.get());
349 }
350 }
351 return "<unknown>";
352 }
353