1 #include <ATen/core/jit_type.h>
2 #include <torch/csrc/jit/ir/ir.h>
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/integer_value_refinement.h>
5 #include <torch/csrc/jit/passes/value_refinement_utils.h>
6
7 #include <utility>
8
9 namespace torch::jit {
10
// Maps an int-typed Value to the constant it is known to equal within the
// block scope where the refinement is active.
using IntegerRefinement = std::unordered_map<Value*, int64_t>;
12
// See [value refinement algorithm] for a full explanation.
// When a comparison like `cond = x == 4` or `cond = x != 4` is made, the
// `cond` value carries information (refinements) about the value of `x`.
// For example, in:
//   if x == 1:
//     ...
// we can substitute every use of `x` dominated by the true block with 1.
21
22 struct IntegerValueRefiner {
IntegerValueRefinertorch::jit::IntegerValueRefiner23 IntegerValueRefiner(std::shared_ptr<Graph> graph)
24 : graph_(std::move(graph)) {}
25
runtorch::jit::IntegerValueRefiner26 bool run() {
27 if (!blockHasIntComparisons(graph_->block())) {
28 return false;
29 }
30 IntegerRefinement refinements;
31 RefineIntegerValues(graph_->block(), std::move(refinements));
32 return changed_;
33 }
34
blockHasIntComparisonstorch::jit::IntegerValueRefiner35 bool blockHasIntComparisons(Block* b) {
36 for (Node* n : b->nodes()) {
37 if (n->matches("aten::eq(int a, int b) -> bool") ||
38 n->matches("aten::ne(int a, int b) -> bool")) {
39 for (size_t const_index : {0, 1}) {
40 auto non_const_index = 1 - const_index;
41 if (n->inputs().at(const_index)->node()->kind() == prim::Constant &&
42 n->inputs().at(non_const_index)->uses().size() > 1) {
43 return true;
44 }
45 }
46 }
47 for (Block* block : n->blocks()) {
48 if (blockHasIntComparisons(block)) {
49 return true;
50 }
51 }
52 }
53 return false;
54 }
55
removeIfNodeOutputsWithRefinementstorch::jit::IntegerValueRefiner56 void removeIfNodeOutputsWithRefinements(
57 Node* if_node,
58 IntegerRefinement& true_block_refinements,
59 IntegerRefinement& false_block_refinements) {
60 // we are looking for cases where we can replace both block outputs with the
61 // same value, which opens up further optimization opportunities. The pass
62 // will already handle if both outputs are refined to the same constant.
63 // Here, we look for cases where one block output has been refined in the
64 // other block to be equal to the same constant value as the other other
65 // block output:
66 // graph(%y.1 : int):
67 // %one_constant : int = prim::Constant[value=1]()
68 // %3 : bool = aten::eq(%y.1, %one_constant)
69 // %15 : int = prim::If(%3)
70 // block0():
71 // -> (%one_constant)
72 // block1():
73 // -> (%y.1)
74 // return (%15)
75 // %15 can always be safely replaced with %y.1
76 // this is an important case for symbolic shape analysis
77 for (size_t block_index : {0, 1}) {
78 Block* if_block = if_node->blocks().at(block_index);
79 Block* other_if_block = if_node->blocks().at(1 - block_index);
80 for (size_t i = 0; i < if_node->outputs().size(); ++i) {
81 Value* block_output = if_block->outputs().at(i);
82 if (!block_output->type()->cast<IntType>()) {
83 continue;
84 }
85 // Value must be in scope for both blocks
86 // in example above, %y.1 cannot be defined in block1
87 if (!if_node->isDominatedBy(block_output->node())) {
88 continue;
89 }
90 // one constant value one not - we are looking for the pattern
91 // where y.1 is refined to the existing block output %one_constant
92 auto other_output = other_if_block->outputs().at(i);
93 auto other_const_value = other_output->type()->cast<IntType>()
94 ? constant_as<int64_t>(other_output)
95 : std::nullopt;
96 if (!other_const_value ||
97 block_output->node()->kind() == prim::Constant) {
98 continue;
99 }
100 // here, we are looking in refinements in the other block of our
101 // current output. in the example, we are looking for refinements of
102 // %y.1 in `block0`, and we are checking that %y.1 is refined
103 // to the constant value of %one_constant
104 const auto& other_block_refinements =
105 block_index == 0 ? false_block_refinements : true_block_refinements;
106 if (!other_block_refinements.count(block_output)) {
107 continue;
108 }
109 if (other_block_refinements.at(block_output) == *other_const_value) {
110 if_node->outputs().at(i)->replaceAllUsesWith(block_output);
111 changed_ = true;
112 }
113 }
114 }
115 }
116
117 // iteratively look through the block `b` for refinements or Value uses that
118 // can be refined, `block_refinements` are the refinements present starting at
119 // this block (and for all blocks dominated by this block).
RefineIntegerValuestorch::jit::IntegerValueRefiner120 IntegerRefinement RefineIntegerValues(
121 Block* b,
122 IntegerRefinement block_refinements) {
123 active_refinements_.push_back(&block_refinements);
124 for (Node* n : b->nodes()) {
125 if (n->matches("aten::eq(int a, int b) -> bool") ||
126 n->matches("aten::ne(int a, int b) -> bool")) {
127 for (size_t const_index : {0, 1}) {
128 if (auto ival = constant_as<int64_t>(n->inputs().at(const_index))) {
129 IntegerRefinement refine;
130 refine[n->inputs().at(1 - const_index)] = *ival;
131 info_[n->output()] = n->kind() == aten::eq
132 ? BooleanRefinementMapping::TrueRefinements(std::move(refine))
133 : BooleanRefinementMapping::FalseRefinements(std::move(refine));
134 }
135 }
136 }
137 for (size_t input = 0; input < n->inputs().size(); ++input) {
138 Value* input_v = n->inputs().at(input);
139 if (!input_v->type()->cast<IntType>()) {
140 continue;
141 }
142
143 if (auto refine = tryFindRefinement(input_v)) {
144 WithInsertPoint guard(n);
145 auto refine_constant =
146 graph_->insertConstant(static_cast<int64_t>(*refine));
147 n->replaceInputWith(input_v, refine_constant);
148 changed_ = true;
149 }
150 }
151
152 if (n->kind() == prim::If) {
153 IfView if_n(n);
154 bool has_cond_ref = info_.count(if_n.cond()) != 0;
155 IntegerRefinement empty;
156 auto true_block_refinements = RefineIntegerValues(
157 if_n.thenBlock(),
158 has_cond_ref ? info_[if_n.cond()].true_refine() : empty);
159 auto false_block_refinements = RefineIntegerValues(
160 if_n.elseBlock(),
161 has_cond_ref ? info_[if_n.cond()].false_refine() : empty);
162
163 removeIfNodeOutputsWithRefinements(
164 n, true_block_refinements, false_block_refinements);
165
166 joinIfRefinements(
167 n,
168 throwing_blocks_,
169 block_refinements,
170 true_block_refinements,
171 false_block_refinements,
172 info_);
173 } else {
174 handleCommonRefinentOperators(n, throwing_blocks_, info_);
175 }
176 }
177
178 // iterating over all nodes in the block will not iterate over
179 // block outputs, so we need to add handling of them.
180 // %3 : int = prim::Constant[value=3]()
181 // %4 : bool = aten::eq(%y.1, %3)
182 // %a : int = prim::If(%4)
183 // block0():
184 // -> (%y.1)
185 // Here, we can replace y.1 with 3
186
187 for (size_t i = 0; i < b->outputs().size(); ++i) {
188 Value* output_v = b->outputs().at(i);
189 if (!output_v->type()->cast<IntType>()) {
190 continue;
191 }
192
193 if (auto refine = tryFindRefinement(output_v)) {
194 WithInsertPoint guard(b);
195 auto refine_constant =
196 graph_->insertConstant(static_cast<int64_t>(*refine));
197 b->replaceOutput(i, refine_constant);
198 changed_ = true;
199 }
200 }
201
202 active_refinements_.pop_back();
203 return block_refinements;
204 };
205
tryFindRefinementtorch::jit::IntegerValueRefiner206 std::optional<int64_t> tryFindRefinement(Value* v) {
207 for (const auto& ref : active_refinements_) {
208 auto maybe_refinement = ref->find(v);
209 if (maybe_refinement != ref->end()) {
210 return maybe_refinement->second;
211 }
212 }
213 return std::nullopt;
214 }
215
216 std::shared_ptr<Graph> graph_;
217 // A stack of active refinements, one for each block
218 std::vector<IntegerRefinement*> active_refinements_;
219 // A map from Boolean Value * -> associated refinements
220 std::unordered_map<Value*, BooleanRefinementMapping> info_;
221 std::unordered_set<Block*> throwing_blocks_;
222 bool changed_ = false;
223 };
224
RefineIntegerValues(const std::shared_ptr<Graph> & graph)225 bool RefineIntegerValues(const std::shared_ptr<Graph>& graph) {
226 return IntegerValueRefiner(graph).run();
227 }
228
229 } // namespace torch::jit
230