#include <torch/csrc/autograd/functions/accumulate_grad.h>

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/dynamo/compiled_autograd.h>

#include <cstdint>
#include <stdexcept>
#include <utility>

namespace torch::autograd {

// AccumulateGrad sets sequence_nr to the max value so it's always called
// ASAP during backwards.
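// (The engine's ready queue is a priority queue ordered by sequence_nr, so
// the max value makes leaf gradient accumulation run as early as possible.)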
AccumulateGrad::AccumulateGrad(Variable variable_)
    : Node(/*sequence_nr=*/UINT64_MAX), variable(std::move(variable_)) {
  add_input_metadata(variable);
}

auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
  check_input_variables("AccumulateGrad", grads, 1, 0);

  if (!grads[0].defined())
    return {};
  if (variable.grad_fn())
    throw std::logic_error(
        "leaf variable has been moved into the graph interior");
  if (!variable.requires_grad())
    return {};

  // std::move(grads[0]) to avoid bumping up refcount
  at::Tensor new_grad = std::move(grads[0]);

  // Acquire the lock here to protect thread safety on variable; this ensures
  // AccumulateGrad does not race on the shared variable when updating the
  // gradients from different threads. We don't ensure thread safety on hooks
  // and rely on the user to provide thread-safe hooks.
  // See Note [Thread Safety on Autograd Node]
  std::lock_guard<std::mutex> lock(mutex_);

  at::Tensor& grad = variable.mutable_grad();

  // If the function has post hooks (for example, a DDP allreduce hook),
  // call_function in Engine.cpp will temporarily bump the expected refcount
  // by one, hence the addition of !post_hooks().empty() to 'num_expected_refs'
  // on top of the one reference that we're holding.
  // 'num_expected_refs' is used to determine whether we should clone the grad
  // or can steal it.
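  // accumulateGrad (defined in accumulate_grad.h) decides, based on
  // num_expected_refs and the state of the existing grad, whether the
  // incoming gradient can be stolen outright or must be cloned/added.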
  accumulateGrad(
      variable,
      grad,
      new_grad,
      1 + !post_hooks().empty() /* num_expected_refs */,
      [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });

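  // Fire any post-accumulate-grad hooks registered on the tensor (e.g. via
  // torch.Tensor.register_post_accumulate_grad_hook).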
  auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    (*hook)(variable);
  }

  return variable_list();
}

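// Compiled autograd support: collect this node's state (the leaf variable and
// its current grad) into CompiledNodeArgs so compiled autograd can key and
// guard its cache on it.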
void AccumulateGrad::compiled_args(CompiledNodeArgs& args) {
  if (args.cond(variable.defined() && variable.requires_grad())) {
    args.collect(variable);
    args.collect(variable.grad());
  }
  auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    hook->compiled_args(args);
  }
}
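// Compiled autograd path: instead of accumulating eagerly, swap in the saved
// proxies (saved.before/after) and route the update through the traceable
// inductor::accumulate_grad_ op so it is captured in the compiled graph.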
variable_list AccumulateGrad::apply_with_saved(
    const variable_list& grads,
    SwapSavedVariables& saved) {
  if (!(variable.defined() && variable.requires_grad()) ||
      !grads[0].defined()) {
    return variable_list();
  }
  TORCH_INTERNAL_ASSERT(!variable.grad_fn() && grads.size() == 1);
  at::Tensor variable_copy = variable;
  at::Tensor grad_copy = variable.grad();
  saved.before(variable_copy);
  saved.before(grad_copy);
  variable_copy.mutable_grad() = grad_copy;
  // op is intentionally static so the dispatcher schema lookup happens only
  // once.
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("inductor::accumulate_grad_", "")
                       .typed<void(const at::Tensor&, const at::Tensor&)>();
  op.call(variable_copy, grads[0]);
  auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    hook->apply_with_saved(variable_copy, saved);
  }
  saved.after(variable_copy);
  saved.after(grad_copy);

  return variable_list();
}

} // namespace torch::autograd