xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/loopnest.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/loopnest.h>
2 
3 #include <algorithm>
4 #include <iostream>
5 #include <stdexcept>
6 #include <unordered_map>
7 #include <unordered_set>
8 #include <utility>
9 #include <vector>
10 
11 #include <c10/util/Logging.h>
12 #include <c10/util/irange.h>
13 
14 #include <ATen/core/functional.h>
15 #include <torch/csrc/jit/jit_log.h>
16 #include <torch/csrc/jit/tensorexpr/analysis.h>
17 #include <torch/csrc/jit/tensorexpr/bounds_inference.h>
18 #include <torch/csrc/jit/tensorexpr/eval.h>
19 #include <torch/csrc/jit/tensorexpr/expr.h>
20 #include <torch/csrc/jit/tensorexpr/ir.h>
21 #include <torch/csrc/jit/tensorexpr/ir_cloner.h>
22 #include <torch/csrc/jit/tensorexpr/ir_mutator.h>
23 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
24 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
25 #include <torch/csrc/jit/tensorexpr/ir_verifier.h>
26 #include <torch/csrc/jit/tensorexpr/tensor.h>
27 
28 #include <stdexcept>
29 #include <unordered_map>
30 #include <unordered_set>
31 #include <vector>
32 
33 namespace torch::jit::tensorexpr {
34 
35 LoopNest::LoopNest(const LoopNest& other)
36     : root_stmt_(Stmt::clone(other.root_stmt_)),
37       output_bufs_(other.output_bufs_) {
38   GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
39   verify(root_stmt_);
40 }
41 
42 LoopNest::LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs)
43     : root_stmt_(std::move(stmt)), output_bufs_(std::move(output_bufs)) {
44   GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
45   verify(root_stmt_);
46 }
47 
48 LoopNest::LoopNest(
49     const std::vector<Tensor>& output_tensors,
50     const std::vector<Tensor>& tensors_to_compute) {
51   initialize(output_tensors, tensors_to_compute);
52   GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
53   verify(root_stmt_);
54 }
55 
56 LoopNest::LoopNest(const std::vector<Tensor>& output_tensors) {
57   initialize(output_tensors, output_tensors);
58   GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
59   verify(root_stmt_);
60 }
61 
62 std::vector<BufPtr> LoopNest::getIntermediateBufs() const {
63   std::vector<BufPtr> result;
64   std::unordered_set<BufPtr> result_set;
65   auto input_bufs = getInputBufs();
66   auto bufs = NodeFinder<Buf>::find(root_stmt_);
67   for (const auto& buf : bufs) {
68     if (!output_bufs_.count(buf) && !input_bufs.count(buf) &&
69         !result_set.count(buf)) {
70       result.push_back(buf);
71       result_set.insert(buf);
72     }
73   }
74   return result;
75 }
76 
77 const std::unordered_set<BufPtr> LoopNest::getInputBufs() const {
78   std::unordered_set<BufPtr> result;
79   auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
80   for (auto& kv : buf_load_store_uses) {
81     bool has_store = false;
82     for (auto& use : kv.second) {
83       if (use.isStore) {
84         has_store = true;
85         break;
86       }
87     }
88     if (!has_store) {
89       result.insert(kv.first);
90     }
91   }
92   return result;
93 }
94 
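// Flattens multi-dimensional Load/Store indices into a single index computed
// from the buffer's dims and strides, so later passes can treat all accesses
// as one-dimensional.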
95 class IndexFlattener : public IRMutator {
96  public:
97   StmtPtr flatten(const StmtPtr& s) {
98     return s->accept_mutator(this);
99   }
100 
101   ExprPtr mutate(const LoadPtr& v) override {
102     if (v->indices().size() == 1) {
103       return v;
104     }
105     return alloc<Load>(
106         v->dtype(),
107         v->buf(),
108         std::vector<ExprPtr>({flatten_index(
109             v->buf()->dims(), v->indices(), v->buf()->strides())}));
110   }
111 
112   StmtPtr mutate(const StorePtr& v) override {
113     ExprPtr value = v->value();
114     ExprPtr new_value = value->accept_mutator(this);
115     if (v->indices().size() == 1 && value == new_value) {
116       return v;
117     }
118     std::vector<ExprPtr> indices = {
119         flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides())};
120     v->set_indices(indices);
121     v->set_value(new_value);
122     return v;
123   }
124 };
125 
126 static bool isValidIdentifierChar(char c, size_t pos) {
127   return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
128 }
129 
130 // Replaces all invalid characters with an underscore.
131 std::string sanitizeName(const std::string& input_name) {
132   std::stringstream sanitized_name;
133   for (size_t i = 0; i < input_name.size(); ++i) {
134     if (isValidIdentifierChar(input_name[i], i)) {
135       sanitized_name << input_name[i];
136     } else {
137       if (i == 0) {
138         // Don't start names with underscore
139         sanitized_name << "v";
140       }
141       sanitized_name << "_";
142     }
143   }
144   return sanitized_name.str();
145 }
146 
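// Rewrites Buf and Var names into valid, unique identifiers: loop index
// variables get canonical names (i, j, k, ...) by nesting level, and all other
// names are sanitized and de-duplicated.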
147 class VarNameSanitizer : public IRMutator {
148  public:
149   ExprPtr mutate(const BufPtr& v) override {
150     if (seen_bufs_.count(v)) {
151       return v;
152     }
153     const std::string& name = v->name_hint();
154     auto new_name = sanitizeName(name);
155     if (taken_names_.count(new_name)) {
156       new_name = getNextAvailableName(new_name);
157     }
158     v->set_name_hint(new_name);
159     taken_names_.insert(new_name);
160     seen_bufs_.insert(v);
161     return v;
162   }
163 
164   ExprPtr mutate(const VarPtr& v) override {
165     if (seen_vars_.count(v)) {
166       return v;
167     }
168     const std::string& name = v->name_hint();
169     auto new_name = sanitizeName(name);
170     if (taken_names_.count(new_name)) {
171       new_name = getNextAvailableName(new_name);
172     }
173     v->set_name_hint(new_name);
174     taken_names_.insert(new_name);
175     seen_vars_.insert(v);
176     return v;
177   }
178 
179   StmtPtr mutate(const ForPtr& v) override {
180     auto new_name = getNextAvailableName(getIndexVarNameAtLevel(level_));
181     if (seen_index_vars_.count(v->var())) {
182       auto new_var = alloc<Var>("", v->var()->dtype());
183       Substitute(v, {{v->var(), new_var}});
184     }
185     v->var()->set_name_hint(new_name);
186     seen_index_vars_.insert(v->var());
187     seen_vars_.insert(v->var());
188     taken_names_.insert(new_name);
189     level_++;
190     v->body()->accept_mutator(this);
191     level_--;
192     v->start()->accept_mutator(this);
193     v->stop()->accept_mutator(this);
194     return v;
195   }
196 
197   std::string getIndexVarNameAtLevel(int level_) {
198     auto names_num = index_var_names_.size();
199     auto counter = level_ / names_num;
200     if (counter == 0) {
201       return index_var_names_[level_ % names_num];
202     } else {
203       return index_var_names_[level_ % names_num] + std::to_string(counter);
204     }
205   }
206   std::string getNextAvailableName(const std::string& base_name) {
207     std::string name = base_name;
208     int counter = 0;
209     while (taken_names_.count(name)) {
210       counter++;
211       name = base_name + "_" + std::to_string(counter);
212     }
213     return name;
214   }
215 
216  private:
217   std::vector<std::string> index_var_names_ =
218       {"i", "j", "k", "l", "m", "n", "o", "p"};
219   std::unordered_set<std::string> taken_names_;
220   std::unordered_set<VarPtr> seen_index_vars_;
221   std::unordered_set<VarPtr> seen_vars_;
222   std::unordered_set<BufPtr> seen_bufs_;
223   int level_ = 0;
224 };
225 
226 StmtPtr LoopNest::sanitizeNames(StmtPtr s) {
227   VarNameSanitizer r;
228   s->accept_mutator(&r);
229   return s;
230 }
231 
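// Rewrites the body of a loop with constant bounds so that the loop variable
// becomes a Ramp and any scalar operands are Broadcast to the vector width.
// Sets success_ to false when it encounters a construct it cannot vectorize.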
232 class Vectorizer : public IRMutator {
233  public:
234   StmtPtr vectorize(ForPtr v) {
235     StmtPtr body = v->body();
236     VarPtr var = v->var();
237     ExprPtr start = v->start();
238     ExprPtr stop = v->stop();
239 
240     auto start_imm = intValue(start);
241     auto stop_imm = intValue(stop);
242     if (!start_imm) {
243       // Can't vectorize due to non-constant loop start!
244       success_ = false;
245       return v;
246     }
247 
248     if (!stop_imm) {
249       // Can't vectorize due to non-constant loop stop!
250       success_ = false;
251       return v;
252     }
253 
254     var_ = var;
255     start_ = immLike(start, *start_imm);
256     lanes_ = *stop_imm;
257 
258     StmtPtr new_body = body->accept_mutator(this);
259     if (new_body == body) {
260       // Vectorization failed!
261       success_ = false;
262       return v;
263     }
264 
265     return new_body;
266   }
267 
268   bool success() const {
269     return success_;
270   }
271 
272   ExprPtr mutate(const AddPtr& v) override {
273     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
274     return try_vectorize(v, inputs, [&]() {
275       return ExprHandle(inputs[0]) + ExprHandle(inputs[1]);
276     });
277   }
278 
279   ExprPtr mutate(const SubPtr& v) override {
280     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
281     return try_vectorize(v, inputs, [&]() {
282       return ExprHandle(inputs[0]) - ExprHandle(inputs[1]);
283     });
284   }
285 
286   ExprPtr mutate(const MulPtr& v) override {
287     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
288     return try_vectorize(v, inputs, [&]() {
289       return ExprHandle(inputs[0]) * ExprHandle(inputs[1]);
290     });
291   }
292 
293   ExprPtr mutate(const DivPtr& v) override {
294     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
295     return try_vectorize(v, inputs, [&]() {
296       return ExprHandle(inputs[0]) / ExprHandle(inputs[1]);
297     });
298   }
299 
300   ExprPtr mutate(const ModPtr& v) override {
301     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
302     return try_vectorize(v, inputs, [&]() {
303       return ExprHandle(inputs[0]) % ExprHandle(inputs[1]);
304     });
305   }
306 
307   ExprPtr mutate(const AndPtr& v) override {
308     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
309     return try_vectorize(v, inputs, [&]() {
310       return ExprHandle(inputs[0]) & ExprHandle(inputs[1]);
311     });
312   }
313 
314   ExprPtr mutate(const OrPtr& v) override {
315     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
316     return try_vectorize(v, inputs, [&]() {
317       return ExprHandle(inputs[0]) | ExprHandle(inputs[1]);
318     });
319   }
320 
321   ExprPtr mutate(const XorPtr& v) override {
322     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
323     return try_vectorize(v, inputs, [&]() {
324       return ExprHandle(inputs[0]) ^ ExprHandle(inputs[1]);
325     });
326   }
327 
328   ExprPtr mutate(const LshiftPtr& v) override {
329     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
330     return try_vectorize(v, inputs, [&]() {
331       return ExprHandle(inputs[0]) << ExprHandle(inputs[1]);
332     });
333   }
334 
335   ExprPtr mutate(const RshiftPtr& v) override {
336     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
337     return try_vectorize(v, inputs, [&]() {
338       return ExprHandle(inputs[0]) >> ExprHandle(inputs[1]);
339     });
340   }
341 
342   ExprPtr mutate(const MaxPtr& v) override {
343     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
344     return try_vectorize(v, inputs, [&]() {
345       return Max::make(
346           ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans());
347     });
348   }
349 
350   ExprPtr mutate(const MinPtr& v) override {
351     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
352     return try_vectorize(v, inputs, [&]() {
353       return Min::make(
354           ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans());
355     });
356   }
357 
358   ExprPtr mutate(const CompareSelectPtr& v) override {
359     std::vector<ExprPtr> inputs = {
360         v->lhs(), v->rhs(), v->ret_val1(), v->ret_val2()};
361     return try_vectorize(v, inputs, [&]() {
362       return CompareSelect::make(
363           ExprHandle(inputs[0]),
364           ExprHandle(inputs[1]),
365           ExprHandle(inputs[2]),
366           ExprHandle(inputs[3]),
367           v->compare_select_op(),
368           v->bias());
369     });
370   }
371 
372   ExprPtr mutate(const BitCastPtr& v) override {
373     std::vector<ExprPtr> inputs = {v->src_value()};
374     return try_vectorize(v, inputs, [&]() {
375       return BitCast::make(
376           Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0]));
377     });
378   }
379 
380   ExprPtr mutate(const CastPtr& v) override {
381     std::vector<ExprPtr> inputs = {v->src_value()};
382     return try_vectorize(v, inputs, [&]() {
383       return Cast::make(
384           Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0]));
385     });
386   }
387 
388   ExprPtr mutate(const VarPtr& v) override {
389     if (v == var_) {
390       return Ramp::make(
391                  ExprHandle(start_), ExprHandle(immLike(start_, 1)), lanes_)
392           .node();
393     }
394 
395     return v;
396   }
397 
398   ExprPtr mutate(const RampPtr& v) override {
399     ExprPtr base = v->base();
400     ExprPtr stride = v->stride();
401 
402     ExprPtr base_new = base->accept_mutator(this);
403     ExprPtr stride_new = stride->accept_mutator(this);
404 
405     if (base_new == base && stride_new == stride) {
406       return v;
407     }
408 
409     // Can't vectorize a Ramp!
410     success_ = false;
411     return v;
412   }
413 
414   ExprPtr mutate(const LoadPtr& v) override {
415     Dtype dtype(v->dtype().scalar_type(), lanes_);
416     BufPtr buf = v->buf();
417     std::vector<ExprPtr> inputs = {v->flat_index()};
418     return try_vectorize(v, inputs, [&]() {
419       return Load::make(dtype, BufHandle(buf), {ExprHandle(inputs[0])});
420     });
421   }
422 
423   ExprPtr mutate(const ReduceOpPtr& v) override {
424     Dtype dtype(v->dtype().scalar_type(), lanes_);
425 
426     std::vector<ExprPtr> inputs = {v->body()};
427 
428     auto out = try_vectorize(v, inputs, [&]() {
429       return ExprHandle(
430           alloc<ReduceOp>(inputs[0], v->reduce_args(), v->reducer()));
431     });
432     return out;
433   }
434 
435   ExprPtr mutate(const BroadcastPtr& v) override {
436     ExprPtr val = v->value();
437     ExprPtr new_val = val->accept_mutator(this);
438     if (new_val == val) {
439       return v;
440     }
441 
442     // Can't vectorize a Broadcast!
443     success_ = false;
444     return v;
445   }
446 
447   ExprPtr mutate(const IfThenElsePtr& v) override {
448     ExprPtr condition = v->condition();
449     ExprPtr new_condition = condition->accept_mutator(this);
450     if (new_condition != condition) {
451       // Can't vectorize an IfThenElse condition!
452       success_ = false;
453       return v;
454     }
455 
456     std::vector<ExprPtr> inputs = {v->true_value(), v->false_value()};
457     return try_vectorize(v, inputs, [&]() {
458       return IfThenElse::make(
459           ExprHandle(condition), ExprHandle(inputs[0]), ExprHandle(inputs[1]));
460     });
461   }
462 
463   ExprPtr mutate(const IntrinsicsPtr& v) override {
464     std::vector<ExprPtr> inputs = v->params();
465     return try_vectorize(v, inputs, [&]() {
466       return ExprHandle(alloc<Intrinsics>(v->op_type(), inputs));
467     });
468   }
469 
470   StmtPtr mutate(const StorePtr& v) override {
471     BufPtr buf = v->buf();
472     std::vector<ExprPtr> inputs = {v->flat_index(), v->value()};
473     return try_vectorize(v, inputs, [&]() {
474       return Store::make(
475           BufHandle(buf), {ExprHandle(inputs[0])}, ExprHandle(inputs[1]));
476     });
477   }
478 
479   StmtPtr mutate(const ForPtr& v) override {
480     VarPtr var = v->var();
481     ExprPtr start = v->start();
482     ExprPtr stop = v->stop();
483     LoopOptions loop_options = v->loop_options();
484 
485     ExprPtr new_start = start->accept_mutator(this);
486     ExprPtr new_stop = stop->accept_mutator(this);
487 
488     if (new_start != start || new_stop != stop) {
489       // Can't vectorize nested For with dependent loop bounds!
490       success_ = false;
491       return v;
492     }
493 
494     StmtPtr body = v->body();
495     StmtPtr new_body = body->accept_mutator(this);
496 
497     if (new_body == body) {
498       return (ForPtr)v;
499     }
500 
501     return alloc<For>(var, new_start, new_stop, new_body, loop_options);
502   }
503 
504   StmtPtr mutate(const BlockPtr& v) override {
505     // IRMutator does in-place mutations. But the logic in vectorization checks
506     // for success by looking for a new stmt. So, we override the in-place
507     // mutations and create a clone here if any of its statements change.
508     // TODO: Can we change the logic of vectorizer so that we don't need this?
509     bool any_change = false;
510     std::vector<StmtPtr> stmts;
511     for (const StmtPtr& stmt : *v) {
512       StmtPtr stmt_new = stmt->accept_mutator(this);
513       if (stmt != stmt_new) {
514         any_change = true;
515       } else {
516         stmt_new = Stmt::clone(stmt);
517       }
518       if (stmt_new) {
519         stmts.push_back(stmt_new);
520       }
521     }
522     if (any_change) {
523       return alloc<Block>(stmts);
524     }
525     return v;
526   }
527 
528   template <typename T>
529   ExprPtr try_vectorize(ExprPtr e, std::vector<ExprPtr>& inputs, T&& vec_ctor) {
530     bool vectorize = vectorize_inputs(inputs);
531     if (vectorize) {
532       return vec_ctor().node();
533     }
534 
535     return e;
536   }
537 
538   template <typename T>
539   StmtPtr try_vectorize(StmtPtr s, std::vector<ExprPtr>& inputs, T&& vec_ctor) {
540     bool vectorize = vectorize_inputs(inputs);
541     if (vectorize) {
542       return vec_ctor();
543     }
544 
545     return s;
546   }
547 
548   bool vectorize_inputs(std::vector<ExprPtr>& inputs) {
549     bool any_vectorized = false;
550     std::vector<ExprPtr> new_inputs;
551 
552     // Attempt to vectorize each input.
553     for (ExprPtr& in : inputs) {
554       ExprPtr new_in = in->accept_mutator(this);
555       new_inputs.push_back(new_in);
556       if (new_in != in) {
557         any_vectorized = true;
558       }
559     }
560 
561     // If none of them vectorized, then don't vectorize this.
562     if (!any_vectorized) {
563       return false;
564     }
565 
566     // Insert broadcasts for any inputs that weren't vectorized.
567     for (size_t i = 0; i < inputs.size(); ++i) {
568       if (inputs[i] == new_inputs[i]) {
569         inputs[i] = Broadcast::make(ExprHandle(inputs[i]), lanes_).node();
570       } else {
571         inputs[i] = new_inputs[i];
572       }
573     }
574 
575     // And then vectorize this node.
576     return true;
577   }
578 
579   VarPtr var_ = nullptr;
580   int64_t lanes_ = 0;
581   ExprPtr start_ = nullptr;
582   bool success_ = true;
583 };
584 
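// Vectorizes the given loop in place. Returns false if the loop has no
// enclosing Block, iterates over a reduction axis, or could not be vectorized.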
585 bool LoopNest::vectorize(const ForPtr& f) {
586   BlockPtr b = to<Block>(f->get_parent());
587   if (!b) {
588     return false;
589   }
590 
591   // Can't vectorize reduction axes.
592   auto reductions = NodeFinder<ReduceOp>::find(f);
593   for (const auto& r : reductions) {
594     if (std::find(r->reduce_args().begin(), r->reduce_args().end(), f->var()) !=
595         r->reduce_args().end()) {
596       return false;
597     }
598   }
599 
600   Vectorizer v;
601   StmtPtr new_f = nullptr;
602   new_f = Stmt::clone(f);
603   normalize(to<For>(new_f));
604   new_f = FlattenIndexes(new_f);
605   new_f = v.vectorize(to<For>(new_f));
606   if (!v.success()) {
607     // We clone f before vectorizing. So, any partial vectorization will
608     // have modified the clone. In case of an exception, we can continue
609     // using f.
610     new_f = f;
611   }
612 
613   if (new_f != f) {
614     b->replace_stmt(f, IRSimplifier::simplify(new_f));
615     return true;
616   }
617 
618   // Vectorization was not successful.
619   return false;
620 }
621 
622 void LoopNest::initialize(
623     const std::vector<Tensor>& output_tensors,
624     const std::vector<Tensor>& tensors_to_compute) {
625   for (const auto& t : output_tensors) {
626     output_bufs_.insert(t.buf());
627   }
628 
629   std::vector<StmtPtr> loops;
630   for (const Tensor& t : tensors_to_compute) {
631     StmtPtr loop = t.stmt();
632     if (loop->get_parent()) {
633       std::cerr << "Error: creating a loopnest from already used Tensors\n";
634       loops = {};
635       break;
636     }
637     // Flatten initializers.
638     if (BlockPtr block = to<Block>(loop)) {
639       for (const auto& s : block->stmts()) {
640         block->remove_stmt(s);
641         loops.push_back(s);
642       }
643     } else {
644       loops.push_back(loop);
645     }
646   }
647 
648   root_stmt_ = alloc<Block>(loops);
649 }
650 
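// Inlines the computation of a single producer Store into every Load of its
// buffer, substituting the producer's index variables with the consumer's
// index expressions. Calls to rand() are hoisted into Let bindings so that
// each random value is computed only once.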
651 class FunctionInliner : public IRMutator {
652  public:
653   FunctionInliner(StorePtr producer, std::unordered_set<BufPtr> outputs)
654       : buf_(producer->buf()),
655         producer_(std::move(producer)),
656         outputs_(std::move(outputs)) {
657     for (const auto& i : producer_->indices()) {
658       if (auto index_var = to<Var>(i)) {
659         index_vars_.insert(index_var);
660         producer_index_vars_.push_back(index_var);
661       } else {
662         // If the index can be a constant, then that dimension must have size 1
663         // (since we don't support in-place writes). Resolves issue 52581.
664         auto index_val = evalInt(i);
665         if (!index_val || *index_val != 0) {
666           success_ = false;
667           break;
668         }
669         producer_index_vars_.push_back(nullptr);
670       }
671     }
672   }
673 
674   bool success() const {
675     return success_;
676   }
677 
678  private:
679   ExprPtr mutate_loads(const BufPtr& buf, std::vector<ExprPtr> dims) {
680     std::vector<VarPtr> index_vars;
681     if (buf->ndim() != producer_index_vars_.size()) {
682       // Dimensions of producer and consumer expressions do not match in inliner
683       // in the fuser
684       success_ = false;
685       return nullptr;
686     }
687     for (const auto i : c10::irange(buf->ndim())) {
688       VarPtr func_callee_arg = producer_index_vars_.at(i);
689       ExprPtr func_caller_param = dims.at(i);
690       if (func_callee_arg == nullptr) {
691         continue;
692       }
693       auto iter = inline_mapping_.find(func_callee_arg);
694       if (iter != inline_mapping_.end()) {
695         // Duplicated variables
696         success_ = false;
697         return nullptr;
698       }
699       // Add a mapping for each function parameter to its source name.
700       inline_mapping_[func_callee_arg] = func_caller_param;
701       GRAPH_DEBUG(
702           "ComputeInline: Inline mapping: ",
703           std::to_string(func_callee_arg),
704           " -> ",
705           std::to_string(func_caller_param));
706       index_vars.push_back(func_callee_arg);
707     }
708 
709     // Call the actual replacement.
710     ExprPtr body = producer_->value();
711     GRAPH_DEBUG("ComputeInline: Before rewriting body: ", std::to_string(body));
712     ExprPtr result = Expr::clone(body)->accept_mutator(this);
713     GRAPH_DEBUG(
714         "ComputeInline: After rewriting body: ", std::to_string(result));
715 
716     // Remove the mappings we created for this function's parameters.
717     for (const auto& v : index_vars) {
718       for (auto& pair : random_bindings_) {
719         if (pair.second.erase(v)) {
720           ExprPtr inlined = inline_mapping_[v];
721           for (const auto& nv : VarFinder::find(inlined)) {
722             pair.second.insert(nv);
723           }
724         }
725       }
726       GRAPH_DEBUG("ComputeInline: Inline mapping: erasing", std::to_string(v));
727       inline_mapping_.erase(v);
728     }
729     return result;
730   }
731 
732   ExprPtr mutate(const LoadPtr& v) override {
733     if (!success()) {
734       return v;
735     }
736     BufPtr buf = v->buf();
737     if (buf != buf_) {
738       return IRMutator::mutate(v);
739     }
740 
741     if (v->indices().size() != buf->ndim()) {
742       // Number of indices doesn't match buf rank in the fuser
743       success_ = false;
744       return v;
745     }
746     auto result = mutate_loads(buf, v->indices());
747     if (!result) {
748       // If we don't inline successfully, return the given load.
749       success_ = false;
750       return v;
751     }
752     return result;
753   }
754 
755   // Replace the target variable with the caller expressions.
756   ExprPtr mutate(const VarPtr& v) override {
757     if (!success()) {
758       return v;
759     }
760     auto iter = inline_mapping_.find(v);
761     if (iter == inline_mapping_.end()) {
762       return v;
763     } else {
764       ExprPtr expr = iter->second;
765       // Continue to transform the value from the lookup table.
766       return expr->accept_mutator(this);
767     }
768   }
769 
770   // Handle random intrinsics which should be cached.
771   ExprPtr mutate(const IntrinsicsPtr& v) override {
772     if (!success()) {
773       return v;
774     }
775     if (!in_producer_ || v->op_type() != kRand) {
776       return IRMutator::mutate(v);
777     }
778 
779     // Create a new Let Statement for the random variable, which we can refer
780     // to multiple times and resolve the same value (ie. store it in a scalar
781     // rather than the Tensor).
782     const std::string& name = buf_->name_hint();
783     VarPtr new_var = alloc<Var>(name, v->dtype());
784     random_bindings_[alloc<Let>(new_var, v)] = index_vars_;
785     GRAPH_DEBUG(
786         "ComputeInline: created random bindings for ", std::to_string(new_var));
787     return new_var;
788   }
789 
790   // Remove the buffer write from the inlined function.
791   StmtPtr mutate(const StorePtr& v) override {
792     if (!success()) {
793       return v;
794     }
795     // If the buf_ is in the outputs set, keep its statement intact. Otherwise,
796     // remove it.
797     if (v == producer_ && !outputs_.count(buf_)) {
798       in_producer_ = true;
799       producer_ = to<Store>(IRMutator::mutate(v));
800       if (!producer_) {
801         // Producer statement for output buf should remain non-null in the fuser
802         success_ = false;
803         return v;
804       }
805       in_producer_ = false;
806       return nullptr;
807     } else {
808       return IRMutator::mutate(v);
809     }
810   }
811 
812   // Any Random Intrinsics that were turned into vars must be inserted here.
813   StmtPtr mutate(const BlockPtr& v) override {
814     if (!success()) {
815       return v;
816     }
817     std::vector<StmtPtr> stmts;
818     for (const StmtPtr& stmt : *v) {
819       StmtPtr stmt_new = stmt->accept_mutator(this);
820       if (!stmt_new) {
821         continue;
822       }
823 
824       if (stmt == stmt_new) {
825         stmt_new = Stmt::clone(stmt);
826       }
827 
828       stmts.push_back(stmt_new);
829     }
830 
831     return Block::make(stmts);
832   }
833 
834   StmtPtr mutate(const ForPtr& v) override {
835     if (!success()) {
836       return v;
837     }
838     ForPtr res = to<For>(IRMutator::mutate(v));
839     if (!res) {
840       return nullptr;
841     }
842 
843     // Find any random bindings that should be defined in this loop's body.
844     std::vector<LetPtr> bindings_this_loop;
845     VarPtr fv = v->var();
846     for (auto& pair : random_bindings_) {
847       auto& index_var = pair.second;
848       if (index_var.erase(fv)) {
849         bindings_this_loop.push_back(pair.first);
850       }
851     }
852 
853     for (const auto& l : bindings_this_loop) {
854       res->body()->prepend_stmt(l);
855       random_bindings_.erase(l);
856     }
857     return res;
858   }
859 
860  private:
861   BufPtr buf_;
862   StorePtr producer_;
863 
864   // Index Vars present in the producer.
865   std::unordered_set<VarPtr> index_vars_;
866   std::vector<VarPtr> producer_index_vars_;
867 
868   std::unordered_map<VarPtr, ExprPtr> inline_mapping_;
869 
870   // In the producer's scope - we need to bind any calls to rand().
871   bool in_producer_ = false;
872   std::unordered_map<LetPtr, std::unordered_set<VarPtr>> random_bindings_;
873   std::unordered_set<BufPtr> outputs_;
874   bool success_ = true;
875 };
876 
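// Tries to inline buffer `b` into `stmt`. Returns the mutated stmt on success,
// or nullptr if `b` is used in an ExternalCall, is written by a reduction, or
// does not have exactly one Store.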
877 static StmtPtr computeInlineImpl(
878     const BufPtr& b,
879     const StmtPtr& stmt,
880     const std::unordered_set<BufPtr>& output_bufs) {
881   // If buf is used or defined in an ExternalCall, we cannot inline it
882   auto buf_load_store_uses = findLoadOrStoreUses(stmt);
883   if (!buf_load_store_uses.count(b)) {
884     return nullptr;
885   }
886   for (auto& use : buf_load_store_uses.at(b)) {
887     StmtPtr s = use.s;
888     if (to<ExternalCall>(s) || to<ExternalCallWithAlloc>(s)) {
889       return nullptr;
890     }
891   }
892 
893   // Find producers.
894   StorePtr relevant_store{nullptr};
895   auto stores = NodeFinder<Store>::find(stmt);
896   for (const auto& s : stores) {
897     if (s->buf() == b) {
898       auto reductions = NodeFinder<ReduceOp>::find(s);
899       if (!reductions.empty()) {
900         // Cannot inline a reduction computation
901         return nullptr;
902       }
903       if (relevant_store != nullptr) {
904         // Cannot inline Buf with multiple Tensors
905         return nullptr;
906       }
907       relevant_store = s;
908     }
909   }
910 
911   if (!relevant_store) {
912     // Cannot find a relevant store to inline a buf in the fuser
913     return nullptr;
914   }
915 
916   GRAPH_DEBUG("ComputeInline: Def: ", std::to_string(relevant_store));
917   FunctionInliner inliner(relevant_store, output_bufs);
918   auto result = stmt->accept_mutator(&inliner);
919   if (inliner.success()) {
920     return result;
921   }
922   return nullptr;
923 }
924 
925 bool LoopNest::computeInline(const BufPtr& b) {
926   // Inlining may not always be successful. Since all mutations now happen
927   // in-place, an unsuccessful inlining transformation might leave the IR
928   // in an invalid state. To get around this problem, we clone the root stmt,
929   // try inlining on the clone, and if it succeeds, we proceed to perform
930   // inlining on the actual root stmt. This way the root stmt will always be
931   // in a valid state.
932   auto stmt_copy = Stmt::clone(root_stmt_);
933   auto try_inline = computeInlineImpl(b, stmt_copy, output_bufs_);
934   if (!try_inline) {
935     return false;
936   }
937   root_stmt_ = computeInlineImpl(b, root_stmt_, output_bufs_);
938   return true;
939 }
940 
941 bool LoopNest::computeInline(const StmtPtr& s) {
942   auto s_store = to<Store>(s);
943   if (s_store == nullptr) {
944     // Could not find buffer producer to inline
945     return false;
946   }
947   return computeInline(s_store->buf());
948 }
949 
950 // Inlining buffers with multiple uses can create duplicated work, which can
951 // slow down CPU code generation but is enabled on GPU because it avoids
952 // difficult synchronization logic across blocks. Inlining trivial reads does
953 // not duplicate work.
954 void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) {
955   std::unordered_set<BufPtr> bufs_to_inline;
956 
957   auto intermediate_bufs = getIntermediateBufs();
958   if (allow_duplicated_work) {
959     bufs_to_inline.insert(intermediate_bufs.begin(), intermediate_bufs.end());
960   } else {
961     auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
962     auto input_bufs = getInputBufs();
963 
964     for (const auto& buf : intermediate_bufs) {
965       TORCH_INTERNAL_ASSERT(
966           buf_load_store_uses.count(buf),
967           buildErrorMessage(
968               "Could not find uses of buf '" + buf->name_hint() +
969               "' in the fuser."));
970       std::vector<BufLoadOrStoreUse>& uses = buf_load_store_uses[buf];
971       auto stores = c10::filter(
972           uses, [](const BufLoadOrStoreUse& use) { return use.isStore; });
973 
974       // If the intermediate is the buffer formed from reading in the input
975       // tensors, always inline, because we are not duplicating any work
976       // and we avoid an intermediary buffer.
977       if (stores.size() == 1) {
978         if (auto store = to<Store>(stores[0].s)) {
979           auto input_as_load = to<Load>(store->value());
980           if (input_as_load && input_bufs.count(input_as_load->buf())) {
981             bufs_to_inline.insert(buf);
982             continue;
983           }
984         } else {
985           // If S is not a store, it must be an ExternalCall.
986           TORCH_INTERNAL_ASSERT(
987               to<ExternalCall>(stores[0].s) ||
988                   to<ExternalCallWithAlloc>(stores[0].s),
989               buildErrorMessage(
990                   "Expected stmt: " + std::to_string(stores[0].s) +
991                   "\nto be either a Store or an ExternalCall in the fuser."));
992         }
993       }
994 
995       // All bufs will have at least one store (if they have > 1 they can't be
996       // inlined anyway).
997       size_t reads = uses.size() - 1;
998       // if only one read, we can inline it without duplicating work
999       if (reads <= 1) {
1000         bufs_to_inline.insert(buf);
1001       }
1002     }
1003   }
1004 
1005   if (allow_duplicated_work) {
1006     bufs_to_inline.insert(output_bufs_.begin(), output_bufs_.end());
1007   }
1008 
1009   for (const auto& b : bufs_to_inline) {
1010     computeInline(b);
1011   }
1012 }
1013 
1014 // TODO: Unify with DepTracker
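// Records, for each Buf, the statements that access it: Stores and
// ExternalCalls count as writes of their output buf(s), Loads and
// ExternalCall arguments count as reads.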
1015 class LoadOrStoreUseFinder : public IRVisitor {
1016  public:
1017   std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findUses(
1018       const StmtPtr& s) {
1019     uses_.clear();
1020     s->accept(this);
1021     return uses_;
1022   }
1023 
1024  private:
1025   void visit(const StorePtr& v) override {
1026     if (stores_[v->buf()].insert(last_stmt_).second) {
1027       uses_[v->buf()].push_back({(StmtPtr)v, true});
1028     }
1029     last_stmt_ = (StmtPtr)v;
1030     IRVisitor::visit(v);
1031   }
1032 
1033   void visit(const ExternalCallPtr& v) override {
1034     if (stores_[v->buf()].insert(last_stmt_).second) {
1035       uses_[v->buf()].push_back({(StmtPtr)v, true});
1036     }
1037     last_stmt_ = (StmtPtr)v;
1038 
1039     for (const BufPtr& input_buf : v->buf_args()) {
1040       if (loads_[input_buf].insert(last_stmt_).second) {
1041         uses_[input_buf].push_back({last_stmt_, false});
1042       }
1043     }
1044 
1045     IRVisitor::visit(v);
1046   }
1047 
1048   void visit(const ExternalCallWithAllocPtr& v) override {
1049     for (const auto& out_buf : v->buf_out_args()) {
1050       if (stores_[out_buf].insert(last_stmt_).second) {
1051         uses_[out_buf].push_back({(StmtPtr)v, true});
1052       }
1053     }
1054     last_stmt_ = (StmtPtr)v;
1055 
1056     for (const auto& input_buf : v->buf_args()) {
1057       if (loads_[input_buf].insert(last_stmt_).second) {
1058         uses_[input_buf].push_back({last_stmt_, false});
1059       }
1060     }
1061 
1062     IRVisitor::visit(v);
1063   }
1064 
1065   void visit(const LoadPtr& v) override {
1066     if (loads_[v->buf()].insert(last_stmt_).second) {
1067       uses_[v->buf()].push_back({last_stmt_, false});
1068     }
1069     IRVisitor::visit(v);
1070   }
1071 
1072   StmtPtr last_stmt_ = nullptr;
1073   std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> uses_;
1074 
1075   // Sets of loads and stores in order to keep the results unique
1076   std::unordered_map<BufPtr, std::unordered_set<StmtPtr>> loads_;
1077   std::unordered_map<BufPtr, std::unordered_set<StmtPtr>> stores_;
1078 };
1079 
1080 std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
1081     const StmtPtr& s) {
1082   LoadOrStoreUseFinder uf;
1083   return uf.findUses(s);
1084 }
1085 
1086 class ContainedStmtsFinder : public IRVisitor {
1087  public:
1088   // Simply list all Stores, ExternalCalls, and Blocks that are children of the given stmt.
1089   const std::unordered_set<StmtPtr>& findContainedStmts(const StmtPtr& s) {
1090     contained_.clear();
1091     s->accept(this);
1092     return contained_;
1093   }
1094 
1095  private:
1096   void visit(const StorePtr& v) override {
1097     contained_.insert((StmtPtr)v);
1098     IRVisitor::visit(v);
1099   }
1100   void visit(const ExternalCallPtr& v) override {
1101     contained_.insert((StmtPtr)v);
1102     IRVisitor::visit(v);
1103   }
1104   void visit(const ExternalCallWithAllocPtr& v) override {
1105     contained_.insert((StmtPtr)v);
1106     IRVisitor::visit(v);
1107   }
1108   void visit(const BlockPtr& v) override {
1109     contained_.insert((StmtPtr)v);
1110     IRVisitor::visit(v);
1111   }
1112 
1113   std::unordered_set<StmtPtr> contained_;
1114 };
1115 
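// Removes the given target statements from the Blocks that contain them;
// used by eliminateDeadStores below.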
1116 class StmtDeleter : public IRMutator {
1117  public:
1118   StmtDeleter(const std::unordered_set<StmtPtr>& targets) : targets_(targets) {}
1119 
1120  private:
1121   StmtPtr mutate(const BlockPtr& v) override {
1122     std::vector<StmtPtr> stmts;
1123 
1124     for (const auto& s : v->stmts()) {
1125       if (targets_.count(s) == 0) {
1126         StmtPtr ns = s->accept_mutator(this);
1127         if (ns) {
1128           stmts.push_back(Stmt::clone(ns));
1129         }
1130       }
1131     }
1132 
1133     return Block::make(stmts);
1134   }
1135 
1136   const std::unordered_set<StmtPtr>& targets_;
1137 };
1138 
1139 void LoopNest::eliminateDeadStores() {
1140   using namespace analysis;
1141   MemDependencyChecker checker(getInputBufs(), getOutputBufs());
1142   root_stmt_->accept(&checker);
1143 
1144   std::unordered_set<StmtPtr> deadStores;
1145   std::vector<std::shared_ptr<AccessInfo>> outputAccesses;
1146   for (const auto& o : getOutputBufs()) {
1147     outputAccesses.push_back(checker.output(o));
1148   }
1149 
1150   for (auto& info : checker.getHistory()) {
1151     if (!info->isWrite()) {
1152       continue;
1153     }
1154     bool found = false;
1155 
1156     for (auto& output : outputAccesses) {
1157       if (checker.dependsIndirectly(output, info)) {
1158         found = true;
1159         break;
1160       }
1161     }
1162 
1163     if (!found) {
1164       deadStores.insert(info->stmt());
1165     }
1166   }
1167 
1168   StmtDeleter deleter(deadStores);
1169   root_stmt_ = root_stmt_->accept_mutator(&deleter);
1170 }
1171 
1172 void LoopNest::prepareForCodegen() {
1173   // Expand reduction ops.
1174   ReductionExpander reduceExpander;
1175   root_stmt_ = reduceExpander.expand(root_stmt_);
1176 
1177   root_stmt_ = FlattenIndexes(root_stmt_);
1178 }
1179 
1180 namespace {
1181 
1182 // This is extended from IRCloner instead of IRMutator because we want all
1183 // the rest of the IR nodes (the ones not touched directly) to be cloned.
1184 class IfThenElseReplacer : public IRCloner {
1185  public:
1186   IfThenElseReplacer(IfThenElsePtr to_replace, ExprPtr new_expr)
1187       : to_replace_(std::move(to_replace)), new_expr_(std::move(new_expr)) {}
1188 
1189   ExprPtr mutate(const IfThenElsePtr& i) override {
1190     if (i == to_replace_) {
1191       return new_expr_;
1192     }
1193     return IRCloner::mutate(i);
1194   }
1195 
1196  private:
1197   IfThenElsePtr to_replace_;
1198   ExprPtr new_expr_;
1199 };
1200 
1201 // Check if the given condition is optimizable.
1202 // Specifically, this function looks for the following pattern:
1203 //    "var < expr"
1204 //
1205 // If this pattern is found, then this function:
1206 //   * sets `cond_var` to `var`,
1207 //   * sets `compared_value` to `expr`, and
1208 //   * returns true.
1209 bool isConditionOptimizable(
1210     const ExprPtr& condition,
1211     VarPtr* cond_var,
1212     ExprPtr* compared_value) {
1213   auto cs = to<CompareSelect>(condition);
1214   if (cs && cs->compare_select_op() == kLT) {
1215     auto var = to<Var>(cs->lhs());
1216     if (var) {
1217       *cond_var = var;
1218       *compared_value = cs->rhs();
1219       return true;
1220     }
1221   }
1222   return false;
1223 }
1224 
1225 // Checks if the given if-then-else expression is a conditional that is
1226 // generated from `aten::cat`.
1227 //
1228 // The expected format of conditionals is:
1229 //     IfThenElse(var < val1? 1 : 0,
1230 //       IfThenElse (var < val2? 1 : 0,
1231 //         IfThenElse (var < val3? 1 : 0,
1232 //           sub-expr1,
1233 //           sub-expr2),
1234 //         sub-expr3),
1235 //       sub-expr4)
1236 //
1237 // If such a conditional is found, this function also sets:
1238 //   * cond_var to the condition variable found in this expression.
1239 //   * comp_values to the list of compared values in the condition expressions.
1240 //   * sub_exprs to the list of sub-expressions that are the result of this
1241 //     if-then-else expression.
1242 bool isConditionalFromCat(
1243     const IfThenElsePtr& ite,
1244     VarPtr* cond_var,
1245     std::vector<ExprPtr>* comp_values,
1246     std::vector<ExprPtr>* sub_exprs) {
1247   VarPtr var = nullptr;
1248   ExprPtr comp_value;
1249   if (isConditionOptimizable(ite->condition(), &var, &comp_value)) {
1250     if (*cond_var == nullptr) {
1251       *cond_var = var;
1252     } else if (*cond_var != var) {
1253       // Different condition variables found in nested if-then-else
1254       // expressions. Can not optimize such cases.
1255       return false;
1256     }
1257     auto true_ite = to<IfThenElse>(ite->true_value());
1258     if (true_ite) {
1259       if (!isConditionalFromCat(true_ite, cond_var, comp_values, sub_exprs)) {
1260         return false;
1261       }
1262     } else {
1263       sub_exprs->push_back(ite->true_value());
1264     }
1265     auto false_ite = to<IfThenElse>(ite->false_value());
1266     if (false_ite) {
1267       return false;
1268     }
1269     comp_values->push_back(comp_value);
1270     sub_exprs->push_back(ite->false_value());
1271     return true;
1272   }
1273   return false;
1274 }
1275 
1276 bool areConstantsAndSorted(const std::vector<ExprPtr>& comp_values) {
1277   std::vector<int> comp_consts;
1278   comp_consts.reserve(comp_values.size());
1279   for (const auto& c : comp_values) {
1280     if (!c->isConstant()) {
1281       return false;
1282     }
1283     comp_consts.push_back(immediateAs<int>(c));
1284   }
1285   return std::is_sorted(comp_consts.begin(), comp_consts.end());
1286 }
1287 
1288 } // namespace
1289 
1290 bool LoopNest::optimizeConditionals() {
1291   // Consider every store in the root_stmt_ and try to optimize the
1292   // conditionals in that store.
1293   auto stores = NodeFinder<Store>::find(root_stmt_);
1294   std::unordered_set<ForPtr> split_fors;
1295   for (const auto& store : stores) {
1296     VarPtr cond_var = nullptr;
1297     // `comp_values` represent the list of compared values that will be
1298     // collected as we check for the expected pattern. Since that will
1299     // only include the RHS of the conditions in the if-then-else expressions
1300     // we need to start with `0` which is the initial bound, given that we
1301     // only handle normalized loops (check for this is done below).
1302     std::vector<ExprPtr> comp_values;
1303     std::vector<ExprPtr> sub_exprs;
1304     auto ifthenelse_exprs = NodeFinder<IfThenElse>::find(store);
1305     if (ifthenelse_exprs.empty()) {
1306       continue;
1307     }
1308     // We only check if the first if-then-else expression in this store
1309     // corresponds to a conditional of the required format. If there are more
1310     // than one such conditional, optimizing them requires checking if the
1311     // conditions are exactly the same across them and handling all of them
1312     // together. Currently, this is not handled.
1313     if (!isConditionalFromCat(
1314             ifthenelse_exprs.front(), &cond_var, &comp_values, &sub_exprs)) {
1315       continue;
1316     }
1317     TORCH_INTERNAL_ASSERT(
1318         !comp_values.empty(),
1319         buildErrorMessage(
1320             "Expected at least one expression in optimizeConditional in the fuser."));
1321     comp_values.insert(comp_values.begin(), immLike(comp_values[0], 0));
1322 
1323     auto fors = getLoopStmtsFor(store);
1324     if (cond_var != fors.back()->var()) {
1325       // Currently, we only handle the case where the condition variable
1326       // is the same as the inner-most loop variable.
1327       // TODO: Handle all other cases here.
1328       //
1329       // In order to handle all other cases, the method `clone_and_replace`
1330       // called below to clone the body of the loop with a new store needs
1331       // to recursively handle cloning of the loops and other blocks it
1332       // contains.
1333       continue;
1334     }
1335 
1336     auto for_to_split = fors.back();
1337     if (!LoopNest::isNormalized(for_to_split)) {
1338       // Do not optimize this conditional since the condition variable
1339       // refers to a loop that is not normalized.
1340       continue;
1341     }
1342     if (split_fors.count(for_to_split)) {
1343       // This loop has already been split while optimizing conditionals
1344       // earlier.
1345       //
1346       // Optimizing multiple conditionals that require splitting the same loop
1347       // is tricky. It requires checking if the conditions are exactly the same
1348       // across them and handling all of them together by splitting the loop
1349       // exactly once.
1350       //
1351       // Currently, this case is not supported.
1352       continue;
1353     }
1354     split_fors.insert(for_to_split);
1355 
1356     // `comp_values` needs to include the end bound, which is `for_to_split`
1357     // stop value.
1358     comp_values.push_back(for_to_split->stop());
1359 
1360     // Check if all `comp_values` are constants and they are sorted.
1361     if (!areConstantsAndSorted(comp_values)) {
1362       continue;
1363     }
1364 
1365     // Remove all the if-then-else expressions from this store and create
1366     // one loop per sub-expression.
1367     std::vector<StmtPtr> split_loops;
1368     auto cond_to_replace = ifthenelse_exprs.front();
1369     for (size_t i = 0; i < sub_exprs.size(); ++i) {
1370       IfThenElseReplacer ifthenelseReplacer(cond_to_replace, sub_exprs[i]);
1371       auto new_store = store->accept_mutator(&ifthenelseReplacer);
1372       auto new_for_body =
1373           for_to_split->body()->clone_and_replace(store, new_store);
1374       auto new_for = alloc<For>(
1375           for_to_split->var(),
1376           comp_values[i],
1377           comp_values[i + 1],
1378           new_for_body);
1379       LoopNest::normalize(new_for);
1380       split_loops.push_back(new_for);
1381     }
1382     auto par = to<Block>(for_to_split->get_parent());
1383     par->replace_stmt(for_to_split, alloc<Block>(split_loops));
1384   }
1385   root_stmt_ = IRSimplifier::simplify(root_stmt_);
1386   return true;
1387 }
1388 
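// Finds all inner-most loops in the nest, splits each by a vector width of 8,
// vectorizes the resulting inner loop, and splits/vectorizes any tail loop by
// a width of 4.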
1389 void LoopNest::vectorizeInnerLoops() {
1390   std::vector<ForPtr> innerLoops;
1391   std::vector<ForPtr> worklist;
1392 
1393   // Find outer-most For loops
1394   if (ForPtr rootF = to<For>(root_stmt_)) {
1395     worklist.push_back(rootF);
1396   } else if (BlockPtr body = to<Block>(root_stmt_)) {
1397     std::vector<BlockPtr> blocks = {body};
1398     while (!blocks.empty()) {
1399       BlockPtr b = blocks.back();
1400       blocks.pop_back();
1401 
1402       for (const StmtPtr& s : *b) {
1403         if (const ForPtr& f = to<For>(s)) {
1404           worklist.push_back(f);
1405         } else if (BlockPtr b2 = to<Block>(s)) {
1406           blocks.push_back(b2);
1407         }
1408       }
1409     }
1410   }
1411 
1412   // Traverse the For loop nest to find inner-most loops, which are
1413   // vectorization candidates.
1414   while (!worklist.empty()) {
1415     ForPtr f = worklist.back();
1416     worklist.pop_back();
1417 
1418     bool containsSubLoops = false;
1419     if (BlockPtr body = to<Block>(f->body())) {
1420       for (const StmtPtr& s2 : *body) {
1421         if (const ForPtr& f2 = to<For>(s2)) {
1422           containsSubLoops = true;
1423           worklist.push_back(f2);
1424         }
1425       }
1426     }
1427 
1428     if (!containsSubLoops) {
1429       innerLoops.push_back(f);
1430     }
1431   }
1432 
1433   // vectorize inner loops.
1434   for (const ForPtr& loop : innerLoops) {
1435     ForPtr split1;
1436     ForPtr tail1;
1437 
1438     static const int kBodyVectorWidth = 8;
1439     splitWithTail(loop, kBodyVectorWidth, &split1, &tail1);
1440     vectorize(split1);
1441 
1442     if (tail1) {
1443       ForPtr split2;
1444       ForPtr tail2;
1445       static const int kTailVectorWidth = 4;
1446       splitWithTail(tail1, kTailVectorWidth, &split2, &tail2);
1447       vectorize(split2);
1448     }
1449   }
1450 }
1451 
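// Peels the first `factor` iterations of `f` into a new loop inserted before
// it; `f` itself becomes the tail. If the loop has constant bounds and no more
// than `factor` iterations, the whole loop becomes the head and *tail is null.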
1452 void LoopNest::sliceHead(
1453     const ForPtr& f,
1454     int factor,
1455     ForPtr* head,
1456     ForPtr* tail) {
1457   if (!f) {
1458     throw malformed_input("sliceHead attempted on null loop");
1459   }
1460 
1461   if (intValue(f->start()) && intValue(f->stop())) {
1462     auto start_val = *intValue(f->start());
1463     auto stop_val = *intValue(f->stop());
1464     auto size_val = stop_val - start_val;
1465     if (factor >= size_val) {
1466       *head = f;
1467       *tail = nullptr;
1468       return;
1469     }
1470   }
1471 
1472   BlockPtr p = to<Block>(f->get_parent());
1473   if (!p) {
1474     throw malformed_input("sliceHead attempted on loop with no parent");
1475   }
1476 
1477   ExprPtr head_end = alloc<Min>(
1478       alloc<Add>(f->start(), immLike(f->stop(), factor)), f->stop(), true);
1479   *head = alloc<For>(f->var(), f->start(), head_end, Stmt::clone(f->body()));
1480   p->insert_stmt_before(*head, f);
1481 
1482   f->set_start(head_end);
1483   *tail = f;
1484 
1485   if (f->loop_options().is_gpu_block_index() ||
1486       f->loop_options().is_gpu_thread_index()) {
1487     LoopNest::normalize(*tail);
1488   }
1489 }
1490 void LoopNest::sliceHead(const ForPtr& f, int factor) {
1491   ForPtr head, tail;
1492   sliceHead(f, factor, &head, &tail);
1493 }
1494 
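// Peels the last `factor` iterations of `f` into a new loop inserted after
// it; `f` itself becomes the head. If the loop has constant bounds and no more
// than `factor` iterations, the whole loop becomes the tail and *head is null.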
1495 void LoopNest::sliceTail(
1496     const ForPtr& f,
1497     int factor,
1498     ForPtr* head,
1499     ForPtr* tail) {
1500   if (!f) {
1501     throw malformed_input("sliceTail attempted on null loop");
1502   }
1503 
1504   if (intValue(f->start()) && intValue(f->stop())) {
1505     auto start_val = *intValue(f->start());
1506     auto stop_val = *intValue(f->stop());
1507     auto size_val = stop_val - start_val;
1508     if (factor >= size_val) {
1509       *head = nullptr;
1510       *tail = f;
1511       return;
1512     }
1513   }
1514 
1515   BlockPtr p = to<Block>(f->get_parent());
1516   if (!p) {
1517     throw malformed_input("sliceTail attempted on loop with no parent");
1518   }
1519 
1520   ExprPtr tail_start = alloc<Max>(
1521       f->start(), alloc<Sub>(f->stop(), immLike(f->stop(), factor)), true);
1522   *tail = alloc<For>(f->var(), tail_start, f->stop(), Stmt::clone(f->body()));
1523   p->insert_stmt_after(*tail, f);
1524 
1525   f->set_stop(tail_start);
1526   *head = f;
1527 
1528   if (f->loop_options().is_gpu_block_index() ||
1529       f->loop_options().is_gpu_thread_index()) {
1530     LoopNest::normalize(*head);
1531   }
1532 }
1533 void LoopNest::sliceTail(const ForPtr& f, int factor) {
1534   ForPtr head, tail;
1535   sliceTail(f, factor, &head, &tail);
1536 }
1537 
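// Splits `f` by `factor` into an outer loop over size/factor iterations, an
// inner loop over `factor` iterations, and, when the trip count is not a
// multiple of `factor`, a separate tail loop over the remainder. For example,
// for (i, 0, 10) split by 4 becomes outer=2, inner=4, plus a tail loop of 2.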
1538 void LoopNest::splitWithTail(const ForPtr& f, int factor) {
1539   ForPtr inner, tail;
1540   splitWithTail(f, factor, &inner, &tail);
1541 }
1542 
1543 void LoopNest::splitWithTail(
1544     const ForPtr& f,
1545     int factor,
1546     ForPtr* inner,
1547     ForPtr* tail) {
1548   if (!f) {
1549     throw malformed_input("splitWithTail attempted on null loop");
1550   }
1551 
1552   BlockPtr p = to<Block>(f->get_parent());
1553   if (!p) {
1554     throw malformed_input("splitWithTail attempted on loop with no parent");
1555   }
1556 
1557   // Normalize the loop to simplify start and stop bound computation
1558   normalize(f);
1559 
1560   bool tail_is_needed = true;
1561   if (intValue(f->start()) && intValue(f->stop())) {
1562     auto const start_val = *intValue(f->start());
1563     auto const stop_val = *intValue(f->stop());
1564     auto const size_val = stop_val - start_val;
1565     auto const tail_size = size_val % factor;
1566     if (tail_size == 0) {
1567       tail_is_needed = false;
1568     }
1569   }
1570 
1571   ExprPtr factor_expr = immLike(f->stop(), factor);
1572   ExprPtr size = alloc<Sub>(f->stop(), f->start());
1573   ExprPtr split_count = alloc<Div>(size, factor_expr);
1574   ExprPtr tail_size = alloc<Mod>(size, factor_expr);
1575 
1576   const std::string& loop_var_name = f->var()->name_hint();
1577   Dtype loop_var_dtype = f->var()->dtype();
1578 
1579   VarPtr i_inner = alloc<Var>(loop_var_name + "_inner", loop_var_dtype);
1580   VarPtr i_outer = alloc<Var>(loop_var_name + "_outer", loop_var_dtype);
1581 
1582   // x -> x.outer * inner.size + x.inner
1583   ExprPtr combined_index1 =
1584       alloc<Add>(alloc<Mul>(i_outer, factor_expr), i_inner);
1585 
1586   if (tail_is_needed) {
1587     VarPtr i_tail = alloc<Var>(loop_var_name + "_tail", loop_var_dtype);
1588     // x -> x.tail + outer.size * inner.size
1589     ExprPtr combined_index2 =
1590         alloc<Add>(i_tail, alloc<Mul>(split_count, factor_expr));
1591 
1592     StmtPtr body_tail =
1593         SubstituteInClone(f->body(), {{f->var(), combined_index2}});
1594     *tail = alloc<For>(i_tail, immLike(tail_size, 0), tail_size, body_tail);
1595 
1596     p->insert_stmt_after(*tail, f);
1597   } else {
1598     *tail = nullptr;
1599   }
1600 
1601   StmtPtr body_inner =
1602       Substitute(f->removeBody(), {{f->var(), combined_index1}});
1603 
1604   *inner =
1605       alloc<For>(i_inner, immLike(factor_expr, 0), factor_expr, body_inner);
1606   // The input loop `f` will be the outer loop after split.
1607   f->set_var(i_outer);
1608   f->set_start(immLike(split_count, 0));
1609   f->set_stop(split_count);
1610   f->set_body(*inner);
1611 }
1612 
1613 void LoopNest::splitWithMask(const ForPtr& f, int factor) {
1614   ForPtr inner;
1615   splitWithMask(f, factor, &inner);
1616 }
1617 
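// In pseudocode, splitting with a mask roughly turns
//   for x in 0..N:
//     body(x)
// into
//   for x_outer in 0..(N + factor - 1)/factor:
//     for x_inner in 0..factor:
//       if x_outer * factor + x_inner < N:
//         body(x_outer * factor + x_inner)
// i.e. leftover iterations are guarded by a predicate rather than peeled into
// a separate tail loop.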
1618 void LoopNest::splitWithMask(const ForPtr& f, int factor, ForPtr* inner) {
1619   BlockPtr p = to<Block>(f->get_parent());
1620   if (!p) {
1621     std::cerr << "Parent is not a Block!\n";
1622     return;
1623   }
1624 
1625   bool tail_is_needed = true;
1626   ExprPtr start = IRSimplifier::simplify(f->start());
1627   ExprPtr stop = IRSimplifier::simplify(f->stop());
1628   if (start->isConstant() && stop->isConstant()) {
1629     auto start_val = *intValue(start);
1630     auto stop_val = *intValue(stop);
1631     auto size_val = stop_val - start_val;
1632     auto tail_size = size_val % factor;
1633     if (tail_size == 0) {
1634       tail_is_needed = false;
1635     }
1636   }
1637 
1638   auto factor_expr = immLike(f->stop(), factor);
1639   ExprPtr size = alloc<Sub>(f->stop(), f->start());
1640   // split_count = (size + factor - 1) / factor
1641   ExprPtr split_count = alloc<Div>(
1642       alloc<Sub>(alloc<Add>(size, factor_expr), immLike(size, 1)), factor_expr);
1643 
1644   const std::string& loop_var_name = f->var()->name_hint();
1645   Dtype loop_var_dtype = f->var()->dtype();
1646 
1647   VarPtr i_inner = alloc<Var>(loop_var_name + "_inner", loop_var_dtype);
1648   VarPtr i_outer = alloc<Var>(loop_var_name + "_outer", loop_var_dtype);
1649 
1650   // x -> x.outer * inner.size + x.inner
1651   ExprPtr combined_index =
1652       alloc<Add>(alloc<Mul>(i_outer, factor_expr), i_inner);
1653 
1654   StmtPtr body_inner = f->removeBody();
1655   // TODO: is it ok that we're doing it eagerly? In the other implementation we
1656   // are only materializing predicates at the last, lowering, step.
1657   if (tail_is_needed) {
1658     auto start = intValue(f->start());
1659     if (!start || *start != 0) {
1660       throw unimplemented_lowering();
1661     }
1662 
1663     ExprPtr predicate =
1664         CompareSelect::make(ExprHandle(f->var()), ExprHandle(f->stop()), kLT)
1665             .node();
1666     body_inner = Cond::make(ExprHandle(predicate), body_inner, nullptr);
1667   }
1668   body_inner = Substitute(body_inner, {{f->var(), combined_index}});
1669 
1670   *inner =
1671       alloc<For>(i_inner, immLike(factor_expr, 0), factor_expr, body_inner);
1672   // The input loop `f` will be the outer loop after split.
1673   f->set_var(i_outer);
1674   f->set_start(immLike(split_count, 0));
1675   f->set_stop(split_count);
1676   f->set_body(*inner);
1677 }
1678 
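// Distributing a loop over pivots splits its body after each pivot statement.
// For example, distributing
//   for i:
//     S1
//     S2   <- pivot
//     S3
// over {S2} roughly produces
//   for i:
//     S1
//     S2
//   for i:
//     S3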
1679 std::vector<ForPtr> LoopNest::distributeLoop(
1680     const ForPtr& loop,
1681     const std::unordered_set<StmtPtr>& pivots) {
1682   TORCH_INTERNAL_ASSERT(
1683       loop,
1684       buildErrorMessage(
1685           "Expected non-null loop in distributeLoop in the fuser."));
1686   auto root = loop->get_parent();
1687   if (root == nullptr) {
1688     throw malformed_input("Loop without parent: ", loop);
1689   }
1690   auto root_block = to<Block>(root);
1691   if (root_block == nullptr) {
1692     throw malformed_input(
1693         "Loop's parent must be a Block, instead found ", root);
1694   }
1695 
1696   // Extract bodies for all the loops after distribution.
1697   std::vector<BlockPtr> new_loop_bodies;
1698   auto new_loop_body = alloc<Block>(std::vector<StmtPtr>({}));
1699   while (!loop->body()->empty()) {
1700     auto s = loop->body()->front();
1701     loop->body()->remove_stmt(s);
1702     new_loop_body->append_stmt(s);
1703     if (pivots.count(s)) {
1704       new_loop_bodies.push_back(new_loop_body);
1705       new_loop_body = alloc<Block>(std::vector<StmtPtr>({}));
1706     }
1707   }
1708   if (!new_loop_body->empty()) {
1709     new_loop_bodies.push_back(new_loop_body);
1710   }
1711 
1712   // The first loop body has to be in the original loop.
1713   loop->body()->splice(loop->body()->begin(), new_loop_bodies.front());
1714   std::vector<ForPtr> new_loops = {loop};
1715 
1716   // Create loops for all the remaining blocks.
1717   // Add all the new loops to the parent block.
1718   for (size_t i = 1; i < new_loop_bodies.size(); ++i) {
1719     auto new_loop = loop->cloneWithNewBody(new_loop_bodies[i]);
1720     root_block->insert_stmt_after(new_loop, new_loops.back());
1721     new_loops.push_back(new_loop);
1722   }
1723 
1724   return new_loops;
1725 }
1726 
1727 std::vector<ForPtr> LoopNest::distributeLoop(const ForPtr& loop) {
1728   std::unordered_set<StmtPtr> stmtsInBlock(
1729       loop->body()->begin(), loop->body()->end());
1730   return distributeLoop(loop, stmtsInBlock);
1731 }
1732 
1733 std::vector<ForPtr> LoopNest::distributeLoopAndParents(const ForPtr& loop) {
1734   auto parentLoop = getParentLoop(loop);
1735   auto result = distributeLoop(loop);
1736   if (parentLoop) {
1737     return distributeLoopAndParents(parentLoop);
1738   }
1739   return result;
1740 }
1741 
1742 std::vector<ForPtr> LoopNest::distributeLoopOverInnerLoops(const ForPtr& loop) {
1743   auto loops = NodeFinder<For>::find(loop);
1744   std::unordered_set<StmtPtr> loopsSet(loops.begin(), loops.end());
1745   return distributeLoop(loop, loopsSet);
1746 }
1747 
1748 std::vector<ForPtr> LoopNest::distributeLoopAndParentsOverInnerLoops(
1749     const ForPtr& loop) {
1750   auto parentLoop = getParentLoop(loop);
1751   auto result = distributeLoopOverInnerLoops(loop);
1752   if (parentLoop) {
1753     return distributeLoopAndParentsOverInnerLoops(parentLoop);
1754   }
1755   return result;
1756 }
1757 
1758 static bool areEqual(const ExprPtr& expr1, const ExprPtr& expr2) {
1759   auto diff = IRSimplifier::simplify(alloc<Sub>(expr1, expr2));
1760   return diff->isConstant() && (immediateAs<int>(diff) == 0);
1761 }
1762 
1763 static bool doesExprContainAnyVar(
1764     const ExprPtr& expr,
1765     const std::unordered_set<VarPtr>& vars) {
1766   for (const auto& v : VarFinder::find(expr)) {
1767     if (vars.count(v)) {
1768       return true;
1769     }
1770   }
1771   return false;
1772 }
1773 
1774 // Returns true if the given lists of indices refer to two accesses
1775 // that are loop-independent w.r.t. the given list of outer loop
1776 // variables.
1777 static bool areIndicesLoopIndependent(
1778     const std::vector<ExprPtr>& expr_list1,
1779     const std::vector<ExprPtr>& expr_list2,
1780     const std::unordered_set<VarPtr>& outer_loop_vars) {
1781   if (expr_list1.size() != expr_list2.size()) {
1782     return false;
1783   }
1784   for (size_t i = 0; i < expr_list1.size(); ++i) {
1785     const auto& expr1 = expr_list1[i];
1786     const auto& expr2 = expr_list2[i];
1787     if (doesExprContainAnyVar(expr1, outer_loop_vars) ||
1788         doesExprContainAnyVar(expr2, outer_loop_vars)) {
1789       if (!areEqual(expr1, expr2)) {
1790         return false;
1791       }
1792     }
1793   }
1794   return true;
1795 }
1796 
1797 bool LoopNest::hasLoopCarriedDependence(const ForPtr& loop) {
1798   analysis::MemDependencyChecker analyzer;
1799   loop->accept(&analyzer);
1800 
1801   std::unordered_set<VarPtr> outer_loop_vars = {loop->var()};
1802   auto outer_loops = LoopNest::getEnclosingLoopNest(loop);
1803   for (const auto& l : outer_loops) {
1804     outer_loop_vars.insert(l->var());
1805   }
1806 
1807   // High-level algorithm to check if two accesses to a buffer, A and B, one of
1808   // which is a Store, result in a loop-carried dependence:
1809   //   1. For every pair of index expressions, Ai and Bi, that refer to a dim
1810   //      of A and B, if one of the following conditions is satisfied:
1811   //       a) Ai and Bi are equal (OR)
1812   //       b) Both Ai and Bi do not contain any outer-loop variables
1813   //      then, the dependence between A and B is a loop-independent
1814   //      dependence. This is because, in the case of b), those index
1815   //      expressions do not affect the ordering of accesses A and B.
1816   //   2. If condition 1) is not satisfied:
1817   //       a) if the bounds on the accesses overlap, then this is a
1818   //          loop-carried dependence.
1819   //       b) if the bounds on the accesses do not overlap, then there is no
1820   //          dependence.
1821   //
1822   // NOTE: Since we check for equality of index expressions whenever outer
1823   //     loop variables are involved, this may incorrectly report some cases as
1824   //     having a loop-carried dependence. It is impractical to handle all
1825   //     possible cases here, so, we are being conservative and allow for
1826   //     some false positives. While this will prevent some loop fusion
1827   //     opportunities, that should be a small fraction of the cases that are
1828   //     allowed.
1829   //
1830   // Implementation:
1831   //
1832   // For every pair of statements, S1 and S2, in the loop:
1833   //  * Get the loads and stores in S1 and S2.
1834   //  * For every store in S1 and load in S2 to the same buffer, if the index
1835   //    expressions are not equal and there is an overlap in accesses, return
1836   //    true to indicate a loop-carried dependence.
1837   //  * For every load in S1 and store in S2 to the same buffer, if the index
1838   //    expressions are not equal and there is an overlap in accesses, return
1839   //    true to indicate a loop-carried dependence.
1840   //  * For every store in S1 and store in S2 to the same buffer, if the index
1841   //    expressions are not equal and there is an overlap in accesses, return
1842   //    true to indicate a loop-carried dependence.
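  //
  // For example, in
  //   for i in 0..N:
  //     A[i] = ...          (S1)
  //     B[i] = A[i + 1]     (S2)
  // S2 reads a value that S1 writes in a later iteration, so this loop has a
  // loop-carried dependence; if S2 instead read A[i], the two accesses would
  // be loop-independent.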
1843   for (auto it1 = loop->body()->begin(); it1 != loop->body()->end(); ++it1) {
1844     for (auto it2 = std::next(it1); it2 != loop->body()->end(); ++it2) {
1845       auto aStores = NodeFinder<Store>::find(*it1);
1846       auto aLoads = NodeFinder<Load>::find(*it1);
1847       auto bStores = NodeFinder<Store>::find(*it2);
1848       auto bLoads = NodeFinder<Load>::find(*it2);
1849       // ReadAfterWrite
1850       for (auto& aStore : aStores) {
1851         for (auto& bLoad : bLoads) {
1852           if (aStore->buf() == bLoad->buf()) {
1853             if (!areIndicesLoopIndependent(
1854                     aStore->indices(), bLoad->indices(), outer_loop_vars)) {
1855               if (isOverlapping(analyzer, aStore, bLoad)) {
1856                 return true;
1857               }
1858             }
1859           }
1860         }
1861       }
1862       // WriteAfterRead
1863       for (auto& bStore : bStores) {
1864         for (auto& aLoad : aLoads) {
1865           if (bStore->buf() == aLoad->buf()) {
1866             if (!areIndicesLoopIndependent(
1867                     bStore->indices(), aLoad->indices(), outer_loop_vars)) {
1868               if (isOverlapping(analyzer, bStore, aLoad)) {
1869                 return true;
1870               }
1871             }
1872           }
1873         }
1874       }
1875       // WriteAfterWrite
1876       for (auto& aStore : aStores) {
1877         for (auto& bStore : bStores) {
1878           if (aStore->buf() == bStore->buf()) {
1879             if (!areIndicesLoopIndependent(
1880                     aStore->indices(), bStore->indices(), outer_loop_vars)) {
1881               if (isOverlapping(analyzer, aStore, bStore)) {
1882                 return true;
1883               }
1884             }
1885           }
1886         }
1887       }
1888     }
1889   }
1890   return false;
1891 }
1892 
1893 bool LoopNest::unsafeFuseLoops(
1894     const std::vector<ForPtr>& loops,
1895     ForPtr* fused) {
1896   if (loops.empty()) {
1897     return false;
1898   }
1899   if (loops.size() == 1) {
1900     *fused = loops.front();
1901     return true;
1902   }
1903 
1904   // Check if all the loops have the same parent.
1905   auto root = loops.front()->get_parent();
1906   for (const auto& l : loops) {
1907     auto par = l->get_parent();
1908     if (par == nullptr) {
1909       return false;
1910     }
1911     if (par != root) {
1912       return false;
1913     }
1914   }
1915   auto root_block = to<Block>(root);
1916   if (root_block == nullptr) {
1917     return false;
1918   }
1919 
1920   // Currently, we only handle cases where there are no statements between
1921   // the given loops in their parent's body. We can possibly relax this
1922   // constraint by allowing statements that do not affect the loops being
1923   // fused, by performing some dependency analysis. TODO.
1924   auto it = root_block->begin();
1925   for (; it != root_block->end(); ++it) {
1926     if (*it == loops.front()) {
1927       break;
1928     }
1929   }
1930   TORCH_INTERNAL_ASSERT(
1931       it != root_block->end(),
1932       buildErrorMessage(
1933           "Could not find the given loop in the root stmt in unsafeFuseLoop the fuser."));
1934   for (const auto& l : loops) {
1935     if (*it != l) {
1936       return false;
1937     }
1938     ++it;
1939   }
1940 
1941   const auto& first_loop = loops.front();
1942   // Fuse the loops by taking all the statements from the second loop
1943   // onwards and moving them into the first loop's body.
1944   // This way the final fused loop will be the same as the first loop.
1945   for (size_t i = 1; i < loops.size(); ++i) {
1946     auto body = to<Block>(SubstituteInClone(
1947         loops[i]->body(), {{loops[i]->var(), first_loop->var()}}));
1948     first_loop->body()->splice(first_loop->body()->end(), body);
1949     root_block->remove_stmt(loops[i]);
1950   }
1951 
1952   *fused = loops.front();
1953   return true;
1954 }
1955 
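// Fusing
//   for i in 0..N: S1(i)
//   for j in 0..N: S2(j)
// roughly produces
//   for i in 0..N:
//     S1(i)
//     S2(i)
// provided the loops are adjacent in the same parent, their bounds match, and
// fusing them does not introduce a loop-carried dependence.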
1956 bool LoopNest::fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused) {
1957   if (loops.empty()) {
1958     return false;
1959   }
1960   if (loops.size() == 1) {
1961     *fused = loops.front();
1962     return true;
1963   }
1964 
1965   // Check if bounds are the same for all the loops.
1966   const auto& first_loop = loops.front();
1967   auto first_loop_start = IRSimplifier::simplify(first_loop->start());
1968   auto first_loop_stop = IRSimplifier::simplify(first_loop->stop());
1969   for (size_t i = 1; i < loops.size(); ++i) {
1970     const auto& curr_loop = loops[i];
1971     auto curr_loop_start = IRSimplifier::simplify(curr_loop->start());
1972     auto curr_loop_stop = IRSimplifier::simplify(curr_loop->stop());
1973     if (!areEqual(curr_loop_start, first_loop_start)) {
1974       return false;
1975     }
1976     if (!areEqual(curr_loop_stop, first_loop_stop)) {
1977       return false;
1978     }
1979   }
1980 
1981   // We need to check if fusing the loops results in a loop-carried dependence.
1982   // This check can be done only after the loops are fused into one. But if the
1983   // check is violated, we need to return the given loops in the original form.
1984   // So, we create a clone of all the loops, fuse them and check for this.
1985   std::vector<ForPtr> loops_copy;
1986   loops_copy.reserve(loops.size());
1987   BlockPtr parent = alloc<Block>(std::vector<StmtPtr>({}));
1988   for (auto& l : loops) {
1989     auto l_copy = Stmt::clone(l);
1990     loops_copy.push_back(to<For>(l_copy));
1991     parent->append_stmt(l_copy);
1992   }
1993   ForPtr fused_copy;
1994   bool ret = unsafeFuseLoops(loops_copy, &fused_copy);
1995   if (!ret || hasLoopCarriedDependence(fused_copy)) {
1996     return false;
1997   }
1998 
1999   // Now that all conditions are satisfied, we fuse the given loops.
2000   return unsafeFuseLoops(loops, fused);
2001 }
2002 
2003 ForPtr LoopNest::findOuterFor(ForPtr a, ForPtr b) {
2004   StmtPtr s = b; // Guess that b is the inner loop.
2005   while (s != nullptr) {
2006     if (s == a) {
2007       // a encloses b, so a is the outer loop.
2008       return a;
2009     }
2010     s = s->get_parent();
2011   }
2012 
2013   // check that the two are in the same loop nest.
2014   s = a;
2015   while (s != nullptr) {
2016     if (s == b) {
2017       // b encloses a, so b is the outer loop.
2018       return b;
2019     }
2020     s = s->get_parent();
2021   }
2022 
2023   // a and b have no relationship.
2024   return nullptr;
2025 }
2026 
2027 void LoopNest::reorderAxis(const ForPtr& a, const ForPtr& b) {
2028   if (a == b) {
2029     // nothing to do.
2030     return;
2031   }
2032   // find inner and outer.
2033   ForPtr outer = findOuterFor(a, b);
2034   if (outer == nullptr) {
2035     throw std::runtime_error("Reordered a loop not in LoopNest");
2036   }
2037 
2038   ForPtr inner = a == outer ? b : a;
2039   std::deque<ForPtr> internal_axes;
2040 
2041   // Find relevant axes, store reversed.
2042   StmtPtr s = inner;
2043   while (s != outer) {
2044     if (const ForPtr& f = to<For>(s)) {
2045       internal_axes.push_back(f);
2046     }
2047 
2048     s = s->get_parent();
2049   }
2050 
2051   internal_axes.push_back(outer);
2052 
2053   BlockPtr root = to<Block>(outer->get_parent());
2054   CHECK(root);
2055 
2056   // Do a shallow copy of the inner blocks.
2057   BlockPtr body = alloc<Block>(std::vector<StmtPtr>({}));
2058   body->splice(body->end(), inner->body());
2059 
2060   const ForPtr& before{outer};
2061   ForPtr after{nullptr};
2062   ForPtr last = internal_axes.front();
2063   StmtPtr newInner = body;
2064 
2065   s = inner;
2066   while (s != outer) {
2067     if (auto cond = to<Cond>(s->get_parent())) {
2068       if (s == cond->true_stmt()) {
2069         newInner = cond->cloneWithNewBody(newInner);
2070       } else {
2071         // s is the false branch of Cond
2072         newInner = cond->cloneWithNewBodies(
2073             alloc<Block>(std::vector<StmtPtr>({})), newInner);
2074       }
2075     }
2076     s = s->get_parent();
2077   }
2078 
2079   // This is the major complexity in loop reordering: handling statements not in
2080   // the straight line of the reorder. To handle this we partition the tree into
2081   // the section before the critical path and after the critical path.
2082   //
2083   // An example of this pattern is:
2084   // for i in ..
2085   //   Statement A
2086   //   for j in ..
2087   //     Statement B
2088   //   Statement C
2089   //
2090   // When reordering loop i and j we need to ensure that Statement A and C are
2091   // still both executed with the loop extents of i, and that the three
2092   // statements are not reordered (as much as possible).
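  //
  // For the example above, reordering i and j roughly produces:
  //   for i in ..
  //     Statement A
  //   for j in ..
  //     for i in ..
  //       Statement B
  //   for i in ..
  //     Statement C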
2093   for (const auto& loop : internal_axes) {
2094     // If the inner loop had a component after the loop we must wrap it in a For
2095     // loop matching this level of the tree.
2096     if (after != nullptr) {
2097       after = loop->cloneWithNewBody(after);
2098     }
2099 
2100     bool pastMidpoint = false;
2101     bool hadBeforeStmts = false;
2102     for (auto I = loop->body()->begin(), E = loop->body()->end(); I != E;) {
2103       // Be careful not to invalidate the iterator.
2104       StmtPtr s = *(I++);
2105       if (s == last) {
2106         // This is the midpoint.
2107         loop->body()->remove_stmt(s);
2108         if (!hadBeforeStmts) {
2109           // If there were no existing statements, this loop does not need to be
2110           // preserved and we can roll it into the above loop.
2111           last = loop;
2112         }
2113         pastMidpoint = true;
2114       } else if (pastMidpoint) {
2115         // Statements after the reordered path must be moved to a new tree after
2116         // the reordered statement has occurred to preserve ordering.
2117         loop->body()->remove_stmt(s);
2118         if (after == nullptr) {
2119           after = loop->cloneWithNewBody(s);
2120         } else {
2121           after->body()->append_stmt(s);
2122         }
2123       } else {
2124         // We can leave any statements before the reordered loop alone, so long
2125         // as we preserve the loop structure.
2126         hadBeforeStmts = true;
2127       }
2128     }
2129   }
2130 
2131   // now we can actually reorder the chosen axes.
2132   std::swap(internal_axes.front(), internal_axes.back());
2133 
2134   // Create the reordered internals:
2135   for (const auto& loop : internal_axes) {
2136     newInner = loop->cloneWithNewBody(newInner);
2137   }
2138 
2139   // Append the new statements to the root of the tree.
2140   if (before->body()->nstmts() == 0) {
2141     // If the top level is now empty, eliminate it.
2142     root->replace_stmt(before, newInner);
2143   } else {
2144     root->insert_stmt_after(newInner, before);
2145   }
2146 
2147   if (after) {
2148     root->insert_stmt_after(after, newInner);
2149   }
2150 }
2151 
2152 static bool isTrivialPermutation(const std::vector<size_t>& permutation) {
2153   for (size_t i = 0; i < permutation.size(); ++i) {
2154     if (permutation[i] != i) {
2155       return false;
2156     }
2157   }
2158   return true;
2159 }
2160 
2161 static bool isValidPermutation(std::vector<size_t> permutation) {
2162   std::sort(permutation.begin(), permutation.end());
2163   return isTrivialPermutation(permutation);
2164 }
2165 
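// Reordering perfectly nested loops by a permutation, e.g.
//   reorder({i, j, k}, {2, 0, 1})
// roughly turns
//   for i: for j: for k: body
// into
//   for k: for i: for j: body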
2166 std::vector<ForPtr> LoopNest::reorder(
2167     const std::vector<ForPtr>& loops,
2168     const std::vector<size_t>& permutation) {
2169   if (loops.size() != permutation.size()) {
2170     throw malformed_input("invalid permutation size");
2171   }
2172   if (isTrivialPermutation(permutation)) {
2173     return loops;
2174   }
2175   if (!isValidPermutation(permutation)) {
2176     throw malformed_input("invalid permutation for reorder");
2177   }
2178   if (loops.size() < 2) {
2179     return loops;
2180   }
2181   if (!areLoopsPerfectlyNested(loops)) {
2182     throw malformed_input("reorder is only allowed on perfectly nested loops");
2183   }
2184 
2185   auto parent = to<Block>(loops.front()->get_parent());
2186   if (parent == nullptr) {
2187     throw malformed_input("parent of the loops must be a Block");
2188   }
2189 
2190   // Reorder the loops according to the permutation.
2191   std::vector<ForPtr> result(loops.size());
2192   for (size_t i = 0; i < loops.size(); ++i) {
2193     result[i] = loops[permutation[i]];
2194   }
2195 
2196   // Remove the bodies from all the loops.
2197   auto innermost_body = loops.back()->removeBody();
2198   // We use an empty block statement to replace the outermost loop
2199   // so that we know the position where the outermost reordered loop
2200   // is to be inserted.
2201   auto empty_block = alloc<Block>(std::vector<StmtPtr>({}));
2202   parent->replace_stmt(loops.front(), empty_block);
2203   for (size_t i = 1; i < loops.size(); ++i) {
2204     auto block = to<Block>(loops[i]->get_parent());
2205     TORCH_INTERNAL_ASSERT(
2206         block,
2207         buildErrorMessage(
2208             "Expected parent stmt to be a non-null Block in reorder transformation the fuser."));
2209     block->remove_stmt(loops[i]);
2210   }
2211 
2212   // Set the new bodies after reorder for all the loops.
2213   for (size_t i = 0; i < result.size() - 1; ++i) {
2214     result[i]->set_body(result[i + 1]);
2215   }
2216   result.back()->set_body(innermost_body);
2217   parent->replace_stmt(empty_block, result.front());
2218   return result;
2219 }
2220 
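// Walks down a loop nest by statement position. For example, given
//   root: for i:
//           S0
//           for j:      // stmt #1 in i's body
//             for k:    // stmt #0 in j's body
//               ...
// getLoopAt(root, {1, 0}) returns the `k` loop.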
2221 ForPtr LoopNest::getLoopAt(ForPtr root, const std::vector<int>& indices) const {
2222   if (indices.empty()) {
2223     return root;
2224   }
2225   if (root == nullptr) {
2226     throw malformed_input("root loop is null");
2227   }
2228 
2229   ForPtr curr = std::move(root);
2230   for (auto i : indices) {
2231     if (i < 0 || curr->body()->nstmts() <= static_cast<size_t>(i)) {
2232       return nullptr;
2233     }
2234     std::list<StmtPtr>::iterator stmtp = curr->body()->begin();
2235     std::advance(stmtp, i);
2236     curr = to<For>(*stmtp);
2237     if (curr == nullptr) {
2238       return nullptr;
2239     }
2240   }
2241 
2242   return curr;
2243 }
2244 
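// Tiling splits both loops by their factors and reorders the middle axes, so
// that, roughly,
//   for x: for y: body
// becomes
//   for x_outer: for y_outer: for x_inner: for y_inner: body
// plus tail loops when the extents do not divide evenly; the tail loop of x
// is returned.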
2245 ForPtr LoopNest::tile(
2246     const ForPtr& x,
2247     const ForPtr& y,
2248     int x_factor,
2249     int y_factor) {
2250   auto parent = to<Block>(x->get_parent());
2251   if (parent == nullptr) {
2252     throw malformed_input("parent of the loops must be a Block");
2253   }
2254   if (!areLoopsPerfectlyNested({x, y})) {
2255     throw malformed_input("two loops must be perfectly nested");
2256   }
2257 
2258   // Split x, y axes by x_factor and y_factor
2259   ForPtr yi, ytail;
2260   splitWithTail(y, y_factor, &yi, &ytail);
2261   ForPtr xi, xtail;
2262   splitWithTail(x, x_factor, &xi, &xtail);
2263 
2264   // Distribute xi over yo and ytail so we can manipulate the loop order of {xo,
2265   // xi, yo, yi}
2266   auto loops = distributeLoop(xi);
2267 
2268   // For {xi, yo, yi}, reorder the axes to be yo, xi, yi
2269   xi = loops.front();
2270   ForPtr yo = to<For>(xi->body()->stmts().front());
2271   CHECK(yo);
2272   reorder({xi, yo}, {1, 0});
2273 
2274   // For {xi, ytail}, reorder the axes to be ytail, xi
2275   if (loops.size() == 2) {
2276     xi = loops.back();
2277     ytail = to<For>(xi->body()->stmts().front());
2278     CHECK(ytail);
2279     reorder({xi, ytail}, {1, 0});
2280   }
2281 
2282   return xtail;
2283 }
2284 
2285 bool LoopNest::areLoopsPerfectlyNested(const std::vector<ForPtr>& loops) {
2286   if (loops.size() < 2) {
2287     return true;
2288   }
2289   for (size_t i = 0; i < loops.size() - 1; ++i) {
2290     auto loop_body = loops[i]->body();
2291     if (loop_body->nstmts() != 1 || loop_body->front() != loops[i + 1]) {
2292       return false;
2293     }
2294   }
2295   return true;
2296 }
2297 
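// Fully unrolling a loop with constant bounds, e.g.
//   for i in 0..3:
//     body(i)
// roughly produces
//   body(0)
//   body(1)
//   body(2)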
2298 void LoopNest::fullUnroll(const ForPtr& f, StmtPtr* unrolled) {
2299   if (!f) {
2300     throw malformed_input("unroll attempted on null loop");
2301   }
2302   BlockPtr p = to<Block>(f->get_parent());
2303   if (!p) {
2304     throw malformed_input("unroll attempted on loop with no parent");
2305   }
2306   auto start_expr = IRSimplifier::simplify(f->start());
2307   auto stop_expr = IRSimplifier::simplify(f->stop());
2308   if (!start_expr->isConstant()) {
2309     throw std::runtime_error("Can't unroll due to non-constant loop start!");
2310   }
2311   if (!stop_expr->isConstant()) {
2312     throw std::runtime_error("Can't unroll due to non-constant loop stop!");
2313   }
2314 
2315   std::vector<StmtPtr> unrolled_stmts;
2316   int start_val = immediateAs<int>(start_expr);
2317   int stop_val = immediateAs<int>(stop_expr);
2318   for (int current = start_val; current < stop_val; ++current) {
2319     for (const auto& stmt : f->body()->stmts()) {
2320       unrolled_stmts.push_back(SubstituteInClone(
2321           stmt, {{f->var(), getImmediateByType(f->var()->dtype(), current)}}));
2322     }
2323   }
2324   *unrolled = alloc<Block>(unrolled_stmts);
2325   *unrolled = IRSimplifier::simplify(*unrolled);
2326 
2327   p->replace_stmt(f, *unrolled);
2328 }
2329 
2330 void LoopNest::fullUnroll(const ForPtr& f) {
2331   StmtPtr unrolled;
2332   fullUnroll(f, &unrolled);
2333 }
2334 
2335 void LoopNest::unroll(const ForPtr& f, int factor, ForPtr* tail) {
2336   if (factor < 2) {
2337     return;
2338   }
2339   ForPtr inner;
2340   splitWithTail(f, factor, &inner, tail);
2341   fullUnroll(inner);
2342 }
2343 
2344 void LoopNest::unroll(const ForPtr& f, int factor) {
2345   ForPtr tail;
2346   unroll(f, factor, &tail);
2347 }
2348 
2349 bool LoopNest::isNormalized(const ForPtr& f) {
2350   if (f->start()->isConstant()) {
2351     return immediateAs<int>(f->start()) == 0;
2352   }
2353   return false;
2354 }
2355 
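// Normalizing shifts a loop to start at zero:
//   for x in a..b: body(x)
// roughly becomes
//   for x in 0..(b - a): body(x + a)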
2356 bool LoopNest::normalize(const ForPtr& f) {
2357   if (!f) {
2358     throw malformed_input("normalize attempted on null loop");
2359   }
2360 
2361   if (isNormalized(f)) {
2362     // No need to normalize anymore here.
2363     return false;
2364   }
2365 
2366   auto for_body_normalized = Substitute(
2367       f->body(),
2368       {{f->var(), (VarHandle(f->var()) + ExprHandle(f->start())).node()}});
2369   f->set_body(IRSimplifier::simplify(for_body_normalized));
2370   f->set_stop(IRSimplifier::simplify(alloc<Sub>(f->stop(), f->start())));
2371   f->set_start(immLike(f->stop(), 0));
2372   return true;
2373 }
2374 
2375 // This function expects that there are 'num' loops perfectly nested within
2376 // and including 'f'.
2377 std::vector<ForPtr> LoopNest::getLoopStmtsInLoopNest(
2378     const ForPtr& f,
2379     size_t num) {
2380   std::vector<ForPtr> loops(num);
2381   ForPtr curr_for = f;
2382   loops[0] = curr_for;
2383   for (size_t i = 1; i < num; ++i) {
2384     TORCH_INTERNAL_ASSERT(
2385         curr_for->body()->nstmts() == 1,
2386         buildErrorMessage("Expected a single stmt in the loop body."));
2387     curr_for = to<For>(curr_for->body()->front());
2388     TORCH_INTERNAL_ASSERT(
2389         curr_for,
2390         buildErrorMessage("Expected the only child stmt to be a For loop."));
2391     loops[i] = curr_for;
2392   }
2393   return loops;
2394 }
2395 
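// Flattening a perfect loop nest into a single loop, e.g.
//   for i in 0..M:
//     for j in 0..N:
//       body(i, j)
// roughly produces
//   for i_flat in 0..M*N:
//     body(i_flat / N, i_flat % N)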
2396 bool LoopNest::flatten(const std::vector<ForPtr>& loops, ForPtr* flattened) {
2397   if (loops.empty()) {
2398     throw malformed_input("flatten attempted on empty set of loops");
2399   }
2400   BlockPtr p = to<Block>(loops[0]->get_parent());
2401   if (!p) {
2402     throw malformed_input("flatten attempted on loops with no parent");
2403   }
2404 
2405   if (loops.size() == 1) {
2406     // This loop nest is already flattened.
2407     *flattened = loops[0];
2408     return false;
2409   }
2410 
2411   // Check if all the loops correspond to a perfect loopnest:
2412   //  * every loop except the inner-most should have only one stmt, the For.
2413   // Do not flatten, otherwise.
2414   // This check also ensures we do not flatten reduction loops.
2415   for (size_t i = 0; i < loops.size() - 1; ++i) {
2416     if ((loops[i]->body()->nstmts() != 1) ||
2417         (loops[i]->body()->front() != loops[i + 1])) {
2418       return false;
2419     }
2420   }
2421 
2422   // Normalize the loops before flattening.
2423   // We need to normalize them from inner-most to outer because once the outer
2424   // loop is normalized, the given pointers to inner loops point to old code.
2425   // For the same reason, we can't store the normalized inner loops until after
2426   // the outer-most loop is normalized.
2427   for (size_t i = 0; i < loops.size(); ++i) {
2428     size_t idx = loops.size() - i - 1;
2429     LoopNest::normalize(loops[idx]);
2430   }
2431 
2432   // Collect all the loops in the normalized loopnest; the first loop in
2433   // this list is the outer-most loop of the normalized loopnest.
2434   auto normalized_loops = getLoopStmtsInLoopNest(loops.front(), loops.size());
2435 
2436   auto flat_var = alloc<Var>(
2437       normalized_loops[0]->var()->name_hint() + "_flat",
2438       normalized_loops[0]->var()->dtype());
2439   VarMapping var_mapping;
2440   ExprPtr stop = immLike(flat_var, 1);
2441   for (size_t i = 0; i < normalized_loops.size(); ++i) {
2442     size_t idx = normalized_loops.size() - i - 1;
2443     auto curr_loop = normalized_loops[idx];
2444     ExprPtr div = alloc<Div>(flat_var, stop);
2445     ExprPtr sub_expr = idx == 0 ? div : alloc<Mod>(div, curr_loop->stop());
2446     var_mapping.emplace_back(curr_loop->var(), sub_expr);
2447     stop = alloc<Mul>(curr_loop->stop(), stop);
2448   }
2449   auto flattened_body =
2450       Substitute(normalized_loops.back()->removeBody(), var_mapping);
2451 
2452   normalized_loops.front()->set_var(flat_var);
2453   normalized_loops.front()->set_start(immLike(stop, 0));
2454   normalized_loops.front()->set_stop(stop);
2455   normalized_loops.front()->set_body(flattened_body);
2456   *flattened = normalized_loops.front();
2457   return true;
2458 }
2459 
2460 bool LoopNest::flatten(const std::vector<ForPtr>& loops) {
2461   ForPtr flattened;
2462   return flatten(loops, &flattened);
2463 }
2464 
2465 void LoopNest::compressBuffer(const BufPtr& buf, const StmtPtr& stmt) {
2466   // Loop iterations in NNC IR do not follow sequential semantics by default.
2467   // In other words, the iterations of the loops could be executed in any
2468   // random order without affecting correctness. This constraint in turn
2469   // implies that there can’t be any *inter-iteration* dependences
2470   // (or *loop-carried* dependences) in NNC loops. So, any NNC IR with such
2471   // dependences is considered invalid.
2472   //
2473   // Given the constraint above, for any pair of accesses to a buffer (where
2474   // at least one of the access is a write), the accesses must be
2475   // loop-independent on the innermost loop containing the accesses as well as
2476   // all the loops above it. So, any dimension that uses only those loop
2477   // variables to access the given buffer could be optimized away.
2478   //
2479   // Algorithm:
2480   //   * Find all the accesses to the given buf. (A)
2481   //   * Find the parent common to all accesses in A. (P)
2482   //   * Collect all the loops above P. (L)
2483   //   * Collect all the loop variables corresponding to L. (LV)
2484   //   * For every access a in A:
2485   //      * For the index I in every dimension of a:
2486   //          * If the variables in I are all in LV, mark this dimension
2487   //            for deletion.
2488   //   * For every dimension that is marked for deletion in ALL accesses in A:
2489   //      * Update the buffer to set the size of that dimension to 1.
2490   //      * Update all accesses in A to set the index in that dimension to 0.
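  //
  // For example, in
  //   for i:
  //     for j:
  //       A[i, j] = ...
  //       ... = A[i, j]
  // both dimensions of A are indexed only with the enclosing loop variables
  // i and j, so A can be compressed in both dimensions:
  //   for i:
  //     for j:
  //       A[0, 0] = ...
  //       ... = A[0, 0]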
2491 
2492   auto writes = WritesToBuf::find(stmt, buf);
2493   auto reads = StmtsReadingBuf::find(stmt, buf);
2494 
2495   // Find the parent common to all the buffer accesses.
2496   BlockPtr parent = to<Block>(writes.front()->get_parent());
2497   TORCH_INTERNAL_ASSERT(
2498       parent,
2499       buildErrorMessage(
2500           "Expected parent stmt to be a non-null block in compressBuffer in the fuser."));
2501   for (const auto& w : writes) {
2502     parent = Block::getSharedParent(parent, w);
2503   }
2504   for (const auto& r : reads) {
2505     parent = Block::getSharedParent(parent, r);
2506   }
2507 
2508   // Collect all the loops that are above the common parent.
2509   auto loops = LoopNest::getEnclosingLoopNest(parent);
2510   std::unordered_set<VarPtr> loop_vars;
2511   for (const auto& l : loops) {
2512     loop_vars.insert(l->var());
2513   }
2514 
2515   // TODO: Need to handle other Stmts / Exprs that read / write buffers.
2516   auto stores = NodeFinder<Store>::find(stmt);
2517   auto loads = NodeFinder<Load>::find(stmt);
2518 
2519   // Vector to indicate which dimensions could be compressed away.
2520   std::vector<bool> dims(buf->dims().size(), true);
2521   auto check_indices = [&](const std::vector<ExprPtr>& indices) {
2522     TORCH_INTERNAL_ASSERT(
2523         indices.size() == dims.size(),
2524         buildErrorMessage(
2525             "Expected ranks to match in compressBuffer in the fuser."));
2526     for (size_t i = 0; i < indices.size(); ++i) {
2527       auto index_vars = NodeFinder<Var>::find(indices[i]);
2528       for (const auto& iv : index_vars) {
2529         if (loop_vars.count(iv) == 0) {
2530           // A variable in this index is not in loop_vars.
2531           // This implies that this dimension cannot be optimized away.
2532           dims[i] = false;
2533           break;
2534         }
2535       }
2536     }
2537   };
2538   for (const auto& s : stores) {
2539     if (s->buf() == buf) {
2540       check_indices(s->indices());
2541     }
2542   }
2543   for (const auto& l : loads) {
2544     if (l->buf() == buf) {
2545       check_indices(l->indices());
2546     }
2547   }
2548   bool any_dim_to_compress = false;
2549   for (auto d : dims) {
2550     any_dim_to_compress |= d;
2551   }
2552   if (!any_dim_to_compress) {
2553     return;
2554   }
2555 
2556   // Compress buffer by removing the marked dims.
2557   std::vector<ExprPtr> new_dims(buf->dims());
2558   for (size_t i = 0; i < dims.size(); ++i) {
2559     if (dims[i]) {
2560       new_dims[i] = immLike(buf->dims()[i], 1);
2561     }
2562   }
2563   buf->set_dims(new_dims);
2564 
2565   // Modify all access to reflect the removed dims.
2566   auto get_new_indices = [&](const std::vector<ExprPtr>& indices) {
2567     TORCH_INTERNAL_ASSERT(
2568         indices.size() == dims.size(),
2569         buildErrorMessage(
2570             "Expected ranks to match in compressBuffer in the fuser."));
2571     std::vector<ExprPtr> new_indices(indices);
2572     for (size_t i = 0; i < dims.size(); ++i) {
2573       if (dims[i]) {
2574         new_indices[i] = immLike(indices[i], 0);
2575       }
2576     }
2577     return new_indices;
2578   };
2579   for (const auto& s : stores) {
2580     if (s->buf() == buf) {
2581       s->set_indices(get_new_indices(s->indices()));
2582     }
2583   }
2584   for (const auto& l : loads) {
2585     if (l->buf() == buf) {
2586       l->set_indices(get_new_indices(l->indices()));
2587     }
2588   }
2589 }
2590 
2591 void LoopNest::compressAllBuffers(const StmtPtr& stmt) {
2592   for (const auto& buf : BufFinder::find(stmt)) {
2593     compressBuffer(buf, stmt);
2594   }
2595 }
2596 
2597 std::vector<ForPtr> LoopNest::getLoopStmtsFor(const Tensor& t) const {
2598   StmtPtr cur_stmt = getLoopBodyFor(t);
2599   return getLoopStmtsFor(cur_stmt);
2600 }
2601 
2602 std::vector<ForPtr> LoopNest::getLoopStmtsFor(const BufPtr& buf) const {
2603   StmtPtr cur_stmt = getLoopBodyFor(buf);
2604   return getLoopStmtsFor(cur_stmt);
2605 }
2606 
2607 std::vector<ForPtr> LoopNest::getLoopStmtsFor(StmtPtr s) const {
2608   std::vector<ForPtr> result;
2609 
2610   while (s) {
2611     if (auto loop = to<For>(s)) {
2612       result.push_back(loop);
2613     }
2614     s = s->get_parent();
2615   }
2616   std::reverse(result.begin(), result.end());
2617   return result;
2618 }
2619 
2620 StmtPtr LoopNest::getLoopBodyFor(const Tensor& t) const {
2621   return getLoopBodyFor(t.buf());
2622 }
2623 
2624 StmtPtr LoopNest::getLoopBodyFor(BufPtr buf) const {
2625   auto writes = WritesToBuf::find(root_stmt_, std::move(buf));
2626 
2627   // Special case for reduction Tensors: when the only writes are the
2628   // initializer and the reduction store, ignore the initializer.
2629   if (writes.size() == 2) {
2630     if (StorePtr s = to<Store>(writes.back())) {
2631       if (ReduceOpPtr r = to<ReduceOp>(s->value())) {
2632         return (StmtPtr)s;
2633       }
2634     }
2635   }
2636 
2637   StmtPtr res = nullptr;
2638   for (const auto& s : writes) {
2639     if (!res) {
2640       res = s;
2641       continue;
2642     }
2643 
2644     res = Block::getSharedParent(res, s);
2645   }
2646 
2647   return (StmtPtr)res;
2648 }
2649 
2650 ForPtr LoopNest::getParentLoop(const StmtPtr& st) {
2651   if (st == nullptr) {
2652     return nullptr;
2653   }
2654   auto par = st->get_parent();
2655   if (auto f = to<For>(par)) {
2656     return f;
2657   }
2658   return getParentLoop(par);
2659 }
2660 
2661 std::vector<ForPtr> LoopNest::getEnclosingLoopNest(const StmtPtr& st) {
2662   std::vector<ForPtr> loops;
2663   auto f = getParentLoop(st);
2664   while (f) {
2665     loops.push_back(f);
2666     f = getParentLoop(f);
2667   }
2668   std::reverse(loops.begin(), loops.end());
2669   return loops;
2670 }
2671 
2672 std::vector<StmtPtr> LoopNest::getAllWritesToBuf(BufPtr buf) const {
2673   return WritesToBuf::find(root_stmt_, std::move(buf));
2674 }
2675 
2676 std::vector<ForPtr> LoopNest::getAllInnermostLoopsWritingToBuf(
2677     BufPtr buf) const {
2678   auto writes = getAllWritesToBuf(std::move(buf));
2679   std::vector<ForPtr> innermost_loops;
2680   innermost_loops.reserve(writes.size());
2681   for (const auto& w : writes) {
2682     innermost_loops.push_back(LoopNest::getParentLoop(w));
2683   }
2684   return innermost_loops;
2685 }
2686 
2687 std::vector<std::vector<ForPtr>> LoopNest::getAllLoopNestsWritingToBuf(
2688     BufPtr buf) const {
2689   auto writes = getAllWritesToBuf(std::move(buf));
2690   std::vector<std::vector<ForPtr>> loopnests;
2691   loopnests.reserve(writes.size());
2692   for (const auto& w : writes) {
2693     loopnests.emplace_back(LoopNest::getEnclosingLoopNest(w));
2694   }
2695   return loopnests;
2696 }
2697 
2698 StmtPtr LoopNest::simplify() {
2699   root_stmt_ = IRSimplifier::simplify(root_stmt_);
2700   return root_stmt_;
2701 }
2702 
2703 StmtPtr FlattenIndexes(const StmtPtr& s) {
2704   IndexFlattener idx_flattener;
2705   return idx_flattener.flatten(s);
2706 }
2707 
2708 // Auxiliary class for the rewriting we do in `compute_at`. See
2709 // LoopNest::computeAt for more details.
2710 class LoopComputeAtRewriter : public IRMutator {
2711  public:
2712   LoopComputeAtRewriter(
2713       BufPtr buf,
2714       BufPtr new_buf,
2715       std::vector<ExprPtr> offsets)
2716       : buf_(std::move(buf)),
2717         new_buf_(std::move(new_buf)),
2718         offsets_(std::move(offsets)) {}
2719 
2720  private:
2721   BufPtr buf_;
2722   BufPtr new_buf_;
2723   std::vector<ExprPtr> offsets_;
2724 
2725   ExprPtr mutate(const LoadPtr& v) override {
2726     if (v->buf() != buf_) {
2727       return v;
2728     }
2729     std::vector<ExprPtr> new_indices(v->indices().size());
2730     for (const auto i : c10::irange(v->indices().size())) {
2731       new_indices[i] =
2732           IRSimplifier::simplify(alloc<Sub>(v->indices()[i], offsets_[i]));
2733     }
2734     return alloc<Load>(v->dtype(), new_buf_, new_indices);
2735   }
2736 };
2737 
2738 static StorePtr getStoreStmtOfProducer(const StmtPtr& s) {
2739   if (StorePtr st = to<Store>(s)) {
2740     return st;
2741   }
2742   if (BlockPtr b = to<Block>(s)) {
2743     for (const StmtPtr& ss : *b) {
2744       if (StorePtr st = to<Store>(ss)) {
2745         return st;
2746       }
2747     }
2748   }
2749   return nullptr;
2750 }
2751 
2752 static std::vector<VarPtr> getOuterLoopIndexes(StmtPtr s) {
2753   std::vector<VarPtr> res;
2754   StmtPtr cur = std::move(s);
2755   while (cur) {
2756     if (auto l = to<For>(cur)) {
2757       res.push_back(l->var());
2758     }
2759     cur = cur->get_parent();
2760   }
2761   return res;
2762 }
2763 
2764 class CacheReplacer : public IRMutator {
2765  public:
2766   CacheReplacer(BufPtr buffer, BufPtr cache, std::vector<ExprPtr>& offsets)
2767       : buf_(std::move(buffer)), cache_(std::move(cache)), offsets_(offsets) {}
2768 
2769  private:
2770   ExprPtr mutate(const LoadPtr& v) override {
2771     BufPtr buf = v->buf();
2772     if (buf != buf_) {
2773       return IRMutator::mutate(v);
2774     }
2775 
2776     // Map indices to call-parameters.
2777     std::vector<ExprPtr> newIndices;
2778     TORCH_INTERNAL_ASSERT(
2779         offsets_.size() == v->indices().size(),
2780         buildErrorMessage(
2781             "Expected ranks to match in CacheReplacer in the fuser."));
2782     for (size_t i = 0; i < v->indices().size(); ++i) {
2783       ExprPtr index = v->indices()[i]->accept_mutator(this);
2784       ExprPtr offset = offsets_[i];
2785       ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
2786       newIndices.push_back(sub);
2787     }
2788     v->set_buf(cache_);
2789     v->set_indices(newIndices);
2790     return v;
2791   }
2792 
2793   StmtPtr mutate(const StorePtr& v) override {
2794     BufPtr buf = v->buf();
2795     if (buf != buf_) {
2796       return IRMutator::mutate(v);
2797     }
2798 
2799     ExprPtr newValue = v->value()->accept_mutator(this);
2800 
2801     // Map indices to call-parameters.
2802     std::vector<ExprPtr> newIndices;
2803     TORCH_INTERNAL_ASSERT(
2804         offsets_.size() == v->indices().size(),
2805         buildErrorMessage(
2806             "Expected ranks to match in CacheReplacer in the fuser."));
2807     for (size_t i = 0; i < v->indices().size(); ++i) {
2808       ExprPtr index = v->indices()[i]->accept_mutator(this);
2809       ExprPtr offset = offsets_[i];
2810       ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
2811       newIndices.push_back(sub);
2812     }
2813     v->set_buf(cache_);
2814     v->set_indices(newIndices);
2815     v->set_value(newValue);
2816     return v;
2817   }
2818 
2819   BufPtr buf_;
2820   BufPtr cache_;
2821   std::vector<ExprPtr>& offsets_;
2822 };
2823 
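// Caches accesses to `producer` within `consumer`: a temporary buffer sized
// to the accessed region is created, accesses to the producer inside the
// consumer are redirected to it, and fill / write-back loops are inserted
// around the consumer as needed. Roughly, for a consumer that only reads the
// producer P over [start, stop]:
//   for k in 0..(stop - start + 1):
//     cache[k] = P[k + start]
//   <consumer, with loads of P rewritten to cache, offset by start>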
2824 LoopNest::AccessResult LoopNest::cacheAccesses(
2825     const BufPtr& producer,
2826     const std::string& name,
2827     const StmtPtr& consumer) {
2828   ReduceOpPtr reduceOp{nullptr};
2829   auto stores = NodeFinder<Store>::find(consumer);
2830   for (const auto& store : stores) {
2831     if (auto ro = to<ReduceOp>(store->value())) {
2832       if (store->buf() != producer) {
2833         continue;
2834       }
2835 
2836       if (reduceOp) {
2837         throw std::runtime_error(
2838             "can only cache accesses used by at most a single reduceOp");
2839         return {nullptr, nullptr};
2840       }
2841 
2842       reduceOp = ro;
2843     }
2844   }
2845 
2846   // Check bounds but don't care about AccessKind.
2847   auto consumer_bounds_info = inferBounds(consumer, false);
2848   auto bounds_it = consumer_bounds_info.find(producer);
2849   if (bounds_it == consumer_bounds_info.end()) {
2850     throw std::runtime_error("consumer does not use the Tensor produced");
2851     return {nullptr, nullptr};
2852   }
2853 
2854   TORCH_INTERNAL_ASSERT(
2855       bounds_it->second.size() == 1,
2856       buildErrorMessage(
2857           "Unexpected number of bound info entries in cacheAccesses in the fuser."));
2858   TensorAccessBoundsInfo& info = bounds_it->second[0];
2859   bool hasReads = info.kind == kLoad || info.kind == kMutate;
2860   bool hasWrites = info.kind == kStore || info.kind == kMutate;
2861 
2862   std::vector<std::string> var_names = {"i", "j", "k", "l", "m", "n", "o", "p"};
2863   std::vector<ExprPtr> tmp_dims;
2864   std::vector<VarPtr> new_loop_vars;
2865   std::vector<ExprPtr> new_loop_vars_expr;
2866 
2867   // Determine the size of the cache, and create a loop var for each dimension.
2868   for (size_t i = 0; i < info.start.size(); ++i) {
2869     ExprPtr dim = IRSimplifier::simplify(alloc<Add>(
2870         alloc<Sub>(info.stop[i], info.start[i]), immLike(info.stop[i], 1)));
2871 
2872     tmp_dims.push_back(dim);
2873 
2874     new_loop_vars.push_back(
2875         alloc<Var>(var_names[i % var_names.size()], info.stop[i]->dtype()));
2876     new_loop_vars_expr.push_back(new_loop_vars[i]);
2877   }
2878 
2879   // Create the var.
2880   BufPtr tmp_buf =
2881       alloc<Buf>(alloc<Var>(name, kHandle), tmp_dims, producer->dtype());
2882 
2883   // determine the offsets for calls into the cache based on the loop start of
2884   // each axis.
2885   std::vector<ExprPtr> tmp_params;
2886   for (size_t i = 0; i < new_loop_vars.size(); ++i) {
2887     tmp_params.push_back(alloc<Add>(new_loop_vars[i], info.start[i]));
2888   }
2889 
2890   // Replace accesses to the producer in the consumer with the cache.
2891   CacheReplacer replacer(producer, tmp_buf, info.start);
2892   consumer->accept_mutator(&replacer);
2893 
2894   // replace the old consumer with the replaced consumer.
2895   BlockPtr consumer_block = to<Block>(consumer);
2896   BlockPtr parent_block = to<Block>(consumer->get_parent());
2897   // if the consumer is a block, we should mutate it in place.
2898   bool is_block = consumer_block != nullptr;
2899 
2900   // If there's a reduction and we are operating on the reduce axis, we need to
2901   // initialize the cache with 0s. Also, we can't just write the result straight
2902   // back to the original buffer, since after parallelism the writes will race.
2903   // Instead we need to create a new ReduceOp.
2904   bool on_reduce_axis = false;
2905   if (reduceOp) {
2906     std::set<VarPtr> reduce_args(
2907         reduceOp->reduce_args().begin(), reduceOp->reduce_args().end());
2908     std::set<VarPtr> enclosing_vars;
2909     for (const auto& enclosing_for_stmt : NodeFinder<For>::find(consumer)) {
2910       enclosing_vars.insert(enclosing_for_stmt->var());
2911     }
2912     for (const auto& reduce_arg : reduce_args) {
2913       if (enclosing_vars.find(reduce_arg) == enclosing_vars.end()) {
2914         on_reduce_axis = true;
2915       }
2916     }
2917   }
2918   if (reduceOp && on_reduce_axis) {
2919     // reduceOp means we had both loads and stores.
2920 
2921     // Init cache to 0.
2922     StmtPtr tmp_init = alloc<Store>(
2923         tmp_buf, new_loop_vars_expr, getImmediateByType(tmp_buf->dtype(), 0));
2924 
2925     for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
2926       tmp_init = alloc<For>(
2927           new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_init);
2928     }
2929 
2930     if (is_block) {
2931       consumer_block->prepend_stmt(tmp_init);
2932     } else {
2933       parent_block->insert_stmt_before(tmp_init, consumer);
2934     }
2935 
2936     // Reduce back to the original buffer:
2937     StmtPtr tmp_store = alloc<Store>(
2938         producer,
2939         tmp_params,
2940         reduceOp->reducer()(
2941             producer,
2942             alloc<Load>(tmp_buf, new_loop_vars_expr),
2943             tmp_params,
2944             {}));
2945 
2946     for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
2947       tmp_store = alloc<For>(
2948           new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
2949     }
2950 
2951     if (is_block) {
2952       consumer_block->append_stmt(tmp_store);
2953     } else {
2954       parent_block->insert_stmt_after(tmp_store, consumer);
2955     }
2956 
2957     return std::make_pair(tmp_buf, consumer);
2958   }
2959 
2960   if (hasReads) {
2961     // Fill the cache with values from the consumer.
2962     StmtPtr tmp_store = alloc<Store>(
2963         tmp_buf, new_loop_vars_expr, alloc<Load>(producer, tmp_params));
2964 
2965     for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
2966       tmp_store = alloc<For>(
2967           new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
2968     }
2969 
2970     if (is_block) {
2971       consumer_block->prepend_stmt(tmp_store);
2972     } else {
2973       parent_block->insert_stmt_before(tmp_store, consumer);
2974     }
2975   }
2976 
2977   if (hasWrites) {
2978     // sync the cache back to the producer buf.
2979     StmtPtr tmp_store = alloc<Store>(
2980         producer, tmp_params, alloc<Load>(tmp_buf, new_loop_vars_expr));
2981 
2982     for (int64_t i = static_cast<int64_t>(new_loop_vars.size()) - 1; i >= 0;
2983          --i) {
2984       tmp_store = alloc<For>(
2985           new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
2986     }
2987 
2988     if (is_block) {
2989       consumer_block->append_stmt(tmp_store);
2990     } else {
2991       parent_block->insert_stmt_after(tmp_store, consumer);
2992     }
2993   }
2994 
2995   return std::make_pair(tmp_buf, consumer);
2996 }
2997 
2998 /*
2999  * WHAT COMPUTE_AT DOES
3000  * ====================
3001  *
3002  * Suppose we have two loops:
3003  *
3004  * for i in 0..100:
3005  *   for j in 0..200:
3006  *     A[i,j] = sin(i*j)
3007  * for i in 0..100:
3008  *   for j in 0..199:
3009  *     B[i,j] = A[i,j] + A[i, j+1]
3010  *
3011  * If we compute these loops as is, we would have to allocate two buffers:
3012  * 100x200 for A and 100x199 for B. To decrease the memory usage one can use
3013  * compute_inline primitive, which would result in the following:
3014  *
3015  * for i in 0..100:
3016  *   for j in 0..199:
3017  *     B[i,j] = sin(i*j) + sin(i*(j+1))
3018  *
3019  * We now need only one buffer - 100x199 for B. However, we're now doing some
3020  * redundant computations: we're calling `sin` twice as much as in the first
3021  * version.
3022  *
3023  * Ultimately, we need to choose at what point we prefer to compute values of
3024  * A[i,j] - we can do it in the very beginning for the entire buffer A (the
3025  * first option) or compute it on the fly when we compute B (the second option).
3026  * There are also options in between those two: we can compute a part of A which
3027  * is required for a computation of part of B, e.g. for a single row of B. The
3028  * code would then look like:
3029  *
3030  * for i in 0..100:
3031  *   for j in 0..200:
3032  *     A[j] = sin(i*j)
3033  *   for j in 0..199:
3034  *     B[i,j] = A[j] + A[j+1]
3035  *
3036  * In this case we're only using 1x200 for A, and we're avoiding redundant
3037  * computations.
3038  *
3039  * The purpose of `compute_at` is to achieve exactly this transformation.
3040  *
3041  * compute_at requires specifying What to compute and Where to compute it: in our
3042  * example we would call compute_at(What=`A[i,j] = sin(i*j)`, Where=`for i in
3043  * 0..100`).
3044  *
3045  * More info about compute_at could be found in Halide's tutorials:
3046  * https://halide-lang.org/tutorials/tutorial_lesson_08_scheduling_2.html
3047  *
3048  * HOW COMPUTE_AT WORKS
3049  * ====================
3050  *
3051  * The most important part of compute_at is bounds inference: we need to figure
3052  * out what part of the used tensors we need to compute when we move the
3053  * computation to a new scope. In the example above, we need bounds inference to
3054  * tell us that in order to compute A at each iteration of the outer loop, we
3055  * need to compute A within indices [i:i+1,0:200].
3056  *
3057  * This info allows us to conclude that we need a temp buffer of size 1x200.
3058  *
3059  * Once this is known we need to insert statements for allocation and freeing
3060  * the temporary buffer and copy the original computation to fill the temp
3061  * buffer with proper values. When we copy the computation we also must rewrite
3062  * indices used in it: old indices are referring to the old loop and are not
3063  * valid in the new loop.
3064  *
3065  * To make the logic easier to follow, let's examine an example. Suppose we start from
3066  * the following loop nest:
3067  *   for py in 0..100:
3068  *     for px in 0..100:
3069  *       producer[py,px] = py*px
3070  *   for cy in 0..100:
3071  *     for cx in 0..100:
3072  *       consumer[cy,cx] = producer[cy,cx]
3073  *
3074  * And then we're running `compute_at(producer, cy)`.
3075  *
3076  * What we would like to get is the following loop nest:
3077  *   for py in 0..100:
3078  *     for px in 0..100:
3079  *       producer[py,px] = py*px
3080  *   for cy in 0..100:
3081  *     Allocate(temp, {1, 100})
3082  *     for ty in 0..1:
3083  *       for tx in 0..100:
3084  *         temp[ty,tx] = (ty+cy)*(tx+0)
3085  *     for cx in 0..100:
3086  *       consumer[cy,cx] = temp[0,cx]
3087  *     Free(temp)
3088  *
3089  * NB: this loop nest can and should be simplified (e.g. the producer loop can
3090  * be removed since its result is no longer used), but this clean-up
3091  * optimization is performed separately (currently, not performed at all).
3092  *
3093  * If we examine the final loop nest, we can identify that the following steps
3094  * need to be performed:
3095  *   - Bounds inference needs to tell us that we need a 1x100 buffer for temp.
3096  *   - Allocate and Free statements for this buffer need to be inserted into the
3097  *   loop.
3098  *   - A new loop nest should be inserted into the loop CY for computing `temp`,
3099  *   and it should replicate the loop nest of the producer (PY,PX loops). The indices
3100  *   in the loop body need to be offset by (cy, 0) - the offsets come from
3101  *   bounds inference too.
3102  *   - The computation of `consumer` needs to be rewritten so that it uses
3103  *   `temp` instead of `producer`. The indices in the corresponding accesses
3104  *   also need to be offset.
3105  */
computeAt(const StmtPtr & s,const ForPtr & f)3106 void LoopNest::computeAt(const StmtPtr& s, const ForPtr& f) {
3107   StorePtr st = getStoreStmtOfProducer(s);
3108   if (!st) {
3109     return;
3110   }
3111 
3112   // Infer bounds info for all accesses that we make in the loop
3113   auto loop_bounds_info = inferBounds(f->body());
3114 
3115   // bounds_it holds bounds info for the store we're trying to move to
3116   // the loop. If its result isn't accessed in the loop at all - do nothing and
3117   // exit early.
3118   auto bounds_it = loop_bounds_info.find(st->buf());
3119   if (bounds_it == loop_bounds_info.end()) {
3120     return;
3121   }
3122 
3123   // Compute dimensions of the temp buffer we would need to allocate
3124   std::vector<ExprPtr> dims = getBoundExtents(bounds_it->second);
3125 
3126   // TODO: Use name-hint of the producer instead of "temp"
3127   BufPtr temp_buf = alloc<Buf>("temp", dims, st->value()->dtype());
3128 
3129   // Generate index variables for 'temp'
3130   std::vector<ExprPtr> temp_indices(dims.size());
3131   for (const auto i : c10::irange(dims.size())) {
3132     // TODO: Use name-hint of the producer indices instead of 'idx'
3133     temp_indices[i] =
3134         alloc<Var>(std::string("idx") + std::to_string(i), dims[i]->dtype());
3135   }
3136 
3137   // Prepare substitution rules for constructing the temp statement from the
3138   // producer statement
3139   // TODO: Instead of going up the loop nest we should go through the indices in
3140   // the original tensor expression. The loops in the nest might've been
3141   // modified (e.g. split or merged) so that the loop indices no longer
3142   // correspond to the indices of the original expression and even their number
3143   // might be different. In that case, the loop below would crash.
3144   std::vector<VarPtr> prod_indices = getOuterLoopIndexes(s);
3145   std::vector<std::pair<VarPtr, ExprPtr>> rewrite_indices_map;
3146   std::vector<ExprPtr> offsets;
3147   for (const TensorAccessBoundsInfo& p : bounds_it->second) {
3148     for (const auto i : c10::irange(p.start.size())) {
3149       if (offsets.size() <= i) {
3150         offsets.push_back(p.start[i]);
3151       } else {
3152         offsets[i] =
3153             IRSimplifier::simplify(alloc<Min>(offsets[i], p.start[i], true));
3154       }
3155     }
3156   }
3157 
3158   for (const auto i : c10::irange(prod_indices.size())) {
3159     rewrite_indices_map.emplace_back(
3160         prod_indices[i], alloc<Add>(temp_indices[i], offsets[i]));
3161   }
3162 
3163   // Construct the temp statement
3164   StmtPtr bd = alloc<Store>(
3165       temp_buf,
3166       temp_indices,
3167       SubstituteInClone(st->value(), rewrite_indices_map));
3168 
3169   // Construct the loop nest for the temp computation
3170   for (const auto i : c10::irange(dims.size())) {
3171     // We're creating loops from innermost to outermost, so we need to access
3172     // dimensions in reversed order.
3173     size_t dim_idx = dims.size() - 1 - i;
3174     bd = alloc<For>(
3175         to<Var>(temp_indices[dim_idx]),
3176         immLike(dims[dim_idx], 0),
3177         dims[dim_idx],
3178         bd);
3179   }
3180 
3181   // Add constructed stmts to the consumer loop
3182   f->body()->prepend_stmt(bd);
3183 
3184   // Rewrite accesses to producer in consumer with accesses to temp
3185   LoopComputeAtRewriter lr(st->buf(), temp_buf, offsets);
3186   StmtPtr new_f = f->accept_mutator(&lr);
3187   if (f != new_f) {
3188     BlockPtr bb = to<Block>(f->get_parent());
3189     bb->replace_stmt(f, new_f);
3190   }
3191 }
3192 
3193 class RfactorStoreRewriter : public IRMutator {
3194  public:
RfactorStoreRewriter(BufPtr old_buf,const std::vector<ExprPtr> & old_indices,BufPtr new_buf,VarPtr reduction_var)3195   RfactorStoreRewriter(
3196       BufPtr old_buf,
3197       const std::vector<ExprPtr>& old_indices,
3198       BufPtr new_buf,
3199       VarPtr reduction_var)
3200       : old_buf_(std::move(old_buf)),
3201         old_indices_(old_indices),
3202         new_buf_(std::move(new_buf)),
3203         reduction_var_(std::move(reduction_var)),
3204         new_indices_(old_indices) {
3205     new_indices_.push_back(reduction_var_);
3206   }
3207 
mutate(const LoadPtr & v)3208   ExprPtr mutate(const LoadPtr& v) override {
3209     if (v->buf() != old_buf_) {
3210       return IRMutator::mutate(v);
3211     }
3212 
3213     TORCH_INTERNAL_ASSERT(
3214         old_indices_.size() == v->indices().size(),
3215         buildErrorMessage(
3216             "Expected ranks to match in RfactorStoreRewriter in the fuser."));
3217 
3218     bool equal_indices = true;
3219     for (size_t i = 0; i < v->indices().size(); ++i) {
3220       if (!exprEquals(v->indices()[i], old_indices_[i])) {
3221         equal_indices = false;
3222         break;
3223       }
3224     }
3225     if (!equal_indices) {
3226       return IRMutator::mutate(v);
3227     }
3228 
3229     return alloc<Load>(new_buf_, new_indices_);
3230   }
3231 
mutate(const ReduceOpPtr & v)3232   ExprPtr mutate(const ReduceOpPtr& v) override {
3233     ExprPtr body_new = v->body()->accept_mutator(this);
3234 
3235     std::vector<VarPtr> new_reduce_args;
3236     for (const auto& r : v->reduce_args()) {
3237       if (r != reduction_var_) {
3238         new_reduce_args.push_back(r);
3239       }
3240     }
3241 
3242     return alloc<ReduceOp>(body_new, new_reduce_args, v->reducer());
3243   }
3244 
mutate(const StorePtr & v)3245   StmtPtr mutate(const StorePtr& v) override {
3246     if (v->buf() != old_buf_) {
3247       return IRMutator::mutate(v);
3248     }
3249 
3250     TORCH_INTERNAL_ASSERT(
3251         old_indices_.size() == v->indices().size(),
3252         buildErrorMessage(
3253             "Expected ranks to match in RfactorStoreRewriter in the fuser."));
3254 
3255     bool equal_indices = true;
3256     for (size_t i = 0; i < v->indices().size(); ++i) {
3257       if (!exprEquals(v->indices()[i], old_indices_[i])) {
3258         equal_indices = false;
3259         break;
3260       }
3261     }
3262     if (!equal_indices) {
3263       return IRMutator::mutate(v);
3264     }
3265 
3266     ExprPtr new_value = v->value()->accept_mutator(this);
3267     return alloc<Store>(new_buf_, new_indices_, new_value);
3268   }
3269 
3270  private:
3271   BufPtr old_buf_;
3272   const std::vector<ExprPtr>& old_indices_;
3273   BufPtr new_buf_;
3274   VarPtr reduction_var_;
3275   std::vector<ExprPtr> new_indices_;
3276 };
3277 
rfactor(const StmtPtr & st,const ForPtr & target_for)3278 bool LoopNest::rfactor(const StmtPtr& st, const ForPtr& target_for) {
3279   BufPtr tmp_buf = nullptr;
3280   return rfactor(st, target_for, &tmp_buf);
3281 }
3282 
rfactor(const StmtPtr & st,const ForPtr & outer_reduction_for,BufPtr * rfac_buf_ptr)3283 bool LoopNest::rfactor(
3284     const StmtPtr& st,
3285     const ForPtr& outer_reduction_for,
3286     BufPtr* rfac_buf_ptr) {
3287   StorePtr reduction_store = to<Store>(st);
3288   ReduceOpPtr reduce_op = to<ReduceOp>(reduction_store->value());
3289   if (!reduce_op) {
3290     // Not a reduction store
3291     return false;
3292   }
3293 
3294   auto orig_buf = reduction_store->buf();
3295   auto orig_buf_indices = reduction_store->indices();
3296   VarPtr reduction_var = outer_reduction_for->var();
3297 
3298   std::set<VarPtr> reduce_args = {
3299       reduce_op->reduce_args().begin(), reduce_op->reduce_args().end()};
3300 
3301   if (reduce_args.size() < 2) {
3302     // Not enough reduction axes to do rfactor
3303     return false;
3304   }
3305 
3306   // Verify that outer_reduction_for is a perfect loop nest with all loops being
3307   // reductions
3308   StmtPtr cur = outer_reduction_for;
3309   while (ForPtr cur_for = to<For>(cur)) {
3310     if (!reduce_args.count(cur_for->var())) {
3311       // output axes inside outer_reduction_for are not allowed
3312       return false;
3313     }
3314     reduce_args.erase(cur_for->var());
3315 
3316     BlockPtr b = cur_for->body();
3317     if (b->nstmts() != 1) {
3318       return false;
3319     }
3320     cur = b->stmts().front();
3321   }
3322   if (cur != st) {
3323     // The reduction store is not a single stmt in the innermost loop - bail in
3324     // that case
3325     return false;
3326   }
3327   if (!reduce_args.empty()) {
3328     // This is not the outermost reduction axis
3329     return false;
3330   }
3331 
3332   // assert: reduce axes match loop vars from outer_reduction_for and inside
3333   // assert: no other stmts in outer_reduction_for or its child loops
3334 
3335   std::vector<ExprPtr> rfac_dims = orig_buf->dims();
3336   ExprPtr extra_dim = IRSimplifier::simplify(
3337       alloc<Sub>(outer_reduction_for->stop(), outer_reduction_for->start()));
3338   rfac_dims.push_back(extra_dim);
3339   ExprPtr rfac_init =
3340       alloc<Cast>(reduce_op->dtype(), reduce_op->reducer().initializer());
3341 
3342   *rfac_buf_ptr = alloc<Buf>(
3343       orig_buf->name_hint() + "_rfac",
3344       rfac_dims,
3345       reduce_op->dtype(),
3346       rfac_init);
3347   BufPtr rfac_buf = *rfac_buf_ptr;
3348 
3349   // Rewrite the original reduction store to use the temporary rfac buffer:
3350   //   1) X[*indexes] --> T[*indexes + {reduction_var}]
3351   //   2) reduce_axis -= {reduction_var}
3352   RfactorStoreRewriter rfac_rewriter(
3353       orig_buf, orig_buf_indices, rfac_buf, reduction_var);
3354   to<Block>(st->get_parent())
3355       ->replace_stmt(st, st->accept_mutator(&rfac_rewriter));
3356 
3357   // Insert a store for the final reduction over the temp buffer into the
3358   // original buffer:
3359   //   X[*indexes] = ReduceOp(X[*indexes] + T[*indexes + {reduction_var}],
3360   //                          reduce_axis={reduction_var})
3361   BlockPtr b = outer_reduction_for->body();
3362   TORCH_INTERNAL_ASSERT(
3363       b->nstmts() == 1,
3364       buildErrorMessage(
3365           "Expected to have a single stmt in the block in rfactor transformation in the fuser."));
3366   StmtPtr first_reduction_loop = b->stmts().front();
3367   auto rfac_buf_indices = orig_buf_indices;
3368   rfac_buf_indices.emplace_back(reduction_var);
3369 
3370   ExprPtr final_reduce_load = alloc<Load>(rfac_buf, rfac_buf_indices);
3371   outer_reduction_for->body()->insert_stmt_after(
3372       alloc<Store>(
3373           orig_buf,
3374           orig_buf_indices,
3375           reduce_op->reducer()(
3376               orig_buf, final_reduce_load, orig_buf_indices, {reduction_var})),
3377       first_reduction_loop);
3378 
3379   // Insert an initialization store for the temp buffer:
3380   //   T[a,b,c] = init
3381   outer_reduction_for->body()->insert_stmt_before(
3382       alloc<Store>(rfac_buf, rfac_buf_indices, rfac_init),
3383       first_reduction_loop);
3384   return true;
3385 }
3386 
3387 } // namespace torch::jit::tensorexpr
3388