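// Requires-grad analysis: a forward dataflow pass that annotates each
// tensor-typed Value's TensorType with whether it will require gradient.
// Simple nodes propagate the flag from inputs to outputs; prim::If merges
// its two branches; prim::Loop iterates its body to a fixed point.
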
#include <torch/csrc/jit/passes/requires_grad_analysis.h>

#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>

#include <vector>

namespace torch::jit {

namespace {

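// Reads the cached requires_grad bit off a Value; shaped as a free function
// so it can be passed to std::any_of and fmap below.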
bool getRequiresGrad(Value* value) {
  return value->requires_grad();
}

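// Writes the flag back onto a Value by swapping in a TensorType copy with
// the new requires_grad bit; non-tensor Values are left untouched.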
void setRequiresGrad(Value* value, bool req_value) {
  if (auto type = value->type()->cast<TensorType>()) {
    value->setType(type->withRequiresGrad(req_value));
  }
}

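// The next two overloads fan a vector of flags out to a list of Values and
// to all outputs of a Node, respectively.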
void setRequiresGrad(
    at::ArrayRef<Value*> outputs,
    const std::vector<bool>& values) {
  AT_ASSERT(outputs.size() == values.size());
  for (const auto i : c10::irange(values.size())) {
    setRequiresGrad(outputs[i], values[i]);
  }
}

void setRequiresGrad(Node* node, const std::vector<bool>& values) {
  setRequiresGrad(node->outputs(), values);
}

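// Element-wise OR of two flag vectors, used to merge requires_grad
// information across If branches and Loop iterations.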
std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
  AT_ASSERT(a.size() == b.size());
  for (const auto i : c10::irange(a.size())) {
    a[i] = a[i] || b[i];
  }
  return a;
}

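// Propagates requires_grad for any node other than prim::If and prim::Loop.
// A handful of ops get special treatment; everything else falls through to
// the generic any-input rule at the bottom of the function.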
void PropagateRequiresGradSimpleNode(Node* node) {
  static const OperatorSet comparison_ops = {
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
  };

  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (node->isMemberOf(comparison_ops)) {
    // Comparisons produce boolean tensors, which never require grad.
    return setRequiresGrad(node->output(), false);
  } else if (node->matches(
                 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
    // type_as only changes dtype; the flag follows `self`, not `other`.
    return setRequiresGrad(node->output(), node->input(0)->requires_grad());
  } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
    // detach explicitly severs its output from the autograd graph.
    return setRequiresGrad(node->output(), false);
  } else if (node->kind() == aten::tensor) {
    // aten::tensor takes an explicit requires_grad argument; honor it when
    // it is a compile-time constant, otherwise fall back to whether the
    // output dtype is differentiable.
    if (auto grad_index =
            node->schema().argumentIndexWithName("requires_grad")) {
      if (auto const_arg = constant_as<bool>(node->inputs().at(*grad_index))) {
        return setRequiresGrad(node->output(), *const_arg);
      }
    }
    if (auto type = node->output()->type()->cast<TensorType>()) {
      if (type->scalarType()) {
        setRequiresGrad(
            node->output(),
            autograd::isDifferentiableType(*type->scalarType()));
      }
    }
    return;
  }

  // Generic rule: an output requires grad iff any input does and the
  // output's scalar type (when known) is differentiable.
  auto inputs = node->inputs();
  auto outputs = node->outputs();
  bool should_require =
      std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
  for (Value* output : outputs) {
    if (auto type = output->type()->cast<TensorType>()) {
      if (type->scalarType()) {
        setRequiresGrad(
            output,
            should_require &&
                autograd::isDifferentiableType(*type->scalarType()));
      }
    }
  }
}

void PropagateRequiresGrad(Block* block);

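// Structural dispatch: prim::If merges its two branches, prim::Loop runs
// its body to a fixed point, and every other node takes the simple path.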
void PropagateRequiresGrad(Node* node) {
  if (node->kind() == prim::If) {
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    PropagateRequiresGrad(true_block);
    PropagateRequiresGrad(false_block);

    // An If output requires grad if it does in either branch.
    auto outputs_require = bitwiseOr(
        fmap(true_block->outputs(), getRequiresGrad),
        fmap(false_block->outputs(), getRequiresGrad));
    setRequiresGrad(node, outputs_require);
  } else if (node->kind() == prim::Loop) {
    auto body = node->blocks().at(0);
    // slice(2) skips the max-trip-count and initial-condition inputs, which
    // carry no gradient information.
    std::vector<bool> loop_inputs_require =
        fmap(node->inputs().slice(2), getRequiresGrad);
    std::vector<bool> body_inputs_require = loop_inputs_require;
    std::vector<bool> body_outputs_require(node->outputs().size(), false);

    std::vector<bool> new_body_inputs_require = body_inputs_require;
    std::vector<bool> new_body_outputs_require = body_outputs_require;

    // Continue iterating until the results have converged. Flags only ever
    // flip from false to true (via bitwiseOr), so the iteration reaches a
    // fixed point and terminates.
    do {
      body_inputs_require = new_body_inputs_require;
      body_outputs_require = new_body_outputs_require;

      new_body_inputs_require =
          bitwiseOr(body_inputs_require, body_outputs_require);
      // slice(1) skips the induction variable on the parameter side and the
      // continue condition on the return side.
      setRequiresGrad(
          body->param_node()->outputs().slice(1), new_body_inputs_require);
      PropagateRequiresGrad(body);
      new_body_outputs_require =
          fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
    } while (new_body_inputs_require != body_inputs_require ||
             new_body_outputs_require != body_outputs_require);

    // The body may run zero times, so loop outputs also inherit the flags
    // of the initial loop-carried inputs.
    setRequiresGrad(node, bitwiseOr(body_outputs_require, loop_inputs_require));
  } else {
    PropagateRequiresGradSimpleNode(node);
  }
}

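// Propagates through a block by visiting its nodes in program order.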
void PropagateRequiresGrad(Block* block) {
  for (Node* node : block->nodes()) {
    PropagateRequiresGrad(node);
  }
}
} // anonymous namespace

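// Public entry point: annotates every tensor-typed Value reachable from the
// graph's top-level block. A minimal usage sketch (build_graph is a
// hypothetical helper, not part of this file):
//
//   std::shared_ptr<Graph> graph = build_graph();  // hypothetical
//   PropagateRequiresGrad(graph);
//   bool rg = graph->outputs().at(0)->requires_grad();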
void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
  PropagateRequiresGrad(graph->block());
}
} // namespace torch::jit