xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_saved_variable_hooks.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/SavedTensorHooks.h>
2 #include <torch/csrc/autograd/python_saved_variable_hooks.h>
3 
4 #include <c10/core/SafePyObject.h>
5 #include <torch/csrc/PyInterpreter.h>
6 #include <torch/csrc/THP.h>
7 
8 namespace py = pybind11;
9 
10 namespace torch::autograd {
PySavedVariableHooks(py::function & pack_hook,py::function & unpack_hook)11 PySavedVariableHooks::PySavedVariableHooks(
12     py::function& pack_hook,
13     py::function& unpack_hook)
14     : // steals the reference (we will decref ourselves)
15       pack_hook_(pack_hook.release().ptr()),
16       unpack_hook_(unpack_hook.release().ptr()) {}
17 
18 // We don't use pybind for call_pack_hook and call_unpack_hook to avoid
19 // https://github.com/pytorch/pytorch/issues/34172
call_pack_hook(const at::Tensor & tensor)20 void PySavedVariableHooks::call_pack_hook(const at::Tensor& tensor) {
21   py::gil_scoped_acquire acquire;
22   THPObjectPtr obj(THPVariable_Wrap(tensor));
23   THPObjectPtr packed(
24       PyObject_CallFunctionObjArgs(pack_hook_, obj.get(), nullptr));
25   if (!packed) {
26     throw python_error();
27   }
28   data_ = packed.release();
29   // obj is decrefed on exit, packed has their references stolen
30   // pack_hook_ and data_ will be manually decrefed when the saved variable is
31   // released
32 }
33 
call_unpack_hook()34 at::Tensor PySavedVariableHooks::call_unpack_hook() {
35   py::gil_scoped_acquire acquire;
36   THPObjectPtr res(PyObject_CallFunctionObjArgs(unpack_hook_, data_, nullptr));
37   if (!res) {
38     throw python_error();
39   }
40   TORCH_CHECK_TYPE(
41       THPVariable_Check(res),
42       "Output of saved tensor unpack_hook expected to be a Tensor but got result of type ",
43       THPUtils_typename(res));
44   return THPVariable_Unpack(res);
45   // res is decrefed on exit
46   // unpack_hook_ will be manually decrefed when the saved variable is released
47 }
48 
49 // NOLINTNEXTLINE(bugprone-exception-escape)
~PySavedVariableHooks()50 PySavedVariableHooks::~PySavedVariableHooks() {
51   // If python is already dead, leak the wrapped python objects
52   if (Py_IsInitialized()) {
53     py::gil_scoped_acquire gil;
54     Py_XDECREF(pack_hook_);
55     Py_XDECREF(unpack_hook_);
56     Py_XDECREF(data_);
57   }
58 }
59 
push_hooks(py::function & pack_hook,py::function & unpack_hook)60 void PyDefaultSavedVariableHooks::push_hooks(
61     py::function& pack_hook,
62     py::function& unpack_hook) {
63   at::SavedTensorDefaultHooks::lazy_initialize();
64   at::SavedTensorDefaultHooks::push_hooks(
65       c10::SafePyObject(pack_hook.release().ptr(), getPyInterpreter()),
66       c10::SafePyObject(unpack_hook.release().ptr(), getPyInterpreter()));
67 }
68 
// Removes a hook pair via at::SavedTensorDefaultHooks::pop_hooks() and
// sanity-checks it: both popped entries must yield a non-null Python object
// when queried with the current interpreter. The popped pair is not returned;
// it goes out of scope when this function ends.
pop_hooks()69 void PyDefaultSavedVariableHooks::pop_hooks() {
70   auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::pop_hooks();
  // Internal invariant rather than a user-facing check.
  // NOTE(review): presumably only Python-backed hooks are ever pushed through
  // this class, so a null pointer here would indicate a bug -- confirm.
71   TORCH_INTERNAL_ASSERT(
72       pack_hook.ptr(getPyInterpreter()) != nullptr &&
73       unpack_hook.ptr(getPyInterpreter()) != nullptr);
74 }
75 
get_hooks()76 std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
77   auto out = at::SavedTensorDefaultHooks::get_hooks();
78   if (!out.has_value()) {
79     return nullptr;
80   }
81   auto [pack_hook, unpack_hook] = *out;
82   py::gil_scoped_acquire gil;
83   py::function pack_hook_ =
84       py::reinterpret_steal<py::function>(pack_hook.release());
85   py::function unpack_hook_ =
86       py::reinterpret_steal<py::function>(unpack_hook.release());
87   return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_);
88 }
89 
90 } // namespace torch::autograd
91