xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/integer_value_refinement.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/jit_type.h>
2 #include <torch/csrc/jit/ir/ir.h>
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/integer_value_refinement.h>
5 #include <torch/csrc/jit/passes/value_refinement_utils.h>
6 
7 #include <utility>
8 
9 namespace torch::jit {
10 
11 using IntegerRefinement = std::unordered_map<Value*, int64_t>;
12 
// See [value refinement algorithm] for a full explanation.
// When a comparison like `cond = x == 4` or `cond = x != 4` is made,
// the `cond` value carries information (refinements) about the value of `x`.
// In an example like:
// if x == 1:
//    ...
// we can substitute all uses of x dominated by the true block
// with 1.
21 
22 struct IntegerValueRefiner {
IntegerValueRefinertorch::jit::IntegerValueRefiner23   IntegerValueRefiner(std::shared_ptr<Graph> graph)
24       : graph_(std::move(graph)) {}
25 
runtorch::jit::IntegerValueRefiner26   bool run() {
27     if (!blockHasIntComparisons(graph_->block())) {
28       return false;
29     }
30     IntegerRefinement refinements;
31     RefineIntegerValues(graph_->block(), std::move(refinements));
32     return changed_;
33   }
34 
blockHasIntComparisonstorch::jit::IntegerValueRefiner35   bool blockHasIntComparisons(Block* b) {
36     for (Node* n : b->nodes()) {
37       if (n->matches("aten::eq(int a, int b) -> bool") ||
38           n->matches("aten::ne(int a, int b) -> bool")) {
39         for (size_t const_index : {0, 1}) {
40           auto non_const_index = 1 - const_index;
41           if (n->inputs().at(const_index)->node()->kind() == prim::Constant &&
42               n->inputs().at(non_const_index)->uses().size() > 1) {
43             return true;
44           }
45         }
46       }
47       for (Block* block : n->blocks()) {
48         if (blockHasIntComparisons(block)) {
49           return true;
50         }
51       }
52     }
53     return false;
54   }
55 
removeIfNodeOutputsWithRefinementstorch::jit::IntegerValueRefiner56   void removeIfNodeOutputsWithRefinements(
57       Node* if_node,
58       IntegerRefinement& true_block_refinements,
59       IntegerRefinement& false_block_refinements) {
60     // we are looking for cases where we can replace both block outputs with the
61     // same value, which opens up further optimization opportunities. The pass
62     // will already handle if both outputs are refined to the same constant.
63     // Here, we look for cases where one block output has been refined in the
64     // other block to be equal to the same constant value as the other other
65     // block output:
66     //  graph(%y.1 : int):
67     //   %one_constant : int = prim::Constant[value=1]()
68     //   %3 : bool = aten::eq(%y.1, %one_constant)
69     //   %15 : int = prim::If(%3)
70     //     block0():
71     //       -> (%one_constant)
72     //     block1():
73     //       -> (%y.1)
74     //   return (%15)
75     // %15 can always be safely replaced with %y.1
76     // this is an important case for symbolic shape analysis
77     for (size_t block_index : {0, 1}) {
78       Block* if_block = if_node->blocks().at(block_index);
79       Block* other_if_block = if_node->blocks().at(1 - block_index);
80       for (size_t i = 0; i < if_node->outputs().size(); ++i) {
81         Value* block_output = if_block->outputs().at(i);
82         if (!block_output->type()->cast<IntType>()) {
83           continue;
84         }
85         // Value must be in scope for both blocks
86         // in example above, %y.1 cannot be defined in block1
87         if (!if_node->isDominatedBy(block_output->node())) {
88           continue;
89         }
90         // one constant value one not - we are looking for the pattern
91         // where y.1 is refined to the existing block output %one_constant
92         auto other_output = other_if_block->outputs().at(i);
93         auto other_const_value = other_output->type()->cast<IntType>()
94             ? constant_as<int64_t>(other_output)
95             : std::nullopt;
96         if (!other_const_value ||
97             block_output->node()->kind() == prim::Constant) {
98           continue;
99         }
100         // here, we are looking in refinements in the other block of our
101         // current output. in the example, we are looking for refinements of
102         // %y.1 in `block0`, and we are checking that %y.1 is refined
103         // to the constant value of %one_constant
104         const auto& other_block_refinements =
105             block_index == 0 ? false_block_refinements : true_block_refinements;
106         if (!other_block_refinements.count(block_output)) {
107           continue;
108         }
109         if (other_block_refinements.at(block_output) == *other_const_value) {
110           if_node->outputs().at(i)->replaceAllUsesWith(block_output);
111           changed_ = true;
112         }
113       }
114     }
115   }
116 
117   // iteratively look through the block `b` for refinements or Value uses that
118   // can be refined, `block_refinements` are the refinements present starting at
119   // this block (and for all blocks dominated by this block).
RefineIntegerValuestorch::jit::IntegerValueRefiner120   IntegerRefinement RefineIntegerValues(
121       Block* b,
122       IntegerRefinement block_refinements) {
123     active_refinements_.push_back(&block_refinements);
124     for (Node* n : b->nodes()) {
125       if (n->matches("aten::eq(int a, int b) -> bool") ||
126           n->matches("aten::ne(int a, int b) -> bool")) {
127         for (size_t const_index : {0, 1}) {
128           if (auto ival = constant_as<int64_t>(n->inputs().at(const_index))) {
129             IntegerRefinement refine;
130             refine[n->inputs().at(1 - const_index)] = *ival;
131             info_[n->output()] = n->kind() == aten::eq
132                 ? BooleanRefinementMapping::TrueRefinements(std::move(refine))
133                 : BooleanRefinementMapping::FalseRefinements(std::move(refine));
134           }
135         }
136       }
137       for (size_t input = 0; input < n->inputs().size(); ++input) {
138         Value* input_v = n->inputs().at(input);
139         if (!input_v->type()->cast<IntType>()) {
140           continue;
141         }
142 
143         if (auto refine = tryFindRefinement(input_v)) {
144           WithInsertPoint guard(n);
145           auto refine_constant =
146               graph_->insertConstant(static_cast<int64_t>(*refine));
147           n->replaceInputWith(input_v, refine_constant);
148           changed_ = true;
149         }
150       }
151 
152       if (n->kind() == prim::If) {
153         IfView if_n(n);
154         bool has_cond_ref = info_.count(if_n.cond()) != 0;
155         IntegerRefinement empty;
156         auto true_block_refinements = RefineIntegerValues(
157             if_n.thenBlock(),
158             has_cond_ref ? info_[if_n.cond()].true_refine() : empty);
159         auto false_block_refinements = RefineIntegerValues(
160             if_n.elseBlock(),
161             has_cond_ref ? info_[if_n.cond()].false_refine() : empty);
162 
163         removeIfNodeOutputsWithRefinements(
164             n, true_block_refinements, false_block_refinements);
165 
166         joinIfRefinements(
167             n,
168             throwing_blocks_,
169             block_refinements,
170             true_block_refinements,
171             false_block_refinements,
172             info_);
173       } else {
174         handleCommonRefinentOperators(n, throwing_blocks_, info_);
175       }
176     }
177 
178     // iterating over all nodes in the block will not iterate over
179     // block outputs, so we need to add handling of them.
180     // %3 : int = prim::Constant[value=3]()
181     // %4 : bool = aten::eq(%y.1, %3)
182     // %a : int = prim::If(%4)
183     //   block0():
184     //     -> (%y.1)
185     // Here, we can replace y.1 with 3
186 
187     for (size_t i = 0; i < b->outputs().size(); ++i) {
188       Value* output_v = b->outputs().at(i);
189       if (!output_v->type()->cast<IntType>()) {
190         continue;
191       }
192 
193       if (auto refine = tryFindRefinement(output_v)) {
194         WithInsertPoint guard(b);
195         auto refine_constant =
196             graph_->insertConstant(static_cast<int64_t>(*refine));
197         b->replaceOutput(i, refine_constant);
198         changed_ = true;
199       }
200     }
201 
202     active_refinements_.pop_back();
203     return block_refinements;
204   };
205 
tryFindRefinementtorch::jit::IntegerValueRefiner206   std::optional<int64_t> tryFindRefinement(Value* v) {
207     for (const auto& ref : active_refinements_) {
208       auto maybe_refinement = ref->find(v);
209       if (maybe_refinement != ref->end()) {
210         return maybe_refinement->second;
211       }
212     }
213     return std::nullopt;
214   }
215 
216   std::shared_ptr<Graph> graph_;
217   // A stack of active refinements, one for each block
218   std::vector<IntegerRefinement*> active_refinements_;
219   // A map from Boolean Value * -> associated refinements
220   std::unordered_map<Value*, BooleanRefinementMapping> info_;
221   std::unordered_set<Block*> throwing_blocks_;
222   bool changed_ = false;
223 };
224 
RefineIntegerValues(const std::shared_ptr<Graph> & graph)225 bool RefineIntegerValues(const std::shared_ptr<Graph>& graph) {
226   return IntegerValueRefiner(graph).run();
227 }
228 
229 } // namespace torch::jit
230