#include <torch/csrc/jit/tensorexpr/loopnest.h>

#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <c10/util/Logging.h>
#include <c10/util/irange.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_cloner.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
namespace torch::jit::tensorexpr {

LoopNest::LoopNest(const LoopNest& other)
    : root_stmt_(Stmt::clone(other.root_stmt_)),
      output_bufs_(other.output_bufs_) {
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

LoopNest::LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs)
    : root_stmt_(std::move(stmt)), output_bufs_(std::move(output_bufs)) {
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

LoopNest::LoopNest(
    const std::vector<Tensor>& output_tensors,
    const std::vector<Tensor>& tensors_to_compute) {
  initialize(output_tensors, tensors_to_compute);
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

LoopNest::LoopNest(const std::vector<Tensor>& output_tensors) {
  initialize(output_tensors, output_tensors);
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

std::vector<BufPtr> LoopNest::getIntermediateBufs() const {
  std::vector<BufPtr> result;
  std::unordered_set<BufPtr> result_set;
  auto input_bufs = getInputBufs();
  auto bufs = NodeFinder<Buf>::find(root_stmt_);
  for (const auto& buf : bufs) {
    if (!output_bufs_.count(buf) && !input_bufs.count(buf) &&
        !result_set.count(buf)) {
      result.push_back(buf);
      result_set.insert(buf);
    }
  }
  return result;
}

const std::unordered_set<BufPtr> LoopNest::getInputBufs() const {
  std::unordered_set<BufPtr> result;
  auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
  for (auto& kv : buf_load_store_uses) {
    bool has_store = false;
    for (auto& use : kv.second) {
      if (use.isStore) {
        has_store = true;
        break;
      }
    }
    if (!has_store) {
      result.insert(kv.first);
    }
  }
  return result;
}

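// An illustrative example of what the IndexFlattener below produces, assuming
// a buffer with dims {M, N} and the default row-major strides {N, 1}:
//   A[i, j]  ->  A[i * N + j]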
class IndexFlattener : public IRMutator {
 public:
  StmtPtr flatten(const StmtPtr& s) {
    return s->accept_mutator(this);
  }

  ExprPtr mutate(const LoadPtr& v) override {
    if (v->indices().size() == 1) {
      return v;
    }
    return alloc<Load>(
        v->dtype(),
        v->buf(),
        std::vector<ExprPtr>({flatten_index(
            v->buf()->dims(), v->indices(), v->buf()->strides())}));
  }

  StmtPtr mutate(const StorePtr& v) override {
    ExprPtr value = v->value();
    ExprPtr new_value = value->accept_mutator(this);
    if (v->indices().size() == 1 && value == new_value) {
      return v;
    }
    std::vector<ExprPtr> indices = {
        flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides())};
    v->set_indices(indices);
    v->set_value(new_value);
    return v;
  }
};

static bool isValidIdentifierChar(char c, size_t pos) {
  return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
}

// replaces all invalid characters with underscore
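// For example (illustrative): sanitizeName("aten::add") yields "aten__add",
// and a name starting with an invalid character such as "1var" yields "v_var".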
std::string sanitizeName(const std::string& input_name) {
  std::stringstream sanitized_name;
  for (size_t i = 0; i < input_name.size(); ++i) {
    if (isValidIdentifierChar(input_name[i], i)) {
      sanitized_name << input_name[i];
    } else {
      if (i == 0) {
        // Don't start names with underscore
        sanitized_name << "v";
      }
      sanitized_name << "_";
    }
  }
  return sanitized_name.str();
}

class VarNameSanitizer : public IRMutator {
 public:
  ExprPtr mutate(const BufPtr& v) override {
    if (seen_bufs_.count(v)) {
      return v;
    }
    const std::string& name = v->name_hint();
    auto new_name = sanitizeName(name);
    if (taken_names_.count(new_name)) {
      new_name = getNextAvailableName(new_name);
    }
    v->set_name_hint(new_name);
    taken_names_.insert(new_name);
    seen_bufs_.insert(v);
    return v;
  }

  ExprPtr mutate(const VarPtr& v) override {
    if (seen_vars_.count(v)) {
      return v;
    }
    const std::string& name = v->name_hint();
    auto new_name = sanitizeName(name);
    if (taken_names_.count(new_name)) {
      new_name = getNextAvailableName(new_name);
    }
    v->set_name_hint(new_name);
    taken_names_.insert(new_name);
    seen_vars_.insert(v);
    return v;
  }

  StmtPtr mutate(const ForPtr& v) override {
    auto new_name = getNextAvailableName(getIndexVarNameAtLevel(level_));
    if (seen_index_vars_.count(v->var())) {
      auto new_var = alloc<Var>("", v->var()->dtype());
      Substitute(v, {{v->var(), new_var}});
    }
    v->var()->set_name_hint(new_name);
    seen_index_vars_.insert(v->var());
    seen_vars_.insert(v->var());
    taken_names_.insert(new_name);
    level_++;
    v->body()->accept_mutator(this);
    level_--;
    v->start()->accept_mutator(this);
    v->stop()->accept_mutator(this);
    return v;
  }

  std::string getIndexVarNameAtLevel(int level_) {
    auto names_num = index_var_names_.size();
    auto counter = level_ / names_num;
    if (counter == 0) {
      return index_var_names_[level_ % names_num];
    } else {
      return index_var_names_[level_ % names_num] + std::to_string(counter);
    }
  }
  std::string getNextAvailableName(const std::string& base_name) {
    std::string name = base_name;
    int counter = 0;
    while (taken_names_.count(name)) {
      counter++;
      name = base_name + "_" + std::to_string(counter);
    }
    return name;
  }

 private:
  std::vector<std::string> index_var_names_ =
      {"i", "j", "k", "l", "m", "n", "o", "p"};
  std::unordered_set<std::string> taken_names_;
  std::unordered_set<VarPtr> seen_index_vars_;
  std::unordered_set<VarPtr> seen_vars_;
  std::unordered_set<BufPtr> seen_bufs_;
  int level_ = 0;
};

StmtPtr LoopNest::sanitizeNames(StmtPtr s) {
  VarNameSanitizer r;
  s->accept_mutator(&r);
  return s;
}

class Vectorizer : public IRMutator {
 public:
  StmtPtr vectorize(ForPtr v) {
    StmtPtr body = v->body();
    VarPtr var = v->var();
    ExprPtr start = v->start();
    ExprPtr stop = v->stop();

    auto start_imm = intValue(start);
    auto stop_imm = intValue(stop);
    if (!start_imm) {
      // Can't vectorize due to non-constant loop start!
      success_ = false;
      return v;
    }

    if (!stop_imm) {
      // Can't vectorize due to non-constant loop stop!
      success_ = false;
      return v;
    }

    var_ = var;
    start_ = immLike(start, *start_imm);
    lanes_ = *stop_imm;

    StmtPtr new_body = body->accept_mutator(this);
    if (new_body == body) {
      // Vectorization failed!
      success_ = false;
      return v;
    }

    return new_body;
  }

  bool success() const {
    return success_;
  }

  ExprPtr mutate(const AddPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) + ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const SubPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) - ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const MulPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) * ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const DivPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) / ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const ModPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) % ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const AndPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) & ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const OrPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) | ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const XorPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) ^ ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const LshiftPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) << ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const RshiftPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) >> ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const MaxPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return Max::make(
          ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans());
    });
  }

  ExprPtr mutate(const MinPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return Min::make(
          ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans());
    });
  }

  ExprPtr mutate(const CompareSelectPtr& v) override {
    std::vector<ExprPtr> inputs = {
        v->lhs(), v->rhs(), v->ret_val1(), v->ret_val2()};
    return try_vectorize(v, inputs, [&]() {
      return CompareSelect::make(
          ExprHandle(inputs[0]),
          ExprHandle(inputs[1]),
          ExprHandle(inputs[2]),
          ExprHandle(inputs[3]),
          v->compare_select_op(),
          v->bias());
    });
  }

  ExprPtr mutate(const BitCastPtr& v) override {
    std::vector<ExprPtr> inputs = {v->src_value()};
    return try_vectorize(v, inputs, [&]() {
      return BitCast::make(
          Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0]));
    });
  }

  ExprPtr mutate(const CastPtr& v) override {
    std::vector<ExprPtr> inputs = {v->src_value()};
    return try_vectorize(v, inputs, [&]() {
      return Cast::make(
          Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0]));
    });
  }

  ExprPtr mutate(const VarPtr& v) override {
    if (v == var_) {
      return Ramp::make(
                 ExprHandle(start_), ExprHandle(immLike(start_, 1)), lanes_)
          .node();
    }

    return v;
  }

  ExprPtr mutate(const RampPtr& v) override {
    ExprPtr base = v->base();
    ExprPtr stride = v->stride();

    ExprPtr base_new = base->accept_mutator(this);
    ExprPtr stride_new = stride->accept_mutator(this);

    if (base_new == base && stride_new == stride) {
      return v;
    }

    // Can't vectorize a Ramp!
    success_ = false;
    return v;
  }

  ExprPtr mutate(const LoadPtr& v) override {
    Dtype dtype(v->dtype().scalar_type(), lanes_);
    BufPtr buf = v->buf();
    std::vector<ExprPtr> inputs = {v->flat_index()};
    return try_vectorize(v, inputs, [&]() {
      return Load::make(dtype, BufHandle(buf), {ExprHandle(inputs[0])});
    });
  }

  ExprPtr mutate(const ReduceOpPtr& v) override {
    Dtype dtype(v->dtype().scalar_type(), lanes_);

    std::vector<ExprPtr> inputs = {v->body()};

    auto out = try_vectorize(v, inputs, [&]() {
      return ExprHandle(
          alloc<ReduceOp>(inputs[0], v->reduce_args(), v->reducer()));
    });
    return out;
  }

  ExprPtr mutate(const BroadcastPtr& v) override {
    ExprPtr val = v->value();
    ExprPtr new_val = val->accept_mutator(this);
    if (new_val == val) {
      return v;
    }

    // Can't vectorize a Broadcast!
    success_ = false;
    return v;
  }

  ExprPtr mutate(const IfThenElsePtr& v) override {
    ExprPtr condition = v->condition();
    ExprPtr new_condition = condition->accept_mutator(this);
    if (new_condition != condition) {
      // Can't vectorize an IfThenElse condition!
      success_ = false;
      return v;
    }

    std::vector<ExprPtr> inputs = {v->true_value(), v->false_value()};
    return try_vectorize(v, inputs, [&]() {
      return IfThenElse::make(
          ExprHandle(condition), ExprHandle(inputs[0]), ExprHandle(inputs[1]));
    });
  }

  ExprPtr mutate(const IntrinsicsPtr& v) override {
    std::vector<ExprPtr> inputs = v->params();
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(alloc<Intrinsics>(v->op_type(), inputs));
    });
  }

  StmtPtr mutate(const StorePtr& v) override {
    BufPtr buf = v->buf();
    std::vector<ExprPtr> inputs = {v->flat_index(), v->value()};
    return try_vectorize(v, inputs, [&]() {
      return Store::make(
          BufHandle(buf), {ExprHandle(inputs[0])}, ExprHandle(inputs[1]));
    });
  }

  StmtPtr mutate(const ForPtr& v) override {
    VarPtr var = v->var();
    ExprPtr start = v->start();
    ExprPtr stop = v->stop();
    LoopOptions loop_options = v->loop_options();

    ExprPtr new_start = start->accept_mutator(this);
    ExprPtr new_stop = stop->accept_mutator(this);

    if (new_start != start || new_stop != stop) {
      // Can't vectorize nested For with dependent loop bounds!
      success_ = false;
      return v;
    }

    StmtPtr body = v->body();
    StmtPtr new_body = body->accept_mutator(this);

    if (new_body == body) {
      return (ForPtr)v;
    }

    return alloc<For>(var, new_start, new_stop, new_body, loop_options);
  }

  StmtPtr mutate(const BlockPtr& v) override {
    // IRMutator does in-place mutations. But the logic in vectorization checks
    // for success by looking for a new stmt. So, we override the in-place
    // mutations and create a clone here if any of its statements change.
    // TODO: Can we change the logic of vectorizer so that we don't need this?
    bool any_change = false;
    std::vector<StmtPtr> stmts;
    for (const StmtPtr& stmt : *v) {
      StmtPtr stmt_new = stmt->accept_mutator(this);
      if (stmt != stmt_new) {
        any_change = true;
      } else {
        stmt_new = Stmt::clone(stmt);
      }
      if (stmt_new) {
        stmts.push_back(stmt_new);
      }
    }
    if (any_change) {
      return alloc<Block>(stmts);
    }
    return v;
  }

  template <typename T>
  ExprPtr try_vectorize(ExprPtr e, std::vector<ExprPtr>& inputs, T&& vec_ctor) {
    bool vectorize = vectorize_inputs(inputs);
    if (vectorize) {
      return vec_ctor().node();
    }

    return e;
  }

  template <typename T>
  StmtPtr try_vectorize(StmtPtr s, std::vector<ExprPtr>& inputs, T&& vec_ctor) {
    bool vectorize = vectorize_inputs(inputs);
    if (vectorize) {
      return vec_ctor();
    }

    return s;
  }

  bool vectorize_inputs(std::vector<ExprPtr>& inputs) {
    bool any_vectorized = false;
    std::vector<ExprPtr> new_inputs;

    // Attempt to vectorize each input.
    for (ExprPtr& in : inputs) {
      ExprPtr new_in = in->accept_mutator(this);
      new_inputs.push_back(new_in);
      if (new_in != in) {
        any_vectorized = true;
      }
    }

    // If none of them vectorized, then don't vectorize this.
    if (!any_vectorized) {
      return false;
    }

    // Insert broadcasts for any inputs that weren't vectorized.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i] == new_inputs[i]) {
        inputs[i] = Broadcast::make(ExprHandle(inputs[i]), lanes_).node();
      } else {
        inputs[i] = new_inputs[i];
      }
    }

    // And then vectorize this node.
    return true;
  }

  VarPtr var_ = nullptr;
  int64_t lanes_ = 0;
  ExprPtr start_ = nullptr;
  bool success_ = true;
};

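// A rough sketch of the effect of vectorize() on a simple loop (illustrative;
// A, B, and C are hypothetical buffers with already-flattened indices):
//   for (int i = 0; i < 8; i++) { A[i] = B[i] + C[i]; }
// becomes a single statement operating on 8 lanes:
//   A[Ramp(0, 1, 8)] = B[Ramp(0, 1, 8)] + C[Ramp(0, 1, 8)];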
bool LoopNest::vectorize(const ForPtr& f) {
  BlockPtr b = to<Block>(f->get_parent());
  if (!b) {
    return false;
  }

  // Can't vectorize reduction axes.
  auto reductions = NodeFinder<ReduceOp>::find(f);
  for (const auto& r : reductions) {
    if (std::find(r->reduce_args().begin(), r->reduce_args().end(), f->var()) !=
        r->reduce_args().end()) {
      return false;
    }
  }

  Vectorizer v;
  StmtPtr new_f = nullptr;
  new_f = Stmt::clone(f);
  normalize(to<For>(new_f));
  new_f = FlattenIndexes(new_f);
  new_f = v.vectorize(to<For>(new_f));
  if (!v.success()) {
    // We clone f before vectorizing. So, any partial vectorization will
    // have modified the clone. In case of an exception, we can continue
    // using f.
    new_f = f;
  }

  if (new_f != f) {
    b->replace_stmt(f, IRSimplifier::simplify(new_f));
    return true;
  }

  // Vectorization was not successful.
  return false;
}

void LoopNest::initialize(
    const std::vector<Tensor>& output_tensors,
    const std::vector<Tensor>& tensors_to_compute) {
  for (const auto& t : output_tensors) {
    output_bufs_.insert(t.buf());
  }

  std::vector<StmtPtr> loops;
  for (const Tensor& t : tensors_to_compute) {
    StmtPtr loop = t.stmt();
    if (loop->get_parent()) {
      std::cerr << "Error: creating a loopnest from already used Tensors\n";
      loops = {};
      break;
    }
    // Flatten initializers.
    if (BlockPtr block = to<Block>(loop)) {
      for (const auto& s : block->stmts()) {
        block->remove_stmt(s);
        loops.push_back(s);
      }
    } else {
      loops.push_back(loop);
    }
  }

  root_stmt_ = alloc<Block>(loops);
}

class FunctionInliner : public IRMutator {
 public:
  FunctionInliner(StorePtr producer, std::unordered_set<BufPtr> outputs)
      : buf_(producer->buf()),
        producer_(std::move(producer)),
        outputs_(std::move(outputs)) {
    for (const auto& i : producer_->indices()) {
      if (auto index_var = to<Var>(i)) {
        index_vars_.insert(index_var);
        producer_index_vars_.push_back(index_var);
      } else {
        // If the index can be a constant, then that dimension must have size 1
        // (since we don't support in-place writes). Resolves issue 52581.
        auto index_val = evalInt(i);
        if (!index_val || *index_val != 0) {
          success_ = false;
          break;
        }
        producer_index_vars_.push_back(nullptr);
      }
    }
  }

  bool success() const {
    return success_;
  }

 private:
  ExprPtr mutate_loads(const BufPtr& buf, std::vector<ExprPtr> dims) {
    std::vector<VarPtr> index_vars;
    if (buf->ndim() != producer_index_vars_.size()) {
      // Dimensions of producer and consumer expressions do not match in inliner
      // in the fuser
      success_ = false;
      return nullptr;
    }
    for (const auto i : c10::irange(buf->ndim())) {
      VarPtr func_callee_arg = producer_index_vars_.at(i);
      ExprPtr func_caller_param = dims.at(i);
      if (func_callee_arg == nullptr) {
        continue;
      }
      auto iter = inline_mapping_.find(func_callee_arg);
      if (iter != inline_mapping_.end()) {
        // Duplicated variables
        success_ = false;
        return nullptr;
      }
      // Add a mapping for each function parameter to its source name.
      inline_mapping_[func_callee_arg] = func_caller_param;
      GRAPH_DEBUG(
          "ComputeInline: Inline mapping: ",
          std::to_string(func_callee_arg),
          " -> ",
          std::to_string(func_caller_param));
      index_vars.push_back(func_callee_arg);
    }

    // Call the actual replacement.
    ExprPtr body = producer_->value();
    GRAPH_DEBUG("ComputeInline: Before rewriting body: ", std::to_string(body));
    ExprPtr result = Expr::clone(body)->accept_mutator(this);
    GRAPH_DEBUG(
        "ComputeInline: After rewriting body: ", std::to_string(result));

    // Remove the mappings we created for this function's parameters.
    for (const auto& v : index_vars) {
      for (auto& pair : random_bindings_) {
        if (pair.second.erase(v)) {
          ExprPtr inlined = inline_mapping_[v];
          for (const auto& nv : VarFinder::find(inlined)) {
            pair.second.insert(nv);
          }
        }
      }
      GRAPH_DEBUG("ComputeInline: Inline mapping: erasing", std::to_string(v));
      inline_mapping_.erase(v);
    }
    return result;
  }

  ExprPtr mutate(const LoadPtr& v) override {
    if (!success()) {
      return v;
    }
    BufPtr buf = v->buf();
    if (buf != buf_) {
      return IRMutator::mutate(v);
    }

    if (v->indices().size() != buf->ndim()) {
      // Number of indices doesn't match buf rank in the fuser
      success_ = false;
      return v;
    }
    auto result = mutate_loads(buf, v->indices());
    if (!result) {
      // If we don't inline successfully return the given load.
      success_ = false;
      return v;
    }
    return result;
  }

  // Replace the target variable with the caller expressions.
  ExprPtr mutate(const VarPtr& v) override {
    if (!success()) {
      return v;
    }
    auto iter = inline_mapping_.find(v);
    if (iter == inline_mapping_.end()) {
      return v;
    } else {
      ExprPtr expr = iter->second;
      // Continue to transform the value from the lookup table.
      return expr->accept_mutator(this);
    }
  }

  // Handle random intrinsics which should be cached.
  ExprPtr mutate(const IntrinsicsPtr& v) override {
    if (!success()) {
      return v;
    }
    if (!in_producer_ || v->op_type() != kRand) {
      return IRMutator::mutate(v);
    }

    // Create a new Let Statement for the random variable, which we can refer
    // to multiple times and resolve the same value (ie. store it in a scalar
    // rather than the Tensor).
    const std::string& name = buf_->name_hint();
    VarPtr new_var = alloc<Var>(name, v->dtype());
    random_bindings_[alloc<Let>(new_var, v)] = index_vars_;
    GRAPH_DEBUG(
        "ComputeInline: created random bindings for ", std::to_string(new_var));
    return new_var;
  }

  // Remove the buffer write from the inlined function.
  StmtPtr mutate(const StorePtr& v) override {
    if (!success()) {
      return v;
    }
    // If the buf_ is in the outputs set, keep its statement intact. Otherwise,
    // remove it.
    if (v == producer_ && !outputs_.count(buf_)) {
      in_producer_ = true;
      producer_ = to<Store>(IRMutator::mutate(v));
      if (!producer_) {
        // Producer statement for output buf should remain non-null in the fuser
        success_ = false;
        return v;
      }
      in_producer_ = false;
      return nullptr;
    } else {
      return IRMutator::mutate(v);
    }
  }

  // Any Random Intrinsics that were turned into vars must be inserted here.
  StmtPtr mutate(const BlockPtr& v) override {
    if (!success()) {
      return v;
    }
    std::vector<StmtPtr> stmts;
    for (const StmtPtr& stmt : *v) {
      StmtPtr stmt_new = stmt->accept_mutator(this);
      if (!stmt_new) {
        continue;
      }

      if (stmt == stmt_new) {
        stmt_new = Stmt::clone(stmt);
      }

      stmts.push_back(stmt_new);
    }

    return Block::make(stmts);
  }

  StmtPtr mutate(const ForPtr& v) override {
    if (!success()) {
      return v;
    }
    ForPtr res = to<For>(IRMutator::mutate(v));
    if (!res) {
      return nullptr;
    }

    // Find any random bindings that should be defined in this loop's body.
    std::vector<LetPtr> bindings_this_loop;
    VarPtr fv = v->var();
    for (auto& pair : random_bindings_) {
      auto& index_var = pair.second;
      if (index_var.erase(fv)) {
        bindings_this_loop.push_back(pair.first);
      }
    }

    for (const auto& l : bindings_this_loop) {
      res->body()->prepend_stmt(l);
      random_bindings_.erase(l);
    }
    return res;
  }

 private:
  BufPtr buf_;
  StorePtr producer_;

  // Index Vars present in the producer.
  std::unordered_set<VarPtr> index_vars_;
  std::vector<VarPtr> producer_index_vars_;

  std::unordered_map<VarPtr, ExprPtr> inline_mapping_;

  // In the producer's scope - we need to bind any calls to rand().
  bool in_producer_ = false;
  std::unordered_map<LetPtr, std::unordered_set<VarPtr>> random_bindings_;
  std::unordered_set<BufPtr> outputs_;
  bool success_ = true;
};

static StmtPtr computeInlineImpl(
    const BufPtr& b,
    const StmtPtr& stmt,
    const std::unordered_set<BufPtr>& output_bufs) {
  // If buf is used or defined in an ExternalCall, we cannot inline it
  auto buf_load_store_uses = findLoadOrStoreUses(stmt);
  if (!buf_load_store_uses.count(b)) {
    return nullptr;
  }
  for (auto& use : buf_load_store_uses.at(b)) {
    StmtPtr s = use.s;
    if (to<ExternalCall>(s) || to<ExternalCallWithAlloc>(s)) {
      return nullptr;
    }
  }

  // Find producers.
  StorePtr relevant_store{nullptr};
  auto stores = NodeFinder<Store>::find(stmt);
  for (const auto& s : stores) {
    if (s->buf() == b) {
      auto reductions = NodeFinder<ReduceOp>::find(s);
      if (!reductions.empty()) {
        // Cannot inline a reduction computation
        return nullptr;
      }
      if (relevant_store != nullptr) {
        // Cannot inline Buf with multiple Tensors
        return nullptr;
      }
      relevant_store = s;
    }
  }

  if (!relevant_store) {
    // Cannot find a relevant store to inline a buf in the fuser
    return nullptr;
  }

  GRAPH_DEBUG("ComputeInline: Def: ", std::to_string(relevant_store));
  FunctionInliner inliner(relevant_store, output_bufs);
  auto result = stmt->accept_mutator(&inliner);
  if (inliner.success()) {
    return result;
  }
  return nullptr;
}

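// A rough sketch of what computeInline() does (illustrative; A, B, and C are
// hypothetical buffers):
//   A[i] = B[i] + 1
//   C[j] = A[j] * 2
// becomes, after computeInline(A),
//   C[j] = (B[j] + 1) * 2
// and the store into A is removed (unless A is an output buffer).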
bool LoopNest::computeInline(const BufPtr& b) {
  // Inlining may not always be successful. Since all mutations now happen
  // in-place, an unsuccessful inlining transformation might leave the IR
  // in an invalid state. To get around this problem, we clone the root stmt,
  // try inlining on the clone, and if it succeeds, we proceed to perform
  // inlining on the actual root stmt. This way the root stmt will always be
  // in a valid state.
  auto stmt_copy = Stmt::clone(root_stmt_);
  auto try_inline = computeInlineImpl(b, stmt_copy, output_bufs_);
  if (!try_inline) {
    return false;
  }
  root_stmt_ = computeInlineImpl(b, root_stmt_, output_bufs_);
  return true;
}

bool LoopNest::computeInline(const StmtPtr& s) {
  auto s_store = to<Store>(s);
  if (s_store == nullptr) {
    // Could not find buffer producer to inline
    return false;
  }
  return computeInline(s_store->buf());
}

// Inlining buffers with multiple uses can create duplicated work, which can
// slow down CPU code generation but is enabled on GPU because it avoids
// difficult synchronization logic across blocks. Inlining trivial reads does
// not duplicate work.
void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) {
  std::unordered_set<BufPtr> bufs_to_inline;

  auto intermediate_bufs = getIntermediateBufs();
  if (allow_duplicated_work) {
    bufs_to_inline.insert(intermediate_bufs.begin(), intermediate_bufs.end());
  } else {
    auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
    auto input_bufs = getInputBufs();

    for (const auto& buf : intermediate_bufs) {
      TORCH_INTERNAL_ASSERT(
          buf_load_store_uses.count(buf),
          buildErrorMessage(
              "Could not find uses of buf '" + buf->name_hint() +
              "' in the fuser."));
      std::vector<BufLoadOrStoreUse>& uses = buf_load_store_uses[buf];
      auto stores = c10::filter(
          uses, [](const BufLoadOrStoreUse& use) { return use.isStore; });

      // If the intermediate is the buffer formed from reading in the input
      // tensors, always inline, because we are not duplicating any work
      // and we avoid an intermediary buffer.
      if (stores.size() == 1) {
        if (auto store = to<Store>(stores[0].s)) {
          auto input_as_load = to<Load>(store->value());
          if (input_as_load && input_bufs.count(input_as_load->buf())) {
            bufs_to_inline.insert(buf);
            continue;
          }
        } else {
          // If S is not a store, it must be an ExternalCall.
          TORCH_INTERNAL_ASSERT(
              to<ExternalCall>(stores[0].s) ||
                  to<ExternalCallWithAlloc>(stores[0].s),
              buildErrorMessage(
                  "Expected stmt: " + std::to_string(stores[0].s) +
                  "\nto be either a Store or an ExternalCall in the fuser."));
        }
      }

      // All bufs will have at least one store (if they have more than one,
      // they can't be inlined anyway).
      size_t reads = uses.size() - 1;
      // If there is only one read, we can inline it without duplicating work.
      if (reads <= 1) {
        bufs_to_inline.insert(buf);
      }
    }
  }

  if (allow_duplicated_work) {
    bufs_to_inline.insert(output_bufs_.begin(), output_bufs_.end());
  }

  for (const auto& b : bufs_to_inline) {
    computeInline(b);
  }
}

// TODO: Unify with DepTracker
class LoadOrStoreUseFinder : public IRVisitor {
 public:
  std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findUses(
      const StmtPtr& s) {
    uses_.clear();
    s->accept(this);
    return uses_;
  }

 private:
  void visit(const StorePtr& v) override {
    if (stores_[v->buf()].insert(last_stmt_).second) {
      uses_[v->buf()].push_back({(StmtPtr)v, true});
    }
    last_stmt_ = (StmtPtr)v;
    IRVisitor::visit(v);
  }

  void visit(const ExternalCallPtr& v) override {
    if (stores_[v->buf()].insert(last_stmt_).second) {
      uses_[v->buf()].push_back({(StmtPtr)v, true});
    }
    last_stmt_ = (StmtPtr)v;

    for (const BufPtr& input_buf : v->buf_args()) {
      if (loads_[input_buf].insert(last_stmt_).second) {
        uses_[input_buf].push_back({last_stmt_, false});
      }
    }

    IRVisitor::visit(v);
  }

  void visit(const ExternalCallWithAllocPtr& v) override {
    for (const auto& out_buf : v->buf_out_args()) {
      if (stores_[out_buf].insert(last_stmt_).second) {
        uses_[out_buf].push_back({(StmtPtr)v, true});
      }
    }
    last_stmt_ = (StmtPtr)v;

    for (const auto& input_buf : v->buf_args()) {
      if (loads_[input_buf].insert(last_stmt_).second) {
        uses_[input_buf].push_back({last_stmt_, false});
      }
    }

    IRVisitor::visit(v);
  }

  void visit(const LoadPtr& v) override {
    if (loads_[v->buf()].insert(last_stmt_).second) {
      uses_[v->buf()].push_back({last_stmt_, false});
    }
    IRVisitor::visit(v);
  }

  StmtPtr last_stmt_ = nullptr;
  std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> uses_;

  // Sets of loads and stores in order to keep the results unique
  std::unordered_map<BufPtr, std::unordered_set<StmtPtr>> loads_;
  std::unordered_map<BufPtr, std::unordered_set<StmtPtr>> stores_;
};

std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
    const StmtPtr& s) {
  LoadOrStoreUseFinder uf;
  return uf.findUses(s);
}

class ContainedStmtsFinder : public IRVisitor {
 public:
  // Simply list all Stores, ExternalCalls, and Blocks that are children of the
  // given stmt
  const std::unordered_set<StmtPtr>& findContainedStmts(const StmtPtr& s) {
    contained_.clear();
    s->accept(this);
    return contained_;
  }

 private:
  void visit(const StorePtr& v) override {
    contained_.insert((StmtPtr)v);
    IRVisitor::visit(v);
  }
  void visit(const ExternalCallPtr& v) override {
    contained_.insert((StmtPtr)v);
    IRVisitor::visit(v);
  }
  void visit(const ExternalCallWithAllocPtr& v) override {
    contained_.insert((StmtPtr)v);
    IRVisitor::visit(v);
  }
  void visit(const BlockPtr& v) override {
    contained_.insert((StmtPtr)v);
    IRVisitor::visit(v);
  }

  std::unordered_set<StmtPtr> contained_;
};

class StmtDeleter : public IRMutator {
 public:
  StmtDeleter(const std::unordered_set<StmtPtr>& targets) : targets_(targets) {}

 private:
  StmtPtr mutate(const BlockPtr& v) override {
    std::vector<StmtPtr> stmts;

    for (const auto& s : v->stmts()) {
      if (targets_.count(s) == 0) {
        StmtPtr ns = s->accept_mutator(this);
        if (ns) {
          stmts.push_back(Stmt::clone(ns));
        }
      }
    }

    return Block::make(stmts);
  }

  const std::unordered_set<StmtPtr>& targets_;
};

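// A small illustrative example for eliminateDeadStores() (T and Out are
// hypothetical buffers):
//   T[i] = A[i] * 2;     // never read on any path that reaches an output
//   Out[i] = A[i] + 1;
// After the pass, only the store to Out remains.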
void LoopNest::eliminateDeadStores() {
  using namespace analysis;
  MemDependencyChecker checker(getInputBufs(), getOutputBufs());
  root_stmt_->accept(&checker);

  std::unordered_set<StmtPtr> deadStores;
  std::vector<std::shared_ptr<AccessInfo>> outputAccesses;
  for (const auto& o : getOutputBufs()) {
    outputAccesses.push_back(checker.output(o));
  }

  for (auto& info : checker.getHistory()) {
    if (!info->isWrite()) {
      continue;
    }
    bool found = false;

    for (auto& output : outputAccesses) {
      if (checker.dependsIndirectly(output, info)) {
        found = true;
        break;
      }
    }

    if (!found) {
      deadStores.insert(info->stmt());
    }
  }

  StmtDeleter deleter(deadStores);
  root_stmt_ = root_stmt_->accept_mutator(&deleter);
}

void LoopNest::prepareForCodegen() {
  // Expand reduction ops.
  ReductionExpander reduceExpander;
  root_stmt_ = reduceExpander.expand(root_stmt_);

  root_stmt_ = FlattenIndexes(root_stmt_);
}

namespace {

// This is extended from IRCloner instead of IRMutator because we want all
// the rest of the IR nodes (the ones not touched directly) to be cloned.
class IfThenElseReplacer : public IRCloner {
 public:
  IfThenElseReplacer(IfThenElsePtr to_replace, ExprPtr new_expr)
      : to_replace_(std::move(to_replace)), new_expr_(std::move(new_expr)) {}

  ExprPtr mutate(const IfThenElsePtr& i) override {
    if (i == to_replace_) {
      return new_expr_;
    }
    return IRCloner::mutate(i);
  }

 private:
  IfThenElsePtr to_replace_;
  ExprPtr new_expr_;
};

// Check if the given condition is optimizable.
// Specifically, this function looks for the following pattern:
//    "var < expr"
//
// If this pattern is found, then this function:
//   * sets `cond_var` to `var`,
//   * sets `compared_value` to `expr`, and
//   * returns true.
bool isConditionOptimizable(
    const ExprPtr& condition,
    VarPtr* cond_var,
    ExprPtr* compared_value) {
  auto cs = to<CompareSelect>(condition);
  if (cs && cs->compare_select_op() == kLT) {
    auto var = to<Var>(cs->lhs());
    if (var) {
      *cond_var = var;
      *compared_value = cs->rhs();
      return true;
    }
  }
  return false;
}

// Checks if the given if-then-else expression is a conditional that is
// generated from `aten::cat`.
//
// The expected format of conditionals is:
//     IfThenElse(var < val1? 1 : 0,
//       IfThenElse (var < val2? 1 : 0,
//         IfThenElse (var < val3? 1 : 0,
//           sub-expr1,
//           sub-expr2),
//         sub-expr3),
//       sub-expr4)
//
// If such a conditional is found, this function also sets:
//   * cond_var to the condition variable found in this expression.
//   * comp_values to the list of compared values in the condition expressions.
//   * sub_exprs to the list of sub-expressions that are the result of this
//     if-then-else expression.
bool isConditionalFromCat(
    const IfThenElsePtr& ite,
    VarPtr* cond_var,
    std::vector<ExprPtr>* comp_values,
    std::vector<ExprPtr>* sub_exprs) {
  VarPtr var = nullptr;
  ExprPtr comp_value;
  if (isConditionOptimizable(ite->condition(), &var, &comp_value)) {
    if (*cond_var == nullptr) {
      *cond_var = var;
    } else if (*cond_var != var) {
      // Different condition variables found in nested if-then-else
      // expressions. Can not optimize such cases.
      return false;
    }
    auto true_ite = to<IfThenElse>(ite->true_value());
    if (true_ite) {
      if (!isConditionalFromCat(true_ite, cond_var, comp_values, sub_exprs)) {
        return false;
      }
    } else {
      sub_exprs->push_back(ite->true_value());
    }
    auto false_ite = to<IfThenElse>(ite->false_value());
    if (false_ite) {
      return false;
    }
    comp_values->push_back(comp_value);
    sub_exprs->push_back(ite->false_value());
    return true;
  }
  return false;
}

bool areConstantsAndSorted(const std::vector<ExprPtr>& comp_values) {
  std::vector<int> comp_consts;
  comp_consts.reserve(comp_values.size());
  for (const auto& c : comp_values) {
    if (!c->isConstant()) {
      return false;
    }
    comp_consts.push_back(immediateAs<int>(c));
  }
  return std::is_sorted(comp_consts.begin(), comp_consts.end());
}

} // namespace

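// A rough sketch of the transformation performed by optimizeConditionals()
// (illustrative; A, B, and C are hypothetical buffers):
//   for (int i = 0; i < 10; i++) {
//     C[i] = IfThenElse(i < 4, A[i], B[i - 4]);
//   }
// is rewritten into two loops with the conditional removed (each new loop is
// then normalized):
//   for (int i = 0; i < 4; i++)  { C[i] = A[i]; }
//   for (int i = 4; i < 10; i++) { C[i] = B[i - 4]; }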
bool LoopNest::optimizeConditionals() {
  // Consider every store in the root_stmt_ and try to optimize the
  // conditionals in that store.
  auto stores = NodeFinder<Store>::find(root_stmt_);
  std::unordered_set<ForPtr> split_fors;
  for (const auto& store : stores) {
    VarPtr cond_var = nullptr;
    // `comp_values` represents the list of compared values that will be
    // collected as we check for the expected pattern. Since that will
    // only include the RHS of the conditions in the if-then-else expressions,
    // we need to start with `0`, which is the initial bound, given that we
    // only handle normalized loops (the check for this is done below).
    std::vector<ExprPtr> comp_values;
    std::vector<ExprPtr> sub_exprs;
    auto ifthenelse_exprs = NodeFinder<IfThenElse>::find(store);
    if (ifthenelse_exprs.empty()) {
      continue;
    }
    // We only check if the first if-then-else expression in this store
    // corresponds to a conditional of the required format. If there are more
    // than one such conditional, optimizing them requires checking if the
    // conditions are exactly the same across them and handling all of them
    // together. Currently, this is not handled.
    if (!isConditionalFromCat(
            ifthenelse_exprs.front(), &cond_var, &comp_values, &sub_exprs)) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        !comp_values.empty(),
        buildErrorMessage(
            "Expected at least one expression in optimizeConditional in the fuser."));
    comp_values.insert(comp_values.begin(), immLike(comp_values[0], 0));

    auto fors = getLoopStmtsFor(store);
    if (cond_var != fors.back()->var()) {
      // Currently, we only handle the case where the condition variable
      // is the same as the inner-most loop variable.
      // TODO: Handle all other cases here.
      //
      // In order to handle all other cases, the method `clone_and_replace`
      // called below to clone the body of the loop with a new store needs
      // to recursively handle cloning of the loops and other blocks it
      // contains.
      continue;
    }

    auto for_to_split = fors.back();
    if (!LoopNest::isNormalized(for_to_split)) {
      // Do not optimize this conditional since the condition variable
      // refers to a loop that is not normalized.
      continue;
    }
    if (split_fors.count(for_to_split)) {
      // This loop has already been split while optimizing conditionals
      // earlier.
      //
      // Optimizing multiple conditionals that require splitting the same loop
      // is tricky. It requires checking if the conditions are exactly the same
      // across them and handling all of them together by splitting the loop
      // exactly once.
      //
      // Currently, this case is not supported.
      continue;
    }
    split_fors.insert(for_to_split);

    // `comp_values` needs to include the end bound, which is `for_to_split`
    // stop value.
    comp_values.push_back(for_to_split->stop());

    // Check if all `comp_values` are constants and they are sorted.
    if (!areConstantsAndSorted(comp_values)) {
      continue;
    }

    // Remove all the if-then-else expressions from this store and create
    // one loop per sub-expression.
    std::vector<StmtPtr> split_loops;
    auto cond_to_replace = ifthenelse_exprs.front();
    for (size_t i = 0; i < sub_exprs.size(); ++i) {
      IfThenElseReplacer ifthenelseReplacer(cond_to_replace, sub_exprs[i]);
      auto new_store = store->accept_mutator(&ifthenelseReplacer);
      auto new_for_body =
          for_to_split->body()->clone_and_replace(store, new_store);
      auto new_for = alloc<For>(
          for_to_split->var(),
          comp_values[i],
          comp_values[i + 1],
          new_for_body);
      LoopNest::normalize(new_for);
      split_loops.push_back(new_for);
    }
    auto par = to<Block>(for_to_split->get_parent());
    par->replace_stmt(for_to_split, alloc<Block>(split_loops));
  }
  root_stmt_ = IRSimplifier::simplify(root_stmt_);
  return true;
}

void LoopNest::vectorizeInnerLoops() {
  std::vector<ForPtr> innerLoops;
  std::vector<ForPtr> worklist;

  // Find outer-most For loops
  if (ForPtr rootF = to<For>(root_stmt_)) {
    worklist.push_back(rootF);
  } else if (BlockPtr body = to<Block>(root_stmt_)) {
    std::vector<BlockPtr> blocks = {body};
    while (!blocks.empty()) {
      BlockPtr b = blocks.back();
      blocks.pop_back();

      for (const StmtPtr& s : *b) {
        if (const ForPtr& f = to<For>(s)) {
          worklist.push_back(f);
        } else if (BlockPtr b2 = to<Block>(s)) {
          blocks.push_back(b2);
        }
      }
    }
  }

  // Traverse the For loop nest to find inner-most loops, which are
  // vectorization candidates.
  while (!worklist.empty()) {
    ForPtr f = worklist.back();
    worklist.pop_back();

    bool containsSubLoops = false;
    if (BlockPtr body = to<Block>(f->body())) {
      for (const StmtPtr& s2 : *body) {
        if (const ForPtr& f2 = to<For>(s2)) {
          containsSubLoops = true;
          worklist.push_back(f2);
        }
      }
    }

    if (!containsSubLoops) {
      innerLoops.push_back(f);
    }
  }

  // vectorize inner loops.
  for (const ForPtr& loop : innerLoops) {
    ForPtr split1;
    ForPtr tail1;

    static const int kBodyVectorWidth = 8;
    splitWithTail(loop, kBodyVectorWidth, &split1, &tail1);
    vectorize(split1);

    if (tail1) {
      ForPtr split2;
      ForPtr tail2;
      static const int kTailVectorWidth = 4;
      splitWithTail(tail1, kTailVectorWidth, &split2, &tail2);
      vectorize(split2);
    }
  }
}

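// A small illustrative example for sliceHead(): slicing a head of 4 iterations
// off "for (int i = 0; i < 10; i++) { ... }" produces
//   for (int i = 0; i < 4; i++)  { ... }
//   for (int i = 4; i < 10; i++) { ... }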
void LoopNest::sliceHead(
    const ForPtr& f,
    int factor,
    ForPtr* head,
    ForPtr* tail) {
  if (!f) {
    throw malformed_input("sliceHead attempted on null loop");
  }

  if (intValue(f->start()) && intValue(f->stop())) {
    auto start_val = *intValue(f->start());
    auto stop_val = *intValue(f->stop());
    auto size_val = stop_val - start_val;
    if (factor >= size_val) {
      *head = f;
      *tail = nullptr;
      return;
    }
  }

  BlockPtr p = to<Block>(f->get_parent());
  if (!p) {
    throw malformed_input("sliceHead attempted on loop with no parent");
  }

  ExprPtr head_end = alloc<Min>(
      alloc<Add>(f->start(), immLike(f->stop(), factor)), f->stop(), true);
  *head = alloc<For>(f->var(), f->start(), head_end, Stmt::clone(f->body()));
  p->insert_stmt_before(*head, f);

  f->set_start(head_end);
  *tail = f;

  if (f->loop_options().is_gpu_block_index() ||
      f->loop_options().is_gpu_thread_index()) {
    LoopNest::normalize(*tail);
  }
}
void LoopNest::sliceHead(const ForPtr& f, int factor) {
  ForPtr head, tail;
  sliceHead(f, factor, &head, &tail);
}

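// A small illustrative example for sliceTail(): slicing a tail of 4 iterations
// off "for (int i = 0; i < 10; i++) { ... }" produces
//   for (int i = 0; i < 6; i++)  { ... }
//   for (int i = 6; i < 10; i++) { ... }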
void LoopNest::sliceTail(
    const ForPtr& f,
    int factor,
    ForPtr* head,
    ForPtr* tail) {
  if (!f) {
    throw malformed_input("sliceTail attempted on null loop");
  }

  if (intValue(f->start()) && intValue(f->stop())) {
    auto start_val = *intValue(f->start());
    auto stop_val = *intValue(f->stop());
    auto size_val = stop_val - start_val;
    if (factor >= size_val) {
      *head = nullptr;
      *tail = f;
      return;
    }
  }

  BlockPtr p = to<Block>(f->get_parent());
  if (!p) {
    throw malformed_input("sliceTail attempted on loop with no parent");
  }

  ExprPtr tail_start = alloc<Max>(
      f->start(), alloc<Sub>(f->stop(), immLike(f->stop(), factor)), true);
  *tail = alloc<For>(f->var(), tail_start, f->stop(), Stmt::clone(f->body()));
  p->insert_stmt_after(*tail, f);

  f->set_stop(tail_start);
  *head = f;

  if (f->loop_options().is_gpu_block_index() ||
      f->loop_options().is_gpu_thread_index()) {
    LoopNest::normalize(*head);
  }
}
void LoopNest::sliceTail(const ForPtr& f, int factor) {
  ForPtr head, tail;
  sliceTail(f, factor, &head, &tail);
}

void LoopNest::splitWithTail(const ForPtr& f, int factor) {
  ForPtr inner, tail;
  splitWithTail(f, factor, &inner, &tail);
}

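// A rough sketch of splitWithTail() with factor 4 on
// "for (int i = 0; i < 10; i++)" (illustrative):
//   for (int i_outer = 0; i_outer < 2; i_outer++)
//     for (int i_inner = 0; i_inner < 4; i_inner++)
//       ... uses i_outer * 4 + i_inner ...
//   for (int i_tail = 0; i_tail < 2; i_tail++)
//       ... uses i_tail + 8 ...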
void LoopNest::splitWithTail(
    const ForPtr& f,
    int factor,
    ForPtr* inner,
    ForPtr* tail) {
  if (!f) {
    throw malformed_input("splitWithTail attempted on null loop");
  }

  BlockPtr p = to<Block>(f->get_parent());
  if (!p) {
    throw malformed_input("splitWithTail attempted on loop with no parent");
  }

  // Normalize the loop to simplify start and stop bound computation
  normalize(f);

  bool tail_is_needed = true;
  if (intValue(f->start()) && intValue(f->stop())) {
    auto const start_val = *intValue(f->start());
    auto const stop_val = *intValue(f->stop());
    auto const size_val = stop_val - start_val;
    auto const tail_size = size_val % factor;
    if (tail_size == 0) {
      tail_is_needed = false;
    }
  }

  ExprPtr factor_expr = immLike(f->stop(), factor);
  ExprPtr size = alloc<Sub>(f->stop(), f->start());
  ExprPtr split_count = alloc<Div>(size, factor_expr);
  ExprPtr tail_size = alloc<Mod>(size, factor_expr);

  const std::string& loop_var_name = f->var()->name_hint();
  Dtype loop_var_dtype = f->var()->dtype();

  VarPtr i_inner = alloc<Var>(loop_var_name + "_inner", loop_var_dtype);
  VarPtr i_outer = alloc<Var>(loop_var_name + "_outer", loop_var_dtype);

  // x -> x.outer * inner.size + x.inner
  ExprPtr combined_index1 =
      alloc<Add>(alloc<Mul>(i_outer, factor_expr), i_inner);

  if (tail_is_needed) {
    VarPtr i_tail = alloc<Var>(loop_var_name + "_tail", loop_var_dtype);
    // x -> x.tail + outer.size * inner.size
    ExprPtr combined_index2 =
        alloc<Add>(i_tail, alloc<Mul>(split_count, factor_expr));

    StmtPtr body_tail =
        SubstituteInClone(f->body(), {{f->var(), combined_index2}});
    *tail = alloc<For>(i_tail, immLike(tail_size, 0), tail_size, body_tail);

    p->insert_stmt_after(*tail, f);
  } else {
    *tail = nullptr;
  }

  StmtPtr body_inner =
      Substitute(f->removeBody(), {{f->var(), combined_index1}});

  *inner =
      alloc<For>(i_inner, immLike(factor_expr, 0), factor_expr, body_inner);
  // The input loop `f` will be the outer loop after split.
  f->set_var(i_outer);
  f->set_start(immLike(split_count, 0));
  f->set_stop(split_count);
  f->set_body(*inner);
}

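// A rough sketch of splitWithMask() with factor 4 on
// "for (int i = 0; i < 10; i++)" (illustrative): the loop is rounded up to 3
// outer iterations and the body is guarded by a predicate instead of a tail:
//   for (int i_outer = 0; i_outer < 3; i_outer++)
//     for (int i_inner = 0; i_inner < 4; i_inner++)
//       if (i_outer * 4 + i_inner < 10)
//         ... uses i_outer * 4 + i_inner ...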
void LoopNest::splitWithMask(const ForPtr& f, int factor) {
  ForPtr inner;
  splitWithMask(f, factor, &inner);
}

void LoopNest::splitWithMask(const ForPtr& f, int factor, ForPtr* inner) {
  BlockPtr p = to<Block>(f->get_parent());
  if (!p) {
    std::cerr << "Parent is not a Block!\n";
    return;
  }

  bool tail_is_needed = true;
  ExprPtr start = IRSimplifier::simplify(f->start());
  ExprPtr stop = IRSimplifier::simplify(f->stop());
  if (start->isConstant() && stop->isConstant()) {
    auto start_val = *intValue(start);
    auto stop_val = *intValue(stop);
    auto size_val = stop_val - start_val;
    auto tail_size = size_val % factor;
    if (tail_size == 0) {
      tail_is_needed = false;
    }
  }

  auto factor_expr = immLike(f->stop(), factor);
  ExprPtr size = alloc<Sub>(f->stop(), f->start());
  // split_count = (size + factor - 1) / factor
  ExprPtr split_count = alloc<Div>(
      alloc<Sub>(alloc<Add>(size, factor_expr), immLike(size, 1)), factor_expr);

  const std::string& loop_var_name = f->var()->name_hint();
  Dtype loop_var_dtype = f->var()->dtype();

  VarPtr i_inner = alloc<Var>(loop_var_name + "_inner", loop_var_dtype);
  VarPtr i_outer = alloc<Var>(loop_var_name + "_outer", loop_var_dtype);

  // x -> x.outer * inner.size + x.inner
  ExprPtr combined_index =
      alloc<Add>(alloc<Mul>(i_outer, factor_expr), i_inner);

  StmtPtr body_inner = f->removeBody();
  // TODO: is it ok that we're doing it eagerly? In the other implementation we
  // are only materializing predicates at the last, lowering, step.
  if (tail_is_needed) {
    auto start = intValue(f->start());
    if (!start || *start != 0) {
      throw unimplemented_lowering();
    }

    ExprPtr predicate =
        CompareSelect::make(ExprHandle(f->var()), ExprHandle(f->stop()), kLT)
            .node();
    body_inner = Cond::make(ExprHandle(predicate), body_inner, nullptr);
  }
  body_inner = Substitute(body_inner, {{f->var(), combined_index}});

  *inner =
      alloc<For>(i_inner, immLike(factor_expr, 0), factor_expr, body_inner);
  // The input loop `f` will be the outer loop after split.
  f->set_var(i_outer);
  f->set_start(immLike(split_count, 0));
  f->set_stop(split_count);
  f->set_body(*inner);
}

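// A small illustrative example for distributeLoop(): distributing
//   for (int i = ...) { S1; S2; S3; }
// over the pivot set {S1} produces
//   for (int i = ...) { S1; }
//   for (int i = ...) { S2; S3; }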
distributeLoop(const ForPtr & loop,const std::unordered_set<StmtPtr> & pivots)1679 std::vector<ForPtr> LoopNest::distributeLoop(
1680 const ForPtr& loop,
1681 const std::unordered_set<StmtPtr>& pivots) {
1682 TORCH_INTERNAL_ASSERT(
1683 loop,
1684 buildErrorMessage(
1685 "Expected non-null loop in distributeLoop in the fuser."));
1686 auto root = loop->get_parent();
1687 if (root == nullptr) {
1688 throw malformed_input("Loop without parent: ", loop);
1689 }
1690 auto root_block = to<Block>(root);
1691 if (root_block == nullptr) {
1692 throw malformed_input(
1693 "Loop's parent must be a Block, instead found ", root);
1694 }
1695
1696 // Extract bodies for all the loops after distribution.
1697 std::vector<BlockPtr> new_loop_bodies;
1698 auto new_loop_body = alloc<Block>(std::vector<StmtPtr>({}));
1699 while (!loop->body()->empty()) {
1700 auto s = loop->body()->front();
1701 loop->body()->remove_stmt(s);
1702 new_loop_body->append_stmt(s);
1703 if (pivots.count(s)) {
1704 new_loop_bodies.push_back(new_loop_body);
1705 new_loop_body = alloc<Block>(std::vector<StmtPtr>({}));
1706 }
1707 }
1708 if (!new_loop_body->empty()) {
1709 new_loop_bodies.push_back(new_loop_body);
1710 }
1711
1712 // The first loop body has to be in the original loop.
1713 loop->body()->splice(loop->body()->begin(), new_loop_bodies.front());
1714 std::vector<ForPtr> new_loops = {loop};
1715
1716 // Create loops for all the remaining blocks.
1717 // Add all the new loops to the parent block.
1718 for (size_t i = 1; i < new_loop_bodies.size(); ++i) {
1719 auto new_loop = loop->cloneWithNewBody(new_loop_bodies[i]);
1720 root_block->insert_stmt_after(new_loop, new_loops.back());
1721 new_loops.push_back(new_loop);
1722 }
1723
1724 return new_loops;
1725 }
1726
distributeLoop(const ForPtr & loop)1727 std::vector<ForPtr> LoopNest::distributeLoop(const ForPtr& loop) {
1728 std::unordered_set<StmtPtr> stmtsInBlock(
1729 loop->body()->begin(), loop->body()->end());
1730 return distributeLoop(loop, stmtsInBlock);
1731 }
1732
distributeLoopAndParents(const ForPtr & loop)1733 std::vector<ForPtr> LoopNest::distributeLoopAndParents(const ForPtr& loop) {
1734 auto parentLoop = getParentLoop(loop);
1735 auto result = distributeLoop(loop);
1736 if (parentLoop) {
1737 return distributeLoopAndParents(parentLoop);
1738 }
1739 return result;
1740 }
1741
distributeLoopOverInnerLoops(const ForPtr & loop)1742 std::vector<ForPtr> LoopNest::distributeLoopOverInnerLoops(const ForPtr& loop) {
1743 auto loops = NodeFinder<For>::find(loop);
1744 std::unordered_set<StmtPtr> loopsSet(loops.begin(), loops.end());
1745 return distributeLoop(loop, loopsSet);
1746 }
1747
distributeLoopAndParentsOverInnerLoops(const ForPtr & loop)1748 std::vector<ForPtr> LoopNest::distributeLoopAndParentsOverInnerLoops(
1749 const ForPtr& loop) {
1750 auto parentLoop = getParentLoop(loop);
1751 auto result = distributeLoopOverInnerLoops(loop);
1752 if (parentLoop) {
1753 return distributeLoopAndParentsOverInnerLoops(parentLoop);
1754 }
1755 return result;
1756 }
1757
areEqual(const ExprPtr & expr1,const ExprPtr & expr2)1758 static bool areEqual(const ExprPtr& expr1, const ExprPtr& expr2) {
1759 auto diff = IRSimplifier::simplify(alloc<Sub>(expr1, expr2));
1760 return diff->isConstant() && (immediateAs<int>(diff) == 0);
1761 }
1762
doesExprContainAnyVar(const ExprPtr & expr,const std::unordered_set<VarPtr> & vars)1763 static bool doesExprContainAnyVar(
1764 const ExprPtr& expr,
1765 const std::unordered_set<VarPtr>& vars) {
1766 for (const auto& v : VarFinder::find(expr)) {
1767 if (vars.count(v)) {
1768 return true;
1769 }
1770 }
1771 return false;
1772 }
1773
1774 // Returns true if the given lists of indices refer to two accesses
1775 // that are loop-independent w.r.t. the given set of outer loop
1776 // variables.
areIndicesLoopIndependent(const std::vector<ExprPtr> & expr_list1,const std::vector<ExprPtr> & expr_list2,const std::unordered_set<VarPtr> & outer_loop_vars)1777 static bool areIndicesLoopIndependent(
1778 const std::vector<ExprPtr>& expr_list1,
1779 const std::vector<ExprPtr>& expr_list2,
1780 const std::unordered_set<VarPtr>& outer_loop_vars) {
1781 if (expr_list1.size() != expr_list2.size()) {
1782 return false;
1783 }
1784 for (size_t i = 0; i < expr_list1.size(); ++i) {
1785 const auto& expr1 = expr_list1[i];
1786 const auto& expr2 = expr_list2[i];
1787 if (doesExprContainAnyVar(expr1, outer_loop_vars) ||
1788 doesExprContainAnyVar(expr2, outer_loop_vars)) {
1789 if (!areEqual(expr1, expr2)) {
1790 return false;
1791 }
1792 }
1793 }
1794 return true;
1795 }
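// Illustrative sketch: with outer loop var `i`, the index lists {i, j} and
// {i, k} are loop-independent (the only indices involving `i` are identical),
// whereas {i} and {i + 1} are not (they involve `i` but are not equal).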
1796
hasLoopCarriedDependence(const ForPtr & loop)1797 bool LoopNest::hasLoopCarriedDependence(const ForPtr& loop) {
1798 analysis::MemDependencyChecker analyzer;
1799 loop->accept(&analyzer);
1800
1801 std::unordered_set<VarPtr> outer_loop_vars = {loop->var()};
1802 auto outer_loops = LoopNest::getEnclosingLoopNest(loop);
1803 for (const auto& l : outer_loops) {
1804 outer_loop_vars.insert(l->var());
1805 }
1806
1807 // High-level algorithm to check if two accesses to a buffer, A and B, one of
1808 // which is a Store, result in a loop-carried dependence:
1809 // 1. For every pair of index expressions, Ai and Bi, that refer to a dim
1810   //      of A and B, if one of the following conditions is satisfied:
1811 // a) Ai and Bi are equal (OR)
1812 // b) Both Ai and Bi do not contain any outer-loop variables
1813 // then, the dependence between A and B is a loop-independent
1814 // dependence. This is because, in the case of b), those index
1815 // expressions do not affect the ordering of accesses A and B.
1816 // 2. If condition 1) is not satisfied:
1817 // a) if the bounds on the accesses overlap, then this is a
1818 // loop-carried dependence.
1819 // b) if the bounds on the accesses do not overlap, then there is no
1820 // dependence.
1821 //
1822 // NOTE: Since we check for equality of index expressions whenever outer
1823 // loop variables are involved, this may incorrectly report some cases as
1824 // having a loop-carried dependence. It is impractical to handle all
1825   //       possible cases here, so we are being conservative and allow for
1826 // some false positives. While this will prevent some loop fusion
1827 // opportunities, that should be a small fraction of the cases that are
1828 // allowed.
1829 //
1830 // Implementation:
1831 //
1832 // For every pair of statements, S1 and S2, in the loop:
1833 // * Get the loads and stores in S1 and S2.
1834 // * For every store in S1 and load in S2 to the same buffer, if the index
1835 // expressions are not equal and there is an overlap in accesses, return
1836 // true to indicate a loop-carried dependence.
1837 // * For every load in S1 and store in S2 to the same buffer, if the index
1838 // expressions are not equal and there is an overlap in accesses, return
1839 // true to indicate a loop-carried dependence.
1840 // * For every store in S1 and store in S2 to the same buffer, if the index
1841 // expressions are not equal and there is an overlap in accesses, return
1842 // true to indicate a loop-carried dependence.
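  //
  // Illustrative sketch (hypothetical loop, for exposition only):
  //   for i in 0..100:
  //     A[i] = B[i] + 1        (S1)
  //     C[i] = A[i - 1] * 2    (S2)
  // S1 stores A[i] and S2 loads A[i - 1]; the index expressions involve the
  // loop var `i` and are not equal, and their bounds overlap, so this pair is
  // reported as a loop-carried dependence.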
1843 for (auto it1 = loop->body()->begin(); it1 != loop->body()->end(); ++it1) {
1844 for (auto it2 = std::next(it1); it2 != loop->body()->end(); ++it2) {
1845 auto aStores = NodeFinder<Store>::find(*it1);
1846 auto aLoads = NodeFinder<Load>::find(*it1);
1847 auto bStores = NodeFinder<Store>::find(*it2);
1848 auto bLoads = NodeFinder<Load>::find(*it2);
1849 // ReadAfterWrite
1850 for (auto& aStore : aStores) {
1851 for (auto& bLoad : bLoads) {
1852 if (aStore->buf() == bLoad->buf()) {
1853 if (!areIndicesLoopIndependent(
1854 aStore->indices(), bLoad->indices(), outer_loop_vars)) {
1855 if (isOverlapping(analyzer, aStore, bLoad)) {
1856 return true;
1857 }
1858 }
1859 }
1860 }
1861 }
1862 // WriteAfterRead
1863 for (auto& bStore : bStores) {
1864 for (auto& aLoad : aLoads) {
1865 if (bStore->buf() == aLoad->buf()) {
1866 if (!areIndicesLoopIndependent(
1867 bStore->indices(), aLoad->indices(), outer_loop_vars)) {
1868 if (isOverlapping(analyzer, bStore, aLoad)) {
1869 return true;
1870 }
1871 }
1872 }
1873 }
1874 }
1875 // WriteAfterWrite
1876 for (auto& aStore : aStores) {
1877 for (auto& bStore : bStores) {
1878 if (aStore->buf() == bStore->buf()) {
1879 if (!areIndicesLoopIndependent(
1880 aStore->indices(), bStore->indices(), outer_loop_vars)) {
1881 if (isOverlapping(analyzer, aStore, bStore)) {
1882 return true;
1883 }
1884 }
1885 }
1886 }
1887 }
1888 }
1889 }
1890 return false;
1891 }
1892
unsafeFuseLoops(const std::vector<ForPtr> & loops,ForPtr * fused)1893 bool LoopNest::unsafeFuseLoops(
1894 const std::vector<ForPtr>& loops,
1895 ForPtr* fused) {
1896 if (loops.empty()) {
1897 return false;
1898 }
1899 if (loops.size() == 1) {
1900 *fused = loops.front();
1901 return true;
1902 }
1903
1904 // Check if all the loops have the same parent.
1905 auto root = loops.front()->get_parent();
1906 for (const auto& l : loops) {
1907 auto par = l->get_parent();
1908 if (par == nullptr) {
1909 return false;
1910 }
1911 if (par != root) {
1912 return false;
1913 }
1914 }
1915 auto root_block = to<Block>(root);
1916 if (root_block == nullptr) {
1917 return false;
1918 }
1919
1920   // Currently, we only handle cases where there are no statements between
1921   // the given loops in their parent's body. We could relax this constraint
1922   // by using dependency analysis to allow statements that do not affect the
1923   // loops being fused. TODO.
1924 auto it = root_block->begin();
1925 for (; it != root_block->end(); ++it) {
1926 if (*it == loops.front()) {
1927 break;
1928 }
1929 }
1930 TORCH_INTERNAL_ASSERT(
1931 it != root_block->end(),
1932 buildErrorMessage(
1933          "Could not find the given loop in the root stmt in unsafeFuseLoops in the fuser."));
1934 for (const auto& l : loops) {
1935 if (*it != l) {
1936 return false;
1937 }
1938 ++it;
1939 }
1940
1941 const auto& first_loop = loops.front();
1942   // Fuse the loops by taking all the statements from the second loop
1943 // onwards and moving them into the first loop's body.
1944 // This way the final fused loop will be the same as the first loop.
1945 for (size_t i = 1; i < loops.size(); ++i) {
1946 auto body = to<Block>(SubstituteInClone(
1947 loops[i]->body(), {{loops[i]->var(), first_loop->var()}}));
1948 first_loop->body()->splice(first_loop->body()->end(), body);
1949 root_block->remove_stmt(loops[i]);
1950 }
1951
1952 *fused = loops.front();
1953 return true;
1954 }
1955
fuseLoops(const std::vector<ForPtr> & loops,ForPtr * fused)1956 bool LoopNest::fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused) {
1957 if (loops.empty()) {
1958 return false;
1959 }
1960 if (loops.size() == 1) {
1961 *fused = loops.front();
1962 return true;
1963 }
1964
1965 // Check if bounds are the same for all the loops.
1966 const auto& first_loop = loops.front();
1967 auto first_loop_start = IRSimplifier::simplify(first_loop->start());
1968 auto first_loop_stop = IRSimplifier::simplify(first_loop->stop());
1969 for (size_t i = 1; i < loops.size(); ++i) {
1970 const auto& curr_loop = loops[i];
1971 auto curr_loop_start = IRSimplifier::simplify(curr_loop->start());
1972 auto curr_loop_stop = IRSimplifier::simplify(curr_loop->stop());
1973 if (!areEqual(curr_loop_start, first_loop_start)) {
1974 return false;
1975 }
1976 if (!areEqual(curr_loop_stop, first_loop_stop)) {
1977 return false;
1978 }
1979 }
1980
1981 // We need to check if fusing the loops results in a loop-carried dependence.
1982 // This check can be done only after the loops are fused into one. But if the
1983 // check is violated, we need to return the given loops in the original form.
1984 // So, we create a clone of all the loops, fuse them and check for this.
1985 std::vector<ForPtr> loops_copy;
1986 loops_copy.reserve(loops.size());
1987 BlockPtr parent = alloc<Block>(std::vector<StmtPtr>({}));
1988 for (auto& l : loops) {
1989 auto l_copy = Stmt::clone(l);
1990 loops_copy.push_back(to<For>(l_copy));
1991 parent->append_stmt(l_copy);
1992 }
1993 ForPtr fused_copy;
1994 bool ret = unsafeFuseLoops(loops_copy, &fused_copy);
1995 if (!ret || hasLoopCarriedDependence(fused_copy)) {
1996 return false;
1997 }
1998
1999 // Now that all conditions are satisfied, we fuse the given loops.
2000 return unsafeFuseLoops(loops, fused);
2001 }
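// Example usage (a sketch; `a` and `b` are assumed to be adjacent loops with
// identical bounds taken from the same LoopNest):
//   ForPtr fused = nullptr;
//   if (LoopNest::fuseLoops({a, b}, &fused)) {
//     // `fused` is the single merged loop; fusion is rejected (and the IR is
//     // left untouched) if it would introduce a loop-carried dependence.
//   }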
2002
findOuterFor(ForPtr a,ForPtr b)2003 ForPtr LoopNest::findOuterFor(ForPtr a, ForPtr b) {
2004   StmtPtr s = b; // guess that b is nested inside a.
2005 while (s != nullptr) {
2006 if (s == a) {
2007       // yes, b is nested inside a, so a is the outer loop.
2008 return a;
2009 }
2010 s = s->get_parent();
2011 }
2012
2013 // check that the two are in the same loop nest.
2014 s = a;
2015 while (s != nullptr) {
2016 if (s == b) {
2017       // a is nested inside b, so b is the outer loop.
2018 return b;
2019 }
2020 s = s->get_parent();
2021 }
2022
2023 // a and b have no relationship.
2024 return nullptr;
2025 }
2026
reorderAxis(const ForPtr & a,const ForPtr & b)2027 void LoopNest::reorderAxis(const ForPtr& a, const ForPtr& b) {
2028 if (a == b) {
2029 // nothing to do.
2030 return;
2031 }
2032 // find inner and outer.
2033 ForPtr outer = findOuterFor(a, b);
2034 if (outer == nullptr) {
2035 throw std::runtime_error("Reordered a loop not in LoopNest");
2036 }
2037
2038 ForPtr inner = a == outer ? b : a;
2039 std::deque<ForPtr> internal_axes;
2040
2041 // Find relevant axes, store reversed.
2042 StmtPtr s = inner;
2043 while (s != outer) {
2044 if (const ForPtr& f = to<For>(s)) {
2045 internal_axes.push_back(f);
2046 }
2047
2048 s = s->get_parent();
2049 }
2050
2051 internal_axes.push_back(outer);
2052
2053 BlockPtr root = to<Block>(outer->get_parent());
2054 CHECK(root);
2055
2056 // Do a shallow copy of the inner blocks.
2057 BlockPtr body = alloc<Block>(std::vector<StmtPtr>({}));
2058 body->splice(body->end(), inner->body());
2059
2060 const ForPtr& before{outer};
2061 ForPtr after{nullptr};
2062 ForPtr last = internal_axes.front();
2063 StmtPtr newInner = body;
2064
2065 s = inner;
2066 while (s != outer) {
2067 if (auto cond = to<Cond>(s->get_parent())) {
2068 if (s == cond->true_stmt()) {
2069 newInner = cond->cloneWithNewBody(newInner);
2070 } else {
2071 // s is the false branch of Cond
2072 newInner = cond->cloneWithNewBodies(
2073 alloc<Block>(std::vector<StmtPtr>({})), newInner);
2074 }
2075 }
2076 s = s->get_parent();
2077 }
2078
2079 // This is the major complexity in loop reordering: handling statements not in
2080 // the straight line of the reorder. To handle this we partition the tree into
2081 // the section before the critical path and after the critical path.
2082 //
2083 // An example of this pattern is:
2084 // for i in ..
2085 // Statement A
2086 // for j in ..
2087 // Statement B
2088 // Statement C
2089 //
2090 // When reordering loop i and j we need to ensure that Statement A and C are
2091 // still both executed with the loop extents of i, and that the three
2092 // statements are not reordered (as much as possible).
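  //
  // For the example above, reordering i and j is expected to produce, roughly:
  //   for i in ..
  //     Statement A
  //   for j in ..
  //     for i in ..
  //       Statement B
  //   for i in ..
  //     Statement C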
2093 for (const auto& loop : internal_axes) {
2094     // If a deeper level produced an `after` subtree, we must wrap it in a For
2095     // loop matching this level of the tree.
2096 if (after != nullptr) {
2097 after = loop->cloneWithNewBody(after);
2098 }
2099
2100 bool pastMidpoint = false;
2101 bool hadBeforeStmts = false;
2102 for (auto I = loop->body()->begin(), E = loop->body()->end(); I != E;) {
2103 // Be careful not to invalidate the iterator.
2104 StmtPtr s = *(I++);
2105 if (s == last) {
2106 // This is the midpoint.
2107 loop->body()->remove_stmt(s);
2108 if (!hadBeforeStmts) {
2109 // If there were no existing statements this loop does not need to be
2110 // preserved and we can roll it into the above loop.
2111 last = loop;
2112 }
2113 pastMidpoint = true;
2114 } else if (pastMidpoint) {
2115 // Statements after the reordered path must be moved to a new tree after
2116 // the reordered statement has occurred to preserve ordering.
2117 loop->body()->remove_stmt(s);
2118 if (after == nullptr) {
2119 after = loop->cloneWithNewBody(s);
2120 } else {
2121 after->body()->append_stmt(s);
2122 }
2123 } else {
2124 // We can leave any statements before the reordered loop alone, so long
2125 // as we preserve the loop structure.
2126 hadBeforeStmts = true;
2127 }
2128 }
2129 }
2130
2131 // now we can actually reorder the chosen axes.
2132 std::swap(internal_axes.front(), internal_axes.back());
2133
2134 // Create the reordered internals:
2135 for (const auto& loop : internal_axes) {
2136 newInner = loop->cloneWithNewBody(newInner);
2137 }
2138
2139 // Append the new statements to the root of the tree.
2140 if (before->body()->nstmts() == 0) {
2141 // If the top level is now empty, eliminate it.
2142 root->replace_stmt(before, newInner);
2143 } else {
2144 root->insert_stmt_after(newInner, before);
2145 }
2146
2147 if (after) {
2148 root->insert_stmt_after(after, newInner);
2149 }
2150 }
2151
isTrivialPermutation(const std::vector<size_t> & permutation)2152 static bool isTrivialPermutation(const std::vector<size_t>& permutation) {
2153 for (size_t i = 0; i < permutation.size(); ++i) {
2154 if (permutation[i] != i) {
2155 return false;
2156 }
2157 }
2158 return true;
2159 }
2160
isValidPermutation(std::vector<size_t> permutation)2161 static bool isValidPermutation(std::vector<size_t> permutation) {
2162 std::sort(permutation.begin(), permutation.end());
2163 return isTrivialPermutation(permutation);
2164 }
2165
reorder(const std::vector<ForPtr> & loops,const std::vector<size_t> & permutation)2166 std::vector<ForPtr> LoopNest::reorder(
2167 const std::vector<ForPtr>& loops,
2168 const std::vector<size_t>& permutation) {
2169 if (loops.size() != permutation.size()) {
2170 throw malformed_input("invalid permutation size");
2171 }
2172 if (isTrivialPermutation(permutation)) {
2173 return loops;
2174 }
2175 if (!isValidPermutation(permutation)) {
2176 throw malformed_input("invalid permutation for reorder");
2177 }
2178 if (loops.size() < 2) {
2179 return loops;
2180 }
2181 if (!areLoopsPerfectlyNested(loops)) {
2182 throw malformed_input("reorder is only allowed on perfectly nested loops");
2183 }
2184
2185 auto parent = to<Block>(loops.front()->get_parent());
2186 if (parent == nullptr) {
2187 throw malformed_input("parent of the loops must be a Block");
2188 }
2189
2190 // Reorder the loops according to the permutation.
2191 std::vector<ForPtr> result(loops.size());
2192 for (size_t i = 0; i < loops.size(); ++i) {
2193 result[i] = loops[permutation[i]];
2194 }
2195
2196 // Remove the bodies from all the loops.
2197 auto innermost_body = loops.back()->removeBody();
2198 // We use an empty block statement to replace the outermost loop
2199 // so that we know the position where the outermost reordered loop
2200 // is to be inserted.
2201 auto empty_block = alloc<Block>(std::vector<StmtPtr>({}));
2202 parent->replace_stmt(loops.front(), empty_block);
2203 for (size_t i = 1; i < loops.size(); ++i) {
2204 auto block = to<Block>(loops[i]->get_parent());
2205 TORCH_INTERNAL_ASSERT(
2206 block,
2207 buildErrorMessage(
2208             "Expected parent stmt to be a non-null Block in reorder transformation in the fuser."));
2209 block->remove_stmt(loops[i]);
2210 }
2211
2212 // Set the new bodies after reorder for all the loops.
2213 for (size_t i = 0; i < result.size() - 1; ++i) {
2214 result[i]->set_body(result[i + 1]);
2215 }
2216 result.back()->set_body(innermost_body);
2217 parent->replace_stmt(empty_block, result.front());
2218 return result;
2219 }
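// Illustrative sketch: reorder({i, j, k}, {2, 0, 1}) on the perfect nest
//   for i: for j: for k: body
// produces
//   for k: for i: for j: body
// i.e. result[n] is the loop that was at position permutation[n] in the input.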
2220
getLoopAt(ForPtr root,const std::vector<int> & indices) const2221 ForPtr LoopNest::getLoopAt(ForPtr root, const std::vector<int>& indices) const {
2222 if (indices.empty()) {
2223 return root;
2224 }
2225 if (root == nullptr) {
2226 throw malformed_input("root loop is null");
2227 }
2228
2229 ForPtr curr = std::move(root);
2230 for (auto i : indices) {
2231 if (i < 0 || curr->body()->nstmts() <= static_cast<size_t>(i)) {
2232 return nullptr;
2233 }
2234 std::list<StmtPtr>::iterator stmtp = curr->body()->begin();
2235 std::advance(stmtp, i);
2236 curr = to<For>(*stmtp);
2237 if (curr == nullptr) {
2238 return nullptr;
2239 }
2240 }
2241
2242 return curr;
2243 }
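// Illustrative sketch: given
//   for i { S0; for j { for k { ... } } }
// getLoopAt(loop_i, {1, 0}) returns the `k` loop: index 1 selects the second
// stmt of i's body (the `j` loop) and index 0 selects the first stmt of j's
// body (the `k` loop).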
2244
tile(const ForPtr & x,const ForPtr & y,int x_factor,int y_factor)2245 ForPtr LoopNest::tile(
2246 const ForPtr& x,
2247 const ForPtr& y,
2248 int x_factor,
2249 int y_factor) {
2250 auto parent = to<Block>(x->get_parent());
2251 if (parent == nullptr) {
2252 throw malformed_input("parent of the loops must be a Block");
2253 }
2254 if (!areLoopsPerfectlyNested({x, y})) {
2255 throw malformed_input("two loops must be perfectly nested");
2256 }
2257
2258 // Split x, y axes by x_factor and y_factor
2259 ForPtr yi, ytail;
2260 splitWithTail(y, y_factor, &yi, &ytail);
2261 ForPtr xi, xtail;
2262 splitWithTail(x, x_factor, &xi, &xtail);
2263
2264 // Distribute xi over yo and ytail so we can manipulate the loop order of {xo,
2265 // xi, yo, yi}
2266 auto loops = distributeLoop(xi);
2267
2268 // For {xi, yo, yi}, reorder the axes to be yo, xi, yi
2269 xi = loops.front();
2270 ForPtr yo = to<For>(xi->body()->stmts().front());
2271 CHECK(yo);
2272 reorder({xi, yo}, {1, 0});
2273
2274 // For {xi, ytail}, reorder the axes to be ytail, xi
2275 if (loops.size() == 2) {
2276 xi = loops.back();
2277 ytail = to<For>(xi->body()->stmts().front());
2278 CHECK(ytail);
2279 reorder({xi, ytail}, {1, 0});
2280 }
2281
2282 return xtail;
2283 }
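// Illustrative sketch: tiling
//   for x in 0..64: for y in 0..64: A[x, y] = f(x, y)
// with x_factor = 8 and y_factor = 8 yields, roughly,
//   for x_outer in 0..8:
//     for y_outer in 0..8:
//       for x_inner in 0..8:
//         for y_inner in 0..8:
//           A[x_outer * 8 + x_inner, y_outer * 8 + y_inner] = ...
// plus tail loops when the extents are not evenly divisible by the factors.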
2284
areLoopsPerfectlyNested(const std::vector<ForPtr> & loops)2285 bool LoopNest::areLoopsPerfectlyNested(const std::vector<ForPtr>& loops) {
2286 if (loops.size() < 2) {
2287 return true;
2288 }
2289 for (size_t i = 0; i < loops.size() - 1; ++i) {
2290 auto loop_body = loops[i]->body();
2291 if (loop_body->nstmts() != 1 || loop_body->front() != loops[i + 1]) {
2292 return false;
2293 }
2294 }
2295 return true;
2296 }
2297
fullUnroll(const ForPtr & f,StmtPtr * unrolled)2298 void LoopNest::fullUnroll(const ForPtr& f, StmtPtr* unrolled) {
2299   if (!f) {
2300     throw malformed_input("unroll attempted on null loop");
2301   }
2302   BlockPtr p = to<Block>(f->get_parent());
2303   if (!p) {
2304     throw malformed_input("unroll attempted on loop with no parent");
2305   }

2306 auto start_expr = IRSimplifier::simplify(f->start());
2307 auto stop_expr = IRSimplifier::simplify(f->stop());
2308 if (!start_expr->isConstant()) {
2309 throw std::runtime_error("Can't unroll due to non-constant loop start!");
2310 }
2311 if (!stop_expr->isConstant()) {
2312 throw std::runtime_error("Can't unroll due to non-constant loop stop!");
2313 }
2314
2315 std::vector<StmtPtr> unrolled_stmts;
2316 int start_val = immediateAs<int>(start_expr);
2317 int stop_val = immediateAs<int>(stop_expr);
2318 for (int current = start_val; current < stop_val; ++current) {
2319 for (const auto& stmt : f->body()->stmts()) {
2320 unrolled_stmts.push_back(SubstituteInClone(
2321 stmt, {{f->var(), getImmediateByType(f->var()->dtype(), current)}}));
2322 }
2323 }
2324 *unrolled = alloc<Block>(unrolled_stmts);
2325 *unrolled = IRSimplifier::simplify(*unrolled);
2326
2327 p->replace_stmt(f, *unrolled);
2328 }
2329
fullUnroll(const ForPtr & f)2330 void LoopNest::fullUnroll(const ForPtr& f) {
2331 StmtPtr unrolled;
2332 fullUnroll(f, &unrolled);
2333 }
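// Illustrative sketch: fully unrolling
//   for i in 0..3: A[i] = B[i]
// replaces the loop with the straight-line statements
//   A[0] = B[0]; A[1] = B[1]; A[2] = B[2];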
2334
unroll(const ForPtr & f,int factor,ForPtr * tail)2335 void LoopNest::unroll(const ForPtr& f, int factor, ForPtr* tail) {
2336 if (factor < 2) {
2337 return;
2338 }
2339 ForPtr inner;
2340 splitWithTail(f, factor, &inner, tail);
2341 fullUnroll(inner);
2342 }
2343
unroll(const ForPtr & f,int factor)2344 void LoopNest::unroll(const ForPtr& f, int factor) {
2345 ForPtr tail;
2346 unroll(f, factor, &tail);
2347 }
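// Illustrative sketch: unroll(f, 2) on
//   for i in 0..10: A[i] = ...
// first splits `f` by a factor of 2 and then fully unrolls the inner loop,
// yielding, roughly,
//   for i_outer in 0..5: { A[i_outer * 2 + 0] = ...; A[i_outer * 2 + 1] = ...; }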
2348
isNormalized(const ForPtr & f)2349 bool LoopNest::isNormalized(const ForPtr& f) {
2350 if (f->start()->isConstant()) {
2351 return immediateAs<int>(f->start()) == 0;
2352 }
2353 return false;
2354 }
2355
normalize(const ForPtr & f)2356 bool LoopNest::normalize(const ForPtr& f) {
2357 if (!f) {
2358 throw malformed_input("normalize attempted on null loop");
2359 }
2360
2361 if (isNormalized(f)) {
2362 // No need to normalize anymore here.
2363 return false;
2364 }
2365
2366 auto for_body_normalized = Substitute(
2367 f->body(),
2368 {{f->var(), (VarHandle(f->var()) + ExprHandle(f->start())).node()}});
2369 f->set_body(IRSimplifier::simplify(for_body_normalized));
2370 f->set_stop(IRSimplifier::simplify(alloc<Sub>(f->stop(), f->start())));
2371 f->set_start(immLike(f->stop(), 0));
2372 return true;
2373 }
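// Illustrative sketch: normalizing
//   for i in 5..100: A[i] = B[i]
// rewrites it to start at zero:
//   for i in 0..95: A[i + 5] = B[i + 5]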
2374
2375 // This function expects that there are 'num' loops perfectly nested within
2376 // and including 'f'.
getLoopStmtsInLoopNest(const ForPtr & f,size_t num)2377 std::vector<ForPtr> LoopNest::getLoopStmtsInLoopNest(
2378 const ForPtr& f,
2379 size_t num) {
2380 std::vector<ForPtr> loops(num);
2381 ForPtr curr_for = f;
2382 loops[0] = curr_for;
2383 for (size_t i = 1; i < num; ++i) {
2384 TORCH_INTERNAL_ASSERT(
2385 curr_for->body()->nstmts() == 1,
2386 buildErrorMessage("Expected a single stmt in the loop body."));
2387 curr_for = to<For>(curr_for->body()->front());
2388 TORCH_INTERNAL_ASSERT(
2389 curr_for,
2390 buildErrorMessage("Expected the only child stmt to be a For loop."));
2391 loops[i] = curr_for;
2392 }
2393 return loops;
2394 }
2395
flatten(const std::vector<ForPtr> & loops,ForPtr * flattened)2396 bool LoopNest::flatten(const std::vector<ForPtr>& loops, ForPtr* flattened) {
2397 if (loops.empty()) {
2398 throw malformed_input("flatten attempted on empty set of loops");
2399 }
2400 BlockPtr p = to<Block>(loops[0]->get_parent());
2401 if (!p) {
2402 throw malformed_input("flatten attempted on loops with no parent");
2403 }
2404
2405 if (loops.size() == 1) {
2406 // This loop nest is already flattened.
2407 *flattened = loops[0];
2408 return false;
2409 }
2410
2411 // Check if all the loops correspond to a perfect loopnest:
2412 // * every loop except the inner-most should have only one stmt, the For.
2413 // Do not flatten, otherwise.
2414 // This check also ensures we do not flatten reduction loops.
2415 for (size_t i = 0; i < loops.size() - 1; ++i) {
2416 if ((loops[i]->body()->nstmts() != 1) ||
2417 (loops[i]->body()->front() != loops[i + 1])) {
2418 return false;
2419 }
2420 }
2421
2422 // Normalize the loops before flattening.
2423 // We need to normalize them from inner-most to outer because once the outer
2424 // loop is normalized, the given pointers to inner loops point to old code.
2425 // For the same reason, we can't store the normalized inner loops until after
2426 // the outer-most loop is normalized.
2427 for (size_t i = 0; i < loops.size(); ++i) {
2428 size_t idx = loops.size() - i - 1;
2429 LoopNest::normalize(loops[idx]);
2430 }
2431
2432   // Collect all the loops in the normalized loopnest; the first one is the
2433   // outer-most loop.
2434 auto normalized_loops = getLoopStmtsInLoopNest(loops.front(), loops.size());
2435
2436 auto flat_var = alloc<Var>(
2437 normalized_loops[0]->var()->name_hint() + "_flat",
2438 normalized_loops[0]->var()->dtype());
2439 VarMapping var_mapping;
2440 ExprPtr stop = immLike(flat_var, 1);
2441 for (size_t i = 0; i < normalized_loops.size(); ++i) {
2442 size_t idx = normalized_loops.size() - i - 1;
2443 auto curr_loop = normalized_loops[idx];
2444 ExprPtr div = alloc<Div>(flat_var, stop);
2445 ExprPtr sub_expr = idx == 0 ? div : alloc<Mod>(div, curr_loop->stop());
2446 var_mapping.emplace_back(curr_loop->var(), sub_expr);
2447 stop = alloc<Mul>(curr_loop->stop(), stop);
2448 }
2449 auto flattened_body =
2450 Substitute(normalized_loops.back()->removeBody(), var_mapping);
2451
2452 normalized_loops.front()->set_var(flat_var);
2453 normalized_loops.front()->set_start(immLike(stop, 0));
2454 normalized_loops.front()->set_stop(stop);
2455 normalized_loops.front()->set_body(flattened_body);
2456 *flattened = normalized_loops.front();
2457 return true;
2458 }
2459
flatten(const std::vector<ForPtr> & loops)2460 bool LoopNest::flatten(const std::vector<ForPtr>& loops) {
2461 ForPtr flattened;
2462 return flatten(loops, &flattened);
2463 }
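// Illustrative sketch: flattening
//   for i in 0..10: for j in 0..20: A[i, j] = ...
// produces a single loop whose index is decomposed with div/mod:
//   for i_flat in 0..200: A[i_flat / 20, i_flat % 20] = ...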
2464
compressBuffer(const BufPtr & buf,const StmtPtr & stmt)2465 void LoopNest::compressBuffer(const BufPtr& buf, const StmtPtr& stmt) {
2466 // Loop iterations in NNC IR do not follow sequential semantics by default.
2467 // In other words, the iterations of the loops could be executed in any
2468 // random order without affecting correctness. This constraint in turn
2469 // implies that there can’t be any *inter-iteration* dependences
2470 // (or *loop-carried* dependences) in NNC loops. So, any NNC IR with such
2471 // dependences is considered invalid.
2472 //
2473 // Given the constraint above, for any pair of accesses to a buffer (where
2474 // at least one of the access is a write), the accesses must be
2475 // loop-independent on the innermost loop containing the accesses as well as
2476 // all the loops above it. So, any dimension that uses only those loop
2477 // variables to access the given buffer could be optimized away.
2478 //
2479 // Algorithm:
2480 // * Find all the accesses to the given buf. (A)
2481 // * Find the parent common to all accesses in A. (P)
2482 // * Collect all the loops above P. (L)
2483 // * Collect all the loop variables corresponding to L. (LV)
2484 // * For every access a in A:
2485 // * For the index I in every dimension of a:
2486 // * If the variables in I are all in LV, mark this dimension
2487 // for deletion.
2488 // * For every dimension that is marked for deletion in ALL accesses in A:
2489 // * Update the buffer to set the size of that dimension to 1.
2490 // * Update all accesses in A to set the index in that dimension to 0.
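  //
  // Illustrative sketch: if every access to a buf A inside `stmt` has the form
  // A[i, j, k] where `i` and `j` are vars of loops enclosing the common parent
  // (i.e. they are in LV) and `k` is an inner var, then dims 0 and 1 are
  // compressed: A's dims become {1, 1, K} and every access becomes A[0, 0, k].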
2491
2492 auto writes = WritesToBuf::find(stmt, buf);
2493 auto reads = StmtsReadingBuf::find(stmt, buf);
2494
2495 // Find the parent common to all the buffer accesses.
2496 BlockPtr parent = to<Block>(writes.front()->get_parent());
2497 TORCH_INTERNAL_ASSERT(
2498 parent,
2499 buildErrorMessage(
2500 "Expected parent stmt to be a non-null block in compressBuffer in the fuser."));
2501 for (const auto& w : writes) {
2502 parent = Block::getSharedParent(parent, w);
2503 }
2504 for (const auto& r : reads) {
2505 parent = Block::getSharedParent(parent, r);
2506 }
2507
2508 // Collect all the loops that are above the common parent.
2509 auto loops = LoopNest::getEnclosingLoopNest(parent);
2510 std::unordered_set<VarPtr> loop_vars;
2511 for (const auto& l : loops) {
2512 loop_vars.insert(l->var());
2513 }
2514
2515 // TODO: Need to handle other Stmts / Exprs that read / write buffers.
2516 auto stores = NodeFinder<Store>::find(stmt);
2517 auto loads = NodeFinder<Load>::find(stmt);
2518
2519 // Vector to indicate which dimensions could be compressed away.
2520 std::vector<bool> dims(buf->dims().size(), true);
2521 auto check_indices = [&](const std::vector<ExprPtr>& indices) {
2522 TORCH_INTERNAL_ASSERT(
2523 indices.size() == dims.size(),
2524 buildErrorMessage(
2525 "Expected ranks to match in compressBuffer in the fuser."));
2526 for (size_t i = 0; i < indices.size(); ++i) {
2527 auto index_vars = NodeFinder<Var>::find(indices[i]);
2528 for (const auto& iv : index_vars) {
2529 if (loop_vars.count(iv) == 0) {
2530 // A variable in this index is not in loop_vars.
2531 // This implies that this dimension cannot be optimized away.
2532 dims[i] = false;
2533 break;
2534 }
2535 }
2536 }
2537 };
2538 for (const auto& s : stores) {
2539 if (s->buf() == buf) {
2540 check_indices(s->indices());
2541 }
2542 }
2543 for (const auto& l : loads) {
2544 if (l->buf() == buf) {
2545 check_indices(l->indices());
2546 }
2547 }
2548 bool any_dim_to_compress = false;
2549 for (auto d : dims) {
2550 any_dim_to_compress |= d;
2551 }
2552 if (!any_dim_to_compress) {
2553 return;
2554 }
2555
2556   // Compress the buffer by setting the size of each marked dim to 1.
2557 std::vector<ExprPtr> new_dims(buf->dims());
2558 for (size_t i = 0; i < dims.size(); ++i) {
2559 if (dims[i]) {
2560 new_dims[i] = immLike(buf->dims()[i], 1);
2561 }
2562 }
2563 buf->set_dims(new_dims);
2564
2565   // Modify all accesses to reflect the compressed dims.
2566 auto get_new_indices = [&](const std::vector<ExprPtr>& indices) {
2567 TORCH_INTERNAL_ASSERT(
2568 indices.size() == dims.size(),
2569 buildErrorMessage(
2570 "Expected ranks to match in compressBuffer in the fuser."));
2571 std::vector<ExprPtr> new_indices(indices);
2572 for (size_t i = 0; i < dims.size(); ++i) {
2573 if (dims[i]) {
2574 new_indices[i] = immLike(indices[i], 0);
2575 }
2576 }
2577 return new_indices;
2578 };
2579 for (const auto& s : stores) {
2580 if (s->buf() == buf) {
2581 s->set_indices(get_new_indices(s->indices()));
2582 }
2583 }
2584 for (const auto& l : loads) {
2585 if (l->buf() == buf) {
2586 l->set_indices(get_new_indices(l->indices()));
2587 }
2588 }
2589 }
2590
compressAllBuffers(const StmtPtr & stmt)2591 void LoopNest::compressAllBuffers(const StmtPtr& stmt) {
2592 for (const auto& buf : BufFinder::find(stmt)) {
2593 compressBuffer(buf, stmt);
2594 }
2595 }
2596
getLoopStmtsFor(const Tensor & t) const2597 std::vector<ForPtr> LoopNest::getLoopStmtsFor(const Tensor& t) const {
2598 StmtPtr cur_stmt = getLoopBodyFor(t);
2599 return getLoopStmtsFor(cur_stmt);
2600 }
2601
getLoopStmtsFor(const BufPtr & buf) const2602 std::vector<ForPtr> LoopNest::getLoopStmtsFor(const BufPtr& buf) const {
2603 StmtPtr cur_stmt = getLoopBodyFor(buf);
2604 return getLoopStmtsFor(cur_stmt);
2605 }
2606
getLoopStmtsFor(StmtPtr s) const2607 std::vector<ForPtr> LoopNest::getLoopStmtsFor(StmtPtr s) const {
2608 std::vector<ForPtr> result;
2609
2610 while (s) {
2611 if (auto loop = to<For>(s)) {
2612 result.push_back(loop);
2613 }
2614 s = s->get_parent();
2615 }
2616 std::reverse(result.begin(), result.end());
2617 return result;
2618 }
2619
getLoopBodyFor(const Tensor & t) const2620 StmtPtr LoopNest::getLoopBodyFor(const Tensor& t) const {
2621 return getLoopBodyFor(t.buf());
2622 }
2623
getLoopBodyFor(BufPtr buf) const2624 StmtPtr LoopNest::getLoopBodyFor(BufPtr buf) const {
2625 auto writes = WritesToBuf::find(root_stmt_, std::move(buf));
2626
2627   // Special case for reduction Tensors: when the only writes are the
2628   // initializer and the reduction store, return the reduction store:
2629 if (writes.size() == 2) {
2630 if (StorePtr s = to<Store>(writes.back())) {
2631 if (ReduceOpPtr r = to<ReduceOp>(s->value())) {
2632 return (StmtPtr)s;
2633 }
2634 }
2635 }
2636
2637 StmtPtr res = nullptr;
2638 for (const auto& s : writes) {
2639 if (!res) {
2640 res = s;
2641 continue;
2642 }
2643
2644 res = Block::getSharedParent(res, s);
2645 }
2646
2647 return (StmtPtr)res;
2648 }
2649
getParentLoop(const StmtPtr & st)2650 ForPtr LoopNest::getParentLoop(const StmtPtr& st) {
2651 if (st == nullptr) {
2652 return nullptr;
2653 }
2654 auto par = st->get_parent();
2655 if (auto f = to<For>(par)) {
2656 return f;
2657 }
2658 return getParentLoop(par);
2659 }
2660
getEnclosingLoopNest(const StmtPtr & st)2661 std::vector<ForPtr> LoopNest::getEnclosingLoopNest(const StmtPtr& st) {
2662 std::vector<ForPtr> loops;
2663 auto f = getParentLoop(st);
2664 while (f) {
2665 loops.push_back(f);
2666 f = getParentLoop(f);
2667 }
2668 std::reverse(loops.begin(), loops.end());
2669 return loops;
2670 }
2671
getAllWritesToBuf(BufPtr buf) const2672 std::vector<StmtPtr> LoopNest::getAllWritesToBuf(BufPtr buf) const {
2673 return WritesToBuf::find(root_stmt_, std::move(buf));
2674 }
2675
getAllInnermostLoopsWritingToBuf(BufPtr buf) const2676 std::vector<ForPtr> LoopNest::getAllInnermostLoopsWritingToBuf(
2677 BufPtr buf) const {
2678 auto writes = getAllWritesToBuf(std::move(buf));
2679 std::vector<ForPtr> innermost_loops;
2680 innermost_loops.reserve(writes.size());
2681 for (const auto& w : writes) {
2682 innermost_loops.push_back(LoopNest::getParentLoop(w));
2683 }
2684 return innermost_loops;
2685 }
2686
getAllLoopNestsWritingToBuf(BufPtr buf) const2687 std::vector<std::vector<ForPtr>> LoopNest::getAllLoopNestsWritingToBuf(
2688 BufPtr buf) const {
2689 auto writes = getAllWritesToBuf(std::move(buf));
2690 std::vector<std::vector<ForPtr>> loopnests;
2691 loopnests.reserve(writes.size());
2692 for (const auto& w : writes) {
2693 loopnests.emplace_back(LoopNest::getEnclosingLoopNest(w));
2694 }
2695 return loopnests;
2696 }
2697
simplify()2698 StmtPtr LoopNest::simplify() {
2699 root_stmt_ = IRSimplifier::simplify(root_stmt_);
2700 return root_stmt_;
2701 }
2702
FlattenIndexes(const StmtPtr & s)2703 StmtPtr FlattenIndexes(const StmtPtr& s) {
2704 IndexFlattener idx_flattener;
2705 return idx_flattener.flatten(s);
2706 }
2707
2708 // Auxiliary class for the rewriting we do in `compute_at`. See
2709 // LoopNest::computeAt for more details.
2710 class LoopComputeAtRewriter : public IRMutator {
2711 public:
LoopComputeAtRewriter(BufPtr buf,BufPtr new_buf,std::vector<ExprPtr> offsets)2712 LoopComputeAtRewriter(
2713 BufPtr buf,
2714 BufPtr new_buf,
2715 std::vector<ExprPtr> offsets)
2716 : buf_(std::move(buf)),
2717 new_buf_(std::move(new_buf)),
2718 offsets_(std::move(offsets)) {}
2719
2720 private:
2721 BufPtr buf_;
2722 BufPtr new_buf_;
2723 std::vector<ExprPtr> offsets_;
2724
mutate(const LoadPtr & v)2725 ExprPtr mutate(const LoadPtr& v) override {
2726 if (v->buf() != buf_) {
2727 return v;
2728 }
2729 std::vector<ExprPtr> new_indices(v->indices().size());
2730 for (const auto i : c10::irange(v->indices().size())) {
2731 new_indices[i] =
2732 IRSimplifier::simplify(alloc<Sub>(v->indices()[i], offsets_[i]));
2733 }
2734 return alloc<Load>(v->dtype(), new_buf_, new_indices);
2735 }
2736 };
2737
getStoreStmtOfProducer(const StmtPtr & s)2738 static StorePtr getStoreStmtOfProducer(const StmtPtr& s) {
2739 if (StorePtr st = to<Store>(s)) {
2740 return st;
2741 }
2742 if (BlockPtr b = to<Block>(s)) {
2743 for (const StmtPtr& ss : *b) {
2744 if (StorePtr st = to<Store>(ss)) {
2745 return st;
2746 }
2747 }
2748 }
2749 return nullptr;
2750 }
2751
getOuterLoopIndexes(StmtPtr s)2752 static std::vector<VarPtr> getOuterLoopIndexes(StmtPtr s) {
2753 std::vector<VarPtr> res;
2754 StmtPtr cur = std::move(s);
2755 while (cur) {
2756 if (auto l = to<For>(cur)) {
2757 res.push_back(l->var());
2758 }
2759 cur = cur->get_parent();
2760 }
2761 return res;
2762 }
2763
2764 class CacheReplacer : public IRMutator {
2765 public:
CacheReplacer(BufPtr buffer,BufPtr cache,std::vector<ExprPtr> & offsets)2766 CacheReplacer(BufPtr buffer, BufPtr cache, std::vector<ExprPtr>& offsets)
2767 : buf_(std::move(buffer)), cache_(std::move(cache)), offsets_(offsets) {}
2768
2769 private:
mutate(const LoadPtr & v)2770 ExprPtr mutate(const LoadPtr& v) override {
2771 BufPtr buf = v->buf();
2772 if (buf != buf_) {
2773 return IRMutator::mutate(v);
2774 }
2775
2776 // Map indices to call-parameters.
2777 std::vector<ExprPtr> newIndices;
2778 TORCH_INTERNAL_ASSERT(
2779 offsets_.size() == v->indices().size(),
2780 buildErrorMessage(
2781 "Expected ranks to match in CacheReplacer in the fuser."));
2782 for (size_t i = 0; i < v->indices().size(); ++i) {
2783 ExprPtr index = v->indices()[i]->accept_mutator(this);
2784 ExprPtr offset = offsets_[i];
2785 ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
2786 newIndices.push_back(sub);
2787 }
2788 v->set_buf(cache_);
2789 v->set_indices(newIndices);
2790 return v;
2791 }
2792
mutate(const StorePtr & v)2793 StmtPtr mutate(const StorePtr& v) override {
2794 BufPtr buf = v->buf();
2795 if (buf != buf_) {
2796 return IRMutator::mutate(v);
2797 }
2798
2799 ExprPtr newValue = v->value()->accept_mutator(this);
2800
2801 // Map indices to call-parameters.
2802 std::vector<ExprPtr> newIndices;
2803 TORCH_INTERNAL_ASSERT(
2804 offsets_.size() == v->indices().size(),
2805 buildErrorMessage(
2806 "Expected ranks to match in CacheReplacer in the fuser."));
2807 for (size_t i = 0; i < v->indices().size(); ++i) {
2808 ExprPtr index = v->indices()[i]->accept_mutator(this);
2809 ExprPtr offset = offsets_[i];
2810 ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
2811 newIndices.push_back(sub);
2812 }
2813 v->set_buf(cache_);
2814 v->set_indices(newIndices);
2815 v->set_value(newValue);
2816 return v;
2817 }
2818
2819 BufPtr buf_;
2820 BufPtr cache_;
2821 std::vector<ExprPtr>& offsets_;
2822 };
2823
cacheAccesses(const BufPtr & producer,const std::string & name,const StmtPtr & consumer)2824 LoopNest::AccessResult LoopNest::cacheAccesses(
2825 const BufPtr& producer,
2826 const std::string& name,
2827 const StmtPtr& consumer) {
2828 ReduceOpPtr reduceOp{nullptr};
2829 auto stores = NodeFinder<Store>::find(consumer);
2830 for (const auto& store : stores) {
2831 if (auto ro = to<ReduceOp>(store->value())) {
2832 if (store->buf() != producer) {
2833 continue;
2834 }
2835
2836 if (reduceOp) {
2837 throw std::runtime_error(
2838 "can only cache accesses used by at most a single reduceOp");
2839 return {nullptr, nullptr};
2840 }
2841
2842 reduceOp = ro;
2843 }
2844 }
2845
2846 // Check bounds but don't care about AccessKind.
2847 auto consumer_bounds_info = inferBounds(consumer, false);
2848 auto bounds_it = consumer_bounds_info.find(producer);
2849 if (bounds_it == consumer_bounds_info.end()) {
2850 throw std::runtime_error("consumer does not use the Tensor produced");
2851 return {nullptr, nullptr};
2852 }
2853
2854 TORCH_INTERNAL_ASSERT(
2855 bounds_it->second.size() == 1,
2856 buildErrorMessage(
2857 "Unexpected number of bound info entries in cacheAccesses in the fuser."));
2858 TensorAccessBoundsInfo& info = bounds_it->second[0];
2859 bool hasReads = info.kind == kLoad || info.kind == kMutate;
2860 bool hasWrites = info.kind == kStore || info.kind == kMutate;
2861
2862 std::vector<std::string> var_names = {"i", "j", "k", "l", "m", "n", "o", "p"};
2863 std::vector<ExprPtr> tmp_dims;
2864 std::vector<VarPtr> new_loop_vars;
2865 std::vector<ExprPtr> new_loop_vars_expr;
2866
2867 // Determine the size of the cache, and create a loop var for each dimension.
2868 for (size_t i = 0; i < info.start.size(); ++i) {
2869 ExprPtr dim = IRSimplifier::simplify(alloc<Add>(
2870 alloc<Sub>(info.stop[i], info.start[i]), immLike(info.stop[i], 1)));
2871
2872 tmp_dims.push_back(dim);
2873
2874 new_loop_vars.push_back(
2875 alloc<Var>(var_names[i % var_names.size()], info.stop[i]->dtype()));
2876 new_loop_vars_expr.push_back(new_loop_vars[i]);
2877 }
2878
2879 // Create the var.
2880 BufPtr tmp_buf =
2881 alloc<Buf>(alloc<Var>(name, kHandle), tmp_dims, producer->dtype());
2882
2883   // Determine the offsets for accesses into the cache, based on the loop start
2884   // of each axis.
2885 std::vector<ExprPtr> tmp_params;
2886 for (size_t i = 0; i < new_loop_vars.size(); ++i) {
2887 tmp_params.push_back(alloc<Add>(new_loop_vars[i], info.start[i]));
2888 }
2889
2890 // Replace accesses to the producer in the consumer with the cache.
2891 CacheReplacer replacer(producer, tmp_buf, info.start);
2892 consumer->accept_mutator(&replacer);
2893
2894 // replace the old consumer with the replaced consumer.
2895 BlockPtr consumer_block = to<Block>(consumer);
2896 BlockPtr parent_block = to<Block>(consumer->get_parent());
2897 // if the consumer is a block, we should mutate it in place.
2898 bool is_block = consumer_block != nullptr;
2899
2900 // If there's a reduction and we are operating on the reduce axis, we need to
2901 // initialize the cache with 0s. Also, we can't just write the result straight
2902   // back to the original buffer, since after parallelization the writes would race.
2903 // Instead we need to create a new ReduceOp.
2904 bool on_reduce_axis = false;
2905 if (reduceOp) {
2906 std::set<VarPtr> reduce_args(
2907 reduceOp->reduce_args().begin(), reduceOp->reduce_args().end());
2908 std::set<VarPtr> enclosing_vars;
2909 for (const auto& enclosing_for_stmt : NodeFinder<For>::find(consumer)) {
2910 enclosing_vars.insert(enclosing_for_stmt->var());
2911 }
2912 for (const auto& reduce_arg : reduce_args) {
2913 if (enclosing_vars.find(reduce_arg) == enclosing_vars.end()) {
2914 on_reduce_axis = true;
2915 }
2916 }
2917 }
2918 if (reduceOp && on_reduce_axis) {
2919 // reduceOp means we had both loads and stores.
2920
2921 // Init cache to 0.
2922 StmtPtr tmp_init = alloc<Store>(
2923 tmp_buf, new_loop_vars_expr, getImmediateByType(tmp_buf->dtype(), 0));
2924
2925 for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
2926 tmp_init = alloc<For>(
2927 new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_init);
2928 }
2929
2930 if (is_block) {
2931 consumer_block->prepend_stmt(tmp_init);
2932 } else {
2933 parent_block->insert_stmt_before(tmp_init, consumer);
2934 }
2935
2936 // Reduce back to the original buffer:
2937 StmtPtr tmp_store = alloc<Store>(
2938 producer,
2939 tmp_params,
2940 reduceOp->reducer()(
2941 producer,
2942 alloc<Load>(tmp_buf, new_loop_vars_expr),
2943 tmp_params,
2944 {}));
2945
2946 for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
2947 tmp_store = alloc<For>(
2948 new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
2949 }
2950
2951 if (is_block) {
2952 consumer_block->append_stmt(tmp_store);
2953 } else {
2954 parent_block->insert_stmt_after(tmp_store, consumer);
2955 }
2956
2957 return std::make_pair(tmp_buf, consumer);
2958 }
2959
2960 if (hasReads) {
2961 // Fill the cache with values from the consumer.
2962 StmtPtr tmp_store = alloc<Store>(
2963 tmp_buf, new_loop_vars_expr, alloc<Load>(producer, tmp_params));
2964
2965 for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
2966 tmp_store = alloc<For>(
2967 new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
2968 }
2969
2970 if (is_block) {
2971 consumer_block->prepend_stmt(tmp_store);
2972 } else {
2973 parent_block->insert_stmt_before(tmp_store, consumer);
2974 }
2975 }
2976
2977 if (hasWrites) {
2978 // sync the cache back to the producer buf.
2979 StmtPtr tmp_store = alloc<Store>(
2980 producer, tmp_params, alloc<Load>(tmp_buf, new_loop_vars_expr));
2981
2982 for (int64_t i = static_cast<int64_t>(new_loop_vars.size()) - 1; i >= 0;
2983 --i) {
2984 tmp_store = alloc<For>(
2985 new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
2986 }
2987
2988 if (is_block) {
2989 consumer_block->append_stmt(tmp_store);
2990 } else {
2991 parent_block->insert_stmt_after(tmp_store, consumer);
2992 }
2993 }
2994
2995 return std::make_pair(tmp_buf, consumer);
2996 }
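// Example usage (a sketch; `p` is a producer buffer accessed inside the block
// `body` of some consumer loop in LoopNest `ln`):
//   auto res = ln.cacheAccesses(p, "p_local", body);
// `res.first` is the newly created cache buffer and `res.second` is the
// (mutated) consumer stmt; accesses to `p` inside it now go through the cache,
// with fill and/or write-back loops inserted around it as needed.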
2997
2998 /*
2999 * WHAT COMPUTE_AT DOES
3000 * ====================
3001 *
3002 * Suppose we have two loops:
3003 *
3004 * for i in 0..100:
3005 * for j in 0..200:
3006 * A[i,j] = sin(i*j)
3007 * for i in 0..100:
3008 * for j in 0..199:
3009 * B[i,j] = A[i,j] + A[i, j+1]
3010 *
3011 * If we compute these loops as is, we would have to allocate two buffers:
3012 * 100x200 for A and 100x199 for B. To decrease the memory usage one can use
3013 * compute_inline primitive, which would result in the following:
3014 *
3015 * for i in 0..100:
3016 * for j in 0..199:
3017 * B[i,j] = sin(i*j) + sin(i*(j+1))
3018 *
3019 * We now need only one buffer - 100x199 for B. However, we're now doing some
3020 * redundant computations: we're calling `sin` twice as much as in the first
3021 * version.
3022 *
3023  * Ultimately, we need to choose at what point we prefer to compute values of
3024 * A[i,j] - we can do it in the very beginning for the entire buffer A (the
3025 * first option) or compute it on the fly when we compute B (the second option).
3026  * There are also options in between those two: we can compute the part of A
3027  * required for a computation of a part of B, e.g. for a single row of B. The
3028 * code would then look like:
3029 *
3030 * for i in 0..100:
3031 * for j in 0..200:
3032 * A[j] = sin(i*j)
3033 * for j in 0..199:
3034 * B[i,j] = A[j] + A[j+1]
3035 *
3036 * In this case we're only using 1x200 for A, and we're avoiding redundant
3037 * computations.
3038 *
3039 * The purpose of `compute_at` is to achieve exactly this transformation.
3040 *
3041  * compute_at requires specifying What to compute and Where to compute it: in our
3042 * example we would call compute_at(What=`A[i,j] = sin(i*j)`, Where=`for i in
3043 * 0..100`).
3044 *
3045  * More info about compute_at can be found in Halide's tutorials:
3046 * https://halide-lang.org/tutorials/tutorial_lesson_08_scheduling_2.html
3047 *
3048 * HOW COMPUTE_AT WORKS
3049 * ====================
3050 *
3051 * The most important part of compute_at is bounds inference: we need to figure
3052 * out what part of the used tensors we need to compute when we move the
3053 * computation to a new scope. In the example above, we need bounds inference to
3054 * tell us that in order to compute A at each iteration of the outer loop, we
3055 * need to compute A within indices [i:i+1,0:200].
3056 *
3057 * This info allows us to conclude that we need a temp buffer of size 1x200.
3058 *
3059 * Once this is known we need to insert statements for allocation and freeing
3060 * the temporary buffer and copy the original computation to fill the temp
3061 * buffer with proper values. When we copy the computation we also must rewrite
3062 * indices used in it: old indices are referring to the old loop and are not
3063 * valid in the new loop.
3064 *
3065  * To follow the logic more easily, let's examine an example. Suppose we start from
3066 * the following loop nest:
3067 * for py in 0..100:
3068 * for px in 0..100:
3069 * producer[py,px] = py*px
3070 * for cy in 0..100:
3071 * for cx in 0..100:
3072 * consumer[cy,cx] = producer[cy,cx]
3073 *
3074 * And then we're running `compute_at(producer, cy)`.
3075 *
3076 * What we would like to get is the following loop nest:
3077 * for py in 0..100:
3078 * for px in 0..100:
3079 * producer[py,px] = py*px
3080 * for cy in 0..100:
3081 * Allocate(temp, {1, 100})
3082 * for ty in 0..1:
3083 * for tx in 0..100:
3084 * temp[ty,tx] = (ty+cy)*(tx+0)
3085 * for cx in 0..100:
3086 * consumer[cy,cx] = temp[0,cx]
3087 * Free(temp)
3088 *
3089 * NB: this loop nest can and should be simplified (e.g. the producer loop can
3090 * be removed since its result is no longer used), but this clean-up
3091 * optimization is performed separately (currently, not performed at all).
3092 *
3093 * If we examine the final loop nest, we can identify that the following steps
3094  * need to be performed:
3095 * - Bounds inference needs to tell us that we need a 1x100 buffer for temp.
3096 * - Allocate and Free statements for this buffer need to be inserted to the
3097 * loop.
3098 * - A new loop-nest should be inserted to the loop CY for computing `temp`
3099 * and it should replicate the loopnest of producer (PY,PX loops). The indices
3100 * in the loop body need to be offset by (cy, 0) - the offsets come from
3101 * bounds inference too.
3102 * - The computation of `consumer` needs to be rewritten so that it uses
3103 * `temp` instead of `producer`. The indices in the corresponding accesses
3104 * also need to be offset.
3105 */
computeAt(const StmtPtr & s,const ForPtr & f)3106 void LoopNest::computeAt(const StmtPtr& s, const ForPtr& f) {
3107 StorePtr st = getStoreStmtOfProducer(s);
3108 if (!st) {
3109 return;
3110 }
3111
3112 // Infer bounds info for all accesses that we make in the loop
3113 auto loop_bounds_info = inferBounds(f->body());
3114
3115 // bounds_it holds bounds info for the store we're trying to move to
3116 // the loop. If its result isn't accessed in the loop at all - do nothing and
3117 // exit early.
3118 auto bounds_it = loop_bounds_info.find(st->buf());
3119 if (bounds_it == loop_bounds_info.end()) {
3120 return;
3121 }
3122
3123 // Compute dimensions of the temp buffer we would need to allocate
3124 std::vector<ExprPtr> dims = getBoundExtents(bounds_it->second);
3125
3126 // TODO: Use name-hint of the producer instead of "temp"
3127 BufPtr temp_buf = alloc<Buf>("temp", dims, st->value()->dtype());
3128
3129 // Generate index variables for 'temp'
3130 std::vector<ExprPtr> temp_indices(dims.size());
3131 for (const auto i : c10::irange(dims.size())) {
3132 // TODO: Use name-hint of the producer indices instead of 'idx'
3133 temp_indices[i] =
3134 alloc<Var>(std::string("idx") + std::to_string(i), dims[i]->dtype());
3135 }
3136
3137 // Prepare substitute rules for constructing the temp statement from the prod
3138 // statement
3139 // TODO: Instead of going up the loop nest we should go through the indices in
3140 // the original tensor expression. The loops in the nest might've been
3141 // modified (e.g. split or merged) so that the loop indices no longer
3142 // correspond to the indices of the original expression and even their number
3143 // might be different. In that case, the loop below would crash.
3144 std::vector<VarPtr> prod_indices = getOuterLoopIndexes(s);
3145 std::vector<std::pair<VarPtr, ExprPtr>> rewrite_indices_map;
3146 std::vector<ExprPtr> offsets;
3147 for (const TensorAccessBoundsInfo& p : bounds_it->second) {
3148 for (const auto i : c10::irange(p.start.size())) {
3149 if (offsets.size() <= i) {
3150 offsets.push_back(p.start[i]);
3151 } else {
3152 offsets[i] =
3153 IRSimplifier::simplify(alloc<Min>(offsets[i], p.start[i], true));
3154 }
3155 }
3156 }
3157
3158 for (const auto i : c10::irange(prod_indices.size())) {
3159 rewrite_indices_map.emplace_back(
3160 prod_indices[i], alloc<Add>(temp_indices[i], offsets[i]));
3161 }
3162
3163 // Construct the temp statement
3164 StmtPtr bd = alloc<Store>(
3165 temp_buf,
3166 temp_indices,
3167 SubstituteInClone(st->value(), rewrite_indices_map));
3168
3169 // Construct the loop nest for the temp computation
3170 for (const auto i : c10::irange(dims.size())) {
3171 // We're creating loops from innermost to outermost, so we need to access
3172 // dimensions in reversed order.
3173 size_t dim_idx = dims.size() - 1 - i;
3174 bd = alloc<For>(
3175 to<Var>(temp_indices[dim_idx]),
3176 immLike(dims[dim_idx], 0),
3177 dims[dim_idx],
3178 bd);
3179 }
3180
3181 // Add constructed stmts to the consumer loop
3182 f->body()->prepend_stmt(bd);
3183
3184 // Rewrite accesses to producer in consumer with accesses to temp
3185 LoopComputeAtRewriter lr(st->buf(), temp_buf, offsets);
3186 StmtPtr new_f = f->accept_mutator(&lr);
3187 if (f != new_f) {
3188 BlockPtr bb = to<Block>(f->get_parent());
3189 bb->replace_stmt(f, new_f);
3190 }
3191 }
3192
3193 class RfactorStoreRewriter : public IRMutator {
3194 public:
RfactorStoreRewriter(BufPtr old_buf,const std::vector<ExprPtr> & old_indices,BufPtr new_buf,VarPtr reduction_var)3195 RfactorStoreRewriter(
3196 BufPtr old_buf,
3197 const std::vector<ExprPtr>& old_indices,
3198 BufPtr new_buf,
3199 VarPtr reduction_var)
3200 : old_buf_(std::move(old_buf)),
3201 old_indices_(old_indices),
3202 new_buf_(std::move(new_buf)),
3203 reduction_var_(std::move(reduction_var)),
3204 new_indices_(old_indices) {
3205 new_indices_.push_back(reduction_var_);
3206 }
3207
mutate(const LoadPtr & v)3208 ExprPtr mutate(const LoadPtr& v) override {
3209 if (v->buf() != old_buf_) {
3210 return IRMutator::mutate(v);
3211 }
3212
3213 TORCH_INTERNAL_ASSERT(
3214 old_indices_.size() == v->indices().size(),
3215 buildErrorMessage(
3216 "Expected ranks to match in RfactorStoreRewriter in the fuser."));
3217
3218 bool equal_indices = true;
3219 for (size_t i = 0; i < v->indices().size(); ++i) {
3220 if (!exprEquals(v->indices()[i], old_indices_[i])) {
3221 equal_indices = false;
3222 break;
3223 }
3224 }
3225 if (!equal_indices) {
3226 return IRMutator::mutate(v);
3227 }
3228
3229 return alloc<Load>(new_buf_, new_indices_);
3230 }
3231
mutate(const ReduceOpPtr & v)3232 ExprPtr mutate(const ReduceOpPtr& v) override {
3233 ExprPtr body_new = v->body()->accept_mutator(this);
3234
3235 std::vector<VarPtr> new_reduce_args;
3236 for (const auto& r : v->reduce_args()) {
3237 if (r != reduction_var_) {
3238 new_reduce_args.push_back(r);
3239 }
3240 }
3241
3242 return alloc<ReduceOp>(body_new, new_reduce_args, v->reducer());
3243 }
3244
mutate(const StorePtr & v)3245 StmtPtr mutate(const StorePtr& v) override {
3246 if (v->buf() != old_buf_) {
3247 return IRMutator::mutate(v);
3248 }
3249
3250 TORCH_INTERNAL_ASSERT(
3251 old_indices_.size() == v->indices().size(),
3252 buildErrorMessage(
3253 "Expected ranks to match in RfactorStoreRewriter in the fuser."));
3254
3255 bool equal_indices = true;
3256 for (size_t i = 0; i < v->indices().size(); ++i) {
3257 if (!exprEquals(v->indices()[i], old_indices_[i])) {
3258 equal_indices = false;
3259 break;
3260 }
3261 }
3262 if (!equal_indices) {
3263 return IRMutator::mutate(v);
3264 }
3265
3266 ExprPtr new_value = v->value()->accept_mutator(this);
3267 return alloc<Store>(new_buf_, new_indices_, new_value);
3268 }
3269
3270 private:
3271 BufPtr old_buf_;
3272 const std::vector<ExprPtr>& old_indices_;
3273 BufPtr new_buf_;
3274 VarPtr reduction_var_;
3275 std::vector<ExprPtr> new_indices_;
3276 };
3277
rfactor(const StmtPtr & st,const ForPtr & target_for)3278 bool LoopNest::rfactor(const StmtPtr& st, const ForPtr& target_for) {
3279 BufPtr tmp_buf = nullptr;
3280 return rfactor(st, target_for, &tmp_buf);
3281 }
3282
rfactor(const StmtPtr & st,const ForPtr & outer_reduction_for,BufPtr * rfac_buf_ptr)3283 bool LoopNest::rfactor(
3284 const StmtPtr& st,
3285 const ForPtr& outer_reduction_for,
3286 BufPtr* rfac_buf_ptr) {
3287   StorePtr reduction_store = to<Store>(st);
3288   ReduceOpPtr reduce_op = reduction_store ? to<ReduceOp>(reduction_store->value()) : nullptr;
3289   if (!reduce_op) {
3290     // Not a reduction store (or `st` is not a Store at all)
3291 return false;
3292 }
3293
3294 auto orig_buf = reduction_store->buf();
3295 auto orig_buf_indices = reduction_store->indices();
3296 VarPtr reduction_var = outer_reduction_for->var();
3297
3298 std::set<VarPtr> reduce_args = {
3299 reduce_op->reduce_args().begin(), reduce_op->reduce_args().end()};
3300
3301 if (reduce_args.size() < 2) {
3302     // Not enough reduction axes to do rfactor
3303 return false;
3304 }
3305
3306 // Verify that outer_reduction_for is a perfect loop nest with all loops being
3307 // reductions
3308 StmtPtr cur = outer_reduction_for;
3309 while (ForPtr cur_for = to<For>(cur)) {
3310 if (!reduce_args.count(cur_for->var())) {
3311       // output axes inside outer_reduction_for are not allowed
3312 return false;
3313 }
3314 reduce_args.erase(cur_for->var());
3315
3316 BlockPtr b = cur_for->body();
3317 if (b->nstmts() != 1) {
3318 return false;
3319 }
3320 cur = b->stmts().front();
3321 }
3322 if (cur != st) {
3323 // The reduction store is not a single stmt in the innermost loop - bail in
3324 // that case
3325 return false;
3326 }
3327 if (!reduce_args.empty()) {
3328 // This is not the outermost reduction axis
3329 return false;
3330 }
3331
3332   // assert: reduce_args match the loop vars of outer_reduction_for and its nested loops
3333 // assert: no other stmts in outer_reduction_for or its child loops
3334
3335 std::vector<ExprPtr> rfac_dims = orig_buf->dims();
3336 ExprPtr extra_dim = IRSimplifier::simplify(
3337 alloc<Sub>(outer_reduction_for->stop(), outer_reduction_for->start()));
3338 rfac_dims.push_back(extra_dim);
3339 ExprPtr rfac_init =
3340 alloc<Cast>(reduce_op->dtype(), reduce_op->reducer().initializer());
3341
3342 *rfac_buf_ptr = alloc<Buf>(
3343 orig_buf->name_hint() + "_rfac",
3344 rfac_dims,
3345 reduce_op->dtype(),
3346 rfac_init);
3347 BufPtr rfac_buf = *rfac_buf_ptr;
3348
3349 // Rewrite the original reduction store to use the temporary rfac buffer:
3350 // 1) X[*indexes] --> T[*indexes + {reduction_var}]
3351 // 2) reduce_axis -= {reduction_var}
3352 RfactorStoreRewriter rfac_rewriter(
3353 orig_buf, orig_buf_indices, rfac_buf, reduction_var);
3354 to<Block>(st->get_parent())
3355 ->replace_stmt(st, st->accept_mutator(&rfac_rewriter));
3356
3357 // Insert a store for the final reduction over the temp buffer into the
3358 // original buffer:
3359 // X[*indexes] = ReduceOp(X[*indexes] + T[*indexes + {reduction_var}],
3360 // reduce_axis={reduction_var})
3361 BlockPtr b = outer_reduction_for->body();
3362 TORCH_INTERNAL_ASSERT(
3363 b->nstmts() == 1,
3364 buildErrorMessage(
3365 "Expected to have a single stmt in the block in rfactor transformation in the fuser."));
3366 StmtPtr first_reduction_loop = b->stmts().front();
3367 auto rfac_buf_indices = orig_buf_indices;
3368 rfac_buf_indices.emplace_back(reduction_var);
3369
3370 ExprPtr final_reduce_load = alloc<Load>(rfac_buf, rfac_buf_indices);
3371 outer_reduction_for->body()->insert_stmt_after(
3372 alloc<Store>(
3373 orig_buf,
3374 orig_buf_indices,
3375 reduce_op->reducer()(
3376 orig_buf, final_reduce_load, orig_buf_indices, {reduction_var})),
3377 first_reduction_loop);
3378
3379 // Insert an initialization store for the temp buffer:
3380 // T[a,b,c] = init
3381 outer_reduction_for->body()->insert_stmt_before(
3382 alloc<Store>(rfac_buf, rfac_buf_indices, rfac_init),
3383 first_reduction_loop);
3384 return true;
3385 }
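// Illustrative sketch: calling rfactor on the reduction store below with the
// `m` loop as the target
//   for l: for m: for n:
//     X[l] = ReduceOp(X[l] + A[l, m, n], reduce_args={m, n})
// produces, roughly,
//   for l:
//     for m:
//       X_rfac[l, m] = init
//       for n:
//         X_rfac[l, m] = ReduceOp(X_rfac[l, m] + A[l, m, n], reduce_args={n})
//       X[l] = ReduceOp(X[l] + X_rfac[l, m], reduce_args={m})
// where X_rfac is the newly created rfactor buffer.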
3386
3387 } // namespace torch::jit::tensorexpr
3388