#include <c10/util/irange.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>

#include <utility>

namespace torch::autograd {

// This function has two main goals:
//  1) Use the user-provided jvp function to populate the outputs' forward
//     gradients
//  2) Perform error checking to ensure that view and inplace ops are
//     properly handled
//
// For 1) we have to:
//  - Create a variable_list of grad_inputs based on the function inputs
//  - Call the user jvp function with these to get the grad_outputs
//  - Set the forward grad field on each output based on these grad_outputs
//
// For 2) we want to check the following:
//  - If an output is a view, then the generated forward grad must be a view
//    as well, and the output's base's forward grad must be the output's
//    forward grad's base.
//  - If an input was modified inplace (it must be an output as well), we make
//    sure that its forward grad was also modified inplace and is already
//    present on the corresponding output.
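//
// In terms of data flow (a rough sketch of what this function receives from
// the custom Function machinery): every input that carries a forward
// (tangent) gradient at the current level contributes that tangent to
// input_grads (entries for other inputs stay undefined); the user's jvp is
// then called as jvp_user_function(inputs, input_grads) and must return one
// forward gradient per output, which gets attached to the corresponding
// output via _set_fw_grad below.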
static void _process_forward_mode_AD(
    const variable_list& inputs,
    std::unordered_map<at::TensorImpl*, size_t> inputs_mapping,
    const at::ArrayRef<std::optional<Variable>> raw_outputs,
    const optional_variable_list& outputs,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const _jvp_fn_t& jvp_user_function) {
  // TODO handle multiple levels here
  uint64_t level = 0;

  const auto num_inputs = inputs.size();
  const auto num_outputs = outputs.size();

  // The tracking info below is used to perform the view and inplace checks.
  // It is lazily initialized to reduce the cost of this function in the
  // common case where the user is not using forward mode AD.
  variable_list input_grads;
  std::vector<int64_t> grad_versions;
  std::vector<at::TensorImpl*> grad_impls;
  std::unordered_map<at::TensorImpl*, size_t> inputs_bases;

  auto init_tracked_info = [&]() {
    input_grads.resize(num_inputs);
    grad_versions.resize(num_inputs);
    grad_impls.resize(num_inputs);

    for (const auto i : c10::irange(num_inputs)) {
      const auto& inp = inputs[i];
      if (inp.is_view() && impl::get_view_autograd_meta(inp)->has_fw_view()) {
        inputs_bases.emplace(
            impl::get_view_autograd_meta(inp)
                ->get_forward_view()
                .base_.unsafeGetTensorImpl(),
            i);
      } else {
        inputs_bases.emplace(inp.unsafeGetTensorImpl(), i);
      }
    }
  };

  bool any_input_has_grad = false;
  // Extract the inputs' forward gradients and record any info we will need
  // later
  for (const auto i : c10::irange(num_inputs)) {
    const auto& inp = inputs[i];
    if (!inp.defined()) {
      continue;
    }
    const auto& fw_grad = inp._fw_grad(level);
    if (fw_grad.defined()) {
      if (!any_input_has_grad) {
        any_input_has_grad = true;
        init_tracked_info();
      }
      input_grads[i] = fw_grad;
      grad_versions[i] = fw_grad._version();
      grad_impls[i] = fw_grad.unsafeGetTensorImpl();
    }
  }

  // If no input has forward grad, nothing to do here
  if (!any_input_has_grad) {
    return;
  }

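  // Call the user's jvp with forward-mode AD disabled so that the computation
  // inside the jvp does not itself try to propagate forward-mode gradients;
  // the returned values are attached to the outputs explicitly below.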
  torch::autograd::variable_list forward_grads;
  {
    at::AutoFwGradMode fw_grad_mode(false);
    forward_grads = jvp_user_function(inputs, std::move(input_grads));
  }

  const auto num_forward_grads = forward_grads.size();
  // contrary to backward mode, we don't allow returning too many gradients
  TORCH_CHECK(
      num_forward_grads == num_outputs,
      "Function's jvp returned "
      "an invalid number of forward gradients (expected ",
      num_outputs,
      " but got ",
      num_forward_grads,
      ")");

  for (const auto i : c10::irange(num_outputs)) {
    if (!raw_outputs[i].has_value()) {
      continue;
    }
    const auto& out =
        outputs[i].has_value() ? outputs[i].value() : at::Tensor();
    auto out_tensor_impl = raw_outputs[i].value().unsafeGetTensorImpl();
    bool is_differentiable =
        (non_differentiable.count(out_tensor_impl) == 0 &&
         isDifferentiableType(raw_outputs[i].value().scalar_type()));
    const auto& out_grad = forward_grads[i];
    if (!out.defined() || !is_differentiable) {
      TORCH_CHECK(
          !out_grad.defined(),
          "Function's jvp returned a gradient at position ",
          i,
          ", but the corresponding forward output is not a differentiable "
          "Tensor. You should return None at that position instead.");
      continue;
    }

    bool is_input = inputs_mapping.count(out_tensor_impl) > 0;
    bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;

    if (is_modified) {
      TORCH_CHECK(
          is_input,
          "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
          " is no need to pass it to mark_dirty().");
      auto inp_idx = inputs_mapping[out_tensor_impl];
      if (grad_impls[inp_idx]) {
        // If there was already a forward grad for that input, just make sure
        // that it is modified inplace and returned as-is.
        TORCH_CHECK(
            out_grad._version() != grad_versions[inp_idx],
            "An inplace custom Function is not modifying the "
            "forward mode gradients inplace. If the forward is modifying an input inplace, then the jvp "
            "function must modify the corresponding gradient inplace.")
        TORCH_CHECK(
            out_grad.unsafeGetTensorImpl() == grad_impls[inp_idx],
            "An inplace custom Function is not returning the "
            "forward mode gradients as-is. If the forward is modifying an input inplace, then the jvp "
            "function must modify the gradient inplace and return it as-is.")
      } else {
        // If that Tensor didn't have a gradient already, set the newly
        // returned one. We could also use inputs[inp_idx] here as it is the
        // same as out.
        out._set_fw_grad(out_grad, level, /* is_inplace_op */ true);
      }
    } else {
      // At this point, outputs[i] cannot be one of the inputs (raw_outputs[i]
      // might be, but it was changed by the backward code above)
      TORCH_INTERNAL_ASSERT(
          inputs_mapping.count(out.unsafeGetTensorImpl()) == 0);

      if (out.is_view() && impl::get_view_autograd_meta(out)->has_fw_view()) {
        // If the output is a view
        const auto& out_view_info =
            impl::get_view_autograd_meta(out)->get_forward_view();
        if (inputs_bases.count(out_view_info.base_.unsafeGetTensorImpl())) {
          // And it is a view of an input (either that input is its base or they
          // have a common base)
          const auto matching_input_idx =
              inputs_bases[out_view_info.base_.unsafeGetTensorImpl()];
          const auto& matching_input = inputs[matching_input_idx];

          const auto& matching_input_grad = matching_input._fw_grad(level);

          // If the matching input has a forward grad, the user should have
          // returned a view of that Tensor
          if (matching_input_grad.defined()) {
            TORCH_CHECK(
                out_grad.is_view() &&
                    impl::get_view_autograd_meta(out_grad)->has_fw_view(),
                "A custom Function's forward is returning a view (or an input as-is) but the jvp is not "
                "returning a view.");
            const auto& out_grad_base = impl::get_view_autograd_meta(out_grad)
                                            ->get_forward_view()
                                            .base_;
            if (matching_input_grad.is_view() &&
                impl::get_view_autograd_meta(matching_input_grad)
                    ->has_fw_view()) {
              // If the matching input's grad is a view, ensure that the
              // out_grad is a view of the same base
              const auto& matching_input_grad_base =
                  impl::get_view_autograd_meta(matching_input_grad)
                      ->get_forward_view()
                      .base_;
              TORCH_CHECK(
                  matching_input_grad_base.unsafeGetTensorImpl() ==
                      out_grad_base.unsafeGetTensorImpl(),
                  "A custom Function is returning a view but the jvp is not returning a view of the same base as "
                  "the given grad input.");
            } else {
              // If the matching input's grad is not a view, then it must be the
              // output gradient's base
              TORCH_CHECK(
                  matching_input_grad.unsafeGetTensorImpl() ==
                      out_grad_base.unsafeGetTensorImpl(),
                  "A custom Function is returning a view but the jvp is not returning a view of the given grad input.");
            }
          } else {
            // We have a view op where the input didn't have a forward grad but
            // the user returned one for the output. To ensure that we maintain
            // the view/inplace constraints, we consider this as an inplace op.
            // This case CANNOT happen in codegen as all view ops are mapping
            // from one Tensor to one Tensor and so the output of the view
            // cannot have a forward grad if the base does not.
            out._set_fw_grad(out_grad, level, /* is_inplace_op */ true);
            return;
          }
        }
      }

      out._set_fw_grad(out_grad, level, /* is_inplace_op */ false);
    }
  }
}

static at::Tensor _view_as_self_with_no_grad(
    const at::Tensor& self,
    const _view_as_self_fn_t& view_as_self_fn) {
  // This is called below in _process_backward_mode_ad in two places:
  //
  // (1) An input has been returned, but it wasn't modified. Return it as a view
  // so that we can attach a new grad_fn to the Variable.
  // Run in no_grad mode to mimic the behavior of the forward.
  //
  // (2) Though it is not necessary for the purposes of attaching grad_fn, we
  // also call this function when an output is non-differentiable (and does not
  // require grad), to make the custom forward AD UX more consistent. We'd like
  // to uniformly say that returning an input as-is is treated as if
  // `self.view_as(self)` were returned for that output.
  //
  // Alternatively, we could have not disabled forward grad while performing
  // this view, but it would mean that the user-defined jvp may be silently
  // ignored.
  at::AutoFwGradMode fw_grad_mode(false);
  AutoGradMode grad_mode(false);
  // We thread through this view_as_self_fn lambda so that in the case we are a
  // Python custom function (rather than a cpp one), we can properly call the
  // view_as from python so that torch function logic can still trigger.
  return view_as_self_fn(self);
}

static optional_variable_list _process_backward_mode_ad(
    const std::unordered_map<at::TensorImpl*, size_t>& inputs_mapping,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const at::ArrayRef<std::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node>& cdata,
    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
    const _view_as_self_fn_t& view_as_self_fn) {
  auto num_outputs = raw_outputs.size();

#ifndef STRIP_ERROR_MESSAGES
  const char* error_msg_input_returned_as_is =
      "An input that has been returned as-is as output is being saved for backward. "
      "This is not supported if you override setup_context. You should return and "
      "save a view of the input instead, e.g. with x.view_as(x), or set up ctx inside "
      "the forward function itself.";
#endif

  // Sets the grad_fn and output_nr of an output Variable.
  auto set_history = [&](Variable& var,
                         uint32_t output_nr,
                         bool is_input,
                         bool is_modified,
                         bool is_differentiable,
                         bool is_saved_and_setup_context) {
    if (!is_differentiable) {
      if (!var.requires_grad()) {
        if (is_input && !is_modified) {
          TORCH_CHECK(
              !is_saved_and_setup_context, error_msg_input_returned_as_is)
          var = _view_as_self_with_no_grad(var, view_as_self_fn);
        }
        return;
      }
      // Return detached aliases of inputs, instead of changing their
      // requires_grad property.
      if (is_input) {
        var = var.detach();
      } else if (!var.is_view()) {
        var.detach_();
      }
      // If var is a view of one of the inputs of the custom autograd Function,
      // we don't detach it in a no_grad block. This is so that we can mimic the
      // behavior of returning a view from a no_grad block:
      //   x = torch.randn(3, requires_grad=True)
      //   with torch.no_grad():
      //       y = x.view(-1)
      // Here, `y` requires_grad (!).
    } else if (is_modified) {
      if (var.is_leaf() && var.requires_grad()) {
        TORCH_CHECK(
            false,
            "a leaf Variable that requires grad has been used in an in-place operation.");
      }
      // No need to mark as modified Tensors that are not inputs.
      if (!is_input) {
        TORCH_WARN(
            "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
            " is no need to pass it to mark_dirty().");
      }
      // If the input is a view, the rebase will need to rewrite the graph and
      // this only works if we have a single output to this Function.
      TORCH_CHECK(
          !(var.is_view() && num_outputs > 1),
          "If your Function modifies inplace an input that is a view"
          " of another Tensor, your Function cannot return more than one Tensor. This is not supported"
          " by the current autograd engine. You should either make sure the input is not a view (using"
          " .clone() for example) or make your Function only return one Tensor (potentially splitting"
          " it into two Functions: one doing the inplace that returns a single Tensor and a second one"
          " that does the other operations). You can ask on the forum https://discuss.pytorch.org/ if"
          " you need help to do this change.");

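      // To illustrate when this branch runs, here is a rough sketch (the
      // Function name is made up, but the ctx calls are the public C++
      // custom-function API) of an inplace custom Function whose returned
      // input reaches this point with is_input && is_modified:
      //
      //   struct AddOneInplace : public torch::autograd::Function<AddOneInplace> {
      //     static torch::Tensor forward(AutogradContext* ctx, torch::Tensor x) {
      //       ctx->mark_dirty({x});   // x ends up in dirty_inputs
      //       return x.add_(1);       // returned as-is, so is_input is true
      //     }
      //     static variable_list backward(
      //         AutogradContext* ctx, variable_list grad_output) {
      //       return {grad_output[0]};
      //     }
      //   };
      //   // auto y = AddOneInplace::apply(x);  // y aliases x; x's history is
      //   // rebased onto this node just below.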
      // If the input was modified, transplant the grad_fn in the graph:
      // grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
      var.mutable_grad().reset();
      impl::clear_hooks(var);
      if (auto grad_acc_fn = impl::try_get_grad_accumulator(var)) {
        auto& grad_acc = dynamic_cast<AccumulateGrad&>(*grad_acc_fn);
        grad_acc.variable.reset();
      }
      if (cdata) {
        impl::rebase_history(var, {cdata, output_nr});
      }
    } else if (is_input) {
      TORCH_CHECK(!is_saved_and_setup_context, error_msg_input_returned_as_is)
      var = _view_as_self_with_no_grad(var, view_as_self_fn);
      impl::set_gradient_edge(var, {cdata, output_nr});
    } else if (cdata) {
      impl::set_gradient_edge(var, {cdata, output_nr});
    }
  };

  optional_variable_list outputs;
  std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
  outputs.reserve(num_outputs);
  int num_diff_outputs = 0;

  for (const auto i : c10::irange(num_outputs)) {
    // We put an undefined_input placeholder for outputs that are not Tensors
    // and for when the output Tensor is not differentiable (see below)
    if (!raw_outputs[i].has_value()) {
      if (cdata) {
        auto output_nr = cdata->add_input_metadata(Node::undefined_input());
        AT_ASSERT(i == output_nr);
      }
      outputs.emplace_back();
      continue;
    }

    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    Variable var = raw_outputs[i].value();

    auto out_tensor_impl = var.unsafeGetTensorImpl();
    bool is_input = inputs_mapping.count(out_tensor_impl) > 0;
    bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
    bool is_differentiable = cdata &&
        non_differentiable.count(out_tensor_impl) == 0 &&
        isDifferentiableType(var.scalar_type());
    bool is_saved_and_setup_context =
        to_save_if_setup_context.count(out_tensor_impl) > 0;

    if (cdata) {
      uint32_t output_nr = 0;
      if (!is_differentiable) {
        output_nr = cdata->add_input_metadata(Node::undefined_input());
      } else {
        output_nr = cdata->add_input_metadata(var);
      }
      AT_ASSERT(i == output_nr);
    }
    set_history(
        var,
        i,
        is_input,
        is_modified,
        is_differentiable,
        is_saved_and_setup_context);

    // For deprecation cycle. Can be removed after 1.6. In the case where we
    // detected a view in no grad mode during the forward, only warn the user
    // (do not change the flag if we return an input that is a view as-is). See
    // NOTE [ View + Inplace detection ] for why we replace everything by a
    // warning.
    if (!(is_input && is_modified) && var.is_view()) {
      // is_view() => diff_view_meta
      auto diff_view_meta = impl::get_view_autograd_meta(var);
      diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION);
    }

    if (is_differentiable) {
      ++num_diff_outputs;
    }

    outputs_impl.insert(out_tensor_impl);
    outputs.emplace_back(var);
  }

  // If multiple differentiable outputs are returned, we do not allow views to
  // be modified inplace. See NOTE [ View + Inplace detection ] for more details.
  if (num_diff_outputs > 1) {
    for (auto& var : outputs) {
      if (var.has_value()) {
        auto diff_view_meta = impl::get_view_autograd_meta(var.value());
        if (diff_view_meta && diff_view_meta->has_bw_view()) {
          diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
        }
      }
    }
  }

  // All the modified Tensors must be returned as is for the rewrite to be
  // valid.
  for (auto& dirty_input : dirty_inputs) {
    TORCH_CHECK(
        outputs_impl.count(dirty_input) > 0,
        "Some elements marked as dirty during the forward method were not returned as output. The"
        " inputs that are modified inplace must all be outputs of the Function.");
  }

  return outputs;
}

optional_variable_list _wrap_outputs(
    const variable_list& input_vars,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const at::ArrayRef<std::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node>& cdata,
    const _jvp_fn_t& jvp_user_function,
    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
    const _view_as_self_fn_t& view_as_self_fn) {
  std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
  inputs_mapping.reserve(input_vars.size());
  for (const auto i : c10::irange(input_vars.size())) {
    inputs_mapping.emplace(input_vars[i].unsafeGetTensorImpl(), i);
  }

  auto outputs = _process_backward_mode_ad(
      inputs_mapping,
      non_differentiable,
      dirty_inputs,
      raw_outputs,
      cdata,
      to_save_if_setup_context,
      view_as_self_fn);

  // This must happen after the backward processing as we expect the
  // computations happening here to track backward mode gradients.
  _process_forward_mode_AD(
      input_vars,
      std::move(inputs_mapping),
      raw_outputs,
      outputs,
      non_differentiable,
      dirty_inputs,
      jvp_user_function);

  return outputs;
}

void check_variable_result(
    const at::TensorBase& original,
    const at::TensorBase& result,
    const std::string& hook_name) {
  if (!original.options().type_equal(result.options())) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the type of value (";
    ss << "was " << original.toString() << " got ";
    ss << result.toString() << ")";
    throw std::runtime_error(ss.str());
  }

  if (original.is_cuda() != result.is_cuda()) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the type of value";
    if (original.is_cuda()) {
      ss << " (was CUDA tensor got CPU tensor)";
    } else {
      ss << " (was CPU tensor got CUDA tensor)";
    }
    throw std::runtime_error(ss.str());
  }

  if (original.sym_sizes().vec() != result.sym_sizes().vec()) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the size of value";
    throw std::runtime_error(ss.str());
  }
}
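// For illustration, a rough sketch (not part of this file) of the kind of C++
// tensor hook these checks guard against; the exact hook name shown in the
// message depends on how and in what order the hook was registered:
//
//   auto x = torch::randn({3}, torch::requires_grad());
//   x.register_hook([](const torch::Tensor& grad) {
//     return grad.to(torch::kDouble);  // changes the dtype of the gradient
//   });
//   x.sum().backward();  // expected to throw:
//                        // "hook '...' has changed the type of value ..."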

void AutogradContext::save_for_backward(variable_list to_save) {
  to_save_ = std::move(to_save);
}

// The logic for handling saved variables here is the same as in
// python_function.cpp. See _save_variables() and unpack_saved_variables().
void AutogradContext::save_variables() {
  saved_variables_.clear();
  auto ptr = grad_fn_.lock();

  for (const auto& var : to_save_) {
    // Allow empty variables to be saved
    if (var.defined()) {
      bool is_output = var.grad_fn().get() == ptr.get();
      saved_variables_.emplace_back(var, is_output);
    } else {
      saved_variables_.emplace_back();
    }
  }
  to_save_.clear();
}

variable_list AutogradContext::get_saved_variables() const {
  TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE);
  variable_list saved;
  saved.reserve(saved_variables_.size());
  auto ptr = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(ptr);
  for (auto& var : saved_variables_) {
    saved.push_back(var.unpack(ptr));
  }
  return saved;
}
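
// A minimal usage sketch (the Function itself is hypothetical; the ctx calls
// are the ones defined in this file): save_for_backward() records Tensors
// during forward, save_variables() packs them into SavedVariables once the
// graph node exists, and get_saved_variables() unpacks them in backward:
//
//   struct Square : public torch::autograd::Function<Square> {
//     static torch::Tensor forward(AutogradContext* ctx, const torch::Tensor& x) {
//       ctx->save_for_backward({x});
//       return x * x;
//     }
//     static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
//       auto saved = ctx->get_saved_variables();
//       return {2 * saved[0] * grad_output[0]};
//     }
//   };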

bool AutogradContext::needs_input_grad(size_t output_edge_index) const {
  auto ptr = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(ptr);
  return ptr->task_should_compute_output(output_edge_index);
}

bool AutogradContext::needs_input_grad(
    std::initializer_list<IndexRange> idxs) const {
  auto ptr = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(ptr);
  return ptr->task_should_compute_output(idxs);
}
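
// Typical use inside a custom backward (sketch; the surrounding Function is
// assumed, not shown): skip computing gradients that the current backward
// call does not need, e.g.
//
//   torch::Tensor grad_x;  // stays undefined if not needed
//   if (ctx->needs_input_grad(0)) {
//     grad_x = /* gradient w.r.t. the first forward input */;
//   }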

void AutogradContext::mark_dirty(const variable_list& inputs) {
  dirty_inputs_.clear();
  dirty_inputs_.reserve(inputs.size());
  for (auto& var : inputs) {
    dirty_inputs_.insert(var.unsafeGetTensorImpl());
  }
}

void AutogradContext::mark_non_differentiable(const variable_list& outputs) {
  non_differentiable_.clear();
  non_differentiable_.reserve(outputs.size());
  for (auto& var : outputs) {
    non_differentiable_.insert(var.unsafeGetTensorImpl());
  }
}
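
// Sketch of typical use from a custom forward: mark outputs that should never
// receive gradients, such as integer-valued indices, e.g.
//
//   torch::Tensor values, indices;
//   std::tie(values, indices) = torch::sort(x);
//   ctx->mark_non_differentiable({indices});
//   return {values, indices};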

void AutogradContext::set_materialize_grads(bool value) {
  materialize_grads_ = value;
}
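// When materialize_grads_ is set to false, undefined gradient Tensors are
// passed to the custom backward as-is instead of being materialized as
// zero-filled Tensors, so the backward must be prepared to handle undefined
// inputs.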

const std::unordered_set<at::TensorImpl*>& AutogradContext::get_and_bump_dirty()
    const {
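  // Bumping each dirty input's version counter records the in-place
  // modification so that autograd's version checks (e.g. for previously saved
  // Tensors) can detect that the Tensor has been mutated by this Function.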
  for (auto& var : dirty_inputs_) {
    var->bump_version();
  }
  return dirty_inputs_;
}

const std::unordered_set<at::TensorImpl*>& AutogradContext::
    get_non_differentiable() const {
  return non_differentiable_;
}
} // namespace torch::autograd