#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>

#include <ATen/core/Tensor.h>

#include <cstdint>
#include <memory>

namespace torch::autograd {

using Variable = at::Tensor;
struct Node;

TORCH_API extern const char* ERR_BACKWARD_TWICE;

/// A snapshot of a variable at a certain version. A `SavedVariable` stores
/// enough information to reconstruct a variable from a certain point in time.
class TORCH_API SavedVariable {
 public:
  SavedVariable() = default;
  SavedVariable(
      const Variable& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(
      const std::optional<Variable>& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(SavedVariable&&) = default;
  SavedVariable& operator=(SavedVariable&&) = default;
  ~SavedVariable() {
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }

  /// Reconstructs the saved variable. Pass `saved_for` as the gradient
  /// function if constructing the `SavedVariable` with it would have caused a
  /// circular reference.
  Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;

  void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);

  void reset_data();

  bool has_hooks() const {
    return (bool)hooks_;
  }

 private:
  // This field contains either:
  // 1. the variable to save
  // 2. or its tensor_data.
  // If storing the variable itself would create a circular reference,
  // we fall into the second case and its metadata is also saved separately.
  // In that case, the grad_fn must be passed in to the unpack function when
  // reconstructing the Variable (except when we are doing an inplace operation
  // on a view, see below).
  // The field saved_original_ below reflects the two cases: its value is true
  // in the first case and false in the second case.
  // The value data_.defined() can be false in three cases:
  // 1. SavedVariable was constructed without a Tensor (the value to save is
  // None), in that case was_default_constructed_ will be kept at true
  // 2. The saved variable has been released by calling
  // SavedVariable::reset_data(), typically during the backward pass
  // 3. Hooks have been registered. In that case, hooks_ will be defined
  // instead.
  // Note that the value of saved_original_ only reflects what happened during
  // the construction of the SavedVariable. If saved_original_ is true, we
  // saved the original tensor in data_, but if the user registers hooks, we
  // will no longer have it (despite saved_original_ still being true).
  at::Tensor data_;

  // This field is used to store the forward AD gradients associated with
  // the saved Tensor. Note that this shared_ptr must never be shared with
  // either the saved Tensor or the unpacked Tensor. See note [ Using
  // ForwardGrad ]
  std::shared_ptr<ForwardGrad> fw_grad_;

  // Weak version of grad_fn_ that prevents leaks in rebase_history() for
  // inplace views.
  // This variable is used when the user chooses to create a SavedVariable with
  // is_inplace_on_view = true.
  // In that case, the grad_fn passed in to the unpack function at unwrapping
  // time is unused.
  std::weak_ptr<Node> weak_grad_fn_;

  uint32_t saved_version_ = 0;
  uint32_t output_nr_ = 0;
  bool was_default_constructed_ = true;
  bool is_inplace_on_view_ = false;
  bool saved_original_ = false;
  bool is_leaf_ = false;
  bool is_output_ = false;

  // Hooks are a pair of functions pack_hook/unpack_hook that provide
  // fine-grained control over how the SavedVariable should save its data.
  // pack_hook is called upon registration, while unpack_hook is called when
  // unpacking.
  std::unique_ptr<SavedVariableHooks> hooks_;
  // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if
  // hooks are defined. They are set before pack_hook is called and used after
  // unpack_hook is called.
  std::shared_ptr<Node> grad_fn_;
  // For the usual case where a leaf tensor is the input, we expect its
  // grad_accumulator_ to be kept alive by the graph. The reason SavedVariable
  // holds an owning reference is to support the case where a custom autograd
  // Function saves an intermediate.
  std::shared_ptr<Node> grad_accumulator_;
  bool requires_grad_ = false;

  void save_metadata(const Variable& data);
  static std::unique_ptr<SavedVariableHooks> get_default_hooks();
  void set_hooks_and_pack_data(
      std::unique_ptr<SavedVariableHooks>&& hooks,
      const Variable& data);
};
} // namespace torch::autograd
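
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of this header's API): a rough picture
// of how a backward Node might hold and later unpack a SavedVariable. Only
// the SavedVariable constructor and unpack() above are real; the node name
// `MulBackwardExample` and the surrounding wiring are hypothetical and assume
// the Node / variable_list definitions from torch/csrc/autograd/function.h.
//
//   struct MulBackwardExample : public Node {
//     // Inputs are saved with is_output = false; a node's own outputs would
//     // be saved with is_output = true.
//     SavedVariable self_;   // SavedVariable(self, /*is_output=*/false)
//     SavedVariable other_;  // SavedVariable(other, /*is_output=*/false)
//
//     variable_list apply(variable_list&& grads) override {
//       // Passing `shared_from_this()` as `saved_for` lets a SavedVariable
//       // that skipped storing its grad_fn (to avoid a reference cycle)
//       // recover it at unpack time.
//       auto self = self_.unpack(shared_from_this());
//       auto other = other_.unpack(shared_from_this());
//       return {grads[0] * other, grads[0] * self};
//     }
//   };
// ---------------------------------------------------------------------------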