#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/value_refinement_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <limits>
#include <utility>

namespace torch::jit {

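// Normalizes a Python-style (possibly negative) list index to a positive
// index in [0, len), returning nullopt if it is out of range.
// E.g. normalizeIndex(-1, 4) == 3, normalizeIndex(4, 4) == nullopt.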
static std::optional<size_t> normalizeIndex(int64_t index, size_t len) {
  if (index < 0) {
    index = index + len;
  }
  if (index >= 0 && index < static_cast<int64_t>(len)) {
    return index;
  } else {
    return std::nullopt;
  }
}

// see [value refinement algorithm]
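//
// ListLenRefiner propagates list lengths learned from comparisons of
// len(list) against constants into the blocks guarded by those comparisons,
// so that further len() calls on unmutated lists can be replaced with
// constants. For example (TorchScript sketch):
//
//   if len(x) == 4:
//       y = len(x)   # refinable to the constant 4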

struct ListLenRefiner {
  ListLenRefiner(
      std::shared_ptr<Graph> graph,
      std::unordered_set<Value*>& mutated_lists)
      : graph_(std::move(graph)), mutated_lists_(mutated_lists) {}

  bool run() {
    std::unordered_set<Value*> li_with_len_use;
    collectListsToRefine(graph_->block(), li_with_len_use);
    if (lists_to_refine_.empty()) {
      return false;
    }
    ListRefinement refinements;
    RefineListLens(graph_->block(), std::move(refinements));
    return changed_;
  }

  // we only need to analyze lists that have multiple uses of len(), and we can
  // only analyze lists that are not mutated
  void collectListsToRefine(
      Block* b,
      std::unordered_set<Value*>& li_with_len_use) {
    for (Node* n : b->nodes()) {
      for (Block* block : n->blocks()) {
        collectListsToRefine(block, li_with_len_use);
      }

      if (n->kind() != aten::len) {
        continue;
      }

      auto first_input = n->input(0);
      if (first_input->type()->castRaw<ListType>() &&
          !mutated_lists_.count(first_input)) {
        if (!li_with_len_use.count(first_input)) {
          li_with_len_use.insert(first_input);
        } else {
          lists_to_refine_.insert(first_input);
        }
      }
    }
  }

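  // Walks a block, recording refinements from `len(li) ==/!= constant`
  // comparisons, replacing len() calls that have a known refined length with
  // constants, and recursively refining prim::If sub-blocks under the
  // condition's true/false refinements before joining the results.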
  ListRefinement RefineListLens(Block* b, ListRefinement block_refinements) {
    active_refinements_.push_back(&block_refinements);
    for (Node* n : b->nodes()) {
      if (n->matches("aten::eq(int a, int b) -> bool") ||
          n->matches("aten::ne(int a, int b) -> bool")) {
        // check for one input constant and the other coming from len(li)
        for (size_t const_index : {0, 1}) {
          auto ival = constant_as<int64_t>(n->input(const_index));
          if (!ival) {
            continue;
          }
          auto li_len = n->input(1 - const_index);
          if (!li_len->node()->matches("aten::len.t(t[] a) -> int") ||
              !lists_to_refine_.count(li_len->node()->input())) {
            continue;
          }
          ListRefinement refine;
          refine[li_len->node()->input()] = *ival;
          boolean_value_refinements_[n->output()] = n->kind() == aten::eq
              ? BooleanRefinementMapping::TrueRefinements(std::move(refine))
              : BooleanRefinementMapping::FalseRefinements(std::move(refine));
        }
      } else if (n->kind() == aten::len) {
        if (auto maybe_len = tryFindRefinement(n->input(0))) {
          changed_ = true;
          WithInsertPoint guard(n);
          n->output()->replaceAllUsesWith(
              graph_->insertConstant(static_cast<int64_t>(*maybe_len)));
        }
      } else if (n->kind() == prim::If) {
        IfView if_n(n);
        bool has_cond_ref = boolean_value_refinements_.count(if_n.cond()) != 0;
        ListRefinement empty;
        auto true_block_refinements = RefineListLens(
            if_n.thenBlock(),
            has_cond_ref ? boolean_value_refinements_[if_n.cond()].true_refine()
                         : empty);
        auto false_block_refinements = RefineListLens(
            if_n.elseBlock(),
            has_cond_ref
                ? boolean_value_refinements_[if_n.cond()].false_refine()
                : empty);

        joinIfRefinements(
            n,
            throwing_blocks_,
            block_refinements,
            true_block_refinements,
            false_block_refinements,
            boolean_value_refinements_);
      } else {
        handleCommonRefinentOperators(
            n, throwing_blocks_, boolean_value_refinements_);
      }
    }
    active_refinements_.pop_back();
    return block_refinements;
  }

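  // Looks up a refined length for `v` in the refinements of all enclosing
  // blocks currently being processed.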
  std::optional<int64_t> tryFindRefinement(Value* v) {
    for (const auto& ref : active_refinements_) {
      auto maybe_refinement = ref->find(v);
      if (maybe_refinement != ref->end()) {
        return maybe_refinement->second;
      }
    }
    return std::nullopt;
  }

  std::shared_ptr<Graph> graph_;
  std::unordered_set<Value*> mutated_lists_;
  // candidate lists for optimizations
  std::unordered_set<Value*> lists_to_refine_;
  // A stack of active refinements, one for each block
  std::vector<ListRefinement*> active_refinements_;
  // A map from Boolean Value * -> associated refinements
  std::unordered_map<Value*, BooleanRefinementMapping>
      boolean_value_refinements_;
  std::unordered_set<Block*> throwing_blocks_;
  bool changed_ = false;
};

// This pass only does optimizations on lists which aren't mutated,
// so we first use the Alias Db to collect the set of list values
// which we shouldn't optimize.
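// Handled idioms on unmutated prim::ListConstruct outputs (sketches):
//   len([a, b])        -> 2
//   [a, b, c][1]       -> b          (constant index)
//   a, b = [c, d]      -> uses of a, b become c, d
//   [a] + [b]          -> [a, b]
//   [a, b, c][0:2]     -> [a, b]     (constant start/end/step)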
struct PeepholeOptimizeListIdiomsImpl {
  PeepholeOptimizeListIdiomsImpl(
      std::shared_ptr<Graph> graph,
      bool refine_list_len)
      : graph_(std::move(graph)),
        aliasDb_(std::make_unique<AliasDb>(graph_)),
        refine_list_len_(refine_list_len) {}

  bool run() {
    collectMutatedLists(graph_->block());
    bool changed = runBlock(graph_->block());
    if (refine_list_len_) {
      changed |= ListLenRefiner(graph_, mutated_lists_).run();
    }
    return changed;
  }

 private:
  void checkForMutatedList(Value* v) {
    if (v->type()->castRaw<ListType>() && aliasDb_->hasWriters(v)) {
      mutated_lists_.insert(v);
    }
  }

  void collectMutatedLists(Block* b) {
    for (Value* v : b->inputs()) {
      checkForMutatedList(v);
    }
    for (Node* n : b->nodes()) {
      for (Value* v : n->outputs()) {
        checkForMutatedList(v);
      }
      for (Block* block : n->blocks()) {
        collectMutatedLists(block);
      }
    }
  }

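  // Rewrites aten::slice on a ListConstruct with constant start/end/step into
  // a new, smaller ListConstruct over the selected elements,
  // e.g. (sketch) [a, b, c, d][1:3] -> [b, c]. Returns true if it rewrote.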
  bool optimizeSlice(Node* slice_node, Node* list_construct_node) {
    auto start_val = toIValue(slice_node->input(1));
    auto end_val = toIValue(slice_node->input(2));
    auto step_val = toIValue(slice_node->input(3));

    // All args must be constant to apply this optimization.
    if (start_val == std::nullopt || end_val == std::nullopt ||
        step_val == std::nullopt) {
      return false;
    }

    int64_t start = start_val->isInt() ? start_val->to<int64_t>()
                                       : std::numeric_limits<int64_t>::max();
    int64_t end = end_val->isInt() ? end_val->to<int64_t>()
                                   : std::numeric_limits<int64_t>::max();
    int64_t step = step_val->isInt() ? step_val->to<int64_t>() : 1;

    size_t list_size = list_construct_node->inputs().size();
    size_t num_values = slice_indices_adjust(list_size, &start, &end, step);

    WithInsertPoint guard(slice_node);
    auto slice_list_construct =
        graph_->insertNode(graph_->create(prim::ListConstruct));
    slice_list_construct->output()->setType(slice_node->output()->type());
    for (size_t i = start, j = 0; j < num_values; ++j) {
      slice_list_construct->addInput(list_construct_node->input(i));
      i += step;
    }

    slice_node->output()->replaceAllUsesWith(slice_list_construct->output());
    if (mutated_lists_.count(slice_node->output())) {
      mutated_lists_.insert(slice_list_construct->output());
    }

    return true;
  }

  bool runBlock(Block* block) {
    bool changed = false;
    for (Node* node : block->nodes()) {
      for (Block* b : node->blocks()) {
        changed |= runBlock(b);
      }

      // only optimizing list ops
      if (node->inputs().empty() ||
          !node->input(0)->type()->castRaw<ListType>()) {
        continue;
      }

      auto first_input = node->input(0);

      // only optimizing ops with unmutated lists
      if (mutated_lists_.count(first_input)) {
        continue;
      }

      auto list_creation_node = first_input->node();
      if (list_creation_node->kind() != prim::ListConstruct) {
        continue;
      }

      if (node->kind() == aten::len) {
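        // len(ListConstruct(a, b, ...)) is just the number of inputs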
        WithInsertPoint guard(node);
        node->output()->replaceAllUsesWith(graph_->insertConstant(
            static_cast<int64_t>(first_input->node()->inputs().size())));
        changed = true;
      } else if (node->kind() == aten::__getitem__) {
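        // a constant, in-range index selects the corresponding ListConstruct
        // input directly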
        if (auto index = toIValue(node->input(1))) {
          size_t list_size = list_creation_node->inputs().size();
          if (auto norm_index = normalizeIndex(index->toInt(), list_size)) {
            node->output()->replaceAllUsesWith(
                list_creation_node->input(*norm_index));
            changed = true;
          }
        }
      } else if (node->kind() == prim::ListUnpack) {
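        // unpacking a ListConstruct forwards each element to the
        // corresponding output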
        // if sizes are unequal it's a runtime error
        if (list_creation_node->inputs().size() != node->outputs().size()) {
          continue;
        }
        for (size_t i = 0; i < node->outputs().size(); ++i) {
          node->output(i)->replaceAllUsesWith(list_creation_node->input(i));
          changed = true;
        }
      } else if (node->kind() == aten::add) {
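        // concatenating two unmutated ListConstructs can be fused into a
        // single ListConstruct over both input lists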
        if (node->inputs().size() != 2) {
          continue;
        }
        auto second_input = node->input(1);
        // already checked first, need to check second
        if (mutated_lists_.count(second_input)) {
          continue;
        }
        if (second_input->node()->kind() != prim::ListConstruct) {
          continue;
        }
        WithInsertPoint guard(node);
        auto list_construct =
            graph_->insertNode(graph_->create(prim::ListConstruct));
        list_construct->output()->setType(node->output()->type());
        for (Value* v : first_input->node()->inputs()) {
          list_construct->addInput(v);
        }
        for (Value* v : second_input->node()->inputs()) {
          list_construct->addInput(v);
        }
        node->output()->replaceAllUsesWith(list_construct->output());
        if (mutated_lists_.count(node->output())) {
          mutated_lists_.insert(list_construct->output());
        }
        changed = true;
      } else if (node->kind() == aten::slice) {
        changed |= optimizeSlice(node, first_input->node());
      }
    }
    return changed;
  }

  std::unordered_set<Value*> mutated_lists_;
  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_;
  bool refine_list_len_;
};

bool PeepholeOptimizeListIdioms(
    const std::shared_ptr<Graph>& graph,
    bool refine_list_len) {
  PeepholeOptimizeListIdiomsImpl opt(graph, refine_list_len);
  return opt.run();
}

} // namespace torch::jit