xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/VariableHooksInterface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/Tensor.h>
#include <c10/macros/Export.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
5 
// A little explanation about why this file exists at all.  We have
// a few methods on the Tensor class which require reified access to
// AutogradMeta.  In open source, this isn't a big deal: we just access
// torch/csrc/autograd/variable.h from aten/src/ATen/core/Tensor.cpp and
// we can put the definitions inline.  This is because everything gets linked
// into a single dynamic library in the end.
12 //
// However, inside Facebook's internal build system, we have a split
// between aten and torch/csrc, so we cannot simply cross this boundary.
// "Now wait," you might say, "why don't we just merge the libraries
// inside Facebook?"  Well, the problem is that there are some downstream
// applications which are at their binary size limit, and incorporating
// all of the extra code from libtorch would push them over (e.g.,
// admarket/adreview/service:adreviewservice; see also
// https://github.com/pytorch/pytorch/pull/29299).  So if you want to do
// that, we first have to fix all services like this one.
22 //
// I didn't want to block eliminating the Tensor/Variable split on this work,
// so I had to introduce another dynamic dispatch to get to the variable
// implementations (which live in torch/csrc/autograd/variable.cpp, FYI).
26 //
// I also considered using our existing dynamic dispatch mechanism, the c10
// dispatcher, to do this.  However, (1) some of the functions on Tensor
// have unusual signatures that are not supported by autograd, and (2) there
// is this bug: https://github.com/pytorch/pytorch/issues/30102
30 // see this bug https://github.com/pytorch/pytorch/issues/30102
31 
// Forward declaration only.  This header names Node solely through
// `std::shared_ptr<torch::autograd::Node>` (the return type of
// VariableHooksInterface::grad_fn), and shared_ptr does not require a complete
// type, so no torch/csrc/autograd header needs to be included from ATen.
namespace torch::autograd { struct Node; } // namespace torch::autograd
37 
38 namespace at::impl {
39 
40 struct TORCH_API VariableHooksInterface {
41   virtual ~VariableHooksInterface() = default;
42   virtual TensorBase tensor_data(const TensorBase&) const = 0;
43   virtual TensorBase variable_data(const TensorBase&) const = 0;
44   virtual const std::shared_ptr<torch::autograd::Node>& grad_fn(
45       const TensorBase&) const = 0;
46   virtual unsigned _register_hook(
47       const TensorBase&,
48       std::function<TensorBase(const TensorBase&)> hook) const = 0;
49   virtual void remove_hook(const TensorBase&, unsigned pos) const = 0;
50   virtual bool is_view(const TensorBase&) const = 0;
51   virtual const TensorBase& base(const TensorBase&) const = 0;
52   virtual const std::string& name(const TensorBase&) const = 0;
53   virtual bool is_leaf(const TensorBase&) const = 0;
54   virtual int64_t output_nr(const TensorBase&) const = 0;
55   virtual void set_data(const TensorBase&, const TensorBase&) const = 0;
56   virtual TensorBase data(const TensorBase&) const = 0;
57   virtual int64_t _version(const TensorBase&) const = 0;
58   virtual void retain_grad(const TensorBase&) const = 0;
59   virtual bool retains_grad(const TensorBase&) const = 0;
60   virtual void _backward(
61       const Tensor&,
62       TensorList,
63       const std::optional<Tensor>&,
64       std::optional<bool>,
65       bool) const = 0;
66   virtual void requires_grad_(const TensorBase&, bool) const = 0;
67   virtual void basic_autograd_not_implemented_fallback(
68       const c10::OperatorHandle& op,
69       c10::DispatchKeySet dispatch_keys,
70       torch::jit::Stack* stack) const = 0;
71 };
72 
73 TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
74 TORCH_API VariableHooksInterface* GetVariableHooks();
75 TORCH_API bool HasVariableHooks();
76 
77 struct TORCH_API VariableHooksRegisterer {
VariableHooksRegistererVariableHooksRegisterer78   explicit VariableHooksRegisterer(VariableHooksInterface* hooks) {
79     SetVariableHooks(hooks);
80   }
81 };
82 
83 } // namespace at::impl
84