// xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/saved_variable.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/autograd/saved_variable.h>

#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/Tensor.h>

#include <memory>
#include <sstream>

namespace torch::autograd {

SavedVariable::SavedVariable(
    const Variable& variable,
    bool is_output,
    bool is_inplace_on_view) {
  if (variable.defined()) {
    // Note [Inference tensor cannot be saved for backward]
    // Invariant:
    //   You can't save an inference tensor for backwards.
    // If an inference tensor was saved for backward in an autograd session and
    // then you reenter inference mode and make an inplace update to the tensor
    // without bumping the version_counter, it will lead to silently wrong
    // results when you do backward() for the previous autograd session.
    // Technically we don't have to check here since it'll fail when querying
    // `current_version` on the inference tensor, but we can give a much better
    // error message here.
    //
    // Note that in the documentation we say "inference tensor cannot
    // participate in autograd", which is more restrictive than the invariant.
    // In practice the check is more permissive and only errors out when an
    // inference tensor is saved for backward.  Whether a tensor is saved for
    // backward is determined by the derivative formula and thus varies op by
    // op, so by saying "no inference tensor in autograd" it's easier for users
    // to understand and follow.
    TORCH_CHECK(
        !variable.is_inference(),
        "Inference tensors cannot be saved for backward. To work around "
        "this, you can make a clone to get a normal tensor and use it in "
        "autograd.")
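
    // A hedged illustration (editor's sketch, not part of the original file)
    // of the pattern this check rejects, written against the C++ frontend:
    //
    //   torch::Tensor inf;
    //   {
    //     c10::InferenceMode guard;
    //     inf = torch::ones({2, 2}); // inference tensor: in-place updates on
    //   }                            // it never bump a version counter
    //   auto w = torch::ones({2, 2}, torch::requires_grad());
    //   auto y = w * inf; // mul's derivative would save `inf` -> error above
    //
    // Refusing at save time gives a clearer message than failing later (or
    // silently producing wrong gradients) when `inf` is mutated under
    // InferenceMode and backward() for this session reads the modified values.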

    was_default_constructed_ = false;
    saved_version_ = variable._version();
    is_leaf_ = variable.is_leaf();
    is_output_ = is_output;
    is_inplace_on_view_ = is_inplace_on_view;

    if (is_inplace_on_view) {
      TORCH_INTERNAL_ASSERT(!is_leaf_ && is_output);
      weak_grad_fn_ = variable.grad_fn();
    }

    auto maybe_hooks = get_default_hooks();

    // Prevent wrapped numbers from being leaked to the user
    if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {
      save_metadata(variable);
      set_hooks_and_pack_data(std::move(maybe_hooks), variable);
      return;
    }

    // If the variable is a leaf or is not an output, we can safely save the
    // original variable without running the risk of reference cycles.
    // 1. If the variable is not an output, its grad_fn has already been fully
    // created and in particular will be a different Node than the one
    // we are currently constructing (the one that owns this SavedVariable).
    // 2. If the variable is a leaf, it only has a weak reference to the
    // grad_accumulator, which cannot create a cycle. In those cases, we save
    // the original variable and don't need further processing.
    if (!is_output || is_leaf_) {
      saved_original_ = true;
      data_ = variable;
      return;
    }

    save_metadata(variable);

    // Only do this if we actually need to.
    data_ = variable.tensor_data();
  }
}
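
// A rough sketch (editor's addition; `MulBackwardLike` and its members are
// hypothetical, loosely modeled on the generated backward nodes) of how this
// constructor is typically used: the forward pass stores whatever the
// derivative formula needs as SavedVariable members of the backward Node, and
// apply() unpacks them, passing the owning node so unpack() can restore the
// grad_fn of a saved output. Saving a non-leaf output with is_output=true
// makes the constructor above keep only variable.tensor_data() (when no
// default hooks are installed), which avoids the reference cycle
// Node -> SavedVariable -> output -> grad_fn (the same Node) described in the
// comments above.
//
//   struct MulBackwardLike : public Node {
//     SavedVariable self_;
//     SavedVariable other_;
//     variable_list apply(variable_list&& grads) override {
//       auto self = self_.unpack(shared_from_this());
//       auto other = other_.unpack(shared_from_this());
//       return {grads[0] * other, grads[0] * self};
//     }
//   };
//
//   // and, roughly, during the forward pass:
//   //   node->self_ = SavedVariable(self, /*is_output=*/false);
//   //   node->other_ = SavedVariable(other, /*is_output=*/false);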

void SavedVariable::save_metadata(const Variable& data) {
  // Save output number, grad accumulator or grad_fn, and fw_grad if needed

  output_nr_ = data.output_nr();

  if (is_leaf_) {
    grad_accumulator_ = impl::grad_accumulator(data);
    requires_grad_ = data.requires_grad();
  } else if (!is_output_) {
    grad_fn_ = data.grad_fn();
  }

  // TODO(albanD) This needs to be updated when moving to multiple levels
  const auto& fw_grad = data._fw_grad(/* level */ 0);
  if (fw_grad.defined()) {
    fw_grad_ = std::make_shared<ForwardGrad>();
    fw_grad_->set_value(fw_grad, /* level */ 0);
  }
}

std::unique_ptr<SavedVariableHooks> SavedVariable::get_default_hooks() {
  return Engine::get_default_engine().get_default_saved_variable_hooks();
}

void SavedVariable::reset_data() {
  hooks_.reset();
  grad_fn_.reset();
  data_.reset();
}

SavedVariable::SavedVariable(
    const std::optional<Variable>& variable,
    bool is_output,
    bool is_inplace_on_view)
    : SavedVariable(
          variable.has_value() ? *variable : Variable(),
          is_output,
          is_inplace_on_view) {}
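
// Editor's note: an absent optional behaves like a default-constructed
// SavedVariable, whose unpack() simply returns an undefined Variable (see the
// early return in unpack() below). Presumably this is how optional tensor
// arguments in derivative formulas are saved uniformly. A minimal,
// hypothetical usage sketch:
//
//   std::optional<Variable> maybe_weight; // argument not supplied
//   SavedVariable saved(maybe_weight, /*is_output=*/false);
//   TORCH_INTERNAL_ASSERT(!saved.unpack().defined());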

Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
  if (was_default_constructed_) {
    return Variable();
  }

  if (!data_.defined()) {
    TORCH_CHECK(hooks_, ERR_BACKWARD_TWICE);
  }

  // We want grad_fn here to provide the most helpful debug message to the user
  // if versions don't match

  auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock()
      : !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr
                : grad_fn_;

  if (!is_leaf_ && !grad_fn) {
    // This issue was introduced when we added logic to save the original
    // variable: we now rely on data_.grad_fn(), which can be unreliable if the
    // autograd_meta of that saved tensor is cleared with an in-place detach.
    // As a simple fix, we choose to disallow that behavior here, even though
    // it makes behavior inconsistent depending on whether you are saving an
    // input or an output.
    TORCH_CHECK(
        saved_for,
        "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_(). "
        "This is not supported, please use out-of-place `.detach()` instead");
    grad_fn = std::move(saved_for);
  }

  // Only check the version counter in the case without hooks.
  // If the user provides hooks, we can't track versions through them.
  if (!hooks_) {
    auto current_version = impl::version_counter(data_).current_version();

    if (saved_version_ != current_version) {
      std::stringstream message;
      message
          << "one of the variables needed for gradient computation has been "
             "modified by an inplace operation: ["
          << data_.toString() << " ";
      if (data_.is_nested()) {
        message << data_._nested_tensor_size() << "]";
      } else {
        message << data_.sizes() << "]";
      }
      if (grad_fn) {
        message << ", which is output " << output_nr_ << " of "
                << grad_fn->name() << ",";
      }
      message << " is at version " << current_version << "; expected version "
              << saved_version_ << " instead.";
      if (!AnomalyMode::is_enabled()) {
        message << " Hint: enable anomaly detection to find the operation "
                   "that failed to compute its gradient, with torch.autograd."
                   "set_detect_anomaly(True).";
      } else {
        message
            << " Hint: the backtrace further above shows the operation "
               "that failed to compute its gradient. The variable in question "
               "was changed in there or anywhere later. Good luck!";
      }
      TORCH_CHECK(false, message.str());
    }
  }
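
  // A minimal sketch (hypothetical user code, editor's addition) of what trips
  // the version check above:
  //
  //   auto a = torch::ones({2, 2}, torch::requires_grad());
  //   auto b = a.clone();  // non-leaf, so in-place ops on it are allowed
  //   auto c = b * b;      // mul's backward saves `b` at its current version
  //   b.mul_(2);           // bumps b's version counter
  //   c.sum().backward();  // saved_version_ != current_version -> error above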

  // The version counter is correct.
  // Additionally, if we deal with a non-leaf variable, we have its correct
  // grad_fn.

  // If we have the original variable, we simply return it
  if (!hooks_ && saved_original_) {
    return data_;
  }

  auto data = hooks_ ? hooks_->call_unpack_hook() : data_;

  if (!grad_fn && !requires_grad_ && !data.requires_grad() &&
      !(fw_grad_ && !fw_grad_->empty())) {
    // Avoid detaching if we don't need to.
    return data;
  }

  // NB: saved views are unpacked as normal Variables (not views) even though
  // they still share the same storage. This works only because we never call
  // in-place functions on unpacked variables.
  Variable var;
  if (grad_fn) {
    var = make_variable(data, Edge(std::move(grad_fn), output_nr_));
  } else {
    var = make_variable(data, requires_grad_);
  }

  impl::set_grad_accumulator(var, grad_accumulator_);
  impl::set_version_counter(var, impl::version_counter(data));

  // NB: var here is never a view, so there is no need to do anything special
  // for the case where the saved Tensor was a view. This whole argument relies
  // on the fact that the Tensor returned by this function is never
  // modified in-place.
  if (fw_grad_ && !fw_grad_->empty()) {
    // TODO(albanD) This needs to be updated when moving to multiple levels
    auto new_fw_grad = fw_grad_->value(/* level */ 0);
    var._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }

  return var;
}

void SavedVariable::set_hooks_and_pack_data(
    std::unique_ptr<SavedVariableHooks>&& hooks,
    const Variable& data) {
  hooks_ = std::move(hooks);
  at::NoGradGuard guard;
  const auto version = impl::version_counter(data).current_version();
  hooks_->call_pack_hook(saved_original_ ? data.detach() : data);
  TORCH_CHECK(
      version == impl::version_counter(data).current_version(),
      "A saved tensor pack hook is modifying its input in place. "
      "Tensors provided as input to a pack hook cannot be modified by "
      "in-place operations, as this can lead to unexpected side-effects. "
      "Please open an issue if you need to perform in-place operations on "
      "the input to a pack hook.");
}
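
// A hedged sketch (editor's addition) of a hook pair that offloads the packed
// value; the exact virtual signatures are assumptions inferred from the
// call_pack_hook/call_unpack_hook calls above (the real interface lives in
// torch/csrc/autograd/saved_variable_hooks.h). Note that the pack side must
// not mutate its input in place, or the version check above fires.
//
//   struct OffloadToCPUHooks : public SavedVariableHooks {
//     at::Tensor packed_;
//     void call_pack_hook(const at::Tensor& tensor) override {
//       packed_ = tensor.to(at::kCPU); // out-of-place: version stays intact
//     }
//     at::Tensor call_unpack_hook() override {
//       // a real implementation would also record and restore the original
//       // device here
//       return packed_;
//     }
//   };
//
//   // sv.register_hooks(std::make_unique<OffloadToCPUHooks>());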

void SavedVariable::register_hooks(
    std::unique_ptr<SavedVariableHooks>&& hooks) {
  TORCH_INTERNAL_ASSERT(hooks);
  TORCH_CHECK(
      !hooks_,
      "Calling register_hooks on a saved tensor whose hooks have already been set. "
      "Hint: only one pair of hooks is allowed at a time.");
  if (!data_.defined()) {
    if (!was_default_constructed_) {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor after it has been freed. "
          "Saved intermediate values of the graph are freed when you call "
          ".backward() or autograd.grad(). Specify retain_graph=True if you "
          "need to backward through the graph a second time or if you need to "
          "access saved variables after calling backward.");
    } else {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor with value None is forbidden");
    }
  }
  // If we saved the original variable, its metadata has not been saved yet, so
  // do that now; otherwise it was already saved in the constructor.
  if (saved_original_) {
    save_metadata(data_);
  }
  set_hooks_and_pack_data(std::move(hooks), data_);
  data_.reset();
}

const char* ERR_BACKWARD_TWICE =
    "Trying to backward through the graph a second time (or directly access saved "
    "tensors after they have already been freed). Saved intermediate values "
    "of the graph are freed when you call .backward() or autograd.grad(). Specify "
    "retain_graph=True if you need to backward through the graph a second time or "
    "if you need to access saved tensors after calling backward.";
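
// A short illustration (hypothetical user code, editor's addition) of when
// unpack() hits ERR_BACKWARD_TWICE: after the first backward pass the engine
// releases the graph's saved variables (ultimately calling reset_data()
// above), and with no hooks there is nothing left to unpack.
//
//   auto x = torch::ones({2, 2}, torch::requires_grad());
//   auto y = (x * x).sum();
//   y.backward(); // frees saved tensors: retain_graph defaults to false
//   y.backward(); // unpack() finds data_ undefined and no hooks_ -> this error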

} // namespace torch::autograd