#include <torch/csrc/jit/passes/requires_grad_analysis.h>

#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>

#include <vector>

namespace torch::jit {

namespace {

bool getRequiresGrad(Value* value) {
  return value->requires_grad();
}

void setRequiresGrad(Value* value, bool req_value) {
  if (auto type = value->type()->cast<TensorType>()) {
    value->setType(type->withRequiresGrad(req_value));
  }
}

void setRequiresGrad(
    at::ArrayRef<Value*> outputs,
    const std::vector<bool>& values) {
  AT_ASSERT(outputs.size() == values.size());
  for (const auto i : c10::irange(values.size())) {
    setRequiresGrad(outputs[i], values[i]);
  }
}

void setRequiresGrad(Node* node, const std::vector<bool>& values) {
  setRequiresGrad(node->outputs(), values);
}

std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
  AT_ASSERT(a.size() == b.size());
  for (const auto i : c10::irange(a.size())) {
    a[i] = a[i] || b[i];
  }
  return a;
}

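// Handles nodes without control flow. Special-cases comparison ops and
// detach (their outputs never require grad), type_as (follows its first
// input), and aten::tensor (honors a constant requires_grad argument when
// one is present); all other ops fall through to the default input-to-output
// propagation below.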
void PropagateRequiresGradSimpleNode(Node* node) {
  static const OperatorSet comparison_ops = {
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
  };

  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (node->isMemberOf(comparison_ops)) {
    return setRequiresGrad(node->output(), false);
  } else if (node->matches(
                 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
    return setRequiresGrad(node->output(), node->input(0)->requires_grad());
  } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
    return setRequiresGrad(node->output(), false);
  } else if (node->kind() == aten::tensor) {
    if (auto grad_index =
            node->schema().argumentIndexWithName("requires_grad")) {
      if (auto const_arg = constant_as<bool>(node->inputs().at(*grad_index))) {
        return setRequiresGrad(node->output(), *const_arg);
      }
    }
    if (auto type = node->output()->type()->cast<TensorType>()) {
      if (type->scalarType()) {
        setRequiresGrad(
            node->output(),
            autograd::isDifferentiableType(*type->scalarType()));
      }
    }
    return;
  }

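  // Default rule: an output requires grad iff any input requires grad and the
  // output's scalar type is differentiable (floating point or complex).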
  auto inputs = node->inputs();
  auto outputs = node->outputs();
  bool should_require =
      std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
  for (Value* output : outputs) {
    if (auto type = output->type()->cast<TensorType>()) {
      if (type->scalarType()) {
        setRequiresGrad(
            output,
            should_require &&
                autograd::isDifferentiableType(*type->scalarType()));
      }
    }
  }
}

void PropagateRequiresGrad(Block* block);

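// Dispatches on control flow: prim::If merges the results of both branches,
// prim::Loop iterates its body to a fixed point, and every other node is
// handled by PropagateRequiresGradSimpleNode.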
void PropagateRequiresGrad(Node* node) {
  if (node->kind() == prim::If) {
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    PropagateRequiresGrad(true_block);
    PropagateRequiresGrad(false_block);

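    // An output of the If requires grad if the corresponding output of
    // either branch does.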
    auto outputs_require = bitwiseOr(
        fmap(true_block->outputs(), getRequiresGrad),
        fmap(false_block->outputs(), getRequiresGrad));
    setRequiresGrad(node, outputs_require);
  } else if (node->kind() == prim::Loop) {
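    // Loop-carried values can pick up requires_grad on later iterations, so
    // the body is re-propagated until the requires_grad bits stop changing.
    // slice(2) on the node inputs skips the trip count and initial condition;
    // only the loop-carried values are tracked.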
    auto body = node->blocks().at(0);
    std::vector<bool> loop_inputs_require =
        fmap(node->inputs().slice(2), getRequiresGrad);
    std::vector<bool> body_inputs_require = loop_inputs_require;
    std::vector<bool> body_outputs_require(node->outputs().size(), false);

    std::vector<bool> new_body_inputs_require = body_inputs_require;
    std::vector<bool> new_body_outputs_require = body_outputs_require;

    // continue iterating until the results have converged
    do {
      body_inputs_require = new_body_inputs_require;
      body_outputs_require = new_body_outputs_require;

      new_body_inputs_require =
          bitwiseOr(body_inputs_require, body_outputs_require);
      setRequiresGrad(
          body->param_node()->outputs().slice(1), new_body_inputs_require);
      PropagateRequiresGrad(body);
      new_body_outputs_require =
          fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
    } while (new_body_inputs_require != body_inputs_require ||
             new_body_outputs_require != body_outputs_require);

    setRequiresGrad(
        node, bitwiseOr(body_outputs_require, loop_inputs_require));
  } else {
    PropagateRequiresGradSimpleNode(node);
  }
}

void PropagateRequiresGrad(Block* block) {
  for (Node* node : block->nodes()) {
    PropagateRequiresGrad(node);
  }
}
} // anonymous namespace

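// Public entry point: annotates the tensor values in the graph with whether
// they require grad by walking the top-level block.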
void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
  PropagateRequiresGrad(graph->block());
}
} // namespace torch::jit