#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/value_refinement_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <limits>
#include <utility>

namespace torch::jit {

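// Converts a possibly negative, Python-style index into a position in
// [0, len), e.g. normalizeIndex(-1, 4) == 3. Returns nullopt when the index
// is out of bounds even after wrapping.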
static std::optional<size_t> normalizeIndex(int64_t index, size_t len) {
  if (index < 0) {
    index = index + len;
  }
  if (index >= 0 && index < static_cast<int64_t>(len)) {
    return index;
  } else {
    return std::nullopt;
  }
}

// see [value refinement algorithm]

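// Refines statically known list lengths so that later aten::len calls can be
// folded to constants. Illustrative example (not exact IR):
//
//   %len = aten::len(%li)
//   %cond = aten::eq(%len, %four)   // %four is the constant 4
//   prim::If(%cond)
//     block0():
//       ...aten::len(%li)...        // replaced by the constant 4 here
//
// Refinements are tracked per block and joined across If branches.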
struct ListLenRefiner {
  ListLenRefiner(
      std::shared_ptr<Graph> graph,
      std::unordered_set<Value*>& mutated_lists)
      : graph_(std::move(graph)), mutated_lists_(mutated_lists) {}

  bool run() {
    std::unordered_set<Value*> li_with_len_use;
    collectListsToRefine(graph_->block(), li_with_len_use);
    if (lists_to_refine_.empty()) {
      return false;
    }
    ListRefinement refinements;
    RefineListLens(graph_->block(), std::move(refinements));
    return changed_;
  }

  // we only need to analyze lists that have multiple uses of len(), and we can
  // only analyze lists that are not mutated
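  // (a refinement is established through one len() use in an eq/ne comparison
  // and only pays off if another len() use can be folded, so lists whose
  // length is queried just once are skipped)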
  void collectListsToRefine(
      Block* b,
      std::unordered_set<Value*>& li_with_len_use) {
    for (Node* n : b->nodes()) {
      for (Block* block : n->blocks()) {
        collectListsToRefine(block, li_with_len_use);
      }

      if (n->kind() != aten::len) {
        continue;
      }

      auto first_input = n->input(0);
      if (first_input->type()->castRaw<ListType>() &&
          !mutated_lists_.count(first_input)) {
        if (!li_with_len_use.count(first_input)) {
          li_with_len_use.insert(first_input);
        } else {
          lists_to_refine_.insert(first_input);
        }
      }
    }
  }

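  // Walks the block, recording the length refinements implied by eq/ne
  // comparisons of len() against constants, folding aten::len calls whose
  // result is already refined, and recursing into prim::If blocks with the
  // refinements implied by the condition. Returns the refinements this block
  // establishes so the caller can join them.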
  ListRefinement RefineListLens(Block* b, ListRefinement block_refinements) {
    active_refinements_.push_back(&block_refinements);
    for (Node* n : b->nodes()) {
      if (n->matches("aten::eq(int a, int b) -> bool") ||
          n->matches("aten::ne(int a, int b) -> bool")) {
        // check for one input constant and the other coming from len(li)
        for (size_t const_index : {0, 1}) {
          auto ival = constant_as<int64_t>(n->input(const_index));
          if (!ival) {
            continue;
          }
          auto li_len = n->input(1 - const_index);
          if (!li_len->node()->matches("aten::len.t(t[] a) -> int") ||
              !lists_to_refine_.count(li_len->node()->input())) {
            continue;
          }
          ListRefinement refine;
          refine[li_len->node()->input()] = *ival;
          boolean_value_refinements_[n->output()] = n->kind() == aten::eq
              ? BooleanRefinementMapping::TrueRefinements(std::move(refine))
              : BooleanRefinementMapping::FalseRefinements(std::move(refine));
        }
      } else if (n->kind() == aten::len) {
        if (auto maybe_len = tryFindRefinement(n->input(0))) {
          changed_ = true;
          WithInsertPoint guard(n);
          n->output()->replaceAllUsesWith(
              graph_->insertConstant(static_cast<int64_t>(*maybe_len)));
        }
      } else if (n->kind() == prim::If) {
        IfView if_n(n);
        bool has_cond_ref = boolean_value_refinements_.count(if_n.cond()) != 0;
        ListRefinement empty;
        auto true_block_refinements = RefineListLens(
            if_n.thenBlock(),
            has_cond_ref ? boolean_value_refinements_[if_n.cond()].true_refine()
                         : empty);
        auto false_block_refinements = RefineListLens(
            if_n.elseBlock(),
            has_cond_ref
                ? boolean_value_refinements_[if_n.cond()].false_refine()
                : empty);

        joinIfRefinements(
            n,
            throwing_blocks_,
            block_refinements,
            true_block_refinements,
            false_block_refinements,
            boolean_value_refinements_);
      } else {
        handleCommonRefinentOperators(
            n, throwing_blocks_, boolean_value_refinements_);
      }
    }
    active_refinements_.pop_back();
    return block_refinements;
  };

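  // Returns the refined length for `v`, if any enclosing block has
  // established one.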
  std::optional<int64_t> tryFindRefinement(Value* v) {
    for (const auto& ref : active_refinements_) {
      auto maybe_refinement = ref->find(v);
      if (maybe_refinement != ref->end()) {
        return maybe_refinement->second;
      }
    }
    return std::nullopt;
  }

  std::shared_ptr<Graph> graph_;
  std::unordered_set<Value*> mutated_lists_;
  // candidate lists for optimizations
  std::unordered_set<Value*> lists_to_refine_;
  // A stack of active refinements, one for each block
  std::vector<ListRefinement*> active_refinements_;
  // A map from Boolean Value * -> associated refinements
  std::unordered_map<Value*, BooleanRefinementMapping>
      boolean_value_refinements_;
  std::unordered_set<Block*> throwing_blocks_;
  bool changed_ = false;
};

// This pass only does optimizations on lists which aren't mutated,
// so we first use the Alias Db to collect the set of list values
// which we shouldn't optimize.
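// Any list value that alias analysis sees a writer for (e.g. an in-place
// append) is treated as mutated and excluded from these optimizations.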
struct PeepholeOptimizeListIdiomsImpl {
  PeepholeOptimizeListIdiomsImpl(
      std::shared_ptr<Graph> graph,
      bool refine_list_len)
      : graph_(std::move(graph)),
        aliasDb_(std::make_unique<AliasDb>(graph_)),
        refine_list_len_(refine_list_len) {}

  bool run() {
    collectMutatedLists(graph_->block());
    bool changed = runBlock(graph_->block());
    if (refine_list_len_) {
      changed |= ListLenRefiner(graph_, mutated_lists_).run();
    }
    return changed;
  }

 private:
  void checkForMutatedList(Value* v) {
    if (v->type()->castRaw<ListType>() && aliasDb_->hasWriters(v)) {
      mutated_lists_.insert(v);
    }
  }

  void collectMutatedLists(Block* b) {
    for (Value* v : b->inputs()) {
      checkForMutatedList(v);
    }
    for (Node* n : b->nodes()) {
      for (Value* v : n->outputs()) {
        checkForMutatedList(v);
      }
      for (Block* block : n->blocks()) {
        collectMutatedLists(block);
      }
    }
  }

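  // Rewrites a slice of an unmutated prim::ListConstruct with constant
  // start/end/step into a new, smaller prim::ListConstruct holding just the
  // selected elements, e.g. [a, b, c, d][1:3] becomes [b, c].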
  bool optimizeSlice(Node* slice_node, Node* list_construct_node) {
    auto start_val = toIValue(slice_node->input(1));
    auto end_val = toIValue(slice_node->input(2));
    auto step_val = toIValue(slice_node->input(3));

    // All args must be constant to apply this optimization.
    if (start_val == std::nullopt || end_val == std::nullopt ||
        step_val == std::nullopt) {
      return false;
    }

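    // A start/end that is not an int constant (i.e. a None bound) is encoded
    // as INT64_MAX and a missing step as 1, mirroring how the runtime encodes
    // these arguments before calling slice_indices_adjust.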
    int64_t start = start_val->isInt() ? start_val->to<int64_t>()
                                       : std::numeric_limits<int64_t>::max();
    int64_t end = end_val->isInt() ? end_val->to<int64_t>()
                                   : std::numeric_limits<int64_t>::max();
    int64_t step = step_val->isInt() ? step_val->to<int64_t>() : 1;

    size_t list_size = list_construct_node->inputs().size();
    size_t num_values = slice_indices_adjust(list_size, &start, &end, step);

    WithInsertPoint guard(slice_node);
    auto slice_list_construct =
        graph_->insertNode(graph_->create(prim::ListConstruct));
    slice_list_construct->output()->setType(slice_node->output()->type());
    for (size_t i = start, j = 0; j < num_values; ++j) {
      slice_list_construct->addInput(list_construct_node->input(i));
      i += step;
    }

    slice_node->output()->replaceAllUsesWith(slice_list_construct->output());
    if (mutated_lists_.count(slice_node->output())) {
      mutated_lists_.insert(slice_list_construct->output());
    }

    return true;
  }

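  // Applies the per-node rewrites on lists built directly from an unmutated
  // prim::ListConstruct:
  //   - aten::len        -> the constant number of elements
  //   - aten::__getitem__ with a constant index -> the corresponding element
  //   - prim::ListUnpack -> the construct's inputs (when arities match)
  //   - aten::add of two unmutated constructs   -> one concatenated construct
  //   - aten::slice with constant bounds        -> optimizeSlice above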
  bool runBlock(Block* block) {
    bool changed = false;
    for (Node* node : block->nodes()) {
      for (Block* b : node->blocks()) {
        changed |= runBlock(b);
      }

      // only optimizing list ops
      if (node->inputs().empty() ||
          !node->input(0)->type()->castRaw<ListType>()) {
        continue;
      }

      auto first_input = node->input(0);

      // only optimizing ops with unmutated lists
      if (mutated_lists_.count(first_input)) {
        continue;
      }

      auto list_creation_node = first_input->node();
      if (list_creation_node->kind() != prim::ListConstruct) {
        continue;
      }

      if (node->kind() == aten::len) {
        WithInsertPoint guard(node);
        node->output()->replaceAllUsesWith(graph_->insertConstant(
            static_cast<int64_t>(first_input->node()->inputs().size())));
        changed = true;
      } else if (node->kind() == aten::__getitem__) {
        if (auto index = toIValue(node->input(1))) {
          size_t list_size = list_creation_node->inputs().size();
          if (auto norm_index = normalizeIndex(index->toInt(), list_size)) {
            node->output()->replaceAllUsesWith(
                list_creation_node->input(*norm_index));
            changed = true;
          }
        }
      } else if (node->kind() == prim::ListUnpack) {
        // if sizes are unequal it's a runtime error
        if (list_creation_node->inputs().size() != node->outputs().size()) {
          continue;
        }
        for (size_t i = 0; i < node->outputs().size(); ++i) {
          node->output(i)->replaceAllUsesWith(list_creation_node->input(i));
          changed = true;
        }
      } else if (node->kind() == aten::add) {
        if (node->inputs().size() != 2) {
          continue;
        }
        auto second_input = node->input(1);
        // already checked first, need to check second
        if (mutated_lists_.count(second_input)) {
          continue;
        }
        if (second_input->node()->kind() != prim::ListConstruct) {
          continue;
        }
        WithInsertPoint guard(node);
        auto list_construct =
            graph_->insertNode(graph_->create(prim::ListConstruct));
        list_construct->output()->setType(node->output()->type());
        for (Value* v : first_input->node()->inputs()) {
          list_construct->addInput(v);
        }
        for (Value* v : second_input->node()->inputs()) {
          list_construct->addInput(v);
        }
        node->output()->replaceAllUsesWith(list_construct->output());
        if (mutated_lists_.count(node->output())) {
          mutated_lists_.insert(list_construct->output());
        }
        changed = true;
      } else if (node->kind() == aten::slice) {
        changed |= optimizeSlice(node, first_input->node());
      }
    }
    return changed;
  }

  std::unordered_set<Value*> mutated_lists_;
  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_;
  bool refine_list_len_;
};

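// Entry point: runs the list-idiom peepholes (and, optionally, list length
// refinement) on `graph`, returning true if the graph was modified.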
bool PeepholeOptimizeListIdioms(
    const std::shared_ptr<Graph>& graph,
    bool refine_list_len) {
  PeepholeOptimizeListIdiomsImpl opt(graph, refine_list_len);
  return opt.run();
}

} // namespace torch::jit