xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_saved_variable_hooks.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <pybind11/pybind11.h>
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/autograd/python_variable.h>
7 #include <torch/csrc/autograd/saved_variable_hooks.h>
8 #include <torch/csrc/python_headers.h>
9 #include <torch/csrc/utils/pybind.h>
10 
11 namespace py = pybind11;
12 
13 namespace torch::autograd {
14 
15 struct PySavedVariableHooks : public SavedVariableHooks {
16   PySavedVariableHooks(py::function& pack_hook, py::function& unpack_hook);
17   void call_pack_hook(const at::Tensor& tensor) override;
18   at::Tensor call_unpack_hook() override;
19   ~PySavedVariableHooks() override;
20 
21  private:
22   PyObject* pack_hook_;
23   PyObject* unpack_hook_;
24   PyObject* data_ = nullptr;
25 };
26 
27 struct PyDefaultSavedVariableHooks {
28   static void push_hooks(py::function& pack_hook, py::function& unpack_hook);
29   static void pop_hooks();
30   static std::unique_ptr<SavedVariableHooks> get_hooks();
31 };
32 
33 } // namespace torch::autograd
34