1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
3 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5
6 #include <utility>
7
8 namespace torch::jit::tensorexpr {
9
10 // Creates a new Expr of the given type with the provided lhs and rhs.
newBinaryOpOfType(IRNodeType expr_type,const ExprPtr & lhs,const ExprPtr & rhs,bool option)11 inline ExprPtr newBinaryOpOfType(
12 IRNodeType expr_type,
13 const ExprPtr& lhs,
14 const ExprPtr& rhs,
15 bool option) {
16 switch (expr_type) {
17 case IRNodeType::kAdd:
18 return alloc<Add>(lhs, rhs);
19 case IRNodeType::kSub:
20 return alloc<Sub>(lhs, rhs);
21 case IRNodeType::kMul:
22 return alloc<Mul>(lhs, rhs);
23 case IRNodeType::kDiv:
24 return alloc<Div>(lhs, rhs);
25 case IRNodeType::kMod:
26 return alloc<Mod>(lhs, rhs);
27 case IRNodeType::kMax:
28 return alloc<Max>(lhs, rhs, option);
29 case IRNodeType::kMin:
30 return alloc<Min>(lhs, rhs, option);
31 case IRNodeType::kAnd:
32 return alloc<And>(lhs, rhs);
33 case IRNodeType::kXor:
34 return alloc<Xor>(lhs, rhs);
35 case IRNodeType::kLshift:
36 return alloc<Lshift>(lhs, rhs);
37 case IRNodeType::kRshift:
38 return alloc<Rshift>(lhs, rhs);
39 default:
40 LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
41 return nullptr;
42 }
43 }
44
45 template <
46 typename Op,
47 std::enable_if_t<std::is_same_v<
48 decltype(detail::bin_op_deducer(std::declval<Op>())),
49 void>>* = nullptr>
mutateBinaryOp(NodePtr<Op> v,IRMutator * mutator,bool option=false)50 static ExprPtr mutateBinaryOp(
51 NodePtr<Op> v,
52 IRMutator* mutator,
53 bool option = false) {
54 ExprPtr lhs = v->lhs();
55 ExprPtr rhs = v->rhs();
56 ExprPtr lhs_new = lhs->accept_mutator(mutator);
57 ExprPtr rhs_new = rhs->accept_mutator(mutator);
58
59 ExprPtr node = v;
60
61 if (lhs != lhs_new || rhs != rhs_new) {
62 node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option);
63 }
64
65 // Can only fold if both sides are constant.
66 if (!lhs_new->isConstant() || !rhs_new->isConstant()) {
67 return node;
68 }
69
70 return evaluateOp(node);
71 }
72
// Greatest common divisor by the Euclidean algorithm, written as a loop.
// gcd(a, 0) is a. No absolute value is taken, so for negative inputs the
// sign of the result is whatever the remainder sequence ends on.
template <typename T>
T gcd(T a, T b) {
  while (!(b == 0)) {
    T remainder = a % b;
    a = b;
    b = remainder;
  }
  return a;
}
81
82 // Helper for determining if an Expr is a multi-lane primitive (e.g. Broadcast
83 // or Ramp).
isMultilanePrimitive(const ExprPtr & e)84 static bool isMultilanePrimitive(const ExprPtr& e) {
85 return to<Broadcast>(e) || to<Ramp>(e);
86 }
87
hashVars() const88 SimplifierHashType Term::hashVars() const {
89 SimplifierHashType hash;
90 for (const auto& v : variables_) {
91 hash = hasher_.hash_combine(hash, hasher_.hash(v));
92 }
93
94 return hash;
95 }
96
sort()97 void Term::sort() {
98 // order of ops important for float
99 if (dtype().is_floating_point()) {
100 throw std::logic_error("reordering FP ops");
101 }
102 std::unordered_map<ExprPtr, std::string> str_repr_cache;
103 std::sort(
104 variables_.begin(),
105 variables_.end(),
106 [&](const ExprPtr& a, const ExprPtr& b) {
107 if (!str_repr_cache.count(a)) {
108 str_repr_cache[a] = std::to_string(a);
109 }
110 if (!str_repr_cache.count(b)) {
111 str_repr_cache[b] = std::to_string(b);
112 }
113 return str_repr_cache.at(a) < str_repr_cache.at(b);
114 });
115 }
116
hashVars() const117 SimplifierHashType Polynomial::hashVars() const {
118 SimplifierHashType hash;
119 for (const auto& v : variables_) {
120 hash = hasher_.hash_combine(hash, hasher_.hash(v));
121 }
122 return hash;
123 }
124
sort()125 void Polynomial::sort() {
126 if (dtype().is_floating_point()) {
127 throw std::logic_error("reordering FP ops");
128 }
129 std::unordered_map<ExprPtr, std::string> str_repr_cache;
130 std::sort(
131 variables_.begin(),
132 variables_.end(),
133 [&](const ExprPtr& a, const ExprPtr& b) {
134 if (!str_repr_cache.count(a)) {
135 str_repr_cache[a] = std::to_string(a);
136 }
137 if (!str_repr_cache.count(b)) {
138 str_repr_cache[b] = std::to_string(b);
139 }
140 return str_repr_cache.at(a) < str_repr_cache.at(b);
141 });
142 }
143
uniquefy()144 void MaxTerm::uniquefy() {
145 std::sort(
146 variables_.begin(),
147 variables_.end(),
148 [&](const ExprPtr& a, const ExprPtr& b) {
149 return hasher_.hash(a) < hasher_.hash(b);
150 });
151 auto it = std::unique(
152 variables_.begin(),
153 variables_.end(),
154 [&](const ExprPtr& a, const ExprPtr& b) {
155 return hasher_.hash(a) == hasher_.hash(b);
156 });
157 variables_.resize(std::distance(variables_.begin(), it));
158
159 // Once we removed duplicates, sort terms alphabetically for stability.
160 std::unordered_map<ExprPtr, std::string> str_repr_cache;
161 std::sort(
162 variables_.begin(),
163 variables_.end(),
164 [&](const ExprPtr& a, const ExprPtr& b) {
165 if (!str_repr_cache.count(a)) {
166 str_repr_cache[a] = std::to_string(a);
167 }
168 if (!str_repr_cache.count(b)) {
169 str_repr_cache[b] = std::to_string(b);
170 }
171 return str_repr_cache.at(a) < str_repr_cache.at(b);
172 });
173 }
174
uniquefy()175 void MinTerm::uniquefy() {
176 std::sort(
177 variables_.begin(),
178 variables_.end(),
179 [&](const ExprPtr& a, const ExprPtr& b) {
180 return hasher_.hash(a) < hasher_.hash(b);
181 });
182 auto it = std::unique(
183 variables_.begin(),
184 variables_.end(),
185 [&](const ExprPtr& a, const ExprPtr& b) {
186 return hasher_.hash(a) == hasher_.hash(b);
187 });
188 variables_.resize(std::distance(variables_.begin(), it));
189
190 // Once we removed duplicates, sort terms alphabetically for stability.
191 std::unordered_map<ExprPtr, std::string> str_repr_cache;
192 std::sort(
193 variables_.begin(),
194 variables_.end(),
195 [&](const ExprPtr& a, const ExprPtr& b) {
196 if (!str_repr_cache.count(a)) {
197 str_repr_cache[a] = std::to_string(a);
198 }
199 if (!str_repr_cache.count(b)) {
200 str_repr_cache[b] = std::to_string(b);
201 }
202 return str_repr_cache.at(a) < str_repr_cache.at(b);
203 });
204 }
205
206 // Handles optimization cases for Broadcast/Ramp +/- Broadcast/Ramp
207 template <class Op>
combineMultilane(const ExprPtr & lhs,const ExprPtr & rhs)208 ExprPtr combineMultilane(const ExprPtr& lhs, const ExprPtr& rhs) {
209 if (BroadcastPtr bc = to<Broadcast>(lhs)) {
210 if (BroadcastPtr bcother = to<Broadcast>(rhs)) {
211 if (bc->lanes() != bcother->lanes()) {
212 throw malformed_input("multilane lane mismatch");
213 }
214
215 ExprPtr ret = alloc<Broadcast>(
216 alloc<Op>(bc->value(), bcother->value()), bc->lanes());
217 return ret;
218 }
219
220 if (RampPtr r = to<Ramp>(rhs)) {
221 if (bc->lanes() != r->lanes()) {
222 throw malformed_input("multilane lane mismatch");
223 }
224
225 ExprPtr ret = alloc<Ramp>(
226 alloc<Op>(bc->value(), r->base()), r->stride(), r->lanes());
227 return ret;
228 }
229 } else if (RampPtr ramp = to<Ramp>(lhs)) {
230 if (RampPtr rother = to<Ramp>(rhs)) {
231 if (ramp->lanes() != rother->lanes()) {
232 throw malformed_input("multilane lane mismatch");
233 }
234
235 ExprPtr ret = alloc<Ramp>(
236 alloc<Op>(ramp->base(), rother->base()),
237 alloc<Op>(ramp->stride(), rother->stride()),
238 ramp->lanes());
239 return ret;
240 }
241
242 if (BroadcastPtr bc = to<Broadcast>(rhs)) {
243 if (ramp->lanes() != bc->lanes()) {
244 throw malformed_input("multilane lane mismatch");
245 }
246 ExprPtr ret = alloc<Ramp>(
247 alloc<Op>(ramp->base(), bc->value()), ramp->stride(), ramp->lanes());
248 return ret;
249 }
250 }
251
252 return nullptr;
253 }
254
255 // Handles optimization cases for Broadcast/Ramp * Broadcast/Ramp
mulMultilane(const ExprPtr & lhs,const ExprPtr & rhs)256 static ExprPtr mulMultilane(const ExprPtr& lhs, const ExprPtr& rhs) {
257 if (BroadcastPtr bc = to<Broadcast>(lhs)) {
258 if (BroadcastPtr bcother = to<Broadcast>(rhs)) {
259 if (bc->lanes() != bcother->lanes()) {
260 throw malformed_input("multilane lane mismatch");
261 }
262
263 ExprPtr ret = alloc<Broadcast>(
264 alloc<Mul>(bc->value(), bcother->value()), bc->lanes());
265 return ret;
266 }
267
268 if (RampPtr r = to<Ramp>(rhs)) {
269 if (bc->lanes() != r->lanes()) {
270 throw malformed_input("multilane lane mismatch");
271 }
272
273 ExprPtr ret = alloc<Ramp>(
274 alloc<Mul>(bc->value(), r->base()),
275 alloc<Mul>(bc->value(), r->stride()),
276 r->lanes());
277 return ret;
278 }
279 } else if (RampPtr ramp = to<Ramp>(lhs)) {
280 if (RampPtr r = to<Ramp>(rhs)) {
281 if (ramp->lanes() != r->lanes()) {
282 throw malformed_input("multilane lane mismatch");
283 }
284
285 ExprPtr ret = alloc<Ramp>(
286 alloc<Mul>(ramp->base(), r->base()),
287 alloc<Mul>(ramp->stride(), r->stride()),
288 r->lanes());
289 return ret;
290 }
291
292 if (BroadcastPtr bc = to<Broadcast>(rhs)) {
293 if (ramp->lanes() != bc->lanes()) {
294 throw malformed_input("multilane lane mismatch");
295 }
296
297 ExprPtr ret = alloc<Ramp>(
298 alloc<Mul>(bc->value(), ramp->base()),
299 alloc<Mul>(bc->value(), ramp->stride()),
300 ramp->lanes());
301 return ret;
302 }
303 }
304
305 return nullptr;
306 }
307
addOrUpdateTerm(std::unordered_map<SimplifierHashType,TermPtr> & varmap,const TermPtr & term)308 void PolynomialTransformer::addOrUpdateTerm(
309 std::unordered_map<SimplifierHashType, TermPtr>& varmap,
310 const TermPtr& term) {
311 SimplifierHashType hash = term->hashVars();
312 auto insertRes = varmap.emplace(hash, term);
313 if (insertRes.second == false) {
314 TermPtr lt = insertRes.first->second;
315 ExprPtr termScalar = evaluateOp(alloc<Add>(lt->scalar(), term->scalar()));
316
317 // If the term is canceled out, remove from the map.
318 if (immediateEquals(termScalar, 0)) {
319 varmap.erase(hash);
320 return;
321 }
322
323 varmap[hash] = alloc<Term>(hasher_, termScalar, lt->variables());
324 }
325 }
326
addPolynomials(const PolynomialPtr & lhs,const PolynomialPtr & rhs)327 ExprPtr PolynomialTransformer::addPolynomials(
328 const PolynomialPtr& lhs,
329 const PolynomialPtr& rhs) {
330 // simplify common components
331 // The key here is the variable hash, not the term's hash since we do want
332 // to combine terms that have the same vars but different scalar components.
333 std::unordered_map<SimplifierHashType, TermPtr> varmap;
334
335 for (const auto& lt : lhs->variables()) {
336 addOrUpdateTerm(varmap, lt);
337 }
338 for (const auto& rt : rhs->variables()) {
339 addOrUpdateTerm(varmap, rt);
340 }
341
342 ExprPtr newScalar = evaluateOp(alloc<Add>(lhs->scalar(), rhs->scalar()));
343 return alloc<Polynomial>(hasher_, newScalar, varmap);
344 }
345
346 // Insert a new Term into the provided polynomial. If the new term has common
347 // variables to an existing term it is combined.
insertTerm(const PolynomialPtr & poly,const TermPtr & term)348 ExprPtr PolynomialTransformer::insertTerm(
349 const PolynomialPtr& poly,
350 const TermPtr& term) {
351 SimplifierHashType tHash = term->hashVars();
352 std::vector<TermPtr> newVars;
353
354 bool found = false;
355 for (const auto& v : poly->variables()) {
356 if (v->hashVars() == tHash) {
357 ExprPtr newScalar = evaluateOp(alloc<Add>(term->scalar(), v->scalar()));
358 found = true;
359 // Skip this term if we cancelled it out.
360 if (immediateEquals(newScalar, 0)) {
361 continue;
362 }
363 auto term = alloc<Term>(hasher_, newScalar, v->variables());
364 newVars.push_back(term);
365 } else {
366 newVars.push_back(v);
367 }
368 }
369
370 if (!found) {
371 newVars.push_back(term);
372 }
373
374 if (newVars.empty()) {
375 return poly->scalar();
376 }
377
378 auto Poly = alloc<Polynomial>(hasher_, poly->scalar(), newVars);
379 return Poly;
380 }
381
mutate(const AddPtr & v)382 ExprPtr PolynomialTransformer::mutate(const AddPtr& v) {
383 ExprPtr lhs_new = v->lhs()->accept_mutator(this);
384 ExprPtr rhs_new = v->rhs()->accept_mutator(this);
385
386 // Constant Folding.
387 if (lhs_new->isConstant() && rhs_new->isConstant()) {
388 ExprPtr result = evaluateOp(alloc<Add>(lhs_new, rhs_new));
389 return result;
390 }
391
392 // Multilane folding.
393 if (isMultilanePrimitive(lhs_new)) {
394 if (auto ret = combineMultilane<Add>(lhs_new, rhs_new)) {
395 return ret->accept_mutator(this);
396 }
397 }
398
399 ExprPtr scalar = nullptr;
400 ExprPtr variable = nullptr;
401 if (lhs_new->isConstant()) {
402 scalar = evaluateOp(lhs_new);
403 variable = rhs_new;
404 } else if (rhs_new->isConstant()) {
405 scalar = evaluateOp(rhs_new);
406 variable = lhs_new;
407 }
408
409 // If there is a scalar, and it's zero: short circuit and return the other
410 // side.
411 if (scalar && immediateEquals(scalar, 0)) {
412 auto c = alloc<Cast>(v->dtype(), variable);
413 return c->accept_mutator(this);
414 }
415
416 // If this is a floating point Add then order of operations is important, we
417 // dont want to combine ops.
418 if (lhs_new->dtype().is_floating_point() ||
419 rhs_new->dtype().is_floating_point()) {
420 return alloc<Add>(lhs_new, rhs_new);
421 }
422
423 PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
424 PolynomialPtr rhsPoly = to<Polynomial>(rhs_new);
425
426 if (lhsPoly && rhsPoly) {
427 return addPolynomials(lhsPoly, rhsPoly);
428 }
429
430 TermPtr lhsTerm = to<Term>(lhs_new);
431 TermPtr rhsTerm = to<Term>(rhs_new);
432
433 if (lhsPoly && rhsTerm) {
434 return insertTerm(lhsPoly, rhsTerm);
435 }
436
437 if (rhsPoly && lhsTerm) {
438 return insertTerm(rhsPoly, lhsTerm);
439 }
440
441 if (lhsTerm && rhsTerm) {
442 // If the terms refer to the same variables: combine them.
443 if (lhsTerm->hashVars() == rhsTerm->hashVars()) {
444 ExprPtr newScalar =
445 evaluateOp(alloc<Add>(lhsTerm->scalar(), rhsTerm->scalar()));
446
447 // If the terms cancelled out, return zero.
448 if (immediateEquals(newScalar, 0)) {
449 return newScalar->accept_mutator(this);
450 }
451
452 return alloc<Term>(hasher_, newScalar, lhsTerm->variables());
453 }
454
455 // Otherwise this is a new polynomial with no scalar and two variable
456 // terms.
457 return alloc<Polynomial>(hasher_, immLike(v, 0), lhsTerm, rhsTerm);
458 }
459
460 // Adds are commutative.
461 PolynomialPtr poly = lhsPoly ? lhsPoly : rhsPoly;
462
463 // Add to Polynomial->scalar().
464 if (scalar && poly) {
465 ExprPtr newScalar = evaluateOp(alloc<Add>(scalar, poly->scalar()));
466 return alloc<Polynomial>(hasher_, newScalar, poly->variables());
467 }
468
469 // Simple Polynomial with a scalar and Term.
470 TermPtr term = lhsTerm ? lhsTerm : rhsTerm;
471 if (scalar && term) {
472 return alloc<Polynomial>(hasher_, scalar, term);
473 }
474
475 // Simple Term with a scalar and variable type.
476 if (scalar) {
477 return alloc<Polynomial>(
478 hasher_, scalar, alloc<Term>(hasher_, immLike(v, 1), variable));
479 }
480
481 // If LHS is neither Term not Polynomial, wrap it in a Term.
482 if (!lhsTerm && !lhsPoly) {
483 lhsTerm = alloc<Term>(hasher_, immLike(v, 1), lhs_new);
484 }
485
486 // Same for RHS.
487 if (!rhsTerm && !rhsPoly) {
488 rhsTerm = alloc<Term>(hasher_, immLike(v, 1), rhs_new);
489 }
490
491 // If we now have a poly and a term, we can insert.
492 if (poly) {
493 return insertTerm(poly, lhsTerm ? lhsTerm : rhsTerm);
494 }
495
496 if (lhsTerm->hashVars() == rhsTerm->hashVars()) {
497 return alloc<Term>(
498 hasher_,
499 evaluateOp(alloc<Add>(lhsTerm->scalar(), rhsTerm->scalar())),
500 lhsTerm->variables());
501 }
502
503 // If all else fails we have a new Polynomial with two new variable Terms.
504 return alloc<Polynomial>(hasher_, immLike(v, 0), lhsTerm, rhsTerm);
505 }
506
subTerms(const TermPtr & lhs,TermPtr rhs,bool negated)507 ExprPtr PolynomialTransformer::subTerms(
508 const TermPtr& lhs,
509 TermPtr rhs,
510 bool negated) {
511 // If RHS not already negated, negate it.
512 if (!negated) {
513 ExprPtr minusOne = immLike(rhs, -1);
514 ExprPtr negateScalar = evaluateOp(alloc<Mul>(minusOne, rhs->scalar()));
515 rhs = alloc<Term>(hasher_, negateScalar, rhs->variables());
516 }
517
518 if (lhs->hashVars() == rhs->hashVars()) {
519 ExprPtr newScalar = evaluateOp(alloc<Add>(lhs->scalar(), rhs->scalar()));
520
521 // If the terms cancel out, return zero.
522 if (immediateEquals(newScalar, 0)) {
523 return newScalar;
524 }
525
526 return alloc<Term>(hasher_, newScalar, lhs->variables());
527 }
528
529 return alloc<Polynomial>(
530 hasher_,
531 getImmediateByType(promoteTypes(lhs->dtype(), rhs->dtype()), 0),
532 lhs,
533 rhs);
534 }
535
536 // Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
537 // possible.
subPolynomials(const PolynomialPtr & lhs,const PolynomialPtr & rhs)538 ExprPtr PolynomialTransformer::subPolynomials(
539 const PolynomialPtr& lhs,
540 const PolynomialPtr& rhs) {
541 // simplify common components
542 // The key here is the variable hash, not the term's hash since we do want
543 // to combine terms that have the same vars but different scalar components.
544 std::unordered_map<SimplifierHashType, TermPtr> varmap;
545
546 for (const auto& lt : lhs->variables()) {
547 addOrUpdateTerm(varmap, lt);
548 }
549
550 for (const auto& rt : rhs->variables()) {
551 // Polynomials add their terms, so negate the RHS's Terms.
552 ExprPtr negated = evaluateOp(alloc<Mul>(immLike(rt, -1), rt->scalar()));
553 TermPtr newRHS = alloc<Term>(hasher_, negated, rt->variables());
554 addOrUpdateTerm(varmap, newRHS);
555 }
556
557 ExprPtr newScalar = evaluateOp(alloc<Sub>(lhs->scalar(), rhs->scalar()));
558
559 // No vars means this cancelled out to a scalar, return it unwrapped.
560 if (varmap.empty()) {
561 return newScalar;
562 }
563
564 // If there is no scalar and zero or one terms, don't wrap.
565 if (immediateEquals(newScalar, 0)) {
566 if (varmap.empty()) {
567 return nullptr;
568 }
569 if (varmap.size() == 1) {
570 return varmap.begin()->second;
571 }
572 }
573
574 // Wrap new variables in a Polynomial.
575 return alloc<Polynomial>(hasher_, newScalar, varmap);
576 }
577
mutate(const SubPtr & v)578 ExprPtr PolynomialTransformer::mutate(const SubPtr& v) {
579 ExprPtr lhs_new = v->lhs()->accept_mutator(this);
580 ExprPtr rhs_new = v->rhs()->accept_mutator(this);
581
582 // Constant Folding.
583 if (lhs_new->isConstant() && rhs_new->isConstant()) {
584 ExprPtr result = evaluateOp(alloc<Sub>(lhs_new, rhs_new));
585 return result;
586 }
587
588 // Multilane folding.
589 if (isMultilanePrimitive(lhs_new)) {
590 if (auto ret = combineMultilane<Sub>(lhs_new, rhs_new)) {
591 return ret->accept_mutator(this);
592 }
593 }
594
595 if (rhs_new->isConstant() && immediateEquals(rhs_new, 0)) {
596 auto c = alloc<Cast>(v->dtype(), lhs_new);
597 return c->accept_mutator(this);
598 }
599
600 // If this is a floating point Sub then order of operations is important, we
601 // dont want to combine ops.
602 if (lhs_new->dtype().is_floating_point() ||
603 rhs_new->dtype().is_floating_point()) {
604 return alloc<Sub>(lhs_new, rhs_new);
605 }
606
607 PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
608 PolynomialPtr rhsPoly = to<Polynomial>(rhs_new);
609
610 if (lhsPoly && rhsPoly) {
611 auto ret = subPolynomials(lhsPoly, rhsPoly);
612 if (!ret) {
613 // Cancelled out completely.
614 return immLike(v, 0);
615 }
616 return ret;
617 }
618
619 TermPtr lhsTerm = to<Term>(lhs_new);
620 TermPtr rhsTerm = to<Term>(rhs_new);
621
622 // Polynomial - Term.
623 if (lhsPoly && rhsTerm) {
624 // Negate the term.
625 ExprPtr negate =
626 evaluateOp(alloc<Mul>(immLike(rhsTerm, -1), rhsTerm->scalar()));
627 TermPtr newTerm = alloc<Term>(hasher_, negate, rhsTerm->variables());
628 return insertTerm(lhsPoly, newTerm);
629 }
630
631 // Term - Polynomial.
632 if (rhsPoly && lhsTerm) {
633 // Negate every part of the Polynomial.
634 ExprPtr minusOne = immLike(lhsTerm, -1);
635 ExprPtr negateScalar = evaluateOp(alloc<Mul>(minusOne, rhsPoly->scalar()));
636
637 std::vector<TermPtr> variables;
638 for (const auto& t : rhsPoly->variables()) {
639 ExprPtr negate = evaluateOp(alloc<Mul>(minusOne, t->scalar()));
640 variables.push_back(alloc<Term>(hasher_, negate, t->variables()));
641 }
642
643 PolynomialPtr newPoly = alloc<Polynomial>(hasher_, negateScalar, variables);
644 return insertTerm(newPoly, lhsTerm);
645 }
646
647 if (lhsTerm && rhsTerm) {
648 return subTerms(lhsTerm, rhsTerm, false);
649 }
650
651 bool lhsScalar = lhs_new->isConstant();
652 bool rhsScalar = rhs_new->isConstant();
653
654 if (lhsPoly && rhsScalar) {
655 // Easy path, just sub the scalar component.
656 ExprPtr newScalar = evaluateOp(alloc<Sub>(lhsPoly->scalar(), rhs_new));
657 return alloc<Polynomial>(hasher_, newScalar, lhsPoly->variables());
658 }
659
660 if (lhsScalar && rhsPoly) {
661 // Sub the scalar component.
662 ExprPtr newScalar = evaluateOp(alloc<Sub>(lhs_new, rhsPoly->scalar()));
663
664 // Negate each term in the Polynomial RHS.
665 ExprPtr minusOne = immLike(rhsPoly, -1);
666 std::vector<TermPtr> variables;
667 for (const auto& t : rhsPoly->variables()) {
668 ExprPtr negate = evaluateOp(alloc<Mul>(minusOne, t->scalar()));
669 variables.push_back(alloc<Term>(hasher_, negate, t->variables()));
670 }
671
672 return alloc<Polynomial>(hasher_, newScalar, variables);
673 }
674
675 if (lhsTerm && rhsScalar) {
676 // Negate the constant.
677 ExprPtr negate = evaluateOp(alloc<Mul>(immLike(rhs_new, -1), rhs_new));
678 return alloc<Polynomial>(hasher_, negate, lhsTerm);
679 }
680
681 if (lhsScalar && rhsTerm) {
682 // Negate the RHS Term.
683 ExprPtr negate = evaluateOp(
684 alloc<Mul>(immLike(rhsTerm->scalar(), -1), rhsTerm->scalar()));
685
686 return alloc<Polynomial>(
687 hasher_, lhs_new, alloc<Term>(hasher_, negate, rhsTerm->variables()));
688 }
689
690 // simple term with a scalar and variable type.
691 if (lhsScalar) {
692 // Create a negated term.
693 return alloc<Polynomial>(
694 hasher_, lhs_new, alloc<Term>(hasher_, immLike(v, -1), rhs_new));
695 }
696
697 if (rhsScalar) {
698 // Negate the scalar.
699 ExprPtr negate = evaluateOp(alloc<Mul>(immLike(rhs_new, -1), rhs_new));
700 return alloc<Polynomial>(
701 hasher_, negate, alloc<Term>(hasher_, immLike(v, 1), lhs_new));
702 }
703
704 // no scalar...
705 if (!lhsTerm && !lhsPoly) {
706 lhsTerm = alloc<Term>(hasher_, immLike(v, 1), lhs_new);
707 }
708
709 bool createdRHSnegated = false;
710 if (!rhsTerm && !rhsPoly) {
711 rhsTerm = alloc<Term>(hasher_, immLike(v, -1), rhs_new);
712 createdRHSnegated = true;
713 }
714
715 if (lhsTerm && rhsTerm) {
716 return subTerms(lhsTerm, rhsTerm, createdRHSnegated);
717 }
718
719 // Insert wrapped Term into LHS Polynomial.
720 if (lhsPoly) {
721 CHECK(rhsTerm);
722 return insertTerm(lhsPoly, rhsTerm);
723 }
724
725 // Insert wrapper Term into negated RHS Poly.
726 if (rhsPoly) {
727 CHECK(lhsTerm);
728 ExprPtr minusOne = immLike(rhsPoly, -1);
729 ExprPtr newScalar = evaluateOp(alloc<Mul>(minusOne, rhsPoly->scalar()));
730
731 // Negate each term in the Polynomial RHS.
732 std::vector<TermPtr> variables;
733 for (const auto& t : rhsPoly->variables()) {
734 ExprPtr negate = evaluateOp(alloc<Mul>(minusOne, t->scalar()));
735 variables.push_back(alloc<Term>(hasher_, negate, t->variables()));
736 }
737
738 auto poly = alloc<Polynomial>(hasher_, newScalar, variables);
739 return insertTerm(poly, lhsTerm);
740 }
741
742 return alloc<Polynomial>(hasher_, immLike(v, 0), lhsTerm, rhsTerm);
743 }
744
745 // Multiply two terms together, usually creating a new term with the variable
746 // lists concatenated.
mulTerms(const TermPtr & lhs,const TermPtr & rhs)747 TermPtr PolynomialTransformer::mulTerms(
748 const TermPtr& lhs,
749 const TermPtr& rhs) {
750 ExprPtr scalar = evaluateOp(alloc<Mul>(lhs->scalar(), rhs->scalar()));
751 if (immediateEquals(scalar, 0)) {
752 return nullptr;
753 }
754
755 // Can reorder here since floating point ops don't get put into Terms.
756 std::vector<ExprPtr> variables;
757 std::vector<ExprPtr> multilaneVariables;
758 // For now don't handle exponents.
759 for (const auto& c : lhs->variables()) {
760 if (isMultilanePrimitive(c)) {
761 multilaneVariables.push_back(c);
762 } else {
763 variables.push_back(c);
764 }
765 }
766 for (const auto& c : rhs->variables()) {
767 if (isMultilanePrimitive(c)) {
768 multilaneVariables.push_back(c);
769 } else {
770 variables.push_back(c);
771 }
772 }
773
774 // Merge all the multilane vars:
775 ExprPtr lastNode{nullptr};
776 for (const auto& node : multilaneVariables) {
777 if (lastNode == nullptr) {
778 lastNode = node;
779 } else {
780 if (auto next = mulMultilane(lastNode, node)) {
781 lastNode = next->accept_mutator(this);
782 } else {
783 variables.push_back(lastNode);
784 lastNode = node;
785 }
786 }
787 }
788 if (lastNode) {
789 variables.push_back(lastNode);
790 }
791
792 return alloc<Term>(hasher_, scalar, variables);
793 }
794
795 // Multiply a Polynomial by a Term.
polyByTerm(const PolynomialPtr & poly,const TermPtr & term)796 ExprPtr PolynomialTransformer::polyByTerm(
797 const PolynomialPtr& poly,
798 const TermPtr& term) {
799 // poly * term
800 // = (poly_terms + poly_scalar) * term
801 // = poly_terms * term + poly_scalar * term
802
803 // First, multiply all variables (terms) in the polynomial by the input
804 // term.
805 std::vector<TermPtr> newTerms;
806 for (const auto& var : poly->variables()) {
807 TermPtr newTerm = mulTerms(var, term);
808 if (newTerm) {
809 newTerms.push_back(newTerm);
810 }
811 }
812
813 // If the scalar in poly is not 0, it must be multiplied by term.
814 // If there are no variables in term, this becomes the scalar in the result
815 // polynomial. If there are variables in term, this becomes a new term in
816 // the result polynomial.
817 if (!immediateEquals(poly->scalar(), 0)) {
818 ExprPtr scalar = evaluateOp(alloc<Mul>(poly->scalar(), term->scalar()));
819 if (term->variables().empty()) {
820 return alloc<Polynomial>(hasher_, scalar, newTerms);
821 }
822 newTerms.push_back(alloc<Term>(hasher_, scalar, term->variables()));
823 }
824
825 // The only case when the result polynomial has a scalar is when the input
826 // term does not have any variables and the input polynomial has a non-zero
827 // scalar. That case is handled above. So, at this point, we do not have any
828 // scalars in the result polynomial.
829 return alloc<Polynomial>(hasher_, std::move(newTerms));
830 }
831
832 // Does multiplying these two expressions make a Rounding Off operation.
833 // e.g. LHS = (x/y), RHS = y => (x / y) * y => RoundOff(x, y).
isRoundOff(const ExprPtr & lhs,const ExprPtr & rhs)834 ExprPtr PolynomialTransformer::isRoundOff(
835 const ExprPtr& lhs,
836 const ExprPtr& rhs) {
837 DivPtr div{nullptr};
838 ExprPtr other{nullptr};
839
840 if ((div = to<Div>(lhs))) {
841 other = rhs;
842 } else if ((div = to<Div>(rhs))) {
843 other = lhs;
844 } else {
845 return nullptr;
846 }
847
848 ExprPtr denom = div->rhs();
849
850 if (TermPtr denomTerm = to<Term>(denom)) {
851 if (immediateEquals(denomTerm->scalar(), 1) &&
852 denomTerm->variables().size() == 1) {
853 denom = denomTerm->variables()[0];
854 }
855 }
856
857 if (hasher_.hash(denom) == hasher_.hash(other)) {
858 // If the denominator is equal to the other, then yes it's a RoundOff.
859 return alloc<RoundOff>(div->lhs(), div->rhs());
860 }
861
862 if (denom->isConstant() && other->isConstant()) {
863 if (immediateEquals(denom, 0) || immediateEquals(other, 0)) {
864 return nullptr;
865 }
866 // If they are both scalar we may be able to find a common factor.
867 if (immediateEquals(evaluateOp(alloc<Mod>(other, denom)), 0)) {
868 ExprPtr scalar = evaluateOp(alloc<Div>(other, denom));
869 ExprPtr newDenom = evaluateOp(alloc<Div>(other, scalar));
870 return alloc<Term>(
871 hasher_, scalar, alloc<RoundOff>(div->lhs(), newDenom));
872 }
873 }
874
875 return nullptr;
876 }
877
878 // Inserts a new component into a term, looking for opportunities to simplify.
insertIntoTerm(const TermPtr & term,const ExprPtr & expr)879 ExprPtr PolynomialTransformer::insertIntoTerm(
880 const TermPtr& term,
881 const ExprPtr& expr) {
882 std::vector<ExprPtr> vars;
883
884 // Search for RoundOffs.
885 bool merged{false};
886 for (const auto& component : term->variables()) {
887 if (auto roundoff = isRoundOff(component, expr)) {
888 vars.push_back(roundoff);
889 merged = true;
890 } else {
891 vars.push_back(component);
892 }
893 }
894
895 if (!merged) {
896 vars.push_back(expr);
897 }
898
899 if (vars.size() == 1 && immediateEquals(term->scalar(), 1)) {
900 return vars[0];
901 }
902
903 return alloc<Term>(hasher_, term->scalar(), vars);
904 }
905
mutate(const MulPtr & v)906 ExprPtr PolynomialTransformer::mutate(const MulPtr& v) {
907 ExprPtr lhs_new = v->lhs()->accept_mutator(this);
908 ExprPtr rhs_new = v->rhs()->accept_mutator(this);
909
910 // Constant Folding.
911 if (lhs_new->isConstant() && rhs_new->isConstant()) {
912 return evaluateOp(alloc<Mul>(lhs_new, rhs_new));
913 }
914
915 // Multilane folding.
916 if (isMultilanePrimitive(lhs_new)) {
917 if (auto ret = mulMultilane(lhs_new, rhs_new)) {
918 return ret->accept_mutator(this);
919 }
920 }
921
922 // Order doesn't matter.
923 ExprPtr scalar = nullptr;
924 ExprPtr variable = nullptr;
925 if (lhs_new->isConstant()) {
926 scalar = lhs_new;
927 variable = rhs_new;
928 } else if (rhs_new->isConstant()) {
929 scalar = rhs_new;
930 variable = lhs_new;
931 }
932
933 // Handle special case mul by 1 since thats safe for floating point, even if
934 // it's Nan/Inf.
935 if (scalar && immediateEquals(scalar, 1)) {
936 auto c = alloc<Cast>(v->dtype(), variable);
937 return c->accept_mutator(this);
938 }
939
940 // If this is a floating point Mul then order of operations is important, we
941 // dont want to combine ops.
942 if (lhs_new->dtype().is_floating_point() ||
943 rhs_new->dtype().is_floating_point()) {
944 return alloc<Mul>(lhs_new, rhs_new);
945 }
946
947 // Handle special case mul by 0.
948 if (scalar && immediateEquals(scalar, 0)) {
949 return immLike(v, 0);
950 }
951
952 // Catch cases of rounding (Div(A/B) * B).
953 if (auto ret = isRoundOff(lhs_new, rhs_new)) {
954 return ret;
955 } else if (auto ret = isRoundOff(v->lhs(), v->rhs())) {
956 // We can break the Round + Mod pattern via factorization of the Div, so
957 // check whether it would have worked on the unsimplified tree. If so, we
958 // need to simplify again.
959 return ret->accept_mutator(this);
960 }
961
962 PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
963 PolynomialPtr rhsPoly = to<Polynomial>(rhs_new);
964
965 if (lhsPoly && rhsPoly) {
966 // This expands to more terms that we can't generally fix without variable
967 // factorization, it's more efficient to just leave these as Muls.
968 return alloc<Mul>(lhsPoly, rhsPoly);
969 }
970
971 TermPtr lhsTerm = to<Term>(lhs_new);
972 TermPtr rhsTerm = to<Term>(rhs_new);
973
974 if (lhsPoly && rhsTerm) {
975 return polyByTerm(lhsPoly, rhsTerm);
976 }
977
978 if (rhsPoly && lhsTerm) {
979 return polyByTerm(rhsPoly, lhsTerm);
980 }
981
982 if (lhsTerm && rhsTerm) {
983 return mulTerms(lhsTerm, rhsTerm);
984 }
985
986 if (scalar && lhsTerm) {
987 ExprPtr newScalar = evaluateOp(alloc<Mul>(scalar, lhsTerm->scalar()));
988 return alloc<Term>(hasher_, newScalar, lhsTerm->variables());
989 }
990
991 if (scalar && rhsTerm) {
992 ExprPtr newScalar = evaluateOp(alloc<Mul>(scalar, rhsTerm->scalar()));
993 return alloc<Term>(hasher_, newScalar, rhsTerm->variables());
994 }
995
996 // If this is a scalar * a Polynomial, push the scalar term down.
997 // We can wrap the scalar with a Term and use polyByTerm.
998 if (scalar && lhsPoly) {
999 return polyByTerm(lhsPoly, alloc<Term>(hasher_, scalar));
1000 }
1001 if (scalar && rhsPoly) {
1002 return polyByTerm(rhsPoly, alloc<Term>(hasher_, scalar));
1003 }
1004
1005 // simple term with a scalar and variable type.
1006 if (scalar) {
1007 return alloc<Term>(hasher_, scalar, variable);
1008 }
1009
1010 // Multiplying Polynomial by variable can be wrapped in a term and handled
1011 // by polyByTerm also.
1012 if (lhsPoly) {
1013 auto term = alloc<Term>(hasher_, immLike(rhs_new, 1), rhs_new);
1014 return polyByTerm(lhsPoly, term);
1015 }
1016 if (rhsPoly) {
1017 auto term = alloc<Term>(hasher_, immLike(lhs_new, 1), lhs_new);
1018 return polyByTerm(rhsPoly, term);
1019 }
1020
1021 // Multiplying Term by a variable is equivalent to adding the variable to
1022 // the term's list of vars.
1023 if (lhsTerm) {
1024 return insertIntoTerm(lhsTerm, rhs_new);
1025 }
1026 if (rhsTerm) {
1027 return insertIntoTerm(rhsTerm, lhs_new);
1028 }
1029
1030 // Two variables, create a new Term.
1031 return alloc<Term>(hasher_, immLike(v, 1), lhs_new, rhs_new);
1032 }
1033
factorizeDivision(ExprPtr lhs_new,ExprPtr rhs_new)1034 static ExprPtr factorizeDivision(ExprPtr lhs_new, ExprPtr rhs_new) {
1035 if (!lhs_new || !rhs_new) {
1036 return nullptr;
1037 }
1038
1039 ExprPtr leftScalar = lhs_new->isConstant() ? lhs_new : nullptr;
1040 ExprPtr rightScalar = rhs_new->isConstant() ? rhs_new : nullptr;
1041
1042 auto lhsTerm = to<Term>(lhs_new);
1043 auto rhsTerm = to<Term>(rhs_new);
1044 if (lhsTerm) {
1045 leftScalar = lhsTerm->scalar();
1046 }
1047
1048 if (rhsTerm) {
1049 rightScalar = rhsTerm->scalar();
1050 }
1051
1052 if (!leftScalar || !rightScalar) {
1053 return nullptr;
1054 }
1055
1056 long left = immediateAs<long>(leftScalar);
1057 long right = immediateAs<long>(rightScalar);
1058
1059 long GCD = gcd<long>(left, right);
1060 if (GCD <= 1) {
1061 return nullptr;
1062 }
1063
1064 leftScalar = evaluateOp(alloc<Div>(leftScalar, immLike(leftScalar, GCD)));
1065 rightScalar = evaluateOp(alloc<Div>(rightScalar, immLike(rightScalar, GCD)));
1066
1067 if (lhsTerm) {
1068 lhs_new = alloc<Term>(lhsTerm->hasher(), leftScalar, lhsTerm->variables());
1069 } else {
1070 lhs_new = leftScalar;
1071 }
1072
1073 if (rhsTerm) {
1074 rhs_new = alloc<Term>(rhsTerm->hasher(), rightScalar, rhsTerm->variables());
1075 } else {
1076 rhs_new = rightScalar;
1077 }
1078
1079 return alloc<Div>(lhs_new, rhs_new);
1080 }
1081
mutate(const DivPtr & v)1082 ExprPtr PolynomialTransformer::mutate(const DivPtr& v) {
1083 ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1084 ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1085
1086 // Constant Folding.
1087 if (lhs_new->isConstant() && rhs_new->isConstant()) {
1088 return evaluateOp(alloc<Div>(lhs_new, rhs_new));
1089 }
1090
1091 // If this is a floating point Div then order of operations is important, we
1092 // dont want to combine ops.
1093 if (lhs_new->dtype().is_floating_point() ||
1094 rhs_new->dtype().is_floating_point()) {
1095 return alloc<Div>(lhs_new, rhs_new);
1096 }
1097
1098 // If the numerator is zero, so is the result.
1099 if (lhs_new->isConstant() && immediateEquals(lhs_new, 0)) {
1100 return lhs_new;
1101 }
1102
1103 // If the denominator is one, return numerator.
1104 if (rhs_new->isConstant() && immediateEquals(rhs_new, 1)) {
1105 return lhs_new;
1106 }
1107
1108 // If numberator and denominator are equal the result is 1.
1109 // Unless the demoninator could be zero.
1110 // if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) {
1111 // return getImmediateByType(v->dtype(), 1);
1112 // }
1113
1114 if (auto ret = factorizeDivision(lhs_new, rhs_new)) {
1115 return ret->accept_mutator(this);
1116 }
1117
1118 return alloc<Div>(lhs_new, rhs_new);
1119 }
1120
mutate(const ModPtr & v)1121 ExprPtr PolynomialTransformer::mutate(const ModPtr& v) {
1122 ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1123 ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1124
1125 // Constant Folding.
1126 if (lhs_new->isConstant() && rhs_new->isConstant()) {
1127 return evaluateOp(alloc<Mod>(lhs_new, rhs_new));
1128 }
1129
1130 // 0 % x => 0.
1131 if (lhs_new->isConstant() && immediateEquals(lhs_new, 0)) {
1132 return lhs_new;
1133 }
1134
1135 // x % 1 == 0.
1136 if (rhs_new->isConstant() && immediateEquals(rhs_new, 1)) {
1137 return immLike(v, 0);
1138 }
1139
1140 // x % x => 0.
1141 if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) {
1142 return immLike(v, 0);
1143 }
1144
1145 TermPtr lhsTerm = to<Term>(lhs_new);
1146 if (!lhsTerm) {
1147 PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
1148 if (lhsPoly) {
1149 // Can still optimize this out if we can factorize the polynomial.
1150 lhsTerm = factorizePolynomial(lhsPoly);
1151 }
1152 }
1153
1154 if (lhsTerm) {
1155 // ((C1 * C2) * x) % C1 => 0.
1156 if (rhs_new->isConstant() &&
1157 immediateEquals(
1158 evaluateOp(alloc<Mod>(lhsTerm->scalar(), rhs_new)), 0)) {
1159 return immLike(v, 0);
1160 }
1161
1162 // (x * y * z) % x => 0.
1163 for (const auto& component : lhsTerm->variables()) {
1164 if (hasher_.hash(component) == hasher_.hash(rhs_new)) {
1165 return immLike(v, 0);
1166 }
1167 }
1168
1169 // (6 * x * y) % (3 * x * y) => 0.
1170 // also, (x * y * z) % (z * y) => 0.
1171 // This requires all variable terms found in the RHS to be present in the
1172 // LHS.
1173 TermPtr rhsTerm = to<Term>(rhs_new);
1174 if (rhsTerm) {
1175 auto& lVars = lhsTerm->variables();
1176 auto& rVars = rhsTerm->variables();
1177 size_t rLeft = rVars.size();
1178
1179 auto rIt = rVars.begin();
1180
1181 for (auto lIt = lVars.begin(); lIt != lVars.end() && !rVars.empty();
1182 ++lIt) {
1183 auto lHash = hasher_.hash(*lIt);
1184 for (; rIt != rVars.end(); ++rIt) {
1185 auto rHash = hasher_.hash(*rIt);
1186 if (lHash == rHash) {
1187 --rLeft;
1188 break;
1189 } else if (lHash < rHash) {
1190 break;
1191 }
1192 }
1193 }
1194
1195 if (rLeft == 0 &&
1196 immediateEquals(
1197 evaluateOp(alloc<Mod>(lhsTerm->scalar(), rhsTerm->scalar())),
1198 0)) {
1199 return immLike(v, 0);
1200 }
1201 }
1202 }
1203
1204 return alloc<Mod>(lhs_new, rhs_new);
1205 }
1206
1207 namespace {
1208
1209 // Combines two MinTerm / MaxTerm expressions into one.
1210 // The first type on the template refers to the op, as in Min or Max and the
1211 // second type refers to the corresponding term, as in MinTerm or MaxTerm.
1212 template <class Op, class OpTerm>
combineMinMaxTerms(ExprPtr lhs,ExprPtr rhs,bool propagate_nans,HashProvider & hasher)1213 ExprPtr combineMinMaxTerms(
1214 ExprPtr lhs,
1215 ExprPtr rhs,
1216 bool propagate_nans,
1217 HashProvider& hasher) {
1218 auto combine_scalars = [&](ExprPtr c1, ExprPtr c2) -> ExprPtr {
1219 if (c1 && c2) {
1220 return evaluateOp(alloc<Op>(c1, c2, propagate_nans));
1221 }
1222 if (c1) {
1223 return c1;
1224 }
1225 return c2;
1226 };
1227
1228 auto combine_opterms = [&](NodePtr<OpTerm> m1, NodePtr<OpTerm> m2) {
1229 ExprPtr scalar = combine_scalars(m1->scalar(), m2->scalar());
1230 std::vector<ExprPtr> variables;
1231 for (const auto& v : m1->variables()) {
1232 variables.push_back(v);
1233 }
1234 for (const auto& v : m2->variables()) {
1235 variables.push_back(v);
1236 }
1237 return alloc<OpTerm>(hasher, scalar, propagate_nans, std::move(variables));
1238 };
1239
1240 auto add_expr_to_opterm = [&](ExprPtr expr, NodePtr<OpTerm> opterm) {
1241 ExprPtr scalar = nullptr;
1242 std::vector<ExprPtr> variables;
1243 if (opterm) {
1244 scalar = opterm->scalar();
1245 variables = opterm->variables();
1246 }
1247 if (expr->isConstant()) {
1248 scalar = combine_scalars(scalar, expr);
1249 } else {
1250 variables.push_back(expr);
1251 }
1252 return alloc<OpTerm>(hasher, scalar, propagate_nans, std::move(variables));
1253 };
1254
1255 auto lhs_opterm = to<OpTerm>(lhs);
1256 auto rhs_opterm = to<OpTerm>(rhs);
1257 if (lhs_opterm && lhs_opterm->propagate_nans() != propagate_nans) {
1258 return alloc<Op>(lhs, rhs, propagate_nans);
1259 }
1260 if (rhs_opterm && rhs_opterm->propagate_nans() != propagate_nans) {
1261 return alloc<Op>(lhs, rhs, propagate_nans);
1262 }
1263
1264 if (lhs_opterm && rhs_opterm) {
1265 return combine_opterms(lhs_opterm, rhs_opterm);
1266 } else if (lhs_opterm) {
1267 return add_expr_to_opterm(rhs, lhs_opterm);
1268 } else if (rhs_opterm) {
1269 return add_expr_to_opterm(lhs, rhs_opterm);
1270 }
1271 return add_expr_to_opterm(rhs, add_expr_to_opterm(lhs, nullptr));
1272 }
1273
1274 // Returns true if op is one of the 2 operands in opterm and also returns
1275 // the other op of opterm in other_op.
1276 template <class OpTerm>
isOperandInMinMaxTerm(NodePtr<OpTerm> opterm,ExprPtr op,HashProvider & hasher,ExprPtr * other_op)1277 bool isOperandInMinMaxTerm(
1278 NodePtr<OpTerm> opterm,
1279 ExprPtr op,
1280 HashProvider& hasher,
1281 ExprPtr* other_op) {
1282 if (opterm->variables().size() != 2) {
1283 return false;
1284 }
1285 auto lhs = opterm->variables()[0];
1286 auto rhs = opterm->variables()[1];
1287 auto op_hash = hasher.hash(std::move(op));
1288 if (hasher.hash(lhs) == op_hash) {
1289 *other_op = rhs;
1290 return true;
1291 } else if (hasher.hash(rhs) == op_hash) {
1292 *other_op = lhs;
1293 return true;
1294 }
1295 return false;
1296 };
1297
1298 // Simplifies the nested min-max pattern like:
1299 // * Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z))
1300 // * Min(Max(x, y), Max(x, z)) => Max(x, Min(y, z))
1301 // This function is called while processing the outer Min / Max ops.
1302 // At that point the inner Min / Max ops would have been converted to
1303 // MinTerm / MaxTerm as appropriate. So, this function checks for those
1304 // term expressions in the given lhs and rhs.
1305 //
1306 // The first type of the template must be the term type corresponding to the
1307 // outer op (e.g. MaxTerm) and the second type of the template must be the term
1308 // type corresponding to the expected inner op (e.g. MinTerm).
1309 template <class OpTerm, class OtherOpTerm>
simplifyNestedMinMax(ExprPtr lhs,ExprPtr rhs,bool propagate_nans,HashProvider & hasher,ExprPtr * new_op)1310 bool simplifyNestedMinMax(
1311 ExprPtr lhs,
1312 ExprPtr rhs,
1313 bool propagate_nans,
1314 HashProvider& hasher,
1315 ExprPtr* new_op) {
1316 auto lhs_opterm = to<OtherOpTerm>(lhs);
1317 auto rhs_opterm = to<OtherOpTerm>(rhs);
1318 if (lhs_opterm && rhs_opterm &&
1319 lhs_opterm->propagate_nans() == propagate_nans &&
1320 rhs_opterm->propagate_nans() == propagate_nans) {
1321 if (!lhs_opterm->scalar() && !rhs_opterm->scalar()) {
1322 if (lhs_opterm->variables().size() == 2 &&
1323 rhs_opterm->variables().size() == 2) {
1324 auto rhs_v1 = rhs_opterm->variables()[0];
1325 auto rhs_v2 = rhs_opterm->variables()[1];
1326 ExprPtr new_op_lhs;
1327 if (isOperandInMinMaxTerm<OtherOpTerm>(
1328 lhs_opterm, rhs_v1, hasher, &new_op_lhs)) {
1329 auto inner_op = alloc<OpTerm>(
1330 hasher, nullptr, propagate_nans, new_op_lhs, rhs_v2);
1331 *new_op = alloc<OtherOpTerm>(
1332 hasher, nullptr, propagate_nans, rhs_v1, inner_op);
1333 return true;
1334 }
1335 if (isOperandInMinMaxTerm<OtherOpTerm>(
1336 lhs_opterm, rhs_v2, hasher, &new_op_lhs)) {
1337 auto inner_op = alloc<OpTerm>(
1338 hasher, nullptr, propagate_nans, new_op_lhs, rhs_v1);
1339 *new_op = alloc<OtherOpTerm>(
1340 hasher, nullptr, propagate_nans, rhs_v2, inner_op);
1341 return true;
1342 }
1343 }
1344 }
1345 }
1346 return false;
1347 }
1348
1349 } // namespace
1350
mutate(const MaxPtr & v)1351 ExprPtr PolynomialTransformer::mutate(const MaxPtr& v) {
1352 ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1353 ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1354
1355 // Constant Folding.
1356 if (lhs_new->isConstant() && rhs_new->isConstant()) {
1357 return evaluateOp(alloc<Max>(lhs_new, rhs_new, v->propagate_nans()));
1358 }
1359
1360 // If diff is constant, return the appropriate operand.
1361 ExprPtr diff = alloc<Sub>(lhs_new, rhs_new);
1362 diff = diff->accept_mutator(this);
1363 if (diff->isConstant()) {
1364 if (immediateAs<int>(diff) > 0) {
1365 return lhs_new;
1366 }
1367 return rhs_new;
1368 }
1369
1370 // Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z))
1371 ExprPtr new_op;
1372 if (simplifyNestedMinMax<MaxTerm, MinTerm>(
1373 lhs_new, rhs_new, v->propagate_nans(), hasher_, &new_op)) {
1374 return new_op;
1375 }
1376
1377 return combineMinMaxTerms<Max, MaxTerm>(
1378 lhs_new, rhs_new, v->propagate_nans(), hasher_);
1379 }
1380
mutate(const MinPtr & v)1381 ExprPtr PolynomialTransformer::mutate(const MinPtr& v) {
1382 ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1383 ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1384
1385 // Constant Folding.
1386 if (lhs_new->isConstant() && rhs_new->isConstant()) {
1387 return evaluateOp(alloc<Min>(lhs_new, rhs_new, v->propagate_nans()));
1388 }
1389
1390 // If diff is constant, return the appropriate operand.
1391 ExprPtr diff = alloc<Sub>(lhs_new, rhs_new);
1392 diff = diff->accept_mutator(this);
1393 if (diff->isConstant()) {
1394 if (immediateAs<int>(diff) < 0) {
1395 return lhs_new;
1396 }
1397 return rhs_new;
1398 }
1399
1400 // Min(Max(x, y), Max(x, z)) => Max(x, Min(y, z))
1401 ExprPtr new_op;
1402 if (simplifyNestedMinMax<MinTerm, MaxTerm>(
1403 lhs_new, rhs_new, v->propagate_nans(), hasher_, &new_op)) {
1404 return new_op;
1405 }
1406
1407 return combineMinMaxTerms<Min, MinTerm>(
1408 lhs_new, rhs_new, v->propagate_nans(), hasher_);
1409 }
1410
mutate(const CompareSelectPtr & v)1411 ExprPtr PolynomialTransformer::mutate(const CompareSelectPtr& v) {
1412 ExprPtr lhs_new = v->lhs()->accept_mutator(this);
1413 ExprPtr rhs_new = v->rhs()->accept_mutator(this);
1414 ExprPtr true_branch = v->ret_val1()->accept_mutator(this);
1415 ExprPtr false_branch = v->ret_val2()->accept_mutator(this);
1416
1417 // Constant Folding.
1418 if (lhs_new->isConstant() && rhs_new->isConstant() &&
1419 true_branch->isConstant() && false_branch->isConstant()) {
1420 ExprPtr v_new = alloc<CompareSelect>(
1421 lhs_new,
1422 rhs_new,
1423 true_branch,
1424 false_branch,
1425 v->compare_select_op(),
1426 v->bias());
1427 return evaluateOp(v_new);
1428 }
1429
1430 // If the comparison is done in float, don't attempt diff simplification,
1431 // since we can't correctly handle NaN.
1432 if (lhs_new->dtype().is_floating_point() ||
1433 rhs_new->dtype().is_floating_point()) {
1434 return alloc<CompareSelect>(
1435 lhs_new,
1436 rhs_new,
1437 true_branch,
1438 false_branch,
1439 v->compare_select_op(),
1440 v->bias());
1441 }
1442
1443 // If diff is constant, we can determine it.
1444 ExprPtr diff = alloc<Sub>(rhs_new, lhs_new);
1445 diff = diff->accept_mutator(this);
1446
1447 if (!diff->isConstant()) {
1448 return alloc<CompareSelect>(
1449 lhs_new,
1450 rhs_new,
1451 true_branch,
1452 false_branch,
1453 v->compare_select_op(),
1454 v->bias());
1455 }
1456
1457 bool equal = immediateEquals(diff, 0);
1458 bool lhsSmaller = !equal && !immediateIsNegative(diff);
1459
1460 switch (v->compare_select_op()) {
1461 case CompareSelectOperation::kEQ:
1462 return equal ? true_branch : false_branch;
1463 case CompareSelectOperation::kGT:
1464 return (lhsSmaller || equal) ? false_branch : true_branch;
1465 case CompareSelectOperation::kGE:
1466 return lhsSmaller ? false_branch : true_branch;
1467 case CompareSelectOperation::kLT:
1468 return lhsSmaller ? true_branch : false_branch;
1469 case CompareSelectOperation::kLE:
1470 return (lhsSmaller || equal) ? true_branch : false_branch;
1471 case CompareSelectOperation::kNE:
1472 return equal ? false_branch : true_branch;
1473 }
1474
1475 // should not be possible but just in case.
1476 return alloc<CompareSelect>(
1477 lhs_new,
1478 rhs_new,
1479 true_branch,
1480 false_branch,
1481 v->compare_select_op(),
1482 v->bias());
1483 }
1484
mutate(const IntrinsicsPtr & v)1485 ExprPtr PolynomialTransformer::mutate(const IntrinsicsPtr& v) {
1486 std::vector<ExprPtr> new_params;
1487 bool changed = false;
1488 bool allConstant = true;
1489 for (const auto& p : v->params()) {
1490 ExprPtr new_child = p->accept_mutator(this);
1491 new_params.push_back(new_child);
1492
1493 changed |= p != new_child;
1494 allConstant &= new_child->isConstant();
1495 }
1496
1497 ExprPtr node = v;
1498 if (changed) {
1499 node = alloc<Intrinsics>(v->op_type(), new_params);
1500 }
1501
1502 if (!allConstant || !v->isPure()) {
1503 return node;
1504 }
1505
1506 // we're evaluating, but the evaluator only supports float intrinsics.
1507 std::vector<ExprPtr> const_params;
1508 changed = false;
1509 for (const auto& p : new_params) {
1510 if (p->dtype().scalar_type() == ScalarType::Float) {
1511 const_params.push_back(p);
1512 } else {
1513 const_params.push_back(
1514 alloc<Cast>(Dtype(ScalarType::Float, p->dtype().lanes()), p));
1515 changed = true;
1516 }
1517 }
1518
1519 if (changed) {
1520 node = alloc<Intrinsics>(v->op_type(), const_params);
1521 }
1522 return evaluateOp(node);
1523 }
1524
mutate(const CastPtr & v)1525 ExprPtr PolynomialTransformer::mutate(const CastPtr& v) {
1526 ExprPtr node = v->src_value()->accept_mutator(this);
1527 if (node->isConstant()) {
1528 return evaluateOp(alloc<Cast>(v->dtype(), node));
1529 }
1530
1531 if (v->dtype() == node->dtype()) {
1532 return node;
1533 }
1534
1535 return alloc<Cast>(v->dtype(), node);
1536 }
1537
mutate(const IfThenElsePtr & v)1538 ExprPtr PolynomialTransformer::mutate(const IfThenElsePtr& v) {
1539 ExprPtr condition = v->condition();
1540 ExprPtr true_value = v->true_value();
1541 ExprPtr false_value = v->false_value();
1542 ExprPtr condition_new = condition->accept_mutator(this);
1543 ExprPtr true_value_new = true_value->accept_mutator(this);
1544 ExprPtr false_value_new = false_value->accept_mutator(this);
1545
1546 // If the condition is constant then we can choose the right branch now.
1547 if (condition_new->isConstant()) {
1548 if (!immediateEquals(condition_new, 0)) {
1549 return true_value_new;
1550 } else {
1551 return false_value_new;
1552 }
1553 }
1554
1555 // If both branches are the same then don't do the condition.
1556 if (hasher_.hash(true_value_new) == hasher_.hash(false_value_new)) {
1557 return true_value_new;
1558 }
1559
1560 if (condition == condition_new && true_value == true_value_new &&
1561 false_value == false_value_new) {
1562 return v;
1563 }
1564
1565 return alloc<IfThenElse>(condition_new, true_value_new, false_value_new);
1566 }
1567
mutate(const AndPtr & v)1568 ExprPtr PolynomialTransformer::mutate(const AndPtr& v) {
1569 return mutateBinaryOp(v, this);
1570 }
1571
mutate(const XorPtr & v)1572 ExprPtr PolynomialTransformer::mutate(const XorPtr& v) {
1573 return mutateBinaryOp(v, this);
1574 }
1575
mutate(const LshiftPtr & v)1576 ExprPtr PolynomialTransformer::mutate(const LshiftPtr& v) {
1577 return mutateBinaryOp(v, this);
1578 }
1579
mutate(const RshiftPtr & v)1580 ExprPtr PolynomialTransformer::mutate(const RshiftPtr& v) {
1581 return mutateBinaryOp(v, this);
1582 }
1583
mutate(const CondPtr & v)1584 StmtPtr PolynomialBase::mutate(const CondPtr& v) {
1585 ExprPtr cond_old = v->condition();
1586 StmtPtr true_old = v->true_stmt();
1587 StmtPtr false_old = v->false_stmt();
1588
1589 ExprPtr cond_new = cond_old->accept_mutator(this);
1590 StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old;
1591 StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old;
1592
1593 // If the condition is constant then we can choose the right branch now.
1594 if (cond_new->isConstant()) {
1595 if (!immediateEquals(cond_new, 0)) {
1596 return true_new;
1597 } else {
1598 return false_new;
1599 }
1600 }
1601
1602 // If both branches are the same then don't do the condition.
1603 if (true_new && false_new &&
1604 hasher_.hash(true_new) == hasher_.hash(false_new)) {
1605 return true_new;
1606 }
1607
1608 BlockPtr true_block = to<Block>(true_new);
1609 BlockPtr false_block = to<Block>(false_new);
1610 bool true_empty = !true_new || (true_block && true_block->nstmts() == 0);
1611 bool false_empty = !false_new || (false_block && false_block->nstmts() == 0);
1612
1613 if (true_empty && false_empty) {
1614 return alloc<Block>(std::vector<StmtPtr>({}));
1615 }
1616 if (cond_old != cond_new) {
1617 v->set_condition(cond_new);
1618 }
1619 if (true_old != true_new) {
1620 v->set_true_stmt(true_new);
1621 }
1622 if (false_old != false_new) {
1623 v->set_false_stmt(false_new);
1624 }
1625 return v;
1626 }
1627
handleForCondReordering(const ForPtr & loop,const CondPtr & cond)1628 static StmtPtr handleForCondReordering(
1629 const ForPtr& loop,
1630 const CondPtr& cond) {
1631 if (cond->false_stmt()) {
1632 return nullptr;
1633 }
1634
1635 auto condition_vars = VarFinder::find(cond->condition());
1636 for (const auto& v : condition_vars) {
1637 // If the condition depends on a Var that is modified in the loop body, it
1638 // may not be safe to reorder.
1639 if (ModifiesVarChecker::check(loop, v)) {
1640 return nullptr;
1641 }
1642 }
1643
1644 ForPtr new_f = loop->cloneWithNewBody(Stmt::clone(cond->true_stmt()));
1645 return cond->cloneWithNewBody(new_f);
1646 }
1647
mutate(const ForPtr & v)1648 StmtPtr PolynomialBase::mutate(const ForPtr& v) {
1649 ExprPtr var = v->var();
1650 ExprPtr start = v->start();
1651 ExprPtr stop = v->stop();
1652 StmtPtr body = v->body();
1653 LoopOptions loop_options = v->loop_options();
1654 ExprPtr var_new_expr = var->accept_mutator(this);
1655 VarPtr var_new = to<Var>(var_new_expr);
1656 ExprPtr start_new = start->accept_mutator(this);
1657 ExprPtr stop_new = stop->accept_mutator(this);
1658 StmtPtr body_new = body;
1659
1660 ExprPtr loops = alloc<Sub>(stop_new, start_new);
1661 loops = loops->accept_mutator(this);
1662 if (loop_options.isDefault() && loops->isConstant()) {
1663 if (immediateEquals(loops, 0)) {
1664 return alloc<Block>(std::vector<StmtPtr>({}));
1665 } else if (immediateEquals(loops, 1)) {
1666 body_new = Substitute(body, {{var_new, start_new}});
1667 body_new = body_new->accept_mutator(this);
1668 return body_new;
1669 }
1670 }
1671
1672 body_new = body_new->accept_mutator(this);
1673 if (!body_new) {
1674 return alloc<Block>(std::vector<StmtPtr>({}));
1675 }
1676
1677 if (auto block = to<Block>(body_new)) {
1678 if (block->nstmts() == 0) {
1679 return alloc<Block>(std::vector<StmtPtr>({}));
1680 }
1681
1682 if (block->nstmts() == 1) {
1683 if (auto cond = to<Cond>(block->front())) {
1684 StmtPtr reordered = handleForCondReordering(v, cond);
1685 if (reordered) {
1686 return reordered->accept_mutator(this);
1687 }
1688 }
1689 }
1690 }
1691
1692 if (var != var_new) {
1693 v->set_var(var_new);
1694 }
1695 if (start != start_new) {
1696 v->set_start(start_new);
1697 }
1698 if (stop != stop_new) {
1699 v->set_stop(stop_new);
1700 }
1701 if (body != body_new) {
1702 v->set_body(body_new);
1703 }
1704 return v;
1705 }
1706
mutate(const BlockPtr & v)1707 StmtPtr PolynomialBase::mutate(const BlockPtr& v) {
1708 std::vector<StmtPtr> stmts;
1709 // Flatten sub-blocks:
1710 bool stmts_changed = false;
1711 for (const StmtPtr& stmt : *v) {
1712 StmtPtr stmt_new = stmt->accept_mutator(this);
1713 stmts_changed |= stmt != stmt_new;
1714 if (stmt_new == nullptr) {
1715 continue;
1716 }
1717
1718 if (auto subBlock = to<Block>(stmt_new)) {
1719 for (Block::iterator I = subBlock->begin(), E = subBlock->end();
1720 I != E;) {
1721 // Be careful to avoid invalidating the iterator.
1722 StmtPtr s = *(I++);
1723 subBlock->remove_stmt(s);
1724 stmts.push_back(s);
1725 }
1726 stmts_changed = true;
1727 } else {
1728 stmts.push_back(stmt_new);
1729 }
1730 }
1731 if (stmts_changed) {
1732 v->set_stmts(stmts);
1733 }
1734 return v;
1735 }
1736
1737 // TermExpander
1738
mutate(const TermPtr & v)1739 ExprPtr TermExpander::mutate(const TermPtr& v) {
1740 ExprPtr newScalar = v->scalar()->accept_mutator(this);
1741 if (immediateEquals(newScalar, 0)) {
1742 return newScalar;
1743 }
1744
1745 std::vector<ExprPtr> vars;
1746 std::vector<ExprPtr> multilaneVars;
1747
1748 // Assume we can reorder here because we wont merge floating terms.
1749 ExprPtr lastNode{nullptr};
1750 for (const auto& var : v->variables()) {
1751 ExprPtr node = var->accept_mutator(this);
1752 if (MulPtr mul = to<Mul>(node)) {
1753 // If the sub-Expr resolved to a multiplication, lift it into this
1754 // term.
1755 if (isMultilanePrimitive(mul->lhs())) {
1756 multilaneVars.push_back(mul->lhs());
1757 } else {
1758 vars.push_back(mul->lhs());
1759 }
1760
1761 if (isMultilanePrimitive(mul->rhs())) {
1762 multilaneVars.push_back(mul->rhs());
1763 } else {
1764 vars.push_back(mul->rhs());
1765 }
1766 } else {
1767 if (isMultilanePrimitive(node)) {
1768 multilaneVars.push_back(node);
1769 } else {
1770 vars.push_back(node);
1771 }
1772 }
1773 }
1774
1775 for (const auto& node : multilaneVars) {
1776 if (lastNode == nullptr) {
1777 lastNode = node;
1778 } else {
1779 lastNode = mulMultilane(lastNode, node);
1780 // simplify first, then re-expand.
1781 lastNode = lastNode->accept_mutator(simplifier_);
1782 lastNode = lastNode->accept_mutator(this);
1783 }
1784 }
1785
1786 for (const auto& node : vars) {
1787 if (lastNode == nullptr) {
1788 lastNode = node;
1789 } else {
1790 lastNode = alloc<Mul>(lastNode, node);
1791 }
1792 }
1793
1794 if (!immediateEquals(newScalar, 1)) {
1795 if (lastNode) {
1796 // We want to avoid a leaving a CastNode on the scalar, so handle that
1797 // now.
1798 auto termDtype = v->scalar()->dtype();
1799 auto lastNodeDtype = lastNode->dtype();
1800 if (termDtype != lastNodeDtype) {
1801 ExprPtr castV = v->scalar();
1802 // Take care of lane mismatch first.
1803 if (termDtype.lanes() != lastNodeDtype.lanes()) {
1804 castV = alloc<Broadcast>(v->scalar(), lastNodeDtype.lanes());
1805 }
1806 // Now take care of scalar type as well.
1807 if (termDtype.scalar_type() != lastNodeDtype.scalar_type()) {
1808 castV = alloc<Cast>(lastNode->dtype(), castV);
1809 // For scalars, we can simplify the cast further.
1810 if (lastNodeDtype.lanes() == 1) {
1811 castV = evaluateOp(castV);
1812 }
1813 }
1814 lastNode = alloc<Mul>(castV, lastNode);
1815 } else {
1816 lastNode = alloc<Mul>(v->scalar(), lastNode);
1817 }
1818 } else {
1819 lastNode = v->scalar();
1820 }
1821 }
1822
1823 return lastNode;
1824 }
1825
1826 // Returns an immediate containing the greatest common divisor of all terms
1827 // (inc. the scalar term) in the polynomial. If the GCD is uninteresting
1828 // (e.g. 1) then returns nullptr.
polyGCD(const PolynomialPtr & poly)1829 static ExprPtr polyGCD(const PolynomialPtr& poly) {
1830 ExprPtr scalar = poly->scalar();
1831 const std::vector<TermPtr>& variables = poly->variables();
1832
1833 // We ony want to factorize if we're saving complete operations, i.e. no
1834 // value in factorizing 6x + 4y into 2 * (3x + 2y) since we don't save work.
1835 int opsSaved = 1; // default to saving the scalar.
1836 long GCD = std::abs(immediateAs<long>(scalar));
1837 for (const auto& t : variables) {
1838 long termScalar = std::abs(immediateAs<long>(t->scalar()));
1839 long newGCD = gcd(std::max(GCD, termScalar), std::min(GCD, termScalar));
1840 if (newGCD == 1) {
1841 return nullptr;
1842 }
1843
1844 if (GCD != newGCD) {
1845 opsSaved = 0;
1846 GCD = newGCD;
1847 }
1848
1849 if (GCD == termScalar) {
1850 opsSaved++;
1851 }
1852 }
1853
1854 if (opsSaved == 0) {
1855 return nullptr;
1856 }
1857
1858 if (GCD == 0) {
1859 return nullptr;
1860 }
1861
1862 // Not worth, can be a Sub.
1863 if (GCD == -1 && opsSaved == 1) {
1864 return nullptr;
1865 }
1866
1867 return immLike(poly, GCD);
1868 }
1869
1870 // A ModRound is a div-mod-mul in which the divisor in div and multiplier in mul
1871 // are identical and not equal to 1.
1872 // In a ModRound x/y%z*y*c (c is constant), 'scalar' denotes c, 'denominator'
1873 // denotes x, 'divisor' denotes y and 'mod_divisor' denotes z.
1874 class ModRound {
1875 public:
ModRound(ExprPtr scalar,ExprPtr denom,ExprPtr divisor,ExprPtr mod_divisor)1876 ModRound(ExprPtr scalar, ExprPtr denom, ExprPtr divisor, ExprPtr mod_divisor)
1877 : scalar(std::move(scalar)),
1878 denom(std::move(denom)),
1879 divisor(std::move(divisor)),
1880 mod_divisor(std::move(mod_divisor)) {}
1881 ExprPtr scalar;
1882 ExprPtr denom;
1883 ExprPtr divisor;
1884 ExprPtr mod_divisor;
1885 };
1886
isModRound(const TermPtr & e)1887 static std::optional<class ModRound> isModRound(const TermPtr& e) {
1888 DivPtr div{nullptr};
1889 ModPtr mod{nullptr};
1890 ExprPtr denom{nullptr};
1891 ExprPtr divisor{nullptr};
1892 ExprPtr mod_divisor{nullptr};
1893 ExprPtr multiplier = e->scalar();
1894 ExprPtr scalar{nullptr};
1895 ExprPtr other{nullptr};
1896
1897 for (const auto& m : e->variables()) {
1898 if (m->expr_type() == IRNodeType::kMod) {
1899 // TODO: currently only identify terms with one variable being mod; it is
1900 // possible to extend this if we have to handle terms like (t/(x%2 * y) %
1901 // z) * (x%2 *y).
1902 if (!mod) {
1903 mod = to<Mod>(m);
1904 } else {
1905 return std::nullopt;
1906 }
1907 } else {
1908 // Take care of special cases before multiplying the scalar and variable.
1909 if (multiplier->isConstant()) {
1910 // Take care of lane mismatch first.
1911 if (multiplier->dtype().lanes() != m->dtype().lanes()) {
1912 multiplier = alloc<Broadcast>(multiplier, m->dtype().lanes());
1913 }
1914 // Take care of scalar type mismatch.
1915 if (multiplier->dtype().scalar_type() != m->dtype().scalar_type()) {
1916 multiplier = alloc<Cast>(m->dtype(), multiplier);
1917 if (m->dtype().lanes() == 1) {
1918 multiplier = evaluateOp(multiplier);
1919 }
1920 }
1921 }
1922
1923 // All non-mod variables are considered as part of the multiplier.
1924 multiplier = alloc<Mul>(multiplier, m);
1925 }
1926 }
1927 multiplier = IRSimplifier::simplify(multiplier);
1928
1929 if (!mod) {
1930 return std::nullopt;
1931 }
1932
1933 mod_divisor = IRSimplifier::simplify(mod->rhs());
1934 other = mod->lhs();
1935
1936 if (!(div = to<Div>(other))) {
1937 return std::nullopt;
1938 }
1939
1940 divisor = IRSimplifier::simplify(div->rhs());
1941 other = div->lhs();
1942
1943 denom = IRSimplifier::simplify(other);
1944
1945 // Deny cases in which divisor!=multiplier.
1946 HashProvider& hasher = e->hasher();
1947 if (hasher.hash(divisor) != hasher.hash(multiplier)) {
1948 // TODO: currently we do not extract a common factor if divisor and
1949 // multiplier are not constants. The extraction is not supported (e.g.,
1950 // x*2/x -> 2) in IRSimplifier.simplify because x could be 0. As future
1951 // work, we can extend division to 2 versions: 1) division for customers
1952 // that has to be strictly simplified and 2) division we introduced in our
1953 // transformations which can be simplified without considering 0s, e.g.,
1954 // Div_nonzero. The second division will be only used to facilitate our
1955 // transformations.
1956 if (divisor->isConstant() && multiplier->isConstant()) {
1957 // If both are scalar we may be able to find a common factor.
1958 if (immediateEquals(evaluateOp(alloc<Mod>(multiplier, divisor)), 0)) {
1959 // The common factor becomes 'scalar' of the term, e.g.,in t/3%7*6,
1960 // divisor=multiplier=3, scalar=2.
1961 ExprPtr c = evaluateOp(alloc<Div>(multiplier, divisor));
1962 scalar = c;
1963 } else if (immediateEquals(
1964 evaluateOp(alloc<Mod>(divisor, multiplier)), 0)) {
1965 // The common factor becomes part of 'denom', e.g., in t/14%7*2,
1966 // divisor=multiplier=2, denom=t/7.
1967 ExprPtr c = evaluateOp(alloc<Div>(divisor, multiplier));
1968 divisor = multiplier;
1969 denom = IRSimplifier::simplify(alloc<Div>(other, c));
1970 } else {
1971 return std::nullopt;
1972 }
1973 } else {
1974 return std::nullopt;
1975 }
1976 }
1977
1978 // Deny cases in which divisor=1. Such cases are considered as Mods.
1979 if (divisor->isConstant() && immediateEquals(divisor, 1)) {
1980 return std::nullopt;
1981 }
1982
1983 if (!scalar) {
1984 scalar = immLike(multiplier, 1);
1985 }
1986
1987 return ModRound(scalar, denom, divisor, mod_divisor);
1988 }
1989
1990 // Search the polynomial for Terms that can be merged in
1991 // (1) Round + Mod pattern: (x/y) * y + x % y => RoundOff(x,y) + Mod(x, y) => x
1992 // (2) Mod round + Mod pattern: (x/y % z)*y + x%y => ModRound(x, y, z) + Mod(x,
1993 // y) => x % (y*z)
simplifyRoundModPattern(const PolynomialPtr & poly)1994 static ExprPtr simplifyRoundModPattern(const PolynomialPtr& poly) {
1995 std::vector<TermPtr> rounds;
1996 std::vector<TermPtr> mods;
1997 std::vector<TermPtr> mod_rounds;
1998 std::vector<TermPtr> others;
1999
2000 // Split out the Mod, ModRounds and RoundOffs operations so we can inspect.
2001 for (const auto& c : poly->variables()) {
2002 if (c->variables().size() > 1) {
2003 if (auto a = isModRound(c)) {
2004 mod_rounds.push_back(c);
2005 } else {
2006 others.push_back(c);
2007 }
2008 continue;
2009 }
2010
2011 ExprPtr e = c->variables()[0];
2012
2013 if (to<RoundOff>(e)) {
2014 rounds.push_back(c);
2015 } else if (e->expr_type() == IRNodeType::kMod) {
2016 if (auto a = isModRound(c)) {
2017 mod_rounds.push_back(c);
2018 } else {
2019 mods.push_back(c);
2020 }
2021 } else {
2022 others.push_back(c);
2023 }
2024 }
2025
2026 // Can't continue without at least one RoundOff/ModRound and one Mod.
2027 if ((rounds.empty() && mod_rounds.empty()) || mods.empty()) {
2028 return nullptr;
2029 }
2030
2031 HashProvider& hasher = poly->hasher();
2032 bool didAnything = false;
2033 std::vector<TermPtr> mods_merged;
2034 bool repeat = true;
2035 // Repeat merging terms till there are no Mods or the terms cannot be merged
2036 // any further.
2037 while (!mods.empty() && repeat) {
2038 repeat = false;
2039 for (int64_t i = static_cast<int64_t>(mods.size()) - 1; i >= 0; i--) {
2040 TermPtr m = mods[i];
2041 ModPtr mod = to<Mod>(m->variables()[0]);
2042 CHECK(mod);
2043 ExprPtr mod_lhs = IRSimplifier::simplify(mod->lhs());
2044 ExprPtr mod_rhs = IRSimplifier::simplify(mod->rhs());
2045 bool merged = false;
2046 for (int64_t j = static_cast<int64_t>(mod_rounds.size()) - 1; j >= 0;
2047 j--) {
2048 TermPtr mr = mod_rounds[j];
2049 auto a = isModRound(mr);
2050 CHECK(a);
2051 ModRound& mod_round = *a;
2052
2053 // TODO: for now don't attempt partial factorization of this
2054 // optimization. E.g. it's possible to do: 2 * (x/y%z) * y + (x%y) =>
2055 // x%(y*z) + (x/y%z) * y
2056 if (!immediateEquals(
2057 evaluateOp(alloc<Sub>(mod_round.scalar, m->scalar())), 0)) {
2058 continue;
2059 }
2060 // Valid optimization if mod LHS matches denom and mod RHS matches
2061 // divisor.
2062 if (hasher.hash(mod_round.denom) == hasher.hash(mod_lhs) &&
2063 hasher.hash(mod_round.divisor) == hasher.hash(mod_rhs)) {
2064 TermPtr merged_m = alloc<Term>(
2065 hasher,
2066 mod_round.scalar,
2067 IRSimplifier::simplify(alloc<Mod>(
2068 mod_round.denom,
2069 alloc<Mul>(mod_round.divisor, mod_round.mod_divisor))));
2070 mods_merged.push_back(merged_m);
2071 merged = true;
2072 repeat = true;
2073 didAnything = true;
2074 mods.erase(mods.begin() + i);
2075 mod_rounds.erase(mod_rounds.begin() + j);
2076 break;
2077 }
2078 }
2079
2080 if (merged) {
2081 continue;
2082 }
2083
2084 for (int64_t k = static_cast<int64_t>(rounds.size()) - 1; k >= 0; k--) {
2085 TermPtr r = rounds[k];
2086 RoundOffPtr roundoff = to<RoundOff>(r->variables()[0]);
2087 CHECK(roundoff);
2088
2089 // TODO: for now don't attempt partial factorization of this
2090 // optimization. E.g. it's possible to do: 2 * (x/y) * y + (x%y) => x +
2091 // (x/y) * y but unsure thats actually much better, particularly with
2092 // CSE.
2093 if (!immediateEquals(
2094 evaluateOp(alloc<Sub>(r->scalar(), m->scalar())), 0)) {
2095 continue;
2096 }
2097 ExprPtr round_lhs = IRSimplifier::simplify(roundoff->lhs());
2098 ExprPtr round_rhs = IRSimplifier::simplify(roundoff->rhs());
2099 // Valid optimization if LHS and RHS are equal for both.
2100 if (hasher.hash(round_lhs) == hasher.hash(mod_lhs) &&
2101 hasher.hash(round_rhs) == hasher.hash(mod_rhs)) {
2102 TermPtr merged_r = alloc<Term>(hasher, r->scalar(), round_lhs);
2103 others.push_back(merged_r);
2104 merged = true;
2105 didAnything = true;
2106 mods.erase(mods.begin() + i);
2107 rounds.erase(rounds.begin() + k);
2108 break;
2109 }
2110 }
2111
2112 // If we didn't merge, move out the Mod.
2113 if (!merged) {
2114 others.push_back(m);
2115 mods.erase(mods.begin() + i);
2116 }
2117
2118 } // end of for-loop
2119
2120 // Add newly generated Mods for merging opportunities in the next iteration.
2121 if (!mods_merged.empty()) {
2122 mods.insert(mods.end(), mods_merged.begin(), mods_merged.end());
2123 mods_merged.clear();
2124 }
2125
2126 } // end of while-loop
2127
2128 // If we made no changes, just exit.
2129 if (!didAnything) {
2130 return nullptr;
2131 }
2132
2133 // Keep remaining ModRounds and RoundOffs.
2134 if (!mod_rounds.empty()) {
2135 others.insert(others.end(), mod_rounds.begin(), mod_rounds.end());
2136 }
2137
2138 if (!rounds.empty()) {
2139 others.insert(others.end(), rounds.begin(), rounds.end());
2140 }
2141
2142 return alloc<Polynomial>(hasher, poly->scalar(), others);
2143 }
2144
2145 // Trivially factorize terms by GCD of scalar components.
factorizePolynomial(const PolynomialPtr & poly)2146 TermPtr PolynomialBase::factorizePolynomial(const PolynomialPtr& poly) {
2147 ExprPtr scalar = poly->scalar();
2148 const std::vector<TermPtr>& variables = poly->variables();
2149
2150 // Compute the GCD of terms.
2151 ExprPtr GCD = polyGCD(poly);
2152
2153 // No GCD means 0 or 1 and can't be factored.
2154 if (!GCD) {
2155 return nullptr;
2156 }
2157
2158 // Create new structure.
2159 std::vector<TermPtr> newPolyTerms;
2160 newPolyTerms.reserve(variables.size());
2161 for (const auto& t : variables) {
2162 // New term with the scalar divided by the GCD.
2163 newPolyTerms.push_back(alloc<Term>(
2164 poly->hasher(),
2165 evaluateOp(alloc<Div>(t->scalar(), GCD)),
2166 t->variables()));
2167 }
2168
2169 PolynomialPtr newPoly = alloc<Polynomial>(
2170 poly->hasher(), evaluateOp(alloc<Div>(scalar, GCD)), newPolyTerms);
2171
2172 return alloc<Term>(poly->hasher(), GCD, newPoly);
2173 }
2174
mutate(const PolynomialPtr & v)2175 ExprPtr TermExpander::mutate(const PolynomialPtr& v) {
2176 if (v->variables().empty()) {
2177 return v->scalar();
2178 }
2179
2180 // If this Polynomial can be factorized: do it, then expand the result.
2181 if (ExprPtr simplified = simplifyRoundModPattern(v)) {
2182 return simplified->accept_mutator(this);
2183 }
2184
2185 // If this Polynomial can be factorized: do it, then expand the result.
2186 if (ExprPtr factorized = factorizePolynomial(v)) {
2187 return factorized->accept_mutator(this);
2188 }
2189
2190 std::vector<TermPtr> addTerms;
2191 std::vector<TermPtr> subTerms;
2192
2193 auto vars = v->variables();
2194 std::unordered_map<ExprPtr, std::string> str_repr_cache;
2195 std::sort(vars.begin(), vars.end(), [&](const ExprPtr& a, const ExprPtr& b) {
2196 if (!str_repr_cache.count(a)) {
2197 str_repr_cache[a] = std::to_string(a);
2198 }
2199 if (!str_repr_cache.count(b)) {
2200 str_repr_cache[b] = std::to_string(b);
2201 }
2202 return str_repr_cache.at(a) < str_repr_cache.at(b);
2203 });
2204
2205 // partition the terms into a list to add and list to subtract.
2206 for (const auto& node : vars) {
2207 if (immediateIsNegative(node->scalar())) {
2208 subTerms.push_back(node);
2209 } else if (!immediateEquals(node->scalar(), 0)) {
2210 addTerms.push_back(node);
2211 }
2212 // Skip terms with a scalar of zero.
2213 }
2214
2215 // The last node constructed.
2216 ExprPtr lastNode{nullptr};
2217
2218 for (const auto& node : addTerms) {
2219 ExprPtr simpleNode = node->accept_mutator(this);
2220
2221 if (lastNode == nullptr) {
2222 lastNode = simpleNode;
2223 continue;
2224 }
2225
2226 if (isMultilanePrimitive(simpleNode)) {
2227 auto ret = combineMultilane<Add>(lastNode, simpleNode);
2228 if (ret) {
2229 // simplify result first, then expand.
2230 lastNode = ret->accept_mutator(simplifier_);
2231 lastNode = lastNode->accept_mutator(this);
2232 continue;
2233 }
2234 }
2235
2236 lastNode = alloc<Add>(lastNode, simpleNode);
2237 }
2238
2239 // If we have no add terms the scalar should go first.
2240 // E.g. 1 - x.
2241 bool scalarWritten = false;
2242 if (lastNode == nullptr) {
2243 auto scalarNode = v->scalar()->accept_mutator(simplifier_);
2244
2245 if (!immediateEquals(scalarNode, 0)) {
2246 lastNode = scalarNode;
2247 scalarWritten = true;
2248 }
2249 }
2250
2251 for (const auto& node : subTerms) {
2252 // Can still be first node if scalarVal is 0.
2253 if (lastNode == nullptr) {
2254 lastNode = node->accept_mutator(this);
2255 continue;
2256 }
2257
2258 // Negate the term back to positive since we'll be subtracting it.
2259 ExprPtr negated =
2260 evaluateOp(alloc<Mul>(immLike(node->scalar(), -1), node->scalar()));
2261 TermPtr newRHS = alloc<Term>(node->hasher(), negated, node->variables());
2262 lastNode = alloc<Sub>(lastNode, newRHS->accept_mutator(this));
2263 }
2264
2265 if (scalarWritten || immediateEquals(v->scalar(), 0)) {
2266 if (!lastNode) {
2267 return immLike(v, 0);
2268 }
2269 return lastNode;
2270 }
2271
2272 if (immediateIsNegative(v->scalar())) {
2273 // Negate the scalar and subtract.
2274 ExprPtr negated =
2275 evaluateOp(alloc<Mul>(immLike(lastNode, -1), v->scalar()));
2276 lastNode = alloc<Sub>(lastNode, evaluateOp(negated));
2277 } else {
2278 // we want to avoid a cast to the scalar if it would happen.
2279 if (v->scalar()->dtype() != lastNode->dtype()) {
2280 lastNode = alloc<Add>(
2281 lastNode, evaluateOp(alloc<Cast>(lastNode->dtype(), v->scalar())));
2282 } else {
2283 lastNode = alloc<Add>(lastNode, v->scalar());
2284 }
2285 }
2286
2287 return lastNode;
2288 }
2289
mutate(const MaxTermPtr & v)2290 ExprPtr TermExpander::mutate(const MaxTermPtr& v) {
2291 auto& variables = v->variables();
2292 if (variables.empty()) {
2293 if (!v->scalar()) {
2294 // This case should never happen because MaxTerm will be created only
2295 // on valid Max expressions.
2296 throw std::logic_error("empty maxterm op");
2297 }
2298 return v->scalar();
2299 }
2300 ExprPtr max;
2301 if (v->scalar()) {
2302 max = alloc<Max>(variables[0], v->scalar(), v->propagate_nans());
2303 } else {
2304 max = variables[0];
2305 }
2306 for (size_t i = 1; i < variables.size(); i++) {
2307 max = alloc<Max>(max, variables[i], v->propagate_nans());
2308 }
2309 return max->accept_mutator(this);
2310 }
2311
mutate(const MinTermPtr & v)2312 ExprPtr TermExpander::mutate(const MinTermPtr& v) {
2313 auto& variables = v->variables();
2314 if (variables.empty()) {
2315 if (!v->scalar()) {
2316 // This case should never happen because MinTerm will be created only
2317 // on valid Min expressions.
2318 throw std::logic_error("empty minterm op");
2319 }
2320 return v->scalar();
2321 }
2322 ExprPtr min;
2323 if (v->scalar()) {
2324 min = alloc<Min>(variables[0], v->scalar(), v->propagate_nans());
2325 } else {
2326 min = variables[0];
2327 }
2328 for (size_t i = 1; i < variables.size(); i++) {
2329 min = alloc<Min>(min, variables[i], v->propagate_nans());
2330 }
2331 return min->accept_mutator(this);
2332 }
2333
2334 // Expands RoundOff(x, y) => Term(1, Div(x, y), y), which will later be expanded
2335 // to Mul(Div(x, y), y).
mutate(const RoundOffPtr & v)2336 ExprPtr TermExpander::mutate(const RoundOffPtr& v) {
2337 TermPtr term = alloc<Term>(
2338 simplifier_->hasher(),
2339 immLike(v, 1),
2340 alloc<Div>(v->lhs(), v->rhs()),
2341 v->rhs());
2342 return term->accept_mutator(this);
2343 }
2344
buf_flat_size(const BufPtr & v)2345 ExprPtr buf_flat_size(const BufPtr& v) {
2346 std::vector<ExprPtr> dims = v->dims();
2347 if (dims.empty()) {
2348 return alloc<LongImm>(1);
2349 }
2350 ExprPtr flattened = immLike(dims[0], 1);
2351 for (auto& dim : dims) {
2352 flattened = alloc<Mul>(flattened, dim);
2353 }
2354 flattened = IRSimplifier::simplify(flattened);
2355
2356 return flattened;
2357 }
2358
mutate(const AllocatePtr & v)2359 StmtPtr TermExpander::mutate(const AllocatePtr& v) {
2360 BufPtr buf = v->buf();
2361 BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
2362 TORCH_INTERNAL_ASSERT(
2363 buf_new,
2364 buildErrorMessage("TermExpander mutation produced null for Buf."));
2365 ExprPtr flattened = buf_flat_size(buf_new);
2366
2367 if (flattened->isConstant() && immediateEquals(flattened, 0)) {
2368 eliminated_allocations_.insert(buf_new->base_handle());
2369 return nullptr;
2370 }
2371
2372 if (buf != buf_new) {
2373 v->set_buf(buf_new);
2374 }
2375 return v;
2376 }
2377
mutate(const FreePtr & v)2378 StmtPtr TermExpander::mutate(const FreePtr& v) {
2379 BufPtr buf = v->buf();
2380 BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
2381 TORCH_INTERNAL_ASSERT(
2382 buf_new,
2383 buildErrorMessage("TermExpander mutation produced null for Buf."));
2384
2385 if (eliminated_allocations_.count(buf_new->base_handle())) {
2386 eliminated_allocations_.erase(buf_new->base_handle());
2387 return nullptr;
2388 }
2389
2390 if (buf != buf_new) {
2391 v->set_buf(buf_new);
2392 }
2393 return v;
2394 }
2395
2396 // Combines adjacent Cond nodes with identical conditions.
fuseConditions(BlockPtr v)2397 BlockPtr TermExpander::fuseConditions(BlockPtr v) {
2398 std::vector<StmtPtr> stmts;
2399 bool did_anything = false;
2400 CondPtr prev_cond = nullptr;
2401
2402 for (const auto& s : *v) {
2403 CondPtr cond = to<Cond>(s);
2404 if (!cond) {
2405 prev_cond = nullptr;
2406 stmts.push_back(s);
2407 continue;
2408 }
2409
2410 // If the previous statement is a Cond and the conditions are identical,
2411 // then we fuse.
2412 if (!prev_cond ||
2413 hasher_.hash(prev_cond->condition()) !=
2414 hasher_.hash(cond->condition())) {
2415 prev_cond = cond;
2416 stmts.push_back(s);
2417 continue;
2418 }
2419
2420 // Fuse the two Conds by appending the bodies of the second Cond to the
2421 // first.
2422 BlockPtr true_block = alloc<Block>(std::vector<StmtPtr>({}));
2423 BlockPtr false_block = alloc<Block>(std::vector<StmtPtr>({}));
2424
2425 if (prev_cond->true_stmt()) {
2426 true_block->splice(true_block->end(), prev_cond->true_stmt());
2427 }
2428
2429 if (cond->true_stmt()) {
2430 true_block->splice(true_block->end(), cond->true_stmt());
2431 }
2432
2433 if (prev_cond->false_stmt()) {
2434 false_block->splice(false_block->end(), prev_cond->false_stmt());
2435 }
2436
2437 if (cond->false_stmt()) {
2438 false_block->splice(false_block->end(), cond->false_stmt());
2439 }
2440
2441 // avoid unflattening this Cond if we can.
2442 if (true_block->empty()) {
2443 true_block = nullptr;
2444 }
2445
2446 if (false_block->empty()) {
2447 false_block = nullptr;
2448 }
2449
2450 StmtPtr new_cond = prev_cond->cloneWithNewBodies(true_block, false_block)
2451 ->accept_mutator(this);
2452 prev_cond = to<Cond>(new_cond);
2453
2454 // erase, which shortens the list.
2455 stmts.pop_back();
2456 stmts.push_back(new_cond);
2457 did_anything = true;
2458 }
2459
2460 if (!did_anything) {
2461 return v;
2462 }
2463
2464 // clean up parents.
2465 for (const auto& s : stmts) {
2466 if (s->get_parent() == v) {
2467 v->remove_stmt(s);
2468 }
2469 }
2470
2471 return alloc<Block>(stmts);
2472 }
2473
fuseSyncThreads(BlockPtr block)2474 StmtPtr TermExpander::fuseSyncThreads(BlockPtr block) {
2475 // only really first if highest level Block.
2476 bool first = block->get_parent() == nullptr;
2477 SyncThreadsPtr last = nullptr;
2478 std::vector<StmtPtr> stmts;
2479 bool did_anything = false;
2480
2481 for (const auto& s : *block) {
2482 SyncThreadsPtr sync = to<SyncThreads>(s);
2483 if (!sync) {
2484 first = false;
2485 last = nullptr;
2486 stmts.push_back(s);
2487 continue;
2488 }
2489
2490 if (first || last) {
2491 did_anything = true;
2492 continue;
2493 }
2494
2495 last = sync;
2496 first = false;
2497 stmts.push_back(s);
2498 }
2499
2500 if (last) {
2501 stmts.pop_back();
2502 did_anything = true;
2503 }
2504
2505 if (!did_anything) {
2506 return block;
2507 }
2508
2509 // clean up parents.
2510 for (const auto& s : stmts) {
2511 if (s->get_parent() == block) {
2512 block->remove_stmt(s);
2513 }
2514 }
2515
2516 return alloc<Block>(std::vector<StmtPtr>({stmts}));
2517 }
2518
mutate(const BlockPtr & v)2519 StmtPtr TermExpander::mutate(const BlockPtr& v) {
2520 StmtPtr new_stmt = PolynomialBase::mutate(v);
2521 BlockPtr new_block = to<Block>(new_stmt);
2522 if (!new_block) {
2523 return new_stmt;
2524 }
2525
2526 // fuseConditions will return the original block if it cannot fuse.
2527 new_block = fuseConditions(new_block);
2528 /// fuseSyncThreads too.
2529 return fuseSyncThreads(new_block);
2530 }
2531
2532 // SimplifierUnderContext
2533 //
2534 // This function records the bounds(range) info of the index var in a for-stmt.
2535 // The bounds info will be used later when simplifying expressions with the
2536 // index var.
mutate(const ForPtr & v)2537 StmtPtr SimplifierUnderContext::mutate(const ForPtr& v) {
2538 ExprPtr var = v->var();
2539 ExprPtr start = v->start();
2540 ExprPtr stop = v->stop();
2541 StmtPtr body = v->body();
2542 LoopOptions loop_options = v->loop_options();
2543 ExprPtr var_new_expr = var->accept_mutator(this);
2544 VarPtr var_new = to<Var>(var_new_expr);
2545 ExprPtr start_new = start->accept_mutator(this);
2546 ExprPtr stop_new = stop->accept_mutator(this);
2547 StmtPtr body_new = body;
2548
2549 // save bounds info before this for-stmt
2550 //
2551 // The same variable could have appeared in a if-stmt which the for-stmt is
2552 // nested inside, and we need to restore its bounds info after the for-stmt.
2553 //
2554 // An example,
2555 // if (i>=0 && i<5) {
2556 // for (i=0; i<3; i++){
2557 // A[i] = ...
2558 // }
2559 // x = (i+20) / 5;
2560 //}
2561 // Inside the if stmt, i is in the range of [0, 5); and if we can restore this
2562 // bound info after the for stmt, we can use it to simplify the assignment
2563 // stmt x = (i+20)/5 to x = 4.
2564 bool has_bounds = false;
2565 analysis::Bound bound_old;
2566 VarPtr var_key = to<Var>(var);
2567 auto got = var_bound_info_.find(var_key);
2568 if (got != var_bound_info_.end()) {
2569 has_bounds = true;
2570 bound_old = got->second;
2571 }
2572 // set bounds info for index var
2573 const analysis::Bound bound_new(start_new, stop_new);
2574 var_bound_info_[var_key] = bound_new;
2575
2576 ExprPtr iters = alloc<Sub>(stop_new, start_new);
2577 iters = iters->accept_mutator(this);
2578 if (loop_options.isDefault() && iters->isConstant()) {
2579 if (immediateEquals(iters, 0)) {
2580 return alloc<Block>(std::vector<StmtPtr>({}));
2581 } else if (immediateEquals(iters, 1)) {
2582 body_new = Substitute(body, {{var_new, start_new}});
2583 body_new = body_new->accept_mutator(this);
2584
2585 // erase index var bounds info or restore old bounds info
2586 if (has_bounds) {
2587 var_bound_info_[var_key] = bound_old;
2588 } else {
2589 var_bound_info_.erase(var_key);
2590 }
2591
2592 return body_new;
2593 }
2594 }
2595
2596 body_new = body_new->accept_mutator(this);
2597
2598 // erase index var bounds info or restore old bounds info
2599 if (has_bounds) {
2600 var_bound_info_[var_key] = bound_old;
2601 } else {
2602 var_bound_info_.erase(var_key);
2603 }
2604
2605 if (!body_new) {
2606 return alloc<Block>(std::vector<StmtPtr>({}));
2607 }
2608
2609 if (auto block = to<Block>(body_new)) {
2610 if (block->nstmts() == 0) {
2611 return alloc<Block>(std::vector<StmtPtr>({}));
2612 }
2613
2614 if (block->nstmts() == 1) {
2615 // if the stmt in the loop body is a if-stmt, try to move the branching
2616 // out of the loop
2617 if (auto cond = to<Cond>(block->front())) {
2618 StmtPtr reordered = handleForCondReordering(v, cond);
2619 if (reordered) {
2620 return reordered->accept_mutator(this);
2621 }
2622 }
2623 }
2624 }
2625
2626 if (var != var_new) {
2627 v->set_var(var_new);
2628 }
2629 if (start != start_new) {
2630 v->set_start(start_new);
2631 }
2632 if (stop != stop_new) {
2633 v->set_stop(stop_new);
2634 }
2635 if (body != body_new) {
2636 v->set_body(body_new);
2637 }
2638 return v;
2639 }
2640
2641 // Simplify division using distributive laws for the following cases:
2642 // 1) (i + x) / n => x/n, if
2643 // a) n is a positive integer constant;
2644 // b) i is the index var of a for-stmt and the range of i is
2645 // a subset of [0, n);
2646 // c) x is a constant and the end value of i's range is less than n - x%n;
2647 // TODO: remove d) from the requirements because the simplification formula
2648 // still holds when x is a negative integer. In integer division, the result
2649 // of the division is converted to an integer using `floor` function which
2650 // returns the largest integer that is not greater than X. For example, -1/6
2651 // returns -1. But currently, both Pytorch and NNC are performing an incorrect
2652 // integer division: (-1)/6 = 0. With the current implementation of integer
2653 // division, x has to be not negative. d) x is not negative
2654 //
2655 // 2) (i + j*n) / n => j, if
2656 // a) n is a positive integer constant;
2657 // b) i is the index var of a for-stmt and the range of i is
2658 // a subset of [0, n);
2659 // c) j is an integer variable;
2660 // TODO: remove d) from the requirements because the simplification formula
2661 // still holds when j is a negative integer. In integer division, the result
2662 // of the division is converted to an integer using `floor` function which
2663 // returns the largest integer that is not greater than X. For example, -1/6
2664 // returns -1. But currently, both Pytorch and NNC are performing an incorrect
2665 // integer division: (-1)/6 = 0. With the current implementation of integer
2666 // division, x has to be not negative. d) j is not negative
distributeDiv(const ExprPtr & lhs,const ExprPtr & rhs,VarBoundInfo var_bound_info)2667 static ExprPtr distributeDiv(
2668 const ExprPtr& lhs,
2669 const ExprPtr& rhs,
2670 VarBoundInfo var_bound_info) {
2671 if (!lhs || !rhs) {
2672 return nullptr;
2673 }
2674 // return if not integer division
2675 if (lhs->dtype().is_floating_point() || rhs->dtype().is_floating_point()) {
2676 return nullptr;
2677 }
2678
2679 // identify n: a positive integer constant
2680 ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
2681 if (!rhsScalar) {
2682 return nullptr;
2683 }
2684 ExprPtr check_n_value = IRSimplifier::simplify(
2685 alloc<CompareSelect>(rhsScalar, immLike(rhsScalar, 0), kGT));
2686 if (!immediateEquals(check_n_value, 1)) {
2687 return nullptr;
2688 }
2689
2690 auto lhsAdd = to<Add>(lhs);
2691 if (!lhsAdd) {
2692 return nullptr;
2693 }
2694 ExprPtr lhsAdd1 = lhsAdd->lhs();
2695 ExprPtr lhsAdd2 = lhsAdd->rhs();
2696
2697 // identify index var 'i'
2698 VarPtr var_key = to<Var>(lhsAdd1);
2699 ExprPtr main = lhsAdd2;
2700 if (var_key == nullptr) {
2701 var_key = to<Var>(lhsAdd2);
2702 main = lhsAdd1;
2703 }
2704
2705 if (var_key == nullptr) {
2706 return nullptr;
2707 }
2708
2709 auto got = var_bound_info.find(var_key);
2710 if (got == var_bound_info.end()) {
2711 return nullptr;
2712 }
2713
2714 // check the bounds of 'i'
2715 auto start = got->second.start;
2716 // open upper bound, i.e., end is one more than the maximum value in the
2717 // range
2718 auto end = got->second.end;
2719 ExprPtr check_start = IRSimplifier::simplify(
2720 alloc<CompareSelect>(start, immLike(start, 0), kGE));
2721 ExprPtr check_end =
2722 IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
2723 if (!check_start->isConstant() || !check_end->isConstant() ||
2724 !immediateEquals(check_start, 1) || !immediateEquals(check_end, 1)) {
2725 return nullptr;
2726 }
2727
2728 ExprPtr ret = IRSimplifier::simplify(alloc<Div>(main, rhsScalar));
2729
2730 // simplify type 1) exprs: '(i+x)/n' => 'x/n'
2731 ExprPtr sign_check =
2732 IRSimplifier::simplify(alloc<CompareSelect>(main, immLike(main, 0), kGE));
2733 ExprPtr main_mod = IRSimplifier::simplify(alloc<Mod>(main, rhsScalar));
2734 ExprPtr mod_check = IRSimplifier::simplify(
2735 alloc<CompareSelect>(alloc<Add>(main_mod, end), rhsScalar, kLE));
2736 if (sign_check->isConstant() && immediateEquals(sign_check, 1) &&
2737 mod_check->isConstant() && immediateEquals(mod_check, 1)) {
2738 return ret;
2739 }
2740
2741 // simplify type 2 exprs: '(i+j*n)/n' => 'j'
2742 auto ret_var = to<Var>(ret);
2743 // FIXME: Allow any integral type.
2744 if (ret_var && ret_var->dtype() == kInt) {
2745 // retrieve j's range info
2746 auto got = var_bound_info.find(ret_var);
2747 if (got == var_bound_info.end()) {
2748 return nullptr;
2749 }
2750
2751 // check if j is not negative
2752 sign_check = IRSimplifier::simplify(alloc<CompareSelect>(
2753 got->second.start, immLike(got->second.start, 0), kGE));
2754 if (sign_check->isConstant() && immediateEquals(sign_check, 1)) {
2755 return ret_var;
2756 }
2757 }
2758
2759 return nullptr;
2760 }
2761
2762 // Simplify mod using distributive laws for the following cases:
2763 // 1) (i + x) % n => i + x%n if
2764 // a) n is a positive integer constant;
2765 // b) i is the index var of a for-stmt and the range of i is
2766 // a subset of [0, n);
2767 // c) x is a constant and the end value of i's range is less than n - x%n;
2768 // TODO: remove d) from the requirements because the simplification formula
2769 // still holds when x is a negative integer. In integer division, the result
2770 // of the division is converted to an integer using `floor` function which
2771 // returns the largest integer that is not greater than X. For example, -1/6
2772 // returns -1. But currently, both Pytorch and NNC are performing an incorrect
2773 // integer division: (-1)/6 = 0. With the current implementation of integer
2774 // division, x has to be not negative. d) x is not negative
2775 //
2776 // 2) (i + j*n) % n => i if
2777 // a) n is a positive integer constant;
2778 // b) i is the index var of a for-stmt and the range of i is
2779 // a subset of [0, n);
2780 // c) j is an integer variable;
2781 // TODO: remove d) from the requirements because the simplification formula
2782 // still holds when j is a negative integer. In integer division, the result
2783 // of the division is converted to an integer using `floor` function which
2784 // returns the largest integer that is not greater than X. For example, -1/6
2785 // returns -1. But currently, both Pytorch and NNC are performing an incorrect
2786 // integer division: (-1)/6 = 0. With the current implementation of integer
2787 // division, j has to be not negative. d) j is not negative
distributeMod(const ExprPtr & lhs,const ExprPtr & rhs,VarBoundInfo var_bound_info)2788 static ExprPtr distributeMod(
2789 const ExprPtr& lhs,
2790 const ExprPtr& rhs,
2791 VarBoundInfo var_bound_info) {
2792 if (!lhs || !rhs) {
2793 return nullptr;
2794 }
2795 // return if not integer mod
2796 if (lhs->dtype().is_floating_point() || rhs->dtype().is_floating_point()) {
2797 return nullptr;
2798 }
2799
2800 // identify n: a positive integer constant
2801 ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
2802 if (!rhsScalar) {
2803 return nullptr;
2804 }
2805 ExprPtr check_n_value = IRSimplifier::simplify(
2806 alloc<CompareSelect>(rhsScalar, immLike(rhsScalar, 0), kGT));
2807 if (!immediateEquals(check_n_value, 1)) {
2808 return nullptr;
2809 }
2810
2811 auto lhsAdd = to<Add>(lhs);
2812 if (!lhsAdd) {
2813 return nullptr;
2814 }
2815 if (!lhsAdd || !rhsScalar) {
2816 return nullptr;
2817 }
2818 ExprPtr lhsAdd1 = lhsAdd->lhs();
2819 ExprPtr lhsAdd2 = lhsAdd->rhs();
2820
2821 // identify index var 'i'
2822 VarPtr var_key = to<Var>(lhsAdd1);
2823 ExprPtr main = lhsAdd2;
2824 if (var_key == nullptr) {
2825 var_key = to<Var>(lhsAdd2);
2826 main = lhsAdd1;
2827 }
2828 if (var_key == nullptr) {
2829 return nullptr;
2830 }
2831
2832 auto got = var_bound_info.find(var_key);
2833 if (got == var_bound_info.end()) {
2834 return nullptr;
2835 }
2836
2837 // check the bounds of 'i'
2838 auto start = got->second.start;
2839 // open upper bound, i.e., end is one more than the maximum value in the
2840 // range
2841 auto end = got->second.end;
2842 ExprPtr check_start = IRSimplifier::simplify(
2843 alloc<CompareSelect>(start, immLike(start, 0), kGE));
2844 ExprPtr check_end =
2845 IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
2846 if (!check_start->isConstant() || !check_end->isConstant() ||
2847 !immediateEquals(check_start, 1) || !immediateEquals(check_end, 1)) {
2848 return nullptr;
2849 }
2850
2851 // simplify type 1) exprs: '(i+x)%n' => 'i+x%n'
2852 ExprPtr sign_check =
2853 IRSimplifier::simplify(alloc<CompareSelect>(main, immLike(main, 0), kGE));
2854 ExprPtr main_mod = IRSimplifier::simplify(alloc<Mod>(main, rhsScalar));
2855 ExprPtr mod_check = IRSimplifier::simplify(
2856 alloc<CompareSelect>(alloc<Add>(main_mod, end), rhsScalar, kLE));
2857 if (sign_check->isConstant() && immediateEquals(sign_check, 1) &&
2858 mod_check->isConstant() && immediateEquals(mod_check, 1)) {
2859 return alloc<Add>(var_key, main_mod);
2860 }
2861
2862 // simplify type 2) exprs: '(i+j*n)%n' => 'i'
2863 ExprPtr main_div = IRSimplifier::simplify(alloc<Div>(main, rhsScalar));
2864 auto j_var = to<Var>(main_div);
2865 // FIXME: Allow any integral type.
2866 if (j_var && j_var->dtype() == kInt) {
2867 // retrieve j's range info
2868 auto got = var_bound_info.find(j_var);
2869 if (got == var_bound_info.end()) {
2870 return nullptr;
2871 }
2872
2873 // check if j is not negative
2874 sign_check = IRSimplifier::simplify(alloc<CompareSelect>(
2875 got->second.start, immLike(got->second.start, 0), kGE));
2876 if (sign_check->isConstant() && immediateEquals(sign_check, 1)) {
2877 return var_key;
2878 }
2879 }
2880
2881 return nullptr;
2882 }
2883
mutate(const DivPtr & v)2884 ExprPtr SimplifierUnderContext::mutate(const DivPtr& v) {
2885 ExprPtr lhs = v->lhs();
2886 ExprPtr rhs = v->rhs();
2887
2888 std::ostringstream oss;
2889 if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) {
2890 GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
2891 return ret->accept_mutator(this);
2892 }
2893
2894 // i / N -> 0 if the range of i's values is a subset of [0, N)
2895 // where N is an integer constant
2896 auto lhsVar = to<Var>(lhs);
2897 ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
2898 if (lhsVar && rhsScalar && !rhsScalar->dtype().is_floating_point()) {
2899 auto got = var_bound_info_.find(lhsVar);
2900 if (got != var_bound_info_.end()) {
2901 auto start = got->second.start;
2902 auto end = got->second.end;
2903 ExprPtr check_start = IRSimplifier::simplify(
2904 alloc<CompareSelect>(start, immLike(start, 0), kGE));
2905 ExprPtr check_end =
2906 IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
2907 if (check_start->isConstant() && check_end->isConstant() &&
2908 immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) {
2909 GRAPH_DEBUG(
2910 "SimplifierUnderContext: ", *v, " => ", *immLike(lhsVar, 0));
2911 return immLike(lhsVar, 0);
2912 }
2913 }
2914 }
2915
2916 ExprPtr lhs_new = lhs->accept_mutator(this);
2917 ExprPtr rhs_new = rhs->accept_mutator(this);
2918 if (lhs == lhs_new && rhs == rhs_new) {
2919 return v;
2920 }
2921 return alloc<Div>(lhs_new, rhs_new);
2922 }
2923
mutate(const IfThenElsePtr & v)2924 ExprPtr SimplifierUnderContext::mutate(const IfThenElsePtr& v) {
2925 ExprPtr condition = v->condition();
2926 ExprPtr true_val = v->true_value();
2927 ExprPtr false_val = v->false_value();
2928
2929 auto simplified_condition =
2930 IRSimplifier::simplify(condition->accept_mutator(this));
2931 auto simplified_true_val =
2932 IRSimplifier::simplify(true_val->accept_mutator(this));
2933 auto simplified_false_val =
2934 IRSimplifier::simplify(false_val->accept_mutator(this));
2935 if (simplified_condition->isConstant()) {
2936 return immediateAs<int>(simplified_condition) ? simplified_true_val
2937 : simplified_false_val;
2938 }
2939
2940 bool nothing_changed = (simplified_condition == condition) &&
2941 (simplified_true_val == true_val) && (simplified_false_val == false_val);
2942 return nothing_changed
2943 ? v
2944 : alloc<IfThenElse>(
2945 simplified_condition, simplified_true_val, simplified_false_val);
2946 }
2947
mutate(const CompareSelectPtr & v)2948 ExprPtr SimplifierUnderContext::mutate(const CompareSelectPtr& v) {
2949 GRAPH_DEBUG("(SimplifierUnderContext) Original: ", std::to_string(v));
2950
2951 ExprPtr lhs = v->lhs();
2952 ExprPtr rhs = v->rhs();
2953 ExprPtr ret1 = v->ret_val1();
2954 ExprPtr ret2 = v->ret_val2();
2955
2956 auto simplified_lhs = IRSimplifier::simplify(lhs->accept_mutator(this));
2957 auto simplified_rhs = IRSimplifier::simplify(rhs->accept_mutator(this));
2958 auto simplified_ret1 = IRSimplifier::simplify(ret1->accept_mutator(this));
2959 auto simplified_ret2 = IRSimplifier::simplify(ret2->accept_mutator(this));
2960
2961 ExprPtr simplified_cmp_select_expr = nullptr;
2962 if ((simplified_lhs == lhs) && (simplified_rhs == rhs) &&
2963 (simplified_ret1 == ret1) && (simplified_ret2 == ret2)) {
2964 simplified_cmp_select_expr = v;
2965 } else {
2966 simplified_cmp_select_expr = alloc<CompareSelect>(
2967 simplified_lhs,
2968 simplified_rhs,
2969 simplified_ret1,
2970 simplified_ret2,
2971 v->compare_select_op(),
2972 v->bias());
2973 }
2974
2975 GRAPH_DEBUG(
2976 "(SimplifierUnderContext) after simplify: ",
2977 std::to_string(simplified_cmp_select_expr));
2978
2979 analysis::Bound lhs_bound;
2980 analysis::Bound rhs_bound;
2981 auto lhs_has_bound = getLoopBoundInfo(simplified_lhs, &lhs_bound);
2982 auto rhs_has_bound = getLoopBoundInfo(simplified_rhs, &rhs_bound);
2983 if (!lhs_has_bound || !rhs_has_bound) {
2984 GRAPH_DEBUG(
2985 "(SimplifierUnderContext) Final: ",
2986 std::to_string(simplified_cmp_select_expr));
2987 return simplified_cmp_select_expr;
2988 }
2989
2990 analysis::CmpEvalResult cmp_res =
2991 analysis::compareBound(lhs_bound, rhs_bound, v->compare_select_op());
2992
2993 // Return the simplified ret1/ret2 if the compare result is deterministic.
2994 // Otherwise, return the simplified CompareSelect directly.
2995 auto ret_expr = (cmp_res == analysis::CmpEvalResult::True)
2996 ? simplified_ret1
2997 : ((cmp_res == analysis::CmpEvalResult::False)
2998 ? simplified_ret2
2999 : simplified_cmp_select_expr);
3000 GRAPH_DEBUG("(SimplifierUnderContext) Final: ", std::to_string(ret_expr));
3001 return ret_expr;
3002 }
3003
mutate(const ModPtr & v)3004 ExprPtr SimplifierUnderContext::mutate(const ModPtr& v) {
3005 ExprPtr lhs = v->lhs();
3006 ExprPtr rhs = v->rhs();
3007
3008 std::ostringstream oss;
3009 if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) {
3010 GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
3011 return ret->accept_mutator(this);
3012 }
3013
3014 // i % N -> i if the range of i's values is a subset of [0, N)
3015 // where N is an integer constant
3016 auto lhsVar = to<Var>(lhs);
3017 ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
3018 if (lhsVar && rhsScalar && !rhsScalar->dtype().is_floating_point()) {
3019 auto got = var_bound_info_.find(lhsVar);
3020 if (got != var_bound_info_.end()) {
3021 auto start = got->second.start;
3022 auto end = got->second.end;
3023 ExprPtr check_start = IRSimplifier::simplify(
3024 alloc<CompareSelect>(start, immLike(start, 0), kGE));
3025 ExprPtr check_end =
3026 IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
3027 if (check_start->isConstant() && check_end->isConstant() &&
3028 immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) {
3029 GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *lhsVar);
3030 return lhsVar;
3031 }
3032 }
3033 }
3034
3035 ExprPtr lhs_new = lhs->accept_mutator(this);
3036 ExprPtr rhs_new = rhs->accept_mutator(this);
3037 if (lhs == lhs_new && rhs == rhs_new) {
3038 return v;
3039 }
3040 return alloc<Mod>(lhs_new, rhs_new);
3041 }
3042
getLoopBoundInfo(const ExprPtr & expr,analysis::Bound * loop_bound_info)3043 bool SimplifierUnderContext::getLoopBoundInfo(
3044 const ExprPtr& expr,
3045 analysis::Bound* loop_bound_info) {
3046 if (expr == nullptr)
3047 return false;
3048
3049 if (expr->isConstant()) {
3050 loop_bound_info->start = expr;
3051 loop_bound_info->end = expr;
3052 return true;
3053 }
3054
3055 VarPtr var_key = to<Var>(expr);
3056 if (var_key == nullptr) {
3057 return false;
3058 }
3059
3060 auto got = var_bound_info_.find(var_key);
3061 if (got == var_bound_info_.end()) {
3062 return false;
3063 }
3064
3065 loop_bound_info->start = got->second.start;
3066 // TODO: Need to add the boundary information(close/open) of a range to
3067 // Bound. Currently, the VarBoundInfo comes from for-loop statement while
3068 // the end of the boundary is open. But we assume the start and end of a
3069 // range are always close. Hence, we explicitly convert the open boundary to
3070 // close.
3071 // [for-start, for-stop) => [for-start, for-stop -1]
3072 loop_bound_info->end = IRSimplifier::simplify(
3073 alloc<Sub>(got->second.end, immLike(got->second.end, 1)));
3074 return true;
3075 }
3076
exprEquals(const ExprPtr & A,const ExprPtr & B)3077 bool exprEquals(const ExprPtr& A, const ExprPtr& B) {
3078 try {
3079 ExprPtr diff = IRSimplifier::simplify(alloc<Sub>(A, B));
3080 if (!diff->isConstant()) {
3081 return false;
3082 }
3083 return immediateEquals(diff, 0);
3084 } catch (std::exception& e) {
3085 return false;
3086 }
3087 }
3088
simplify(ExprPtr e)3089 ExprPtr IRSimplifier::simplify(ExprPtr e) {
3090 GRAPH_DEBUG("(Simplifier) Original: ", std::to_string(e));
3091 SimplifierUnderContext ctxsimplifier;
3092 e = e->accept_mutator(&ctxsimplifier);
3093
3094 PolynomialTransformer simplifier;
3095 e = e->accept_mutator(&simplifier);
3096
3097 // There may be terms left in the IR, expand them.
3098 TermExpander expander(&simplifier);
3099 e = e->accept_mutator(&expander);
3100 if (!expander.check_safe()) {
3101 throw malformed_input("eliminated null Allocation without free");
3102 }
3103
3104 GRAPH_DEBUG("(Simplifier) Simplified: ", std::to_string(e));
3105 return e;
3106 }
3107
simplify(StmtPtr s)3108 StmtPtr IRSimplifier::simplify(StmtPtr s) {
3109 GRAPH_DEBUG("(Simplifier) Original: ", std::to_string(s));
3110 SimplifierUnderContext ctxsimplifier;
3111 s = s->accept_mutator(&ctxsimplifier);
3112
3113 PolynomialTransformer simplifier;
3114 s = s->accept_mutator(&simplifier);
3115 if (s == nullptr) {
3116 GRAPH_DEBUG("(Simplifier) Simplified: NULL");
3117 return nullptr;
3118 }
3119
3120 // There may be terms left in the IR, expand them.
3121 TermExpander expander(&simplifier);
3122 s = s->accept_mutator(&expander);
3123 if (!expander.check_safe()) {
3124 throw malformed_input("eliminated null Allocation without free");
3125 }
3126
3127 GRAPH_DEBUG("(Simplifier) Simplified: ", std::to_string(s));
3128 return s;
3129 }
3130
3131 } // namespace torch::jit::tensorexpr
3132