xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir_simplifier.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
3 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5 
6 #include <utility>
7 
8 namespace torch::jit::tensorexpr {
9 
10 // Creates a new Expr of the given type with the provided lhs and rhs.
newBinaryOpOfType(IRNodeType expr_type,const ExprPtr & lhs,const ExprPtr & rhs,bool option)11 inline ExprPtr newBinaryOpOfType(
12     IRNodeType expr_type,
13     const ExprPtr& lhs,
14     const ExprPtr& rhs,
15     bool option) {
16   switch (expr_type) {
17     case IRNodeType::kAdd:
18       return alloc<Add>(lhs, rhs);
19     case IRNodeType::kSub:
20       return alloc<Sub>(lhs, rhs);
21     case IRNodeType::kMul:
22       return alloc<Mul>(lhs, rhs);
23     case IRNodeType::kDiv:
24       return alloc<Div>(lhs, rhs);
25     case IRNodeType::kMod:
26       return alloc<Mod>(lhs, rhs);
27     case IRNodeType::kMax:
28       return alloc<Max>(lhs, rhs, option);
29     case IRNodeType::kMin:
30       return alloc<Min>(lhs, rhs, option);
31     case IRNodeType::kAnd:
32       return alloc<And>(lhs, rhs);
33     case IRNodeType::kXor:
34       return alloc<Xor>(lhs, rhs);
35     case IRNodeType::kLshift:
36       return alloc<Lshift>(lhs, rhs);
37     case IRNodeType::kRshift:
38       return alloc<Rshift>(lhs, rhs);
39     default:
40       LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
41       return nullptr;
42   }
43 }
44 
45 template <
46     typename Op,
47     std::enable_if_t<std::is_same_v<
48         decltype(detail::bin_op_deducer(std::declval<Op>())),
49         void>>* = nullptr>
mutateBinaryOp(NodePtr<Op> v,IRMutator * mutator,bool option=false)50 static ExprPtr mutateBinaryOp(
51     NodePtr<Op> v,
52     IRMutator* mutator,
53     bool option = false) {
54   ExprPtr lhs = v->lhs();
55   ExprPtr rhs = v->rhs();
56   ExprPtr lhs_new = lhs->accept_mutator(mutator);
57   ExprPtr rhs_new = rhs->accept_mutator(mutator);
58 
59   ExprPtr node = v;
60 
61   if (lhs != lhs_new || rhs != rhs_new) {
62     node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option);
63   }
64 
65   // Can only fold if both sides are constant.
66   if (!lhs_new->isConstant() || !rhs_new->isConstant()) {
67     return node;
68   }
69 
70   return evaluateOp(node);
71 }
72 
// Greatest common divisor via Euclid's algorithm, written as a loop.
// Repeatedly replaces (a, b) with (b, a % b) until b is zero and returns the
// remaining a; in particular gcd(a, 0) == a.
template <typename T>
T gcd(T a, T b) {
  while (!(b == 0)) {
    const T remainder = a % b;
    a = b;
    b = remainder;
  }
  return a;
}
81 
82 // Helper for determining if an Expr is a multi-lane primitive (e.g. Broadcast
83 // or Ramp).
isMultilanePrimitive(const ExprPtr & e)84 static bool isMultilanePrimitive(const ExprPtr& e) {
85   return to<Broadcast>(e) || to<Ramp>(e);
86 }
87 
hashVars() const88 SimplifierHashType Term::hashVars() const {
89   SimplifierHashType hash;
90   for (const auto& v : variables_) {
91     hash = hasher_.hash_combine(hash, hasher_.hash(v));
92   }
93 
94   return hash;
95 }
96 
sort()97 void Term::sort() {
98   // order of ops important for float
99   if (dtype().is_floating_point()) {
100     throw std::logic_error("reordering FP ops");
101   }
102   std::unordered_map<ExprPtr, std::string> str_repr_cache;
103   std::sort(
104       variables_.begin(),
105       variables_.end(),
106       [&](const ExprPtr& a, const ExprPtr& b) {
107         if (!str_repr_cache.count(a)) {
108           str_repr_cache[a] = std::to_string(a);
109         }
110         if (!str_repr_cache.count(b)) {
111           str_repr_cache[b] = std::to_string(b);
112         }
113         return str_repr_cache.at(a) < str_repr_cache.at(b);
114       });
115 }
116 
hashVars() const117 SimplifierHashType Polynomial::hashVars() const {
118   SimplifierHashType hash;
119   for (const auto& v : variables_) {
120     hash = hasher_.hash_combine(hash, hasher_.hash(v));
121   }
122   return hash;
123 }
124 
sort()125 void Polynomial::sort() {
126   if (dtype().is_floating_point()) {
127     throw std::logic_error("reordering FP ops");
128   }
129   std::unordered_map<ExprPtr, std::string> str_repr_cache;
130   std::sort(
131       variables_.begin(),
132       variables_.end(),
133       [&](const ExprPtr& a, const ExprPtr& b) {
134         if (!str_repr_cache.count(a)) {
135           str_repr_cache[a] = std::to_string(a);
136         }
137         if (!str_repr_cache.count(b)) {
138           str_repr_cache[b] = std::to_string(b);
139         }
140         return str_repr_cache.at(a) < str_repr_cache.at(b);
141       });
142 }
143 
uniquefy()144 void MaxTerm::uniquefy() {
145   std::sort(
146       variables_.begin(),
147       variables_.end(),
148       [&](const ExprPtr& a, const ExprPtr& b) {
149         return hasher_.hash(a) < hasher_.hash(b);
150       });
151   auto it = std::unique(
152       variables_.begin(),
153       variables_.end(),
154       [&](const ExprPtr& a, const ExprPtr& b) {
155         return hasher_.hash(a) == hasher_.hash(b);
156       });
157   variables_.resize(std::distance(variables_.begin(), it));
158 
159   // Once we removed duplicates, sort terms alphabetically for stability.
160   std::unordered_map<ExprPtr, std::string> str_repr_cache;
161   std::sort(
162       variables_.begin(),
163       variables_.end(),
164       [&](const ExprPtr& a, const ExprPtr& b) {
165         if (!str_repr_cache.count(a)) {
166           str_repr_cache[a] = std::to_string(a);
167         }
168         if (!str_repr_cache.count(b)) {
169           str_repr_cache[b] = std::to_string(b);
170         }
171         return str_repr_cache.at(a) < str_repr_cache.at(b);
172       });
173 }
174 
uniquefy()175 void MinTerm::uniquefy() {
176   std::sort(
177       variables_.begin(),
178       variables_.end(),
179       [&](const ExprPtr& a, const ExprPtr& b) {
180         return hasher_.hash(a) < hasher_.hash(b);
181       });
182   auto it = std::unique(
183       variables_.begin(),
184       variables_.end(),
185       [&](const ExprPtr& a, const ExprPtr& b) {
186         return hasher_.hash(a) == hasher_.hash(b);
187       });
188   variables_.resize(std::distance(variables_.begin(), it));
189 
190   // Once we removed duplicates, sort terms alphabetically for stability.
191   std::unordered_map<ExprPtr, std::string> str_repr_cache;
192   std::sort(
193       variables_.begin(),
194       variables_.end(),
195       [&](const ExprPtr& a, const ExprPtr& b) {
196         if (!str_repr_cache.count(a)) {
197           str_repr_cache[a] = std::to_string(a);
198         }
199         if (!str_repr_cache.count(b)) {
200           str_repr_cache[b] = std::to_string(b);
201         }
202         return str_repr_cache.at(a) < str_repr_cache.at(b);
203       });
204 }
205 
206 // Handles optimization cases for Broadcast/Ramp +/- Broadcast/Ramp
207 template <class Op>
combineMultilane(const ExprPtr & lhs,const ExprPtr & rhs)208 ExprPtr combineMultilane(const ExprPtr& lhs, const ExprPtr& rhs) {
209   if (BroadcastPtr bc = to<Broadcast>(lhs)) {
210     if (BroadcastPtr bcother = to<Broadcast>(rhs)) {
211       if (bc->lanes() != bcother->lanes()) {
212         throw malformed_input("multilane lane mismatch");
213       }
214 
215       ExprPtr ret = alloc<Broadcast>(
216           alloc<Op>(bc->value(), bcother->value()), bc->lanes());
217       return ret;
218     }
219 
220     if (RampPtr r = to<Ramp>(rhs)) {
221       if (bc->lanes() != r->lanes()) {
222         throw malformed_input("multilane lane mismatch");
223       }
224 
225       ExprPtr ret = alloc<Ramp>(
226           alloc<Op>(bc->value(), r->base()), r->stride(), r->lanes());
227       return ret;
228     }
229   } else if (RampPtr ramp = to<Ramp>(lhs)) {
230     if (RampPtr rother = to<Ramp>(rhs)) {
231       if (ramp->lanes() != rother->lanes()) {
232         throw malformed_input("multilane lane mismatch");
233       }
234 
235       ExprPtr ret = alloc<Ramp>(
236           alloc<Op>(ramp->base(), rother->base()),
237           alloc<Op>(ramp->stride(), rother->stride()),
238           ramp->lanes());
239       return ret;
240     }
241 
242     if (BroadcastPtr bc = to<Broadcast>(rhs)) {
243       if (ramp->lanes() != bc->lanes()) {
244         throw malformed_input("multilane lane mismatch");
245       }
246       ExprPtr ret = alloc<Ramp>(
247           alloc<Op>(ramp->base(), bc->value()), ramp->stride(), ramp->lanes());
248       return ret;
249     }
250   }
251 
252   return nullptr;
253 }
254 
255 // Handles optimization cases for Broadcast/Ramp * Broadcast/Ramp
mulMultilane(const ExprPtr & lhs,const ExprPtr & rhs)256 static ExprPtr mulMultilane(const ExprPtr& lhs, const ExprPtr& rhs) {
257   if (BroadcastPtr bc = to<Broadcast>(lhs)) {
258     if (BroadcastPtr bcother = to<Broadcast>(rhs)) {
259       if (bc->lanes() != bcother->lanes()) {
260         throw malformed_input("multilane lane mismatch");
261       }
262 
263       ExprPtr ret = alloc<Broadcast>(
264           alloc<Mul>(bc->value(), bcother->value()), bc->lanes());
265       return ret;
266     }
267 
268     if (RampPtr r = to<Ramp>(rhs)) {
269       if (bc->lanes() != r->lanes()) {
270         throw malformed_input("multilane lane mismatch");
271       }
272 
273       ExprPtr ret = alloc<Ramp>(
274           alloc<Mul>(bc->value(), r->base()),
275           alloc<Mul>(bc->value(), r->stride()),
276           r->lanes());
277       return ret;
278     }
279   } else if (RampPtr ramp = to<Ramp>(lhs)) {
280     if (RampPtr r = to<Ramp>(rhs)) {
281       if (ramp->lanes() != r->lanes()) {
282         throw malformed_input("multilane lane mismatch");
283       }
284 
285       ExprPtr ret = alloc<Ramp>(
286           alloc<Mul>(ramp->base(), r->base()),
287           alloc<Mul>(ramp->stride(), r->stride()),
288           r->lanes());
289       return ret;
290     }
291 
292     if (BroadcastPtr bc = to<Broadcast>(rhs)) {
293       if (ramp->lanes() != bc->lanes()) {
294         throw malformed_input("multilane lane mismatch");
295       }
296 
297       ExprPtr ret = alloc<Ramp>(
298           alloc<Mul>(bc->value(), ramp->base()),
299           alloc<Mul>(bc->value(), ramp->stride()),
300           ramp->lanes());
301       return ret;
302     }
303   }
304 
305   return nullptr;
306 }
307 
addOrUpdateTerm(std::unordered_map<SimplifierHashType,TermPtr> & varmap,const TermPtr & term)308 void PolynomialTransformer::addOrUpdateTerm(
309     std::unordered_map<SimplifierHashType, TermPtr>& varmap,
310     const TermPtr& term) {
311   SimplifierHashType hash = term->hashVars();
312   auto insertRes = varmap.emplace(hash, term);
313   if (insertRes.second == false) {
314     TermPtr lt = insertRes.first->second;
315     ExprPtr termScalar = evaluateOp(alloc<Add>(lt->scalar(), term->scalar()));
316 
317     // If the term is canceled out, remove from the map.
318     if (immediateEquals(termScalar, 0)) {
319       varmap.erase(hash);
320       return;
321     }
322 
323     varmap[hash] = alloc<Term>(hasher_, termScalar, lt->variables());
324   }
325 }
326 
addPolynomials(const PolynomialPtr & lhs,const PolynomialPtr & rhs)327 ExprPtr PolynomialTransformer::addPolynomials(
328     const PolynomialPtr& lhs,
329     const PolynomialPtr& rhs) {
330   // simplify common components
331   // The key here is the variable hash, not the term's hash since we do want
332   // to combine terms that have the same vars but different scalar components.
333   std::unordered_map<SimplifierHashType, TermPtr> varmap;
334 
335   for (const auto& lt : lhs->variables()) {
336     addOrUpdateTerm(varmap, lt);
337   }
338   for (const auto& rt : rhs->variables()) {
339     addOrUpdateTerm(varmap, rt);
340   }
341 
342   ExprPtr newScalar = evaluateOp(alloc<Add>(lhs->scalar(), rhs->scalar()));
343   return alloc<Polynomial>(hasher_, newScalar, varmap);
344 }
345 
346 // Insert a new Term into the provided polynomial. If the new term has common
347 // variables to an existing term it is combined.
insertTerm(const PolynomialPtr & poly,const TermPtr & term)348 ExprPtr PolynomialTransformer::insertTerm(
349     const PolynomialPtr& poly,
350     const TermPtr& term) {
351   SimplifierHashType tHash = term->hashVars();
352   std::vector<TermPtr> newVars;
353 
354   bool found = false;
355   for (const auto& v : poly->variables()) {
356     if (v->hashVars() == tHash) {
357       ExprPtr newScalar = evaluateOp(alloc<Add>(term->scalar(), v->scalar()));
358       found = true;
359       // Skip this term if we cancelled it out.
360       if (immediateEquals(newScalar, 0)) {
361         continue;
362       }
363       auto term = alloc<Term>(hasher_, newScalar, v->variables());
364       newVars.push_back(term);
365     } else {
366       newVars.push_back(v);
367     }
368   }
369 
370   if (!found) {
371     newVars.push_back(term);
372   }
373 
374   if (newVars.empty()) {
375     return poly->scalar();
376   }
377 
378   auto Poly = alloc<Polynomial>(hasher_, poly->scalar(), newVars);
379   return Poly;
380 }
381 
// Simplifies an Add node.
//
// Both operands are mutated first; the cases below are then tried in order
// and the first one that applies produces the result:
//   1. both operands constant       -> folded with evaluateOp
//   2. multi-lane operands          -> combineMultilane<Add>, re-simplified
//   3. one operand is constant zero -> the other operand, cast to v's dtype
//   4. floating point operands      -> plain Add, nothing is combined
//   5. Polynomial / Term operands   -> merged through addPolynomials,
//                                      insertTerm or direct Term combination
//   6. anything else                -> wrapped in Terms with scalar 1 and
//                                      then combined as in 5
mutate(const AddPtr & v)382 ExprPtr PolynomialTransformer::mutate(const AddPtr& v) {
383   ExprPtr lhs_new = v->lhs()->accept_mutator(this);
384   ExprPtr rhs_new = v->rhs()->accept_mutator(this);
385 
386   // Constant Folding.
387   if (lhs_new->isConstant() && rhs_new->isConstant()) {
388     ExprPtr result = evaluateOp(alloc<Add>(lhs_new, rhs_new));
389     return result;
390   }
391 
392   // Multilane folding.
      // Only attempted when the LHS is a Broadcast/Ramp; the combined node is
      // mutated again so its new components get simplified too.
393   if (isMultilanePrimitive(lhs_new)) {
394     if (auto ret = combineMultilane<Add>(lhs_new, rhs_new)) {
395       return ret->accept_mutator(this);
396     }
397   }
398 
      // If exactly one side is constant, split the operands into `scalar`
      // (the constant) and `variable` (the other side). Both stay null when
      // neither side is constant.
399   ExprPtr scalar = nullptr;
400   ExprPtr variable = nullptr;
401   if (lhs_new->isConstant()) {
402     scalar = evaluateOp(lhs_new);
403     variable = rhs_new;
404   } else if (rhs_new->isConstant()) {
405     scalar = evaluateOp(rhs_new);
406     variable = lhs_new;
407   }
408 
409   // If there is a scalar, and it's zero: short circuit and return the other
410   // side.
      // The Cast to v's dtype is itself run through the mutator.
411   if (scalar && immediateEquals(scalar, 0)) {
412     auto c = alloc<Cast>(v->dtype(), variable);
413     return c->accept_mutator(this);
414   }
415 
416   // If this is a floating point Add then order of operations is important, we
417   // don't want to combine ops.
418   if (lhs_new->dtype().is_floating_point() ||
419       rhs_new->dtype().is_floating_point()) {
420     return alloc<Add>(lhs_new, rhs_new);
421   }
422 
423   PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
424   PolynomialPtr rhsPoly = to<Polynomial>(rhs_new);
425 
      // Polynomial + Polynomial.
426   if (lhsPoly && rhsPoly) {
427     return addPolynomials(lhsPoly, rhsPoly);
428   }
429 
430   TermPtr lhsTerm = to<Term>(lhs_new);
431   TermPtr rhsTerm = to<Term>(rhs_new);
432 
      // Polynomial + Term, in either operand order.
433   if (lhsPoly && rhsTerm) {
434     return insertTerm(lhsPoly, rhsTerm);
435   }
436 
437   if (rhsPoly && lhsTerm) {
438     return insertTerm(rhsPoly, lhsTerm);
439   }
440 
      // Term + Term.
441   if (lhsTerm && rhsTerm) {
442     // If the terms refer to the same variables: combine them.
443     if (lhsTerm->hashVars() == rhsTerm->hashVars()) {
444       ExprPtr newScalar =
445           evaluateOp(alloc<Add>(lhsTerm->scalar(), rhsTerm->scalar()));
446 
447       // If the terms cancelled out, return zero.
448       if (immediateEquals(newScalar, 0)) {
449         return newScalar->accept_mutator(this);
450       }
451 
452       return alloc<Term>(hasher_, newScalar, lhsTerm->variables());
453     }
454 
455     // Otherwise this is a new polynomial with no scalar and two variable
456     // terms.
457     return alloc<Polynomial>(hasher_, immLike(v, 0), lhsTerm, rhsTerm);
458   }
459 
460   // Adds are commutative.
      // At most one of lhsPoly/rhsPoly is set here (the both-set case returned
      // above); `poly` is null when neither operand is a Polynomial.
461   PolynomialPtr poly = lhsPoly ? lhsPoly : rhsPoly;
462 
463   // Add to Polynomial->scalar().
464   if (scalar && poly) {
465     ExprPtr newScalar = evaluateOp(alloc<Add>(scalar, poly->scalar()));
466     return alloc<Polynomial>(hasher_, newScalar, poly->variables());
467   }
468 
469   // Simple Polynomial with a scalar and Term.
470   TermPtr term = lhsTerm ? lhsTerm : rhsTerm;
471   if (scalar && term) {
472     return alloc<Polynomial>(hasher_, scalar, term);
473   }
474 
475   // Simple Term with a scalar and variable type.
      // The non-constant side is wrapped in a Term with scalar 1.
476   if (scalar) {
477     return alloc<Polynomial>(
478         hasher_, scalar, alloc<Term>(hasher_, immLike(v, 1), variable));
479   }
480 
      // From here on neither operand is constant.
481   // If LHS is neither Term nor Polynomial, wrap it in a Term.
482   if (!lhsTerm && !lhsPoly) {
483     lhsTerm = alloc<Term>(hasher_, immLike(v, 1), lhs_new);
484   }
485 
486   // Same for RHS.
487   if (!rhsTerm && !rhsPoly) {
488     rhsTerm = alloc<Term>(hasher_, immLike(v, 1), rhs_new);
489   }
490 
491   // If we now have a poly and a term, we can insert.
492   if (poly) {
493     return insertTerm(poly, lhsTerm ? lhsTerm : rhsTerm);
494   }
495 
      // Two Terms over the same variables collapse into one Term.
      // NOTE(review): unlike the Term + Term branch above, a zero combined
      // scalar is not special-cased here - confirm that is intended.
496   if (lhsTerm->hashVars() == rhsTerm->hashVars()) {
497     return alloc<Term>(
498         hasher_,
499         evaluateOp(alloc<Add>(lhsTerm->scalar(), rhsTerm->scalar())),
500         lhsTerm->variables());
501   }
502 
503   // If all else fails we have a new Polynomial with two new variable Terms.
504   return alloc<Polynomial>(hasher_, immLike(v, 0), lhsTerm, rhsTerm);
505 }
506 
subTerms(const TermPtr & lhs,TermPtr rhs,bool negated)507 ExprPtr PolynomialTransformer::subTerms(
508     const TermPtr& lhs,
509     TermPtr rhs,
510     bool negated) {
511   // If RHS not already negated, negate it.
512   if (!negated) {
513     ExprPtr minusOne = immLike(rhs, -1);
514     ExprPtr negateScalar = evaluateOp(alloc<Mul>(minusOne, rhs->scalar()));
515     rhs = alloc<Term>(hasher_, negateScalar, rhs->variables());
516   }
517 
518   if (lhs->hashVars() == rhs->hashVars()) {
519     ExprPtr newScalar = evaluateOp(alloc<Add>(lhs->scalar(), rhs->scalar()));
520 
521     // If the terms cancel out, return zero.
522     if (immediateEquals(newScalar, 0)) {
523       return newScalar;
524     }
525 
526     return alloc<Term>(hasher_, newScalar, lhs->variables());
527   }
528 
529   return alloc<Polynomial>(
530       hasher_,
531       getImmediateByType(promoteTypes(lhs->dtype(), rhs->dtype()), 0),
532       lhs,
533       rhs);
534 }
535 
536 // Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
537 // possible.
subPolynomials(const PolynomialPtr & lhs,const PolynomialPtr & rhs)538 ExprPtr PolynomialTransformer::subPolynomials(
539     const PolynomialPtr& lhs,
540     const PolynomialPtr& rhs) {
541   // simplify common components
542   // The key here is the variable hash, not the term's hash since we do want
543   // to combine terms that have the same vars but different scalar components.
544   std::unordered_map<SimplifierHashType, TermPtr> varmap;
545 
546   for (const auto& lt : lhs->variables()) {
547     addOrUpdateTerm(varmap, lt);
548   }
549 
550   for (const auto& rt : rhs->variables()) {
551     // Polynomials add their terms, so negate the RHS's Terms.
552     ExprPtr negated = evaluateOp(alloc<Mul>(immLike(rt, -1), rt->scalar()));
553     TermPtr newRHS = alloc<Term>(hasher_, negated, rt->variables());
554     addOrUpdateTerm(varmap, newRHS);
555   }
556 
557   ExprPtr newScalar = evaluateOp(alloc<Sub>(lhs->scalar(), rhs->scalar()));
558 
559   // No vars means this cancelled out to a scalar, return it unwrapped.
560   if (varmap.empty()) {
561     return newScalar;
562   }
563 
564   // If there is no scalar and zero or one terms, don't wrap.
565   if (immediateEquals(newScalar, 0)) {
566     if (varmap.empty()) {
567       return nullptr;
568     }
569     if (varmap.size() == 1) {
570       return varmap.begin()->second;
571     }
572   }
573 
574   // Wrap new variables in a Polynomial.
575   return alloc<Polynomial>(hasher_, newScalar, varmap);
576 }
577 
// Simplifies a Sub node.
//
// Both operands are mutated first; the cases below are then tried in order
// and the first one that applies produces the result:
//   1. both operands constant      -> folded with evaluateOp
//   2. multi-lane operands         -> combineMultilane<Sub>, re-simplified
//   3. RHS is constant zero        -> the LHS, cast to v's dtype
//   4. floating point operands     -> plain Sub, nothing is combined
//   5. Polynomial / Term / constant operand pairs -> the RHS is negated and
//      merged into the LHS through subPolynomials, insertTerm or subTerms
//   6. anything else               -> wrapped in Terms (the RHS with scalar
//      -1) and then combined as in 5
mutate(const SubPtr & v)578 ExprPtr PolynomialTransformer::mutate(const SubPtr& v) {
579   ExprPtr lhs_new = v->lhs()->accept_mutator(this);
580   ExprPtr rhs_new = v->rhs()->accept_mutator(this);
581 
582   // Constant Folding.
583   if (lhs_new->isConstant() && rhs_new->isConstant()) {
584     ExprPtr result = evaluateOp(alloc<Sub>(lhs_new, rhs_new));
585     return result;
586   }
587 
588   // Multilane folding.
      // Only attempted when the LHS is a Broadcast/Ramp; the combined node is
      // mutated again so its new components get simplified too.
589   if (isMultilanePrimitive(lhs_new)) {
590     if (auto ret = combineMultilane<Sub>(lhs_new, rhs_new)) {
591       return ret->accept_mutator(this);
592     }
593   }
594 
      // x - 0: return x cast to v's dtype (the Cast is itself mutated).
595   if (rhs_new->isConstant() && immediateEquals(rhs_new, 0)) {
596     auto c = alloc<Cast>(v->dtype(), lhs_new);
597     return c->accept_mutator(this);
598   }
599 
600   // If this is a floating point Sub then order of operations is important, we
601   // don't want to combine ops.
602   if (lhs_new->dtype().is_floating_point() ||
603       rhs_new->dtype().is_floating_point()) {
604     return alloc<Sub>(lhs_new, rhs_new);
605   }
606 
607   PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
608   PolynomialPtr rhsPoly = to<Polynomial>(rhs_new);
609 
      // Polynomial - Polynomial.
610   if (lhsPoly && rhsPoly) {
611     auto ret = subPolynomials(lhsPoly, rhsPoly);
612     if (!ret) {
613       // Cancelled out completely.
614       return immLike(v, 0);
615     }
616     return ret;
617   }
618 
619   TermPtr lhsTerm = to<Term>(lhs_new);
620   TermPtr rhsTerm = to<Term>(rhs_new);
621 
622   // Polynomial - Term.
623   if (lhsPoly && rhsTerm) {
624     // Negate the term.
625     ExprPtr negate =
626         evaluateOp(alloc<Mul>(immLike(rhsTerm, -1), rhsTerm->scalar()));
627     TermPtr newTerm = alloc<Term>(hasher_, negate, rhsTerm->variables());
628     return insertTerm(lhsPoly, newTerm);
629   }
630 
631   // Term - Polynomial.
632   if (rhsPoly && lhsTerm) {
633     // Negate every part of the Polynomial.
634     ExprPtr minusOne = immLike(lhsTerm, -1);
635     ExprPtr negateScalar = evaluateOp(alloc<Mul>(minusOne, rhsPoly->scalar()));
636 
637     std::vector<TermPtr> variables;
638     for (const auto& t : rhsPoly->variables()) {
639       ExprPtr negate = evaluateOp(alloc<Mul>(minusOne, t->scalar()));
640       variables.push_back(alloc<Term>(hasher_, negate, t->variables()));
641     }
642 
        // The LHS Term is then added to the negated Polynomial.
643     PolynomialPtr newPoly = alloc<Polynomial>(hasher_, negateScalar, variables);
644     return insertTerm(newPoly, lhsTerm);
645   }
646 
      // Term - Term: the RHS has not been negated yet, hence `false`.
647   if (lhsTerm && rhsTerm) {
648     return subTerms(lhsTerm, rhsTerm, false);
649   }
650 
      // At most one side can be constant here: the both-constant case was
      // folded at the top.
651   bool lhsScalar = lhs_new->isConstant();
652   bool rhsScalar = rhs_new->isConstant();
653 
654   if (lhsPoly && rhsScalar) {
655     // Easy path, just sub the scalar component.
656     ExprPtr newScalar = evaluateOp(alloc<Sub>(lhsPoly->scalar(), rhs_new));
657     return alloc<Polynomial>(hasher_, newScalar, lhsPoly->variables());
658   }
659 
660   if (lhsScalar && rhsPoly) {
661     // Sub the scalar component.
662     ExprPtr newScalar = evaluateOp(alloc<Sub>(lhs_new, rhsPoly->scalar()));
663 
664     // Negate each term in the Polynomial RHS.
665     ExprPtr minusOne = immLike(rhsPoly, -1);
666     std::vector<TermPtr> variables;
667     for (const auto& t : rhsPoly->variables()) {
668       ExprPtr negate = evaluateOp(alloc<Mul>(minusOne, t->scalar()));
669       variables.push_back(alloc<Term>(hasher_, negate, t->variables()));
670     }
671 
672     return alloc<Polynomial>(hasher_, newScalar, variables);
673   }
674 
675   if (lhsTerm && rhsScalar) {
676     // Negate the constant.
677     ExprPtr negate = evaluateOp(alloc<Mul>(immLike(rhs_new, -1), rhs_new));
678     return alloc<Polynomial>(hasher_, negate, lhsTerm);
679   }
680 
681   if (lhsScalar && rhsTerm) {
682     // Negate the RHS Term.
683     ExprPtr negate = evaluateOp(
684         alloc<Mul>(immLike(rhsTerm->scalar(), -1), rhsTerm->scalar()));
685 
686     return alloc<Polynomial>(
687         hasher_, lhs_new, alloc<Term>(hasher_, negate, rhsTerm->variables()));
688   }
689 
690   // simple term with a scalar and variable type.
691   if (lhsScalar) {
692     // Create a negated term.
693     return alloc<Polynomial>(
694         hasher_, lhs_new, alloc<Term>(hasher_, immLike(v, -1), rhs_new));
695   }
696 
697   if (rhsScalar) {
698     // Negate the scalar.
699     ExprPtr negate = evaluateOp(alloc<Mul>(immLike(rhs_new, -1), rhs_new));
700     return alloc<Polynomial>(
701         hasher_, negate, alloc<Term>(hasher_, immLike(v, 1), lhs_new));
702   }
703 
704   // no scalar...
      // Wrap an LHS that is neither Term nor Polynomial in a Term with
      // scalar 1.
705   if (!lhsTerm && !lhsPoly) {
706     lhsTerm = alloc<Term>(hasher_, immLike(v, 1), lhs_new);
707   }
708 
      // Wrap such an RHS in a Term with scalar -1, and remember that the
      // minus sign has already been applied.
709   bool createdRHSnegated = false;
710   if (!rhsTerm && !rhsPoly) {
711     rhsTerm = alloc<Term>(hasher_, immLike(v, -1), rhs_new);
712     createdRHSnegated = true;
713   }
714 
715   if (lhsTerm && rhsTerm) {
716     return subTerms(lhsTerm, rhsTerm, createdRHSnegated);
717   }
718 
719   // Insert wrapped Term into LHS Polynomial.
      // rhsTerm is the wrapper created above, so it is already negated.
720   if (lhsPoly) {
721     CHECK(rhsTerm);
722     return insertTerm(lhsPoly, rhsTerm);
723   }
724 
725   // Insert wrapper Term into negated RHS Poly.
726   if (rhsPoly) {
727     CHECK(lhsTerm);
728     ExprPtr minusOne = immLike(rhsPoly, -1);
729     ExprPtr newScalar = evaluateOp(alloc<Mul>(minusOne, rhsPoly->scalar()));
730 
731     // Negate each term in the Polynomial RHS.
732     std::vector<TermPtr> variables;
733     for (const auto& t : rhsPoly->variables()) {
734       ExprPtr negate = evaluateOp(alloc<Mul>(minusOne, t->scalar()));
735       variables.push_back(alloc<Term>(hasher_, negate, t->variables()));
736     }
737 
738     auto poly = alloc<Polynomial>(hasher_, newScalar, variables);
739     return insertTerm(poly, lhsTerm);
740   }
741 
      // NOTE(review): this fallback looks unreachable - every Term/Polynomial
      // combination appears to be handled above; confirm before relying on it.
742   return alloc<Polynomial>(hasher_, immLike(v, 0), lhsTerm, rhsTerm);
743 }
744 
745 // Multiply two terms together, usually creating a new term with the variable
746 // lists concatenated.
mulTerms(const TermPtr & lhs,const TermPtr & rhs)747 TermPtr PolynomialTransformer::mulTerms(
748     const TermPtr& lhs,
749     const TermPtr& rhs) {
750   ExprPtr scalar = evaluateOp(alloc<Mul>(lhs->scalar(), rhs->scalar()));
751   if (immediateEquals(scalar, 0)) {
752     return nullptr;
753   }
754 
755   // Can reorder here since floating point ops don't get put into Terms.
756   std::vector<ExprPtr> variables;
757   std::vector<ExprPtr> multilaneVariables;
758   // For now don't handle exponents.
759   for (const auto& c : lhs->variables()) {
760     if (isMultilanePrimitive(c)) {
761       multilaneVariables.push_back(c);
762     } else {
763       variables.push_back(c);
764     }
765   }
766   for (const auto& c : rhs->variables()) {
767     if (isMultilanePrimitive(c)) {
768       multilaneVariables.push_back(c);
769     } else {
770       variables.push_back(c);
771     }
772   }
773 
774   // Merge all the multilane vars:
775   ExprPtr lastNode{nullptr};
776   for (const auto& node : multilaneVariables) {
777     if (lastNode == nullptr) {
778       lastNode = node;
779     } else {
780       if (auto next = mulMultilane(lastNode, node)) {
781         lastNode = next->accept_mutator(this);
782       } else {
783         variables.push_back(lastNode);
784         lastNode = node;
785       }
786     }
787   }
788   if (lastNode) {
789     variables.push_back(lastNode);
790   }
791 
792   return alloc<Term>(hasher_, scalar, variables);
793 }
794 
795 // Multiply a Polynomial by a Term.
polyByTerm(const PolynomialPtr & poly,const TermPtr & term)796 ExprPtr PolynomialTransformer::polyByTerm(
797     const PolynomialPtr& poly,
798     const TermPtr& term) {
799   // poly * term
800   //    = (poly_terms + poly_scalar) * term
801   //    = poly_terms * term + poly_scalar * term
802 
803   // First, multiply all variables (terms) in the polynomial by the input
804   // term.
805   std::vector<TermPtr> newTerms;
806   for (const auto& var : poly->variables()) {
807     TermPtr newTerm = mulTerms(var, term);
808     if (newTerm) {
809       newTerms.push_back(newTerm);
810     }
811   }
812 
813   // If the scalar in poly is not 0, it must be multiplied by term.
814   // If there are no variables in term, this becomes the scalar in the result
815   // polynomial. If there are variables in term, this becomes a new term in
816   // the result polynomial.
817   if (!immediateEquals(poly->scalar(), 0)) {
818     ExprPtr scalar = evaluateOp(alloc<Mul>(poly->scalar(), term->scalar()));
819     if (term->variables().empty()) {
820       return alloc<Polynomial>(hasher_, scalar, newTerms);
821     }
822     newTerms.push_back(alloc<Term>(hasher_, scalar, term->variables()));
823   }
824 
825   // The only case when the result polynomial has a scalar is when the input
826   // term does not have any variables and the input polynomial has a non-zero
827   // scalar. That case is handled above. So, at this point, we do not have any
828   // scalars in the result polynomial.
829   return alloc<Polynomial>(hasher_, std::move(newTerms));
830 }
831 
832 // Does multiplying these two expressions make a Rounding Off operation.
833 // e.g. LHS = (x/y),  RHS = y => (x / y) * y => RoundOff(x, y).
isRoundOff(const ExprPtr & lhs,const ExprPtr & rhs)834 ExprPtr PolynomialTransformer::isRoundOff(
835     const ExprPtr& lhs,
836     const ExprPtr& rhs) {
837   DivPtr div{nullptr};
838   ExprPtr other{nullptr};
839 
840   if ((div = to<Div>(lhs))) {
841     other = rhs;
842   } else if ((div = to<Div>(rhs))) {
843     other = lhs;
844   } else {
845     return nullptr;
846   }
847 
848   ExprPtr denom = div->rhs();
849 
850   if (TermPtr denomTerm = to<Term>(denom)) {
851     if (immediateEquals(denomTerm->scalar(), 1) &&
852         denomTerm->variables().size() == 1) {
853       denom = denomTerm->variables()[0];
854     }
855   }
856 
857   if (hasher_.hash(denom) == hasher_.hash(other)) {
858     // If the denominator is equal to the other, then yes it's a RoundOff.
859     return alloc<RoundOff>(div->lhs(), div->rhs());
860   }
861 
862   if (denom->isConstant() && other->isConstant()) {
863     if (immediateEquals(denom, 0) || immediateEquals(other, 0)) {
864       return nullptr;
865     }
866     // If they are both scalar we may be able to find a common factor.
867     if (immediateEquals(evaluateOp(alloc<Mod>(other, denom)), 0)) {
868       ExprPtr scalar = evaluateOp(alloc<Div>(other, denom));
869       ExprPtr newDenom = evaluateOp(alloc<Div>(other, scalar));
870       return alloc<Term>(
871           hasher_, scalar, alloc<RoundOff>(div->lhs(), newDenom));
872     }
873   }
874 
875   return nullptr;
876 }
877 
878 // Inserts a new component into a term, looking for opportunities to simplify.
insertIntoTerm(const TermPtr & term,const ExprPtr & expr)879 ExprPtr PolynomialTransformer::insertIntoTerm(
880     const TermPtr& term,
881     const ExprPtr& expr) {
882   std::vector<ExprPtr> vars;
883 
884   // Search for RoundOffs.
885   bool merged{false};
886   for (const auto& component : term->variables()) {
887     if (auto roundoff = isRoundOff(component, expr)) {
888       vars.push_back(roundoff);
889       merged = true;
890     } else {
891       vars.push_back(component);
892     }
893   }
894 
895   if (!merged) {
896     vars.push_back(expr);
897   }
898 
899   if (vars.size() == 1 && immediateEquals(term->scalar(), 1)) {
900     return vars[0];
901   }
902 
903   return alloc<Term>(hasher_, term->scalar(), vars);
904 }
905 
mutate(const MulPtr & v)906 ExprPtr PolynomialTransformer::mutate(const MulPtr& v) {
907   ExprPtr lhs_new = v->lhs()->accept_mutator(this);
908   ExprPtr rhs_new = v->rhs()->accept_mutator(this);
909 
910   // Constant Folding.
911   if (lhs_new->isConstant() && rhs_new->isConstant()) {
912     return evaluateOp(alloc<Mul>(lhs_new, rhs_new));
913   }
914 
915   // Multilane folding.
916   if (isMultilanePrimitive(lhs_new)) {
917     if (auto ret = mulMultilane(lhs_new, rhs_new)) {
918       return ret->accept_mutator(this);
919     }
920   }
921 
922   // Order doesn't matter.
923   ExprPtr scalar = nullptr;
924   ExprPtr variable = nullptr;
925   if (lhs_new->isConstant()) {
926     scalar = lhs_new;
927     variable = rhs_new;
928   } else if (rhs_new->isConstant()) {
929     scalar = rhs_new;
930     variable = lhs_new;
931   }
932 
933   // Handle special case mul by 1 since thats safe for floating point, even if
934   // it's Nan/Inf.
935   if (scalar && immediateEquals(scalar, 1)) {
936     auto c = alloc<Cast>(v->dtype(), variable);
937     return c->accept_mutator(this);
938   }
939 
940   // If this is a floating point Mul then order of operations is important, we
941   // dont want to combine ops.
942   if (lhs_new->dtype().is_floating_point() ||
943       rhs_new->dtype().is_floating_point()) {
944     return alloc<Mul>(lhs_new, rhs_new);
945   }
946 
947   // Handle special case mul by 0.
948   if (scalar && immediateEquals(scalar, 0)) {
949     return immLike(v, 0);
950   }
951 
952   // Catch cases of rounding (Div(A/B) * B).
953   if (auto ret = isRoundOff(lhs_new, rhs_new)) {
954     return ret;
955   } else if (auto ret = isRoundOff(v->lhs(), v->rhs())) {
956     // We can break the Round + Mod pattern via factorization of the Div, so
957     // check whether it would have worked on the unsimplified tree. If so, we
958     // need to simplify again.
959     return ret->accept_mutator(this);
960   }
961 
962   PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
963   PolynomialPtr rhsPoly = to<Polynomial>(rhs_new);
964 
965   if (lhsPoly && rhsPoly) {
966     // This expands to more terms that we can't generally fix without variable
967     // factorization, it's more efficient to just leave these as Muls.
968     return alloc<Mul>(lhsPoly, rhsPoly);
969   }
970 
971   TermPtr lhsTerm = to<Term>(lhs_new);
972   TermPtr rhsTerm = to<Term>(rhs_new);
973 
974   if (lhsPoly && rhsTerm) {
975     return polyByTerm(lhsPoly, rhsTerm);
976   }
977 
978   if (rhsPoly && lhsTerm) {
979     return polyByTerm(rhsPoly, lhsTerm);
980   }
981 
982   if (lhsTerm && rhsTerm) {
983     return mulTerms(lhsTerm, rhsTerm);
984   }
985 
986   if (scalar && lhsTerm) {
987     ExprPtr newScalar = evaluateOp(alloc<Mul>(scalar, lhsTerm->scalar()));
988     return alloc<Term>(hasher_, newScalar, lhsTerm->variables());
989   }
990 
991   if (scalar && rhsTerm) {
992     ExprPtr newScalar = evaluateOp(alloc<Mul>(scalar, rhsTerm->scalar()));
993     return alloc<Term>(hasher_, newScalar, rhsTerm->variables());
994   }
995 
996   // If this is a scalar * a Polynomial, push the scalar term down.
997   // We can wrap the scalar with a Term and use polyByTerm.
998   if (scalar && lhsPoly) {
999     return polyByTerm(lhsPoly, alloc<Term>(hasher_, scalar));
1000   }
1001   if (scalar && rhsPoly) {
1002     return polyByTerm(rhsPoly, alloc<Term>(hasher_, scalar));
1003   }
1004 
1005   // simple term with a scalar and variable type.
1006   if (scalar) {
1007     return alloc<Term>(hasher_, scalar, variable);
1008   }
1009 
1010   // Multiplying Polynomial by variable can be wrapped in a term and handled
1011   // by polyByTerm also.
1012   if (lhsPoly) {
1013     auto term = alloc<Term>(hasher_, immLike(rhs_new, 1), rhs_new);
1014     return polyByTerm(lhsPoly, term);
1015   }
1016   if (rhsPoly) {
1017     auto term = alloc<Term>(hasher_, immLike(lhs_new, 1), lhs_new);
1018     return polyByTerm(rhsPoly, term);
1019   }
1020 
1021   // Multiplying Term by a variable is equivalent to adding the variable to
1022   // the term's list of vars.
1023   if (lhsTerm) {
1024     return insertIntoTerm(lhsTerm, rhs_new);
1025   }
1026   if (rhsTerm) {
1027     return insertIntoTerm(rhsTerm, lhs_new);
1028   }
1029 
1030   // Two variables, create a new Term.
1031   return alloc<Term>(hasher_, immLike(v, 1), lhs_new, rhs_new);
1032 }
1033 
factorizeDivision(ExprPtr lhs_new,ExprPtr rhs_new)1034 static ExprPtr factorizeDivision(ExprPtr lhs_new, ExprPtr rhs_new) {
1035   if (!lhs_new || !rhs_new) {
1036     return nullptr;
1037   }
1038 
1039   ExprPtr leftScalar = lhs_new->isConstant() ? lhs_new : nullptr;
1040   ExprPtr rightScalar = rhs_new->isConstant() ? rhs_new : nullptr;
1041 
1042   auto lhsTerm = to<Term>(lhs_new);
1043   auto rhsTerm = to<Term>(rhs_new);
1044   if (lhsTerm) {
1045     leftScalar = lhsTerm->scalar();
1046   }
1047 
1048   if (rhsTerm) {
1049     rightScalar = rhsTerm->scalar();
1050   }
1051 
1052   if (!leftScalar || !rightScalar) {
1053     return nullptr;
1054   }
1055 
1056   long left = immediateAs<long>(leftScalar);
1057   long right = immediateAs<long>(rightScalar);
1058 
1059   long GCD = gcd<long>(left, right);
1060   if (GCD <= 1) {
1061     return nullptr;
1062   }
1063 
1064   leftScalar = evaluateOp(alloc<Div>(leftScalar, immLike(leftScalar, GCD)));
1065   rightScalar = evaluateOp(alloc<Div>(rightScalar, immLike(rightScalar, GCD)));
1066 
1067   if (lhsTerm) {
1068     lhs_new = alloc<Term>(lhsTerm->hasher(), leftScalar, lhsTerm->variables());
1069   } else {
1070     lhs_new = leftScalar;
1071   }
1072 
1073   if (rhsTerm) {
1074     rhs_new = alloc<Term>(rhsTerm->hasher(), rightScalar, rhsTerm->variables());
1075   } else {
1076     rhs_new = rightScalar;
1077   }
1078 
1079   return alloc<Div>(lhs_new, rhs_new);
1080 }
1081 
mutate(const DivPtr & v)1082 ExprPtr PolynomialTransformer::mutate(const DivPtr& v) {
1083   ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1084   ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1085 
1086   // Constant Folding.
1087   if (lhs_new->isConstant() && rhs_new->isConstant()) {
1088     return evaluateOp(alloc<Div>(lhs_new, rhs_new));
1089   }
1090 
1091   // If this is a floating point Div then order of operations is important, we
1092   // dont want to combine ops.
1093   if (lhs_new->dtype().is_floating_point() ||
1094       rhs_new->dtype().is_floating_point()) {
1095     return alloc<Div>(lhs_new, rhs_new);
1096   }
1097 
1098   // If the numerator is zero, so is the result.
1099   if (lhs_new->isConstant() && immediateEquals(lhs_new, 0)) {
1100     return lhs_new;
1101   }
1102 
1103   // If the denominator is one, return numerator.
1104   if (rhs_new->isConstant() && immediateEquals(rhs_new, 1)) {
1105     return lhs_new;
1106   }
1107 
1108   // If numberator and denominator are equal the result is 1.
1109   // Unless the demoninator could be zero.
1110   // if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) {
1111   //   return getImmediateByType(v->dtype(), 1);
1112   // }
1113 
1114   if (auto ret = factorizeDivision(lhs_new, rhs_new)) {
1115     return ret->accept_mutator(this);
1116   }
1117 
1118   return alloc<Div>(lhs_new, rhs_new);
1119 }
1120 
mutate(const ModPtr & v)1121 ExprPtr PolynomialTransformer::mutate(const ModPtr& v) {
1122   ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1123   ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1124 
1125   // Constant Folding.
1126   if (lhs_new->isConstant() && rhs_new->isConstant()) {
1127     return evaluateOp(alloc<Mod>(lhs_new, rhs_new));
1128   }
1129 
1130   // 0 % x => 0.
1131   if (lhs_new->isConstant() && immediateEquals(lhs_new, 0)) {
1132     return lhs_new;
1133   }
1134 
1135   // x % 1 == 0.
1136   if (rhs_new->isConstant() && immediateEquals(rhs_new, 1)) {
1137     return immLike(v, 0);
1138   }
1139 
1140   // x % x => 0.
1141   if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) {
1142     return immLike(v, 0);
1143   }
1144 
1145   TermPtr lhsTerm = to<Term>(lhs_new);
1146   if (!lhsTerm) {
1147     PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
1148     if (lhsPoly) {
1149       // Can still optimize this out if we can factorize the polynomial.
1150       lhsTerm = factorizePolynomial(lhsPoly);
1151     }
1152   }
1153 
1154   if (lhsTerm) {
1155     // ((C1 * C2) * x) % C1 => 0.
1156     if (rhs_new->isConstant() &&
1157         immediateEquals(
1158             evaluateOp(alloc<Mod>(lhsTerm->scalar(), rhs_new)), 0)) {
1159       return immLike(v, 0);
1160     }
1161 
1162     // (x * y * z) % x => 0.
1163     for (const auto& component : lhsTerm->variables()) {
1164       if (hasher_.hash(component) == hasher_.hash(rhs_new)) {
1165         return immLike(v, 0);
1166       }
1167     }
1168 
1169     // (6 * x * y) % (3 * x * y) => 0.
1170     // also, (x * y * z) % (z * y) => 0.
1171     // This requires all variable terms found in the RHS to be present in the
1172     // LHS.
1173     TermPtr rhsTerm = to<Term>(rhs_new);
1174     if (rhsTerm) {
1175       auto& lVars = lhsTerm->variables();
1176       auto& rVars = rhsTerm->variables();
1177       size_t rLeft = rVars.size();
1178 
1179       auto rIt = rVars.begin();
1180 
1181       for (auto lIt = lVars.begin(); lIt != lVars.end() && !rVars.empty();
1182            ++lIt) {
1183         auto lHash = hasher_.hash(*lIt);
1184         for (; rIt != rVars.end(); ++rIt) {
1185           auto rHash = hasher_.hash(*rIt);
1186           if (lHash == rHash) {
1187             --rLeft;
1188             break;
1189           } else if (lHash < rHash) {
1190             break;
1191           }
1192         }
1193       }
1194 
1195       if (rLeft == 0 &&
1196           immediateEquals(
1197               evaluateOp(alloc<Mod>(lhsTerm->scalar(), rhsTerm->scalar())),
1198               0)) {
1199         return immLike(v, 0);
1200       }
1201     }
1202   }
1203 
1204   return alloc<Mod>(lhs_new, rhs_new);
1205 }
1206 
1207 namespace {
1208 
1209 // Combines two MinTerm / MaxTerm expressions into one.
1210 // The first type on the template refers to the op, as in Min or Max and the
1211 // second type refers to the corresponding term, as in MinTerm or MaxTerm.
1212 template <class Op, class OpTerm>
combineMinMaxTerms(ExprPtr lhs,ExprPtr rhs,bool propagate_nans,HashProvider & hasher)1213 ExprPtr combineMinMaxTerms(
1214     ExprPtr lhs,
1215     ExprPtr rhs,
1216     bool propagate_nans,
1217     HashProvider& hasher) {
1218   auto combine_scalars = [&](ExprPtr c1, ExprPtr c2) -> ExprPtr {
1219     if (c1 && c2) {
1220       return evaluateOp(alloc<Op>(c1, c2, propagate_nans));
1221     }
1222     if (c1) {
1223       return c1;
1224     }
1225     return c2;
1226   };
1227 
1228   auto combine_opterms = [&](NodePtr<OpTerm> m1, NodePtr<OpTerm> m2) {
1229     ExprPtr scalar = combine_scalars(m1->scalar(), m2->scalar());
1230     std::vector<ExprPtr> variables;
1231     for (const auto& v : m1->variables()) {
1232       variables.push_back(v);
1233     }
1234     for (const auto& v : m2->variables()) {
1235       variables.push_back(v);
1236     }
1237     return alloc<OpTerm>(hasher, scalar, propagate_nans, std::move(variables));
1238   };
1239 
1240   auto add_expr_to_opterm = [&](ExprPtr expr, NodePtr<OpTerm> opterm) {
1241     ExprPtr scalar = nullptr;
1242     std::vector<ExprPtr> variables;
1243     if (opterm) {
1244       scalar = opterm->scalar();
1245       variables = opterm->variables();
1246     }
1247     if (expr->isConstant()) {
1248       scalar = combine_scalars(scalar, expr);
1249     } else {
1250       variables.push_back(expr);
1251     }
1252     return alloc<OpTerm>(hasher, scalar, propagate_nans, std::move(variables));
1253   };
1254 
1255   auto lhs_opterm = to<OpTerm>(lhs);
1256   auto rhs_opterm = to<OpTerm>(rhs);
1257   if (lhs_opterm && lhs_opterm->propagate_nans() != propagate_nans) {
1258     return alloc<Op>(lhs, rhs, propagate_nans);
1259   }
1260   if (rhs_opterm && rhs_opterm->propagate_nans() != propagate_nans) {
1261     return alloc<Op>(lhs, rhs, propagate_nans);
1262   }
1263 
1264   if (lhs_opterm && rhs_opterm) {
1265     return combine_opterms(lhs_opterm, rhs_opterm);
1266   } else if (lhs_opterm) {
1267     return add_expr_to_opterm(rhs, lhs_opterm);
1268   } else if (rhs_opterm) {
1269     return add_expr_to_opterm(lhs, rhs_opterm);
1270   }
1271   return add_expr_to_opterm(rhs, add_expr_to_opterm(lhs, nullptr));
1272 }
1273 
1274 // Returns true if op is one of the 2 operands in opterm and also returns
1275 // the other op of opterm in other_op.
1276 template <class OpTerm>
isOperandInMinMaxTerm(NodePtr<OpTerm> opterm,ExprPtr op,HashProvider & hasher,ExprPtr * other_op)1277 bool isOperandInMinMaxTerm(
1278     NodePtr<OpTerm> opterm,
1279     ExprPtr op,
1280     HashProvider& hasher,
1281     ExprPtr* other_op) {
1282   if (opterm->variables().size() != 2) {
1283     return false;
1284   }
1285   auto lhs = opterm->variables()[0];
1286   auto rhs = opterm->variables()[1];
1287   auto op_hash = hasher.hash(std::move(op));
1288   if (hasher.hash(lhs) == op_hash) {
1289     *other_op = rhs;
1290     return true;
1291   } else if (hasher.hash(rhs) == op_hash) {
1292     *other_op = lhs;
1293     return true;
1294   }
1295   return false;
1296 };
1297 
1298 // Simplifies the nested min-max pattern like:
1299 //   * Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z))
1300 //   * Min(Max(x, y), Max(x, z)) => Max(x, Min(y, z))
1301 // This function is called while processing the outer Min / Max ops.
1302 // At that point the inner Min / Max ops would have been converted to
1303 // MinTerm / MaxTerm as appropriate. So, this function checks for those
1304 // term expressions in the given lhs and rhs.
1305 //
1306 // The first type of the template must be the term type corresponding to the
1307 // outer op (e.g. MaxTerm) and the second type of the template must be the term
1308 // type corresponding to the expected inner op (e.g. MinTerm).
1309 template <class OpTerm, class OtherOpTerm>
simplifyNestedMinMax(ExprPtr lhs,ExprPtr rhs,bool propagate_nans,HashProvider & hasher,ExprPtr * new_op)1310 bool simplifyNestedMinMax(
1311     ExprPtr lhs,
1312     ExprPtr rhs,
1313     bool propagate_nans,
1314     HashProvider& hasher,
1315     ExprPtr* new_op) {
1316   auto lhs_opterm = to<OtherOpTerm>(lhs);
1317   auto rhs_opterm = to<OtherOpTerm>(rhs);
1318   if (lhs_opterm && rhs_opterm &&
1319       lhs_opterm->propagate_nans() == propagate_nans &&
1320       rhs_opterm->propagate_nans() == propagate_nans) {
1321     if (!lhs_opterm->scalar() && !rhs_opterm->scalar()) {
1322       if (lhs_opterm->variables().size() == 2 &&
1323           rhs_opterm->variables().size() == 2) {
1324         auto rhs_v1 = rhs_opterm->variables()[0];
1325         auto rhs_v2 = rhs_opterm->variables()[1];
1326         ExprPtr new_op_lhs;
1327         if (isOperandInMinMaxTerm<OtherOpTerm>(
1328                 lhs_opterm, rhs_v1, hasher, &new_op_lhs)) {
1329           auto inner_op = alloc<OpTerm>(
1330               hasher, nullptr, propagate_nans, new_op_lhs, rhs_v2);
1331           *new_op = alloc<OtherOpTerm>(
1332               hasher, nullptr, propagate_nans, rhs_v1, inner_op);
1333           return true;
1334         }
1335         if (isOperandInMinMaxTerm<OtherOpTerm>(
1336                 lhs_opterm, rhs_v2, hasher, &new_op_lhs)) {
1337           auto inner_op = alloc<OpTerm>(
1338               hasher, nullptr, propagate_nans, new_op_lhs, rhs_v1);
1339           *new_op = alloc<OtherOpTerm>(
1340               hasher, nullptr, propagate_nans, rhs_v2, inner_op);
1341           return true;
1342         }
1343       }
1344     }
1345   }
1346   return false;
1347 }
1348 
1349 } // namespace
1350 
mutate(const MaxPtr & v)1351 ExprPtr PolynomialTransformer::mutate(const MaxPtr& v) {
1352   ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1353   ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1354 
1355   // Constant Folding.
1356   if (lhs_new->isConstant() && rhs_new->isConstant()) {
1357     return evaluateOp(alloc<Max>(lhs_new, rhs_new, v->propagate_nans()));
1358   }
1359 
1360   // If diff is constant, return the appropriate operand.
1361   ExprPtr diff = alloc<Sub>(lhs_new, rhs_new);
1362   diff = diff->accept_mutator(this);
1363   if (diff->isConstant()) {
1364     if (immediateAs<int>(diff) > 0) {
1365       return lhs_new;
1366     }
1367     return rhs_new;
1368   }
1369 
1370   // Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z))
1371   ExprPtr new_op;
1372   if (simplifyNestedMinMax<MaxTerm, MinTerm>(
1373           lhs_new, rhs_new, v->propagate_nans(), hasher_, &new_op)) {
1374     return new_op;
1375   }
1376 
1377   return combineMinMaxTerms<Max, MaxTerm>(
1378       lhs_new, rhs_new, v->propagate_nans(), hasher_);
1379 }
1380 
mutate(const MinPtr & v)1381 ExprPtr PolynomialTransformer::mutate(const MinPtr& v) {
1382   ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1383   ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1384 
1385   // Constant Folding.
1386   if (lhs_new->isConstant() && rhs_new->isConstant()) {
1387     return evaluateOp(alloc<Min>(lhs_new, rhs_new, v->propagate_nans()));
1388   }
1389 
1390   // If diff is constant, return the appropriate operand.
1391   ExprPtr diff = alloc<Sub>(lhs_new, rhs_new);
1392   diff = diff->accept_mutator(this);
1393   if (diff->isConstant()) {
1394     if (immediateAs<int>(diff) < 0) {
1395       return lhs_new;
1396     }
1397     return rhs_new;
1398   }
1399 
1400   // Min(Max(x, y), Max(x, z)) => Max(x, Min(y, z))
1401   ExprPtr new_op;
1402   if (simplifyNestedMinMax<MinTerm, MaxTerm>(
1403           lhs_new, rhs_new, v->propagate_nans(), hasher_, &new_op)) {
1404     return new_op;
1405   }
1406 
1407   return combineMinMaxTerms<Min, MinTerm>(
1408       lhs_new, rhs_new, v->propagate_nans(), hasher_);
1409 }
1410 
mutate(const CompareSelectPtr & v)1411 ExprPtr PolynomialTransformer::mutate(const CompareSelectPtr& v) {
1412   ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1413   ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1414   ExprPtr true_branch = v->ret_val1()->accept_mutator(this);
1415   ExprPtr false_branch = v->ret_val2()->accept_mutator(this);
1416 
1417   // Constant Folding.
1418   if (lhs_new->isConstant() && rhs_new->isConstant() &&
1419       true_branch->isConstant() && false_branch->isConstant()) {
1420     ExprPtr v_new = alloc<CompareSelect>(
1421         lhs_new,
1422         rhs_new,
1423         true_branch,
1424         false_branch,
1425         v->compare_select_op(),
1426         v->bias());
1427     return evaluateOp(v_new);
1428   }
1429 
1430   // If the comparison is done in float, don't attempt diff simplification,
1431   // since we can't correctly handle NaN.
1432   if (lhs_new->dtype().is_floating_point() ||
1433       rhs_new->dtype().is_floating_point()) {
1434     return alloc<CompareSelect>(
1435         lhs_new,
1436         rhs_new,
1437         true_branch,
1438         false_branch,
1439         v->compare_select_op(),
1440         v->bias());
1441   }
1442 
1443   // If diff is constant, we can determine it.
1444   ExprPtr diff = alloc<Sub>(rhs_new, lhs_new);
1445   diff = diff->accept_mutator(this);
1446 
1447   if (!diff->isConstant()) {
1448     return alloc<CompareSelect>(
1449         lhs_new,
1450         rhs_new,
1451         true_branch,
1452         false_branch,
1453         v->compare_select_op(),
1454         v->bias());
1455   }
1456 
1457   bool equal = immediateEquals(diff, 0);
1458   bool lhsSmaller = !equal && !immediateIsNegative(diff);
1459 
1460   switch (v->compare_select_op()) {
1461     case CompareSelectOperation::kEQ:
1462       return equal ? true_branch : false_branch;
1463     case CompareSelectOperation::kGT:
1464       return (lhsSmaller || equal) ? false_branch : true_branch;
1465     case CompareSelectOperation::kGE:
1466       return lhsSmaller ? false_branch : true_branch;
1467     case CompareSelectOperation::kLT:
1468       return lhsSmaller ? true_branch : false_branch;
1469     case CompareSelectOperation::kLE:
1470       return (lhsSmaller || equal) ? true_branch : false_branch;
1471     case CompareSelectOperation::kNE:
1472       return equal ? false_branch : true_branch;
1473   }
1474 
1475   // should not be possible but just in case.
1476   return alloc<CompareSelect>(
1477       lhs_new,
1478       rhs_new,
1479       true_branch,
1480       false_branch,
1481       v->compare_select_op(),
1482       v->bias());
1483 }
1484 
mutate(const IntrinsicsPtr & v)1485 ExprPtr PolynomialTransformer::mutate(const IntrinsicsPtr& v) {
1486   std::vector<ExprPtr> new_params;
1487   bool changed = false;
1488   bool allConstant = true;
1489   for (const auto& p : v->params()) {
1490     ExprPtr new_child = p->accept_mutator(this);
1491     new_params.push_back(new_child);
1492 
1493     changed |= p != new_child;
1494     allConstant &= new_child->isConstant();
1495   }
1496 
1497   ExprPtr node = v;
1498   if (changed) {
1499     node = alloc<Intrinsics>(v->op_type(), new_params);
1500   }
1501 
1502   if (!allConstant || !v->isPure()) {
1503     return node;
1504   }
1505 
1506   // we're evaluating, but the evaluator only supports float intrinsics.
1507   std::vector<ExprPtr> const_params;
1508   changed = false;
1509   for (const auto& p : new_params) {
1510     if (p->dtype().scalar_type() == ScalarType::Float) {
1511       const_params.push_back(p);
1512     } else {
1513       const_params.push_back(
1514           alloc<Cast>(Dtype(ScalarType::Float, p->dtype().lanes()), p));
1515       changed = true;
1516     }
1517   }
1518 
1519   if (changed) {
1520     node = alloc<Intrinsics>(v->op_type(), const_params);
1521   }
1522   return evaluateOp(node);
1523 }
1524 
mutate(const CastPtr & v)1525 ExprPtr PolynomialTransformer::mutate(const CastPtr& v) {
1526   ExprPtr node = v->src_value()->accept_mutator(this);
1527   if (node->isConstant()) {
1528     return evaluateOp(alloc<Cast>(v->dtype(), node));
1529   }
1530 
1531   if (v->dtype() == node->dtype()) {
1532     return node;
1533   }
1534 
1535   return alloc<Cast>(v->dtype(), node);
1536 }
1537 
mutate(const IfThenElsePtr & v)1538 ExprPtr PolynomialTransformer::mutate(const IfThenElsePtr& v) {
1539   ExprPtr condition = v->condition();
1540   ExprPtr true_value = v->true_value();
1541   ExprPtr false_value = v->false_value();
1542   ExprPtr condition_new = condition->accept_mutator(this);
1543   ExprPtr true_value_new = true_value->accept_mutator(this);
1544   ExprPtr false_value_new = false_value->accept_mutator(this);
1545 
1546   // If the condition is constant then we can choose the right branch now.
1547   if (condition_new->isConstant()) {
1548     if (!immediateEquals(condition_new, 0)) {
1549       return true_value_new;
1550     } else {
1551       return false_value_new;
1552     }
1553   }
1554 
1555   // If both branches are the same then don't do the condition.
1556   if (hasher_.hash(true_value_new) == hasher_.hash(false_value_new)) {
1557     return true_value_new;
1558   }
1559 
1560   if (condition == condition_new && true_value == true_value_new &&
1561       false_value == false_value_new) {
1562     return v;
1563   }
1564 
1565   return alloc<IfThenElse>(condition_new, true_value_new, false_value_new);
1566 }
1567 
mutate(const AndPtr & v)1568 ExprPtr PolynomialTransformer::mutate(const AndPtr& v) {
1569   return mutateBinaryOp(v, this);
1570 }
1571 
mutate(const XorPtr & v)1572 ExprPtr PolynomialTransformer::mutate(const XorPtr& v) {
1573   return mutateBinaryOp(v, this);
1574 }
1575 
mutate(const LshiftPtr & v)1576 ExprPtr PolynomialTransformer::mutate(const LshiftPtr& v) {
1577   return mutateBinaryOp(v, this);
1578 }
1579 
mutate(const RshiftPtr & v)1580 ExprPtr PolynomialTransformer::mutate(const RshiftPtr& v) {
1581   return mutateBinaryOp(v, this);
1582 }
1583 
mutate(const CondPtr & v)1584 StmtPtr PolynomialBase::mutate(const CondPtr& v) {
1585   ExprPtr cond_old = v->condition();
1586   StmtPtr true_old = v->true_stmt();
1587   StmtPtr false_old = v->false_stmt();
1588 
1589   ExprPtr cond_new = cond_old->accept_mutator(this);
1590   StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old;
1591   StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old;
1592 
1593   // If the condition is constant then we can choose the right branch now.
1594   if (cond_new->isConstant()) {
1595     if (!immediateEquals(cond_new, 0)) {
1596       return true_new;
1597     } else {
1598       return false_new;
1599     }
1600   }
1601 
1602   // If both branches are the same then don't do the condition.
1603   if (true_new && false_new &&
1604       hasher_.hash(true_new) == hasher_.hash(false_new)) {
1605     return true_new;
1606   }
1607 
1608   BlockPtr true_block = to<Block>(true_new);
1609   BlockPtr false_block = to<Block>(false_new);
1610   bool true_empty = !true_new || (true_block && true_block->nstmts() == 0);
1611   bool false_empty = !false_new || (false_block && false_block->nstmts() == 0);
1612 
1613   if (true_empty && false_empty) {
1614     return alloc<Block>(std::vector<StmtPtr>({}));
1615   }
1616   if (cond_old != cond_new) {
1617     v->set_condition(cond_new);
1618   }
1619   if (true_old != true_new) {
1620     v->set_true_stmt(true_new);
1621   }
1622   if (false_old != false_new) {
1623     v->set_false_stmt(false_new);
1624   }
1625   return v;
1626 }
1627 
handleForCondReordering(const ForPtr & loop,const CondPtr & cond)1628 static StmtPtr handleForCondReordering(
1629     const ForPtr& loop,
1630     const CondPtr& cond) {
1631   if (cond->false_stmt()) {
1632     return nullptr;
1633   }
1634 
1635   auto condition_vars = VarFinder::find(cond->condition());
1636   for (const auto& v : condition_vars) {
1637     // If the condition depends on a Var that is modified in the loop body, it
1638     // may not be safe to reorder.
1639     if (ModifiesVarChecker::check(loop, v)) {
1640       return nullptr;
1641     }
1642   }
1643 
1644   ForPtr new_f = loop->cloneWithNewBody(Stmt::clone(cond->true_stmt()));
1645   return cond->cloneWithNewBody(new_f);
1646 }
1647 
mutate(const ForPtr & v)1648 StmtPtr PolynomialBase::mutate(const ForPtr& v) {
1649   ExprPtr var = v->var();
1650   ExprPtr start = v->start();
1651   ExprPtr stop = v->stop();
1652   StmtPtr body = v->body();
1653   LoopOptions loop_options = v->loop_options();
1654   ExprPtr var_new_expr = var->accept_mutator(this);
1655   VarPtr var_new = to<Var>(var_new_expr);
1656   ExprPtr start_new = start->accept_mutator(this);
1657   ExprPtr stop_new = stop->accept_mutator(this);
1658   StmtPtr body_new = body;
1659 
1660   ExprPtr loops = alloc<Sub>(stop_new, start_new);
1661   loops = loops->accept_mutator(this);
1662   if (loop_options.isDefault() && loops->isConstant()) {
1663     if (immediateEquals(loops, 0)) {
1664       return alloc<Block>(std::vector<StmtPtr>({}));
1665     } else if (immediateEquals(loops, 1)) {
1666       body_new = Substitute(body, {{var_new, start_new}});
1667       body_new = body_new->accept_mutator(this);
1668       return body_new;
1669     }
1670   }
1671 
1672   body_new = body_new->accept_mutator(this);
1673   if (!body_new) {
1674     return alloc<Block>(std::vector<StmtPtr>({}));
1675   }
1676 
1677   if (auto block = to<Block>(body_new)) {
1678     if (block->nstmts() == 0) {
1679       return alloc<Block>(std::vector<StmtPtr>({}));
1680     }
1681 
1682     if (block->nstmts() == 1) {
1683       if (auto cond = to<Cond>(block->front())) {
1684         StmtPtr reordered = handleForCondReordering(v, cond);
1685         if (reordered) {
1686           return reordered->accept_mutator(this);
1687         }
1688       }
1689     }
1690   }
1691 
1692   if (var != var_new) {
1693     v->set_var(var_new);
1694   }
1695   if (start != start_new) {
1696     v->set_start(start_new);
1697   }
1698   if (stop != stop_new) {
1699     v->set_stop(stop_new);
1700   }
1701   if (body != body_new) {
1702     v->set_body(body_new);
1703   }
1704   return v;
1705 }
1706 
mutate(const BlockPtr & v)1707 StmtPtr PolynomialBase::mutate(const BlockPtr& v) {
1708   std::vector<StmtPtr> stmts;
1709   // Flatten sub-blocks:
1710   bool stmts_changed = false;
1711   for (const StmtPtr& stmt : *v) {
1712     StmtPtr stmt_new = stmt->accept_mutator(this);
1713     stmts_changed |= stmt != stmt_new;
1714     if (stmt_new == nullptr) {
1715       continue;
1716     }
1717 
1718     if (auto subBlock = to<Block>(stmt_new)) {
1719       for (Block::iterator I = subBlock->begin(), E = subBlock->end();
1720            I != E;) {
1721         // Be careful to avoid invalidating the iterator.
1722         StmtPtr s = *(I++);
1723         subBlock->remove_stmt(s);
1724         stmts.push_back(s);
1725       }
1726       stmts_changed = true;
1727     } else {
1728       stmts.push_back(stmt_new);
1729     }
1730   }
1731   if (stmts_changed) {
1732     v->set_stmts(stmts);
1733   }
1734   return v;
1735 }
1736 
1737 // TermExpander
1738 
mutate(const TermPtr & v)1739 ExprPtr TermExpander::mutate(const TermPtr& v) {
1740   ExprPtr newScalar = v->scalar()->accept_mutator(this);
1741   if (immediateEquals(newScalar, 0)) {
1742     return newScalar;
1743   }
1744 
1745   std::vector<ExprPtr> vars;
1746   std::vector<ExprPtr> multilaneVars;
1747 
1748   // Assume we can reorder here because we wont merge floating terms.
1749   ExprPtr lastNode{nullptr};
1750   for (const auto& var : v->variables()) {
1751     ExprPtr node = var->accept_mutator(this);
1752     if (MulPtr mul = to<Mul>(node)) {
1753       // If the sub-Expr resolved to a multiplication, lift it into this
1754       // term.
1755       if (isMultilanePrimitive(mul->lhs())) {
1756         multilaneVars.push_back(mul->lhs());
1757       } else {
1758         vars.push_back(mul->lhs());
1759       }
1760 
1761       if (isMultilanePrimitive(mul->rhs())) {
1762         multilaneVars.push_back(mul->rhs());
1763       } else {
1764         vars.push_back(mul->rhs());
1765       }
1766     } else {
1767       if (isMultilanePrimitive(node)) {
1768         multilaneVars.push_back(node);
1769       } else {
1770         vars.push_back(node);
1771       }
1772     }
1773   }
1774 
1775   for (const auto& node : multilaneVars) {
1776     if (lastNode == nullptr) {
1777       lastNode = node;
1778     } else {
1779       lastNode = mulMultilane(lastNode, node);
1780       // simplify first, then re-expand.
1781       lastNode = lastNode->accept_mutator(simplifier_);
1782       lastNode = lastNode->accept_mutator(this);
1783     }
1784   }
1785 
1786   for (const auto& node : vars) {
1787     if (lastNode == nullptr) {
1788       lastNode = node;
1789     } else {
1790       lastNode = alloc<Mul>(lastNode, node);
1791     }
1792   }
1793 
1794   if (!immediateEquals(newScalar, 1)) {
1795     if (lastNode) {
1796       // We want to avoid a leaving a CastNode on the scalar, so handle that
1797       // now.
1798       auto termDtype = v->scalar()->dtype();
1799       auto lastNodeDtype = lastNode->dtype();
1800       if (termDtype != lastNodeDtype) {
1801         ExprPtr castV = v->scalar();
1802         // Take care of lane mismatch first.
1803         if (termDtype.lanes() != lastNodeDtype.lanes()) {
1804           castV = alloc<Broadcast>(v->scalar(), lastNodeDtype.lanes());
1805         }
1806         // Now take care of scalar type as well.
1807         if (termDtype.scalar_type() != lastNodeDtype.scalar_type()) {
1808           castV = alloc<Cast>(lastNode->dtype(), castV);
1809           // For scalars, we can simplify the cast further.
1810           if (lastNodeDtype.lanes() == 1) {
1811             castV = evaluateOp(castV);
1812           }
1813         }
1814         lastNode = alloc<Mul>(castV, lastNode);
1815       } else {
1816         lastNode = alloc<Mul>(v->scalar(), lastNode);
1817       }
1818     } else {
1819       lastNode = v->scalar();
1820     }
1821   }
1822 
1823   return lastNode;
1824 }
1825 
1826 // Returns an immediate containing the greatest common divisor of all terms
1827 // (inc. the scalar term) in the polynomial. If the GCD is uninteresting
1828 // (e.g. 1) then returns nullptr.
polyGCD(const PolynomialPtr & poly)1829 static ExprPtr polyGCD(const PolynomialPtr& poly) {
1830   ExprPtr scalar = poly->scalar();
1831   const std::vector<TermPtr>& variables = poly->variables();
1832 
1833   // We ony want to factorize if we're saving complete operations, i.e. no
1834   // value in factorizing 6x + 4y into 2 * (3x + 2y) since we don't save work.
1835   int opsSaved = 1; // default to saving the scalar.
1836   long GCD = std::abs(immediateAs<long>(scalar));
1837   for (const auto& t : variables) {
1838     long termScalar = std::abs(immediateAs<long>(t->scalar()));
1839     long newGCD = gcd(std::max(GCD, termScalar), std::min(GCD, termScalar));
1840     if (newGCD == 1) {
1841       return nullptr;
1842     }
1843 
1844     if (GCD != newGCD) {
1845       opsSaved = 0;
1846       GCD = newGCD;
1847     }
1848 
1849     if (GCD == termScalar) {
1850       opsSaved++;
1851     }
1852   }
1853 
1854   if (opsSaved == 0) {
1855     return nullptr;
1856   }
1857 
1858   if (GCD == 0) {
1859     return nullptr;
1860   }
1861 
1862   // Not worth, can be a Sub.
1863   if (GCD == -1 && opsSaved == 1) {
1864     return nullptr;
1865   }
1866 
1867   return immLike(poly, GCD);
1868 }
1869 
1870 // A ModRound is a div-mod-mul in which the divisor in div and multiplier in mul
1871 // are identical and not equal to 1.
1872 // In a ModRound x/y%z*y*c (c is constant), 'scalar' denotes c, 'denominator'
1873 // denotes x, 'divisor' denotes y and 'mod_divisor' denotes z.
1874 class ModRound {
1875  public:
ModRound(ExprPtr scalar,ExprPtr denom,ExprPtr divisor,ExprPtr mod_divisor)1876   ModRound(ExprPtr scalar, ExprPtr denom, ExprPtr divisor, ExprPtr mod_divisor)
1877       : scalar(std::move(scalar)),
1878         denom(std::move(denom)),
1879         divisor(std::move(divisor)),
1880         mod_divisor(std::move(mod_divisor)) {}
1881   ExprPtr scalar;
1882   ExprPtr denom;
1883   ExprPtr divisor;
1884   ExprPtr mod_divisor;
1885 };
1886 
isModRound(const TermPtr & e)1887 static std::optional<class ModRound> isModRound(const TermPtr& e) {
1888   DivPtr div{nullptr};
1889   ModPtr mod{nullptr};
1890   ExprPtr denom{nullptr};
1891   ExprPtr divisor{nullptr};
1892   ExprPtr mod_divisor{nullptr};
1893   ExprPtr multiplier = e->scalar();
1894   ExprPtr scalar{nullptr};
1895   ExprPtr other{nullptr};
1896 
1897   for (const auto& m : e->variables()) {
1898     if (m->expr_type() == IRNodeType::kMod) {
1899       // TODO: currently only identify terms with one variable being mod; it is
1900       // possible to extend this if we have to handle terms like (t/(x%2 * y) %
1901       // z) * (x%2 *y).
1902       if (!mod) {
1903         mod = to<Mod>(m);
1904       } else {
1905         return std::nullopt;
1906       }
1907     } else {
1908       // Take care of special cases before multiplying the scalar and variable.
1909       if (multiplier->isConstant()) {
1910         // Take care of lane mismatch first.
1911         if (multiplier->dtype().lanes() != m->dtype().lanes()) {
1912           multiplier = alloc<Broadcast>(multiplier, m->dtype().lanes());
1913         }
1914         // Take care of scalar type mismatch.
1915         if (multiplier->dtype().scalar_type() != m->dtype().scalar_type()) {
1916           multiplier = alloc<Cast>(m->dtype(), multiplier);
1917           if (m->dtype().lanes() == 1) {
1918             multiplier = evaluateOp(multiplier);
1919           }
1920         }
1921       }
1922 
1923       // All non-mod variables are considered as part of the multiplier.
1924       multiplier = alloc<Mul>(multiplier, m);
1925     }
1926   }
1927   multiplier = IRSimplifier::simplify(multiplier);
1928 
1929   if (!mod) {
1930     return std::nullopt;
1931   }
1932 
1933   mod_divisor = IRSimplifier::simplify(mod->rhs());
1934   other = mod->lhs();
1935 
1936   if (!(div = to<Div>(other))) {
1937     return std::nullopt;
1938   }
1939 
1940   divisor = IRSimplifier::simplify(div->rhs());
1941   other = div->lhs();
1942 
1943   denom = IRSimplifier::simplify(other);
1944 
1945   // Deny cases in which divisor!=multiplier.
1946   HashProvider& hasher = e->hasher();
1947   if (hasher.hash(divisor) != hasher.hash(multiplier)) {
1948     // TODO: currently we do not extract a common factor if divisor and
1949     // multiplier are not constants. The extraction is not supported (e.g.,
1950     // x*2/x -> 2) in IRSimplifier.simplify because x could be 0. As future
1951     // work, we can extend division to 2 versions: 1) division for customers
1952     // that has to be strictly simplified and 2) division we introduced in our
1953     // transformations which can be simplified without considering 0s, e.g.,
1954     // Div_nonzero. The second division will be only used to facilitate our
1955     // transformations.
1956     if (divisor->isConstant() && multiplier->isConstant()) {
1957       // If both are scalar we may be able to find a common factor.
1958       if (immediateEquals(evaluateOp(alloc<Mod>(multiplier, divisor)), 0)) {
1959         // The common factor becomes 'scalar' of the term, e.g.,in t/3%7*6,
1960         // divisor=multiplier=3, scalar=2.
1961         ExprPtr c = evaluateOp(alloc<Div>(multiplier, divisor));
1962         scalar = c;
1963       } else if (immediateEquals(
1964                      evaluateOp(alloc<Mod>(divisor, multiplier)), 0)) {
1965         // The common factor becomes part of 'denom', e.g., in t/14%7*2,
1966         // divisor=multiplier=2, denom=t/7.
1967         ExprPtr c = evaluateOp(alloc<Div>(divisor, multiplier));
1968         divisor = multiplier;
1969         denom = IRSimplifier::simplify(alloc<Div>(other, c));
1970       } else {
1971         return std::nullopt;
1972       }
1973     } else {
1974       return std::nullopt;
1975     }
1976   }
1977 
1978   // Deny cases in which divisor=1. Such cases are considered as Mods.
1979   if (divisor->isConstant() && immediateEquals(divisor, 1)) {
1980     return std::nullopt;
1981   }
1982 
1983   if (!scalar) {
1984     scalar = immLike(multiplier, 1);
1985   }
1986 
1987   return ModRound(scalar, denom, divisor, mod_divisor);
1988 }
1989 
1990 // Search the polynomial for Terms that can be merged in
1991 // (1) Round + Mod pattern: (x/y) * y + x % y => RoundOff(x,y) + Mod(x, y) => x
1992 // (2) Mod round + Mod pattern: (x/y % z)*y + x%y => ModRound(x, y, z) + Mod(x,
1993 // y) => x % (y*z)
simplifyRoundModPattern(const PolynomialPtr & poly)1994 static ExprPtr simplifyRoundModPattern(const PolynomialPtr& poly) {
1995   std::vector<TermPtr> rounds;
1996   std::vector<TermPtr> mods;
1997   std::vector<TermPtr> mod_rounds;
1998   std::vector<TermPtr> others;
1999 
2000   // Split out the Mod, ModRounds and RoundOffs operations so we can inspect.
2001   for (const auto& c : poly->variables()) {
2002     if (c->variables().size() > 1) {
2003       if (auto a = isModRound(c)) {
2004         mod_rounds.push_back(c);
2005       } else {
2006         others.push_back(c);
2007       }
2008       continue;
2009     }
2010 
2011     ExprPtr e = c->variables()[0];
2012 
2013     if (to<RoundOff>(e)) {
2014       rounds.push_back(c);
2015     } else if (e->expr_type() == IRNodeType::kMod) {
2016       if (auto a = isModRound(c)) {
2017         mod_rounds.push_back(c);
2018       } else {
2019         mods.push_back(c);
2020       }
2021     } else {
2022       others.push_back(c);
2023     }
2024   }
2025 
2026   // Can't continue without at least one RoundOff/ModRound and one Mod.
2027   if ((rounds.empty() && mod_rounds.empty()) || mods.empty()) {
2028     return nullptr;
2029   }
2030 
2031   HashProvider& hasher = poly->hasher();
2032   bool didAnything = false;
2033   std::vector<TermPtr> mods_merged;
2034   bool repeat = true;
2035   // Repeat merging terms till there are no Mods or the terms cannot be merged
2036   // any further.
2037   while (!mods.empty() && repeat) {
2038     repeat = false;
2039     for (int64_t i = static_cast<int64_t>(mods.size()) - 1; i >= 0; i--) {
2040       TermPtr m = mods[i];
2041       ModPtr mod = to<Mod>(m->variables()[0]);
2042       CHECK(mod);
2043       ExprPtr mod_lhs = IRSimplifier::simplify(mod->lhs());
2044       ExprPtr mod_rhs = IRSimplifier::simplify(mod->rhs());
2045       bool merged = false;
2046       for (int64_t j = static_cast<int64_t>(mod_rounds.size()) - 1; j >= 0;
2047            j--) {
2048         TermPtr mr = mod_rounds[j];
2049         auto a = isModRound(mr);
2050         CHECK(a);
2051         ModRound& mod_round = *a;
2052 
2053         // TODO: for now don't attempt partial factorization of this
2054         // optimization. E.g. it's possible to do: 2 * (x/y%z) * y + (x%y) =>
2055         // x%(y*z) + (x/y%z) * y
2056         if (!immediateEquals(
2057                 evaluateOp(alloc<Sub>(mod_round.scalar, m->scalar())), 0)) {
2058           continue;
2059         }
2060         // Valid optimization if mod LHS matches denom and mod RHS matches
2061         // divisor.
2062         if (hasher.hash(mod_round.denom) == hasher.hash(mod_lhs) &&
2063             hasher.hash(mod_round.divisor) == hasher.hash(mod_rhs)) {
2064           TermPtr merged_m = alloc<Term>(
2065               hasher,
2066               mod_round.scalar,
2067               IRSimplifier::simplify(alloc<Mod>(
2068                   mod_round.denom,
2069                   alloc<Mul>(mod_round.divisor, mod_round.mod_divisor))));
2070           mods_merged.push_back(merged_m);
2071           merged = true;
2072           repeat = true;
2073           didAnything = true;
2074           mods.erase(mods.begin() + i);
2075           mod_rounds.erase(mod_rounds.begin() + j);
2076           break;
2077         }
2078       }
2079 
2080       if (merged) {
2081         continue;
2082       }
2083 
2084       for (int64_t k = static_cast<int64_t>(rounds.size()) - 1; k >= 0; k--) {
2085         TermPtr r = rounds[k];
2086         RoundOffPtr roundoff = to<RoundOff>(r->variables()[0]);
2087         CHECK(roundoff);
2088 
2089         // TODO: for now don't attempt partial factorization of this
2090         // optimization. E.g. it's possible to do: 2 * (x/y) * y + (x%y) => x +
2091         // (x/y) * y but unsure thats actually much better, particularly with
2092         // CSE.
2093         if (!immediateEquals(
2094                 evaluateOp(alloc<Sub>(r->scalar(), m->scalar())), 0)) {
2095           continue;
2096         }
2097         ExprPtr round_lhs = IRSimplifier::simplify(roundoff->lhs());
2098         ExprPtr round_rhs = IRSimplifier::simplify(roundoff->rhs());
2099         // Valid optimization if LHS and RHS are equal for both.
2100         if (hasher.hash(round_lhs) == hasher.hash(mod_lhs) &&
2101             hasher.hash(round_rhs) == hasher.hash(mod_rhs)) {
2102           TermPtr merged_r = alloc<Term>(hasher, r->scalar(), round_lhs);
2103           others.push_back(merged_r);
2104           merged = true;
2105           didAnything = true;
2106           mods.erase(mods.begin() + i);
2107           rounds.erase(rounds.begin() + k);
2108           break;
2109         }
2110       }
2111 
2112       // If we didn't merge, move out the Mod.
2113       if (!merged) {
2114         others.push_back(m);
2115         mods.erase(mods.begin() + i);
2116       }
2117 
2118     } // end of for-loop
2119 
2120     // Add newly generated Mods for merging opportunities in the next iteration.
2121     if (!mods_merged.empty()) {
2122       mods.insert(mods.end(), mods_merged.begin(), mods_merged.end());
2123       mods_merged.clear();
2124     }
2125 
2126   } // end of while-loop
2127 
2128   // If we made no changes, just exit.
2129   if (!didAnything) {
2130     return nullptr;
2131   }
2132 
2133   // Keep remaining ModRounds and RoundOffs.
2134   if (!mod_rounds.empty()) {
2135     others.insert(others.end(), mod_rounds.begin(), mod_rounds.end());
2136   }
2137 
2138   if (!rounds.empty()) {
2139     others.insert(others.end(), rounds.begin(), rounds.end());
2140   }
2141 
2142   return alloc<Polynomial>(hasher, poly->scalar(), others);
2143 }
2144 
2145 // Trivially factorize terms by GCD of scalar components.
factorizePolynomial(const PolynomialPtr & poly)2146 TermPtr PolynomialBase::factorizePolynomial(const PolynomialPtr& poly) {
2147   ExprPtr scalar = poly->scalar();
2148   const std::vector<TermPtr>& variables = poly->variables();
2149 
2150   // Compute the GCD of terms.
2151   ExprPtr GCD = polyGCD(poly);
2152 
2153   // No GCD means 0 or 1 and can't be factored.
2154   if (!GCD) {
2155     return nullptr;
2156   }
2157 
2158   // Create new structure.
2159   std::vector<TermPtr> newPolyTerms;
2160   newPolyTerms.reserve(variables.size());
2161   for (const auto& t : variables) {
2162     // New term with the scalar divided by the GCD.
2163     newPolyTerms.push_back(alloc<Term>(
2164         poly->hasher(),
2165         evaluateOp(alloc<Div>(t->scalar(), GCD)),
2166         t->variables()));
2167   }
2168 
2169   PolynomialPtr newPoly = alloc<Polynomial>(
2170       poly->hasher(), evaluateOp(alloc<Div>(scalar, GCD)), newPolyTerms);
2171 
2172   return alloc<Term>(poly->hasher(), GCD, newPoly);
2173 }
2174 
mutate(const PolynomialPtr & v)2175 ExprPtr TermExpander::mutate(const PolynomialPtr& v) {
2176   if (v->variables().empty()) {
2177     return v->scalar();
2178   }
2179 
2180   // If this Polynomial can be factorized: do it, then expand the result.
2181   if (ExprPtr simplified = simplifyRoundModPattern(v)) {
2182     return simplified->accept_mutator(this);
2183   }
2184 
2185   // If this Polynomial can be factorized: do it, then expand the result.
2186   if (ExprPtr factorized = factorizePolynomial(v)) {
2187     return factorized->accept_mutator(this);
2188   }
2189 
2190   std::vector<TermPtr> addTerms;
2191   std::vector<TermPtr> subTerms;
2192 
2193   auto vars = v->variables();
2194   std::unordered_map<ExprPtr, std::string> str_repr_cache;
2195   std::sort(vars.begin(), vars.end(), [&](const ExprPtr& a, const ExprPtr& b) {
2196     if (!str_repr_cache.count(a)) {
2197       str_repr_cache[a] = std::to_string(a);
2198     }
2199     if (!str_repr_cache.count(b)) {
2200       str_repr_cache[b] = std::to_string(b);
2201     }
2202     return str_repr_cache.at(a) < str_repr_cache.at(b);
2203   });
2204 
2205   // partition the terms into a list to add and list to subtract.
2206   for (const auto& node : vars) {
2207     if (immediateIsNegative(node->scalar())) {
2208       subTerms.push_back(node);
2209     } else if (!immediateEquals(node->scalar(), 0)) {
2210       addTerms.push_back(node);
2211     }
2212     // Skip terms with a scalar of zero.
2213   }
2214 
2215   // The last node constructed.
2216   ExprPtr lastNode{nullptr};
2217 
2218   for (const auto& node : addTerms) {
2219     ExprPtr simpleNode = node->accept_mutator(this);
2220 
2221     if (lastNode == nullptr) {
2222       lastNode = simpleNode;
2223       continue;
2224     }
2225 
2226     if (isMultilanePrimitive(simpleNode)) {
2227       auto ret = combineMultilane<Add>(lastNode, simpleNode);
2228       if (ret) {
2229         // simplify result first, then expand.
2230         lastNode = ret->accept_mutator(simplifier_);
2231         lastNode = lastNode->accept_mutator(this);
2232         continue;
2233       }
2234     }
2235 
2236     lastNode = alloc<Add>(lastNode, simpleNode);
2237   }
2238 
2239   // If we have no add terms the scalar should go first.
2240   // E.g. 1 - x.
2241   bool scalarWritten = false;
2242   if (lastNode == nullptr) {
2243     auto scalarNode = v->scalar()->accept_mutator(simplifier_);
2244 
2245     if (!immediateEquals(scalarNode, 0)) {
2246       lastNode = scalarNode;
2247       scalarWritten = true;
2248     }
2249   }
2250 
2251   for (const auto& node : subTerms) {
2252     // Can still be first node if scalarVal is 0.
2253     if (lastNode == nullptr) {
2254       lastNode = node->accept_mutator(this);
2255       continue;
2256     }
2257 
2258     // Negate the term back to positive since we'll be subtracting it.
2259     ExprPtr negated =
2260         evaluateOp(alloc<Mul>(immLike(node->scalar(), -1), node->scalar()));
2261     TermPtr newRHS = alloc<Term>(node->hasher(), negated, node->variables());
2262     lastNode = alloc<Sub>(lastNode, newRHS->accept_mutator(this));
2263   }
2264 
2265   if (scalarWritten || immediateEquals(v->scalar(), 0)) {
2266     if (!lastNode) {
2267       return immLike(v, 0);
2268     }
2269     return lastNode;
2270   }
2271 
2272   if (immediateIsNegative(v->scalar())) {
2273     // Negate the scalar and subtract.
2274     ExprPtr negated =
2275         evaluateOp(alloc<Mul>(immLike(lastNode, -1), v->scalar()));
2276     lastNode = alloc<Sub>(lastNode, evaluateOp(negated));
2277   } else {
2278     // we want to avoid a cast to the scalar if it would happen.
2279     if (v->scalar()->dtype() != lastNode->dtype()) {
2280       lastNode = alloc<Add>(
2281           lastNode, evaluateOp(alloc<Cast>(lastNode->dtype(), v->scalar())));
2282     } else {
2283       lastNode = alloc<Add>(lastNode, v->scalar());
2284     }
2285   }
2286 
2287   return lastNode;
2288 }
2289 
mutate(const MaxTermPtr & v)2290 ExprPtr TermExpander::mutate(const MaxTermPtr& v) {
2291   auto& variables = v->variables();
2292   if (variables.empty()) {
2293     if (!v->scalar()) {
2294       // This case should never happen because MaxTerm will be created only
2295       // on valid Max expressions.
2296       throw std::logic_error("empty maxterm op");
2297     }
2298     return v->scalar();
2299   }
2300   ExprPtr max;
2301   if (v->scalar()) {
2302     max = alloc<Max>(variables[0], v->scalar(), v->propagate_nans());
2303   } else {
2304     max = variables[0];
2305   }
2306   for (size_t i = 1; i < variables.size(); i++) {
2307     max = alloc<Max>(max, variables[i], v->propagate_nans());
2308   }
2309   return max->accept_mutator(this);
2310 }
2311 
mutate(const MinTermPtr & v)2312 ExprPtr TermExpander::mutate(const MinTermPtr& v) {
2313   auto& variables = v->variables();
2314   if (variables.empty()) {
2315     if (!v->scalar()) {
2316       // This case should never happen because MinTerm will be created only
2317       // on valid Min expressions.
2318       throw std::logic_error("empty minterm op");
2319     }
2320     return v->scalar();
2321   }
2322   ExprPtr min;
2323   if (v->scalar()) {
2324     min = alloc<Min>(variables[0], v->scalar(), v->propagate_nans());
2325   } else {
2326     min = variables[0];
2327   }
2328   for (size_t i = 1; i < variables.size(); i++) {
2329     min = alloc<Min>(min, variables[i], v->propagate_nans());
2330   }
2331   return min->accept_mutator(this);
2332 }
2333 
2334 // Expands RoundOff(x, y) => Term(1, Div(x, y), y), which will later be expanded
2335 // to Mul(Div(x, y), y).
mutate(const RoundOffPtr & v)2336 ExprPtr TermExpander::mutate(const RoundOffPtr& v) {
2337   TermPtr term = alloc<Term>(
2338       simplifier_->hasher(),
2339       immLike(v, 1),
2340       alloc<Div>(v->lhs(), v->rhs()),
2341       v->rhs());
2342   return term->accept_mutator(this);
2343 }
2344 
buf_flat_size(const BufPtr & v)2345 ExprPtr buf_flat_size(const BufPtr& v) {
2346   std::vector<ExprPtr> dims = v->dims();
2347   if (dims.empty()) {
2348     return alloc<LongImm>(1);
2349   }
2350   ExprPtr flattened = immLike(dims[0], 1);
2351   for (auto& dim : dims) {
2352     flattened = alloc<Mul>(flattened, dim);
2353   }
2354   flattened = IRSimplifier::simplify(flattened);
2355 
2356   return flattened;
2357 }
2358 
mutate(const AllocatePtr & v)2359 StmtPtr TermExpander::mutate(const AllocatePtr& v) {
2360   BufPtr buf = v->buf();
2361   BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
2362   TORCH_INTERNAL_ASSERT(
2363       buf_new,
2364       buildErrorMessage("TermExpander mutation produced null for Buf."));
2365   ExprPtr flattened = buf_flat_size(buf_new);
2366 
2367   if (flattened->isConstant() && immediateEquals(flattened, 0)) {
2368     eliminated_allocations_.insert(buf_new->base_handle());
2369     return nullptr;
2370   }
2371 
2372   if (buf != buf_new) {
2373     v->set_buf(buf_new);
2374   }
2375   return v;
2376 }
2377 
mutate(const FreePtr & v)2378 StmtPtr TermExpander::mutate(const FreePtr& v) {
2379   BufPtr buf = v->buf();
2380   BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
2381   TORCH_INTERNAL_ASSERT(
2382       buf_new,
2383       buildErrorMessage("TermExpander mutation produced null for Buf."));
2384 
2385   if (eliminated_allocations_.count(buf_new->base_handle())) {
2386     eliminated_allocations_.erase(buf_new->base_handle());
2387     return nullptr;
2388   }
2389 
2390   if (buf != buf_new) {
2391     v->set_buf(buf_new);
2392   }
2393   return v;
2394 }
2395 
2396 // Combines adjacent Cond nodes with identical conditions.
fuseConditions(BlockPtr v)2397 BlockPtr TermExpander::fuseConditions(BlockPtr v) {
2398   std::vector<StmtPtr> stmts;
2399   bool did_anything = false;
2400   CondPtr prev_cond = nullptr;
2401 
2402   for (const auto& s : *v) {
2403     CondPtr cond = to<Cond>(s);
2404     if (!cond) {
2405       prev_cond = nullptr;
2406       stmts.push_back(s);
2407       continue;
2408     }
2409 
2410     // If the previous statement is a Cond and the conditions are identical,
2411     // then we fuse.
2412     if (!prev_cond ||
2413         hasher_.hash(prev_cond->condition()) !=
2414             hasher_.hash(cond->condition())) {
2415       prev_cond = cond;
2416       stmts.push_back(s);
2417       continue;
2418     }
2419 
2420     // Fuse the two Conds by appending the bodies of the second Cond to the
2421     // first.
2422     BlockPtr true_block = alloc<Block>(std::vector<StmtPtr>({}));
2423     BlockPtr false_block = alloc<Block>(std::vector<StmtPtr>({}));
2424 
2425     if (prev_cond->true_stmt()) {
2426       true_block->splice(true_block->end(), prev_cond->true_stmt());
2427     }
2428 
2429     if (cond->true_stmt()) {
2430       true_block->splice(true_block->end(), cond->true_stmt());
2431     }
2432 
2433     if (prev_cond->false_stmt()) {
2434       false_block->splice(false_block->end(), prev_cond->false_stmt());
2435     }
2436 
2437     if (cond->false_stmt()) {
2438       false_block->splice(false_block->end(), cond->false_stmt());
2439     }
2440 
2441     // avoid unflattening this Cond if we can.
2442     if (true_block->empty()) {
2443       true_block = nullptr;
2444     }
2445 
2446     if (false_block->empty()) {
2447       false_block = nullptr;
2448     }
2449 
2450     StmtPtr new_cond = prev_cond->cloneWithNewBodies(true_block, false_block)
2451                            ->accept_mutator(this);
2452     prev_cond = to<Cond>(new_cond);
2453 
2454     // erase, which shortens the list.
2455     stmts.pop_back();
2456     stmts.push_back(new_cond);
2457     did_anything = true;
2458   }
2459 
2460   if (!did_anything) {
2461     return v;
2462   }
2463 
2464   // clean up parents.
2465   for (const auto& s : stmts) {
2466     if (s->get_parent() == v) {
2467       v->remove_stmt(s);
2468     }
2469   }
2470 
2471   return alloc<Block>(stmts);
2472 }
2473 
fuseSyncThreads(BlockPtr block)2474 StmtPtr TermExpander::fuseSyncThreads(BlockPtr block) {
2475   // only really first if highest level Block.
2476   bool first = block->get_parent() == nullptr;
2477   SyncThreadsPtr last = nullptr;
2478   std::vector<StmtPtr> stmts;
2479   bool did_anything = false;
2480 
2481   for (const auto& s : *block) {
2482     SyncThreadsPtr sync = to<SyncThreads>(s);
2483     if (!sync) {
2484       first = false;
2485       last = nullptr;
2486       stmts.push_back(s);
2487       continue;
2488     }
2489 
2490     if (first || last) {
2491       did_anything = true;
2492       continue;
2493     }
2494 
2495     last = sync;
2496     first = false;
2497     stmts.push_back(s);
2498   }
2499 
2500   if (last) {
2501     stmts.pop_back();
2502     did_anything = true;
2503   }
2504 
2505   if (!did_anything) {
2506     return block;
2507   }
2508 
2509   // clean up parents.
2510   for (const auto& s : stmts) {
2511     if (s->get_parent() == block) {
2512       block->remove_stmt(s);
2513     }
2514   }
2515 
2516   return alloc<Block>(std::vector<StmtPtr>({stmts}));
2517 }
2518 
mutate(const BlockPtr & v)2519 StmtPtr TermExpander::mutate(const BlockPtr& v) {
2520   StmtPtr new_stmt = PolynomialBase::mutate(v);
2521   BlockPtr new_block = to<Block>(new_stmt);
2522   if (!new_block) {
2523     return new_stmt;
2524   }
2525 
2526   // fuseConditions will return the original block if it cannot fuse.
2527   new_block = fuseConditions(new_block);
2528   /// fuseSyncThreads too.
2529   return fuseSyncThreads(new_block);
2530 }
2531 
2532 // SimplifierUnderContext
2533 //
2534 // This function records the bounds(range) info of the index var in a for-stmt.
2535 // The bounds info will be used later when simplifying expressions with the
2536 // index var.
mutate(const ForPtr & v)2537 StmtPtr SimplifierUnderContext::mutate(const ForPtr& v) {
2538   ExprPtr var = v->var();
2539   ExprPtr start = v->start();
2540   ExprPtr stop = v->stop();
2541   StmtPtr body = v->body();
2542   LoopOptions loop_options = v->loop_options();
2543   ExprPtr var_new_expr = var->accept_mutator(this);
2544   VarPtr var_new = to<Var>(var_new_expr);
2545   ExprPtr start_new = start->accept_mutator(this);
2546   ExprPtr stop_new = stop->accept_mutator(this);
2547   StmtPtr body_new = body;
2548 
2549   // save bounds info before this for-stmt
2550   //
2551   // The same variable could have appeared in a if-stmt which the for-stmt is
2552   // nested inside, and we need to restore its bounds info after the for-stmt.
2553   //
2554   // An example,
2555   // if (i>=0 && i<5) {
2556   //   for (i=0; i<3; i++){
2557   //     A[i] = ...
2558   //   }
2559   //   x = (i+20) / 5;
2560   //}
2561   // Inside the if stmt, i is in the range of [0, 5); and if we can restore this
2562   // bound info after the for stmt, we can use it to simplify the assignment
2563   // stmt x = (i+20)/5 to x = 4.
2564   bool has_bounds = false;
2565   analysis::Bound bound_old;
2566   VarPtr var_key = to<Var>(var);
2567   auto got = var_bound_info_.find(var_key);
2568   if (got != var_bound_info_.end()) {
2569     has_bounds = true;
2570     bound_old = got->second;
2571   }
2572   // set bounds info for index var
2573   const analysis::Bound bound_new(start_new, stop_new);
2574   var_bound_info_[var_key] = bound_new;
2575 
2576   ExprPtr iters = alloc<Sub>(stop_new, start_new);
2577   iters = iters->accept_mutator(this);
2578   if (loop_options.isDefault() && iters->isConstant()) {
2579     if (immediateEquals(iters, 0)) {
2580       return alloc<Block>(std::vector<StmtPtr>({}));
2581     } else if (immediateEquals(iters, 1)) {
2582       body_new = Substitute(body, {{var_new, start_new}});
2583       body_new = body_new->accept_mutator(this);
2584 
2585       // erase index var bounds info or restore old bounds info
2586       if (has_bounds) {
2587         var_bound_info_[var_key] = bound_old;
2588       } else {
2589         var_bound_info_.erase(var_key);
2590       }
2591 
2592       return body_new;
2593     }
2594   }
2595 
2596   body_new = body_new->accept_mutator(this);
2597 
2598   // erase index var bounds info or restore old bounds info
2599   if (has_bounds) {
2600     var_bound_info_[var_key] = bound_old;
2601   } else {
2602     var_bound_info_.erase(var_key);
2603   }
2604 
2605   if (!body_new) {
2606     return alloc<Block>(std::vector<StmtPtr>({}));
2607   }
2608 
2609   if (auto block = to<Block>(body_new)) {
2610     if (block->nstmts() == 0) {
2611       return alloc<Block>(std::vector<StmtPtr>({}));
2612     }
2613 
2614     if (block->nstmts() == 1) {
2615       // if the stmt in the loop body is a if-stmt, try to move the branching
2616       // out of the loop
2617       if (auto cond = to<Cond>(block->front())) {
2618         StmtPtr reordered = handleForCondReordering(v, cond);
2619         if (reordered) {
2620           return reordered->accept_mutator(this);
2621         }
2622       }
2623     }
2624   }
2625 
2626   if (var != var_new) {
2627     v->set_var(var_new);
2628   }
2629   if (start != start_new) {
2630     v->set_start(start_new);
2631   }
2632   if (stop != stop_new) {
2633     v->set_stop(stop_new);
2634   }
2635   if (body != body_new) {
2636     v->set_body(body_new);
2637   }
2638   return v;
2639 }
2640 
2641 // Simplify division using distributive laws for the following cases:
2642 // 1) (i + x) / n => x/n, if
2643 //   a) n is a positive integer constant;
2644 //   b) i is the index var of a for-stmt and the range of i is
2645 // a subset of [0, n);
2646 //   c) x is a constant and the end value of i's range is less than n - x%n;
2647 //   TODO: remove d) from the requirements because the simplification formula
2648 //   still holds when x is a negative integer. In integer division, the result
2649 //   of the division is converted to an integer using `floor` function which
2650 //   returns the largest integer that is not greater than X. For example, -1/6
2651 //   returns -1. But currently, both Pytorch and NNC are performing an incorrect
2652 //   integer division: (-1)/6 = 0. With the current implementation of integer
2653 //   division, x has to be not negative. d) x is not negative
2654 //
2655 // 2) (i + j*n) / n => j, if
2656 //   a) n is a positive integer constant;
2657 //   b) i is the index var of a for-stmt and the range of i is
2658 // a subset of [0, n);
2659 //   c) j is an integer variable;
2660 //   TODO: remove d) from the requirements because the simplification formula
2661 //   still holds when j is a negative integer. In integer division, the result
2662 //   of the division is converted to an integer using `floor` function which
2663 //   returns the largest integer that is not greater than X. For example, -1/6
2664 //   returns -1. But currently, both Pytorch and NNC are performing an incorrect
2665 //   integer division: (-1)/6 = 0. With the current implementation of integer
2666 //   division, x has to be not negative. d) j is not negative
distributeDiv(const ExprPtr & lhs,const ExprPtr & rhs,VarBoundInfo var_bound_info)2667 static ExprPtr distributeDiv(
2668     const ExprPtr& lhs,
2669     const ExprPtr& rhs,
2670     VarBoundInfo var_bound_info) {
2671   if (!lhs || !rhs) {
2672     return nullptr;
2673   }
2674   // return if not integer division
2675   if (lhs->dtype().is_floating_point() || rhs->dtype().is_floating_point()) {
2676     return nullptr;
2677   }
2678 
2679   // identify n: a positive integer constant
2680   ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
2681   if (!rhsScalar) {
2682     return nullptr;
2683   }
2684   ExprPtr check_n_value = IRSimplifier::simplify(
2685       alloc<CompareSelect>(rhsScalar, immLike(rhsScalar, 0), kGT));
2686   if (!immediateEquals(check_n_value, 1)) {
2687     return nullptr;
2688   }
2689 
2690   auto lhsAdd = to<Add>(lhs);
2691   if (!lhsAdd) {
2692     return nullptr;
2693   }
2694   ExprPtr lhsAdd1 = lhsAdd->lhs();
2695   ExprPtr lhsAdd2 = lhsAdd->rhs();
2696 
2697   // identify index var 'i'
2698   VarPtr var_key = to<Var>(lhsAdd1);
2699   ExprPtr main = lhsAdd2;
2700   if (var_key == nullptr) {
2701     var_key = to<Var>(lhsAdd2);
2702     main = lhsAdd1;
2703   }
2704 
2705   if (var_key == nullptr) {
2706     return nullptr;
2707   }
2708 
2709   auto got = var_bound_info.find(var_key);
2710   if (got == var_bound_info.end()) {
2711     return nullptr;
2712   }
2713 
2714   // check the bounds of 'i'
2715   auto start = got->second.start;
2716   // open upper bound, i.e.,  end is one more than the maximum value in the
2717   // range
2718   auto end = got->second.end;
2719   ExprPtr check_start = IRSimplifier::simplify(
2720       alloc<CompareSelect>(start, immLike(start, 0), kGE));
2721   ExprPtr check_end =
2722       IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
2723   if (!check_start->isConstant() || !check_end->isConstant() ||
2724       !immediateEquals(check_start, 1) || !immediateEquals(check_end, 1)) {
2725     return nullptr;
2726   }
2727 
2728   ExprPtr ret = IRSimplifier::simplify(alloc<Div>(main, rhsScalar));
2729 
2730   // simplify type 1) exprs: '(i+x)/n' => 'x/n'
2731   ExprPtr sign_check =
2732       IRSimplifier::simplify(alloc<CompareSelect>(main, immLike(main, 0), kGE));
2733   ExprPtr main_mod = IRSimplifier::simplify(alloc<Mod>(main, rhsScalar));
2734   ExprPtr mod_check = IRSimplifier::simplify(
2735       alloc<CompareSelect>(alloc<Add>(main_mod, end), rhsScalar, kLE));
2736   if (sign_check->isConstant() && immediateEquals(sign_check, 1) &&
2737       mod_check->isConstant() && immediateEquals(mod_check, 1)) {
2738     return ret;
2739   }
2740 
2741   // simplify type 2 exprs: '(i+j*n)/n' => 'j'
2742   auto ret_var = to<Var>(ret);
2743   // FIXME: Allow any integral type.
2744   if (ret_var && ret_var->dtype() == kInt) {
2745     // retrieve j's range info
2746     auto got = var_bound_info.find(ret_var);
2747     if (got == var_bound_info.end()) {
2748       return nullptr;
2749     }
2750 
2751     // check if j is not negative
2752     sign_check = IRSimplifier::simplify(alloc<CompareSelect>(
2753         got->second.start, immLike(got->second.start, 0), kGE));
2754     if (sign_check->isConstant() && immediateEquals(sign_check, 1)) {
2755       return ret_var;
2756     }
2757   }
2758 
2759   return nullptr;
2760 }
2761 
2762 // Simplify mod using distributive laws for the following cases:
2763 // 1) (i + x) % n => i + x%n if
2764 //   a) n is a positive integer constant;
2765 //   b) i is the index var of a for-stmt and the range of i is
2766 // a subset of [0, n);
2767 //   c) x is a constant and the end value of i's range is less than n - x%n;
2768 //   TODO: remove d) from the requirements because the simplification formula
2769 //   still holds when x is a negative integer. In integer division, the result
2770 //   of the division is converted to an integer using `floor` function which
2771 //   returns the largest integer that is not greater than X. For example, -1/6
2772 //   returns -1. But currently, both Pytorch and NNC are performing an incorrect
2773 //   integer division: (-1)/6 = 0. With the current implementation of integer
2774 //   division, x has to be not negative. d) x is not negative
2775 //
2776 // 2) (i + j*n) % n => i if
2777 //   a) n is a positive integer constant;
2778 //   b) i is the index var of a for-stmt and the range of i is
2779 // a subset of [0, n);
2780 //   c) j is an integer variable;
2781 //   TODO: remove d) from the requirements because the simplification formula
2782 //   still holds when j is a negative integer. In integer division, the result
2783 //   of the division is converted to an integer using `floor` function which
2784 //   returns the largest integer that is not greater than X. For example, -1/6
2785 //   returns -1. But currently, both Pytorch and NNC are performing an incorrect
2786 //   integer division: (-1)/6 = 0. With the current implementation of integer
2787 //   division, j has to be not negative. d) j is not negative
distributeMod(const ExprPtr & lhs,const ExprPtr & rhs,VarBoundInfo var_bound_info)2788 static ExprPtr distributeMod(
2789     const ExprPtr& lhs,
2790     const ExprPtr& rhs,
2791     VarBoundInfo var_bound_info) {
2792   if (!lhs || !rhs) {
2793     return nullptr;
2794   }
2795   // return if not integer mod
2796   if (lhs->dtype().is_floating_point() || rhs->dtype().is_floating_point()) {
2797     return nullptr;
2798   }
2799 
2800   // identify n: a positive integer constant
2801   ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
2802   if (!rhsScalar) {
2803     return nullptr;
2804   }
2805   ExprPtr check_n_value = IRSimplifier::simplify(
2806       alloc<CompareSelect>(rhsScalar, immLike(rhsScalar, 0), kGT));
2807   if (!immediateEquals(check_n_value, 1)) {
2808     return nullptr;
2809   }
2810 
2811   auto lhsAdd = to<Add>(lhs);
2812   if (!lhsAdd) {
2813     return nullptr;
2814   }
2815   if (!lhsAdd || !rhsScalar) {
2816     return nullptr;
2817   }
2818   ExprPtr lhsAdd1 = lhsAdd->lhs();
2819   ExprPtr lhsAdd2 = lhsAdd->rhs();
2820 
2821   // identify index var 'i'
2822   VarPtr var_key = to<Var>(lhsAdd1);
2823   ExprPtr main = lhsAdd2;
2824   if (var_key == nullptr) {
2825     var_key = to<Var>(lhsAdd2);
2826     main = lhsAdd1;
2827   }
2828   if (var_key == nullptr) {
2829     return nullptr;
2830   }
2831 
2832   auto got = var_bound_info.find(var_key);
2833   if (got == var_bound_info.end()) {
2834     return nullptr;
2835   }
2836 
2837   // check the bounds of 'i'
2838   auto start = got->second.start;
2839   // open upper bound, i.e.,  end is one more than the maximum value in the
2840   // range
2841   auto end = got->second.end;
2842   ExprPtr check_start = IRSimplifier::simplify(
2843       alloc<CompareSelect>(start, immLike(start, 0), kGE));
2844   ExprPtr check_end =
2845       IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
2846   if (!check_start->isConstant() || !check_end->isConstant() ||
2847       !immediateEquals(check_start, 1) || !immediateEquals(check_end, 1)) {
2848     return nullptr;
2849   }
2850 
2851   // simplify type 1) exprs: '(i+x)%n' => 'i+x%n'
2852   ExprPtr sign_check =
2853       IRSimplifier::simplify(alloc<CompareSelect>(main, immLike(main, 0), kGE));
2854   ExprPtr main_mod = IRSimplifier::simplify(alloc<Mod>(main, rhsScalar));
2855   ExprPtr mod_check = IRSimplifier::simplify(
2856       alloc<CompareSelect>(alloc<Add>(main_mod, end), rhsScalar, kLE));
2857   if (sign_check->isConstant() && immediateEquals(sign_check, 1) &&
2858       mod_check->isConstant() && immediateEquals(mod_check, 1)) {
2859     return alloc<Add>(var_key, main_mod);
2860   }
2861 
2862   // simplify type 2) exprs: '(i+j*n)%n' => 'i'
2863   ExprPtr main_div = IRSimplifier::simplify(alloc<Div>(main, rhsScalar));
2864   auto j_var = to<Var>(main_div);
2865   // FIXME: Allow any integral type.
2866   if (j_var && j_var->dtype() == kInt) {
2867     // retrieve j's range info
2868     auto got = var_bound_info.find(j_var);
2869     if (got == var_bound_info.end()) {
2870       return nullptr;
2871     }
2872 
2873     // check if j is not negative
2874     sign_check = IRSimplifier::simplify(alloc<CompareSelect>(
2875         got->second.start, immLike(got->second.start, 0), kGE));
2876     if (sign_check->isConstant() && immediateEquals(sign_check, 1)) {
2877       return var_key;
2878     }
2879   }
2880 
2881   return nullptr;
2882 }
2883 
mutate(const DivPtr & v)2884 ExprPtr SimplifierUnderContext::mutate(const DivPtr& v) {
2885   ExprPtr lhs = v->lhs();
2886   ExprPtr rhs = v->rhs();
2887 
2888   std::ostringstream oss;
2889   if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) {
2890     GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
2891     return ret->accept_mutator(this);
2892   }
2893 
2894   // i / N -> 0 if the range of i's values is a subset of [0, N)
2895   // where N is an integer constant
2896   auto lhsVar = to<Var>(lhs);
2897   ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
2898   if (lhsVar && rhsScalar && !rhsScalar->dtype().is_floating_point()) {
2899     auto got = var_bound_info_.find(lhsVar);
2900     if (got != var_bound_info_.end()) {
2901       auto start = got->second.start;
2902       auto end = got->second.end;
2903       ExprPtr check_start = IRSimplifier::simplify(
2904           alloc<CompareSelect>(start, immLike(start, 0), kGE));
2905       ExprPtr check_end =
2906           IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
2907       if (check_start->isConstant() && check_end->isConstant() &&
2908           immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) {
2909         GRAPH_DEBUG(
2910             "SimplifierUnderContext: ", *v, " => ", *immLike(lhsVar, 0));
2911         return immLike(lhsVar, 0);
2912       }
2913     }
2914   }
2915 
2916   ExprPtr lhs_new = lhs->accept_mutator(this);
2917   ExprPtr rhs_new = rhs->accept_mutator(this);
2918   if (lhs == lhs_new && rhs == rhs_new) {
2919     return v;
2920   }
2921   return alloc<Div>(lhs_new, rhs_new);
2922 }
2923 
mutate(const IfThenElsePtr & v)2924 ExprPtr SimplifierUnderContext::mutate(const IfThenElsePtr& v) {
2925   ExprPtr condition = v->condition();
2926   ExprPtr true_val = v->true_value();
2927   ExprPtr false_val = v->false_value();
2928 
2929   auto simplified_condition =
2930       IRSimplifier::simplify(condition->accept_mutator(this));
2931   auto simplified_true_val =
2932       IRSimplifier::simplify(true_val->accept_mutator(this));
2933   auto simplified_false_val =
2934       IRSimplifier::simplify(false_val->accept_mutator(this));
2935   if (simplified_condition->isConstant()) {
2936     return immediateAs<int>(simplified_condition) ? simplified_true_val
2937                                                   : simplified_false_val;
2938   }
2939 
2940   bool nothing_changed = (simplified_condition == condition) &&
2941       (simplified_true_val == true_val) && (simplified_false_val == false_val);
2942   return nothing_changed
2943       ? v
2944       : alloc<IfThenElse>(
2945             simplified_condition, simplified_true_val, simplified_false_val);
2946 }
2947 
mutate(const CompareSelectPtr & v)2948 ExprPtr SimplifierUnderContext::mutate(const CompareSelectPtr& v) {
2949   GRAPH_DEBUG("(SimplifierUnderContext) Original: ", std::to_string(v));
2950 
2951   ExprPtr lhs = v->lhs();
2952   ExprPtr rhs = v->rhs();
2953   ExprPtr ret1 = v->ret_val1();
2954   ExprPtr ret2 = v->ret_val2();
2955 
2956   auto simplified_lhs = IRSimplifier::simplify(lhs->accept_mutator(this));
2957   auto simplified_rhs = IRSimplifier::simplify(rhs->accept_mutator(this));
2958   auto simplified_ret1 = IRSimplifier::simplify(ret1->accept_mutator(this));
2959   auto simplified_ret2 = IRSimplifier::simplify(ret2->accept_mutator(this));
2960 
2961   ExprPtr simplified_cmp_select_expr = nullptr;
2962   if ((simplified_lhs == lhs) && (simplified_rhs == rhs) &&
2963       (simplified_ret1 == ret1) && (simplified_ret2 == ret2)) {
2964     simplified_cmp_select_expr = v;
2965   } else {
2966     simplified_cmp_select_expr = alloc<CompareSelect>(
2967         simplified_lhs,
2968         simplified_rhs,
2969         simplified_ret1,
2970         simplified_ret2,
2971         v->compare_select_op(),
2972         v->bias());
2973   }
2974 
2975   GRAPH_DEBUG(
2976       "(SimplifierUnderContext) after simplify: ",
2977       std::to_string(simplified_cmp_select_expr));
2978 
2979   analysis::Bound lhs_bound;
2980   analysis::Bound rhs_bound;
2981   auto lhs_has_bound = getLoopBoundInfo(simplified_lhs, &lhs_bound);
2982   auto rhs_has_bound = getLoopBoundInfo(simplified_rhs, &rhs_bound);
2983   if (!lhs_has_bound || !rhs_has_bound) {
2984     GRAPH_DEBUG(
2985         "(SimplifierUnderContext) Final: ",
2986         std::to_string(simplified_cmp_select_expr));
2987     return simplified_cmp_select_expr;
2988   }
2989 
2990   analysis::CmpEvalResult cmp_res =
2991       analysis::compareBound(lhs_bound, rhs_bound, v->compare_select_op());
2992 
2993   // Return the simplified ret1/ret2 if the compare result is deterministic.
2994   // Otherwise, return the simplified CompareSelect directly.
2995   auto ret_expr = (cmp_res == analysis::CmpEvalResult::True)
2996       ? simplified_ret1
2997       : ((cmp_res == analysis::CmpEvalResult::False)
2998              ? simplified_ret2
2999              : simplified_cmp_select_expr);
3000   GRAPH_DEBUG("(SimplifierUnderContext) Final: ", std::to_string(ret_expr));
3001   return ret_expr;
3002 }
3003 
mutate(const ModPtr & v)3004 ExprPtr SimplifierUnderContext::mutate(const ModPtr& v) {
3005   ExprPtr lhs = v->lhs();
3006   ExprPtr rhs = v->rhs();
3007 
3008   std::ostringstream oss;
3009   if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) {
3010     GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
3011     return ret->accept_mutator(this);
3012   }
3013 
3014   // i % N -> i if the range of i's values is a subset of [0, N)
3015   // where N is an integer constant
3016   auto lhsVar = to<Var>(lhs);
3017   ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
3018   if (lhsVar && rhsScalar && !rhsScalar->dtype().is_floating_point()) {
3019     auto got = var_bound_info_.find(lhsVar);
3020     if (got != var_bound_info_.end()) {
3021       auto start = got->second.start;
3022       auto end = got->second.end;
3023       ExprPtr check_start = IRSimplifier::simplify(
3024           alloc<CompareSelect>(start, immLike(start, 0), kGE));
3025       ExprPtr check_end =
3026           IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
3027       if (check_start->isConstant() && check_end->isConstant() &&
3028           immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) {
3029         GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *lhsVar);
3030         return lhsVar;
3031       }
3032     }
3033   }
3034 
3035   ExprPtr lhs_new = lhs->accept_mutator(this);
3036   ExprPtr rhs_new = rhs->accept_mutator(this);
3037   if (lhs == lhs_new && rhs == rhs_new) {
3038     return v;
3039   }
3040   return alloc<Mod>(lhs_new, rhs_new);
3041 }
3042 
getLoopBoundInfo(const ExprPtr & expr,analysis::Bound * loop_bound_info)3043 bool SimplifierUnderContext::getLoopBoundInfo(
3044     const ExprPtr& expr,
3045     analysis::Bound* loop_bound_info) {
3046   if (expr == nullptr)
3047     return false;
3048 
3049   if (expr->isConstant()) {
3050     loop_bound_info->start = expr;
3051     loop_bound_info->end = expr;
3052     return true;
3053   }
3054 
3055   VarPtr var_key = to<Var>(expr);
3056   if (var_key == nullptr) {
3057     return false;
3058   }
3059 
3060   auto got = var_bound_info_.find(var_key);
3061   if (got == var_bound_info_.end()) {
3062     return false;
3063   }
3064 
3065   loop_bound_info->start = got->second.start;
3066   // TODO: Need to add the boundary information(close/open) of a range to
3067   // Bound. Currently, the VarBoundInfo comes from for-loop statement while
3068   // the end of the boundary is open. But we assume the start and end of a
3069   // range are always close. Hence, we explicitly convert the open boundary to
3070   // close.
3071   //   [for-start, for-stop) => [for-start, for-stop -1]
3072   loop_bound_info->end = IRSimplifier::simplify(
3073       alloc<Sub>(got->second.end, immLike(got->second.end, 1)));
3074   return true;
3075 }
3076 
exprEquals(const ExprPtr & A,const ExprPtr & B)3077 bool exprEquals(const ExprPtr& A, const ExprPtr& B) {
3078   try {
3079     ExprPtr diff = IRSimplifier::simplify(alloc<Sub>(A, B));
3080     if (!diff->isConstant()) {
3081       return false;
3082     }
3083     return immediateEquals(diff, 0);
3084   } catch (std::exception& e) {
3085     return false;
3086   }
3087 }
3088 
simplify(ExprPtr e)3089 ExprPtr IRSimplifier::simplify(ExprPtr e) {
3090   GRAPH_DEBUG("(Simplifier) Original: ", std::to_string(e));
3091   SimplifierUnderContext ctxsimplifier;
3092   e = e->accept_mutator(&ctxsimplifier);
3093 
3094   PolynomialTransformer simplifier;
3095   e = e->accept_mutator(&simplifier);
3096 
3097   // There may be terms left in the IR, expand them.
3098   TermExpander expander(&simplifier);
3099   e = e->accept_mutator(&expander);
3100   if (!expander.check_safe()) {
3101     throw malformed_input("eliminated null Allocation without free");
3102   }
3103 
3104   GRAPH_DEBUG("(Simplifier) Simplified: ", std::to_string(e));
3105   return e;
3106 }
3107 
simplify(StmtPtr s)3108 StmtPtr IRSimplifier::simplify(StmtPtr s) {
3109   GRAPH_DEBUG("(Simplifier) Original: ", std::to_string(s));
3110   SimplifierUnderContext ctxsimplifier;
3111   s = s->accept_mutator(&ctxsimplifier);
3112 
3113   PolynomialTransformer simplifier;
3114   s = s->accept_mutator(&simplifier);
3115   if (s == nullptr) {
3116     GRAPH_DEBUG("(Simplifier) Simplified: NULL");
3117     return nullptr;
3118   }
3119 
3120   // There may be terms left in the IR, expand them.
3121   TermExpander expander(&simplifier);
3122   s = s->accept_mutator(&expander);
3123   if (!expander.check_safe()) {
3124     throw malformed_input("eliminated null Allocation without free");
3125   }
3126 
3127   GRAPH_DEBUG("(Simplifier) Simplified: ", std::to_string(s));
3128   return s;
3129 }
3130 
3131 } // namespace torch::jit::tensorexpr
3132