1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <c10/macros/Export.h> 5 6 // A little explanation about why this file exists at all. We have 7 // a few methods on Tensor class which require access to reified access to 8 // AutogradMeta. In open source, this isn't a big deal: we just access 9 // torch/csrc/autograd/variable.h from aten/src/ATen/core/Tensor.cpp and 10 // we can put the definitions inline. This is because everything gets balled 11 // into a single dynamic library in the end. 12 // 13 // However, inside our Facebook internal version of our build system, we 14 // have a split between aten and torch/csrc. So we cannot simply just 15 // cross this boundary. "Now wait," you might say, "Why don't we just 16 // merge the libraries inside Facebook". Well, the problem is that there 17 // are some downstream applications which are at binary size limit, and 18 // incorporating all of the extra code from libtorch would push them 19 // over (admarket/adreview/service:adreviewservice, see also 20 // https://github.com/pytorch/pytorch/pull/29299) So if you want to do that, 21 // we have to fix all of the services like this. 22 // 23 // I didn't want to block eliminating Tensor-Variable on this work, so I 24 // had to introduce another dynamic dispatch to get to the variable 25 // implementations (which live in torch/csrc/autograd/variable.cpp, FYI). 26 // 27 // I also considered using our existing dynamic dispatch mechanism, c10 28 // dispatcher, to do this. However, (1) some of the functions on Tensor 29 // have weird signatures that are not supported by autograd, and (2) 30 // see this bug https://github.com/pytorch/pytorch/issues/30102 31 32 namespace torch::autograd { 33 34 struct Node; 35 36 } // namespace torch::autograd 37 38 namespace at::impl { 39 40 struct TORCH_API VariableHooksInterface { 41 virtual ~VariableHooksInterface() = default; 42 virtual TensorBase tensor_data(const TensorBase&) const = 0; 43 virtual TensorBase variable_data(const TensorBase&) const = 0; 44 virtual const std::shared_ptr<torch::autograd::Node>& grad_fn( 45 const TensorBase&) const = 0; 46 virtual unsigned _register_hook( 47 const TensorBase&, 48 std::function<TensorBase(const TensorBase&)> hook) const = 0; 49 virtual void remove_hook(const TensorBase&, unsigned pos) const = 0; 50 virtual bool is_view(const TensorBase&) const = 0; 51 virtual const TensorBase& base(const TensorBase&) const = 0; 52 virtual const std::string& name(const TensorBase&) const = 0; 53 virtual bool is_leaf(const TensorBase&) const = 0; 54 virtual int64_t output_nr(const TensorBase&) const = 0; 55 virtual void set_data(const TensorBase&, const TensorBase&) const = 0; 56 virtual TensorBase data(const TensorBase&) const = 0; 57 virtual int64_t _version(const TensorBase&) const = 0; 58 virtual void retain_grad(const TensorBase&) const = 0; 59 virtual bool retains_grad(const TensorBase&) const = 0; 60 virtual void _backward( 61 const Tensor&, 62 TensorList, 63 const std::optional<Tensor>&, 64 std::optional<bool>, 65 bool) const = 0; 66 virtual void requires_grad_(const TensorBase&, bool) const = 0; 67 virtual void basic_autograd_not_implemented_fallback( 68 const c10::OperatorHandle& op, 69 c10::DispatchKeySet dispatch_keys, 70 torch::jit::Stack* stack) const = 0; 71 }; 72 73 TORCH_API void SetVariableHooks(VariableHooksInterface* hooks); 74 TORCH_API VariableHooksInterface* GetVariableHooks(); 75 TORCH_API bool HasVariableHooks(); 76 77 struct TORCH_API VariableHooksRegisterer { VariableHooksRegistererVariableHooksRegisterer78 explicit VariableHooksRegisterer(VariableHooksInterface* hooks) { 79 SetVariableHooks(hooks); 80 } 81 }; 82 83 } // namespace at::impl 84