#include <torch/csrc/autograd/functions/accumulate_grad.h>

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/dynamo/compiled_autograd.h>

#include <cstdint>
#include <stdexcept>
#include <utility>

namespace torch::autograd {

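// AccumulateGrad is the node attached to each leaf Variable that requires
// grad: during the backward pass it receives the incoming gradient and
// accumulates it into the Variable's .grad field.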
// AccumulateGrad sets sequence_nr to the max value so it's always called
// ASAP during backwards.
AccumulateGrad::AccumulateGrad(Variable variable_)
    : Node(/*sequence_nr=*/UINT64_MAX), variable(std::move(variable_)) {
  add_input_metadata(variable);
}

auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
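  // Accumulate the incoming gradient into variable.grad(), cloning, adding,
  // or stealing it as appropriate, and then fire any post-accumulate-grad
  // hooks registered on the tensor.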
  check_input_variables("AccumulateGrad", grads, 1, 0);

  if (!grads[0].defined())
    return {};
  if (variable.grad_fn())
    throw std::logic_error(
        "leaf variable has been moved into the graph interior");
  if (!variable.requires_grad())
    return {};

  // std::move(grads[0]) to avoid bumping up refcount
  at::Tensor new_grad = std::move(grads[0]);

  // Acquire the lock here to protect thread safety on variable; this ensures
  // AccumulateGrad does not race on the shared variable when different
  // threads update the gradients. We don't ensure thread safety on hooks and
  // rely on the user to provide thread-safe hooks.
  // See Note [Thread Safety on Autograd Node]
  std::lock_guard<std::mutex> lock(mutex_);

  at::Tensor& grad = variable.mutable_grad();

  // If the function has post hooks (for example, a DDP allreduce hook),
  // call_function in Engine.cpp will temporarily bump the expected refcount
  // by one, hence the addition of !post_hooks().empty() to 'num_expected_refs'
  // on top of the one reference that we're holding.
  // 'num_expected_refs' is used to determine whether we need to clone the
  // grad or can steal it.
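  // The actual accumulation is handled by the accumulateGrad helper declared
  // in accumulate_grad.h, which decides whether to steal new_grad, clone it,
  // or add it to the existing grad, and hands the result back through the
  // update_grad callback below.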
  accumulateGrad(
      variable,
      grad,
      new_grad,
      1 + !post_hooks().empty() /* num_expected_refs */,
      [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });

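  // Fire tensor-level post-accumulate-grad hooks (e.g. those registered from
  // Python via Tensor.register_post_accumulate_grad_hook) now that
  // variable.grad has been updated.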
  auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    (*hook)(variable);
  }

  return variable_list();
}

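// The two overrides below form the compiled autograd interface for this node
// (see torch/csrc/dynamo/compiled_autograd.h): compiled_args records the
// variable and its current grad for the compiled graph, and apply_with_saved
// replays the accumulation through the inductor::accumulate_grad_ op on the
// swapped-in saved tensors.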
void AccumulateGrad::compiled_args(CompiledNodeArgs& args) {
  if (args.cond(variable.defined() && variable.requires_grad())) {
    args.collect(variable);
    args.collect(variable.grad());
  }
  auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    hook->compiled_args(args);
  }
}

variable_list AccumulateGrad::apply_with_saved(
    const variable_list& grads,
    SwapSavedVariables& saved) {
  if (!(variable.defined() && variable.requires_grad()) ||
      !grads[0].defined()) {
    return variable_list();
  }
  TORCH_INTERNAL_ASSERT(!variable.grad_fn() && grads.size() == 1);
  at::Tensor variable_copy = variable;
  at::Tensor grad_copy = variable.grad();
  saved.before(variable_copy);
  saved.before(grad_copy);
  variable_copy.mutable_grad() = grad_copy;
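  // Route the accumulation through the inductor::accumulate_grad_ custom op
  // so the grad update is visible to compiled autograd as a single
  // dispatcher op.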
  // op is intentionally static
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("inductor::accumulate_grad_", "")
                       .typed<void(const at::Tensor&, const at::Tensor&)>();
  op.call(variable_copy, grads[0]);
  auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    hook->apply_with_saved(variable_copy, saved);
  }
  saved.after(variable_copy);
  saved.after(grad_copy);

  return variable_list();
}

} // namespace torch::autograd