1 #pragma once 2 3 #include <ATen/Tensor.h> 4 #include <torch/csrc/Export.h> 5 #include <string> 6 #include <vector> 7 8 namespace torch::dynamo::autograd { 9 class CompiledNodeArgs; 10 class SwapSavedVariables; 11 } // namespace torch::dynamo::autograd 12 13 // A hook that's called on gradients 14 15 namespace torch::autograd { 16 17 using Variable = at::Tensor; 18 using variable_list = std::vector<Variable>; 19 20 struct TORCH_API FunctionPreHook { 21 virtual ~FunctionPreHook() = default; 22 virtual variable_list operator()(const variable_list& grads) = 0; 23 // only implemented for python hooks, registers hook with compiled autograd compiled_argsFunctionPreHook24 virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) { 25 throw std::runtime_error( 26 std::string("compiled_args nyi, see [Note: Compiled Autograd] ") + 27 typeid(*this).name()); 28 } 29 }; 30 31 struct TORCH_API FunctionPostHook { 32 virtual ~FunctionPostHook() = default; 33 virtual variable_list operator()( 34 const variable_list& outputs /* grad_inputs */, 35 const variable_list& inputs /* grad_outputs */) = 0; 36 // only implemented for python hooks, registers hook with compiled autograd compiled_argsFunctionPostHook37 virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) { 38 throw std::runtime_error( 39 std::string("compiled_args nyi, see [Note: Compiled Autograd] ") + 40 typeid(*this).name()); 41 } 42 }; 43 44 struct TORCH_API PostAccumulateGradHook { 45 virtual ~PostAccumulateGradHook() = default; 46 virtual void operator()(const Variable& tensor) = 0; 47 // only implemented for python hooks on nodes, registers hook with compiled 48 // autograd compiled_argsPostAccumulateGradHook49 virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) { 50 throw std::runtime_error( 51 std::string("not yet implemented for compiled autograd: ") + 52 typeid(*this).name()); 53 } 54 apply_with_savedPostAccumulateGradHook55 virtual void apply_with_saved( 56 Variable&, 57 torch::dynamo::autograd::SwapSavedVariables&) { 58 throw std::runtime_error( 59 std::string("not yet implemented for compiled autograd: ") + 60 typeid(*this).name()); 61 } 62 }; 63 64 } // namespace torch::autograd 65