1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <pybind11/pybind11.h> 5 #include <torch/csrc/Export.h> 6 #include <torch/csrc/autograd/python_variable.h> 7 #include <torch/csrc/autograd/saved_variable_hooks.h> 8 #include <torch/csrc/python_headers.h> 9 #include <torch/csrc/utils/pybind.h> 10 11 namespace py = pybind11; 12 13 namespace torch::autograd { 14 15 struct PySavedVariableHooks : public SavedVariableHooks { 16 PySavedVariableHooks(py::function& pack_hook, py::function& unpack_hook); 17 void call_pack_hook(const at::Tensor& tensor) override; 18 at::Tensor call_unpack_hook() override; 19 ~PySavedVariableHooks() override; 20 21 private: 22 PyObject* pack_hook_; 23 PyObject* unpack_hook_; 24 PyObject* data_ = nullptr; 25 }; 26 27 struct PyDefaultSavedVariableHooks { 28 static void push_hooks(py::function& pack_hook, py::function& unpack_hook); 29 static void pop_hooks(); 30 static std::unique_ptr<SavedVariableHooks> get_hooks(); 31 }; 32 33 } // namespace torch::autograd 34