xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/function_hook.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/Tensor.h>
#include <torch/csrc/Export.h>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>
7 
// Forward declarations of the compiled-autograd helper types. This header
// only names them through reference parameters, so incomplete types suffice
// and the defining headers need not be included here.
namespace torch::dynamo::autograd {
class SwapSavedVariables;
class CompiledNodeArgs;
} // namespace torch::dynamo::autograd
12 
// Hook interfaces that are called on gradients
14 
15 namespace torch::autograd {
16 
17 using Variable = at::Tensor;
18 using variable_list = std::vector<Variable>;
19 
20 struct TORCH_API FunctionPreHook {
21   virtual ~FunctionPreHook() = default;
22   virtual variable_list operator()(const variable_list& grads) = 0;
23   // only implemented for python hooks, registers hook with compiled autograd
compiled_argsFunctionPreHook24   virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
25     throw std::runtime_error(
26         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
27         typeid(*this).name());
28   }
29 };
30 
31 struct TORCH_API FunctionPostHook {
32   virtual ~FunctionPostHook() = default;
33   virtual variable_list operator()(
34       const variable_list& outputs /* grad_inputs */,
35       const variable_list& inputs /* grad_outputs */) = 0;
36   // only implemented for python hooks, registers hook with compiled autograd
compiled_argsFunctionPostHook37   virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
38     throw std::runtime_error(
39         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
40         typeid(*this).name());
41   }
42 };
43 
44 struct TORCH_API PostAccumulateGradHook {
45   virtual ~PostAccumulateGradHook() = default;
46   virtual void operator()(const Variable& tensor) = 0;
47   // only implemented for python hooks on nodes, registers hook with compiled
48   // autograd
compiled_argsPostAccumulateGradHook49   virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
50     throw std::runtime_error(
51         std::string("not yet implemented for compiled autograd: ") +
52         typeid(*this).name());
53   }
54 
apply_with_savedPostAccumulateGradHook55   virtual void apply_with_saved(
56       Variable&,
57       torch::dynamo::autograd::SwapSavedVariables&) {
58     throw std::runtime_error(
59         std::string("not yet implemented for compiled autograd: ") +
60         typeid(*this).name());
61   }
62 };
63 
64 } // namespace torch::autograd
65