xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_simplify.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <test/cpp/tensorexpr/test_base.h>
3 
4 #include <c10/util/irange.h>
5 #include <test/cpp/tensorexpr/test_utils.h>
6 #include <torch/csrc/jit/tensorexpr/hash_provider.h>
7 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
8 #include <torch/csrc/jit/tensorexpr/loopnest.h>
9 
10 #include <cmath>
11 
12 namespace torch {
13 namespace jit {
14 using namespace torch::jit::tensorexpr;
15 using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
16 
// 2.0f + 3.0f must collapse to the single immediate 5.0f.
TEST(Simplify, ConstantFoldSimple) {
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle sum = two + three;

  ExprHandle folded = IRSimplifier::simplify(sum);
  auto imm = folded.AsNode<FloatImm>();
  ASSERT_NE(imm, nullptr);
  ASSERT_EQ(imm->value(), 5);

  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<float>(), 5.f);
}
29 
// (2 + 3) - (4 + 5) folds through two levels of arithmetic down to -4.
TEST(Simplify, ConstantFoldTwoLayer) {
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);
  ExprHandle five(5.0f);
  ExprHandle difference = (two + three) - (four + five);

  ExprHandle folded = IRSimplifier::simplify(difference);
  auto imm = folded.AsNode<FloatImm>();
  ASSERT_NE(imm, nullptr);
  ASSERT_EQ(imm->value(), -4);

  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<float>(), -4.f);
}
44 
// ((7 << 2) << 2) >> 3 is 7 << 1, i.e. 14: shift operators fold like any
// other integer binary op.
TEST(Simplify, ConstantFoldShifts) {
  ExprHandle seven(7);
  ExprHandle two(2);
  ExprHandle three(3);
  ExprHandle shifted = ((seven << two) << two) >> three;

  ExprHandle folded = IRSimplifier::simplify(shifted);
  auto imm = folded.AsNode<IntImm>();
  ASSERT_NE(imm, nullptr);
  ASSERT_EQ(imm->value(), 14);

  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<int>(), 7 << (4 - 3));
}
58 
// (59 ^ 22) & 101 == 45 & 101 == 37: xor and and fold to one IntImm.
TEST(Simplify, ConstantFoldBitwise) {
  ExprHandle lhs(59);
  ExprHandle rhs(22);
  ExprHandle mask(101);
  ExprHandle masked = (lhs ^ rhs) & mask;

  ExprHandle folded = IRSimplifier::simplify(masked);
  auto imm = folded.AsNode<IntImm>();
  ASSERT_NE(imm, nullptr);
  ASSERT_EQ(imm->value(), 37);

  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<int>(), (59 ^ 22) & 101);
}
72 
// A mixed tree of /, -, + and * over float immediates folds to one FloatImm,
// and that immediate equals what the interpreter computes for the unfolded
// tree.
TEST(Simplify, ConstantFoldMultiOp) {
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);
  ExprHandle five(5.0f);
  ExprHandle six(6.0f);
  ExprHandle seven(7.0f);
  ExprHandle original = ((two / six) - (four + five)) * (seven / three);

  ExprHandle folded = IRSimplifier::simplify(original);
  ASSERT_NE(folded.AsNode<FloatImm>(), nullptr);

  SimpleIRExprEval foldedEval(folded);
  SimpleIRExprEval referenceEval(original);

  ASSERT_EQ(foldedEval.value<float>(), referenceEval.value<float>());
}
90 
// max(12, min(15, 17)) folds to the float immediate 15.
TEST(Simplify, ConstantFoldMinMax) {
  ExprHandle twelve(12.0f);
  ExprHandle fifteen(15.0f);
  ExprHandle seventeen(17.0f);

  // The inner Min and outer Max are built with opposite boolean flags
  // (presumably NaN propagation — confirm against Min/Max::make).
  ExprHandle inner = Min::make(fifteen, seventeen, true);
  ExprHandle outer = Max::make(twelve, inner, false);

  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  ASSERT_EQ(outer.dtype().scalar_type(), ScalarType::Float);

  ExprHandle folded = IRSimplifier::simplify(outer);
  ASSERT_NE(folded.AsNode<FloatImm>(), nullptr);

  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<float>(), 15.f);
}
109 
// abs(round(log10(fmod(4, sin(pow(2, 3)))))) over float immediates folds to
// the single immediate 1, matching direct interpretation of the tree.
TEST(Simplify, ConstantFoldIntrinsics) {
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);

  ExprHandle nested = Intrinsics::make(
      kAbs,
      Intrinsics::make(
          kRound,
          Intrinsics::make(
              kLog10,
              Intrinsics::make(
                  kFmod,
                  four,
                  Intrinsics::make(
                      kSin, Intrinsics::make(kPow, two, three))))));

  ExprHandle folded = IRSimplifier::simplify(nested);
  auto imm = folded.AsNode<FloatImm>();
  ASSERT_NE(imm, nullptr);
  ASSERT_EQ(imm->value(), 1);

  SimpleIRExprEval foldedEval(folded);
  SimpleIRExprEval referenceEval(nested);

  ASSERT_EQ(foldedEval.value<float>(), referenceEval.value<float>());
}
130 
// Casting the integer immediate 0 to kBool evaluates to false once
// simplified.
TEST(Simplify, ConstantFoldCastToBool) {
  ExprHandle zeroAsBool = Cast::make(kBool, IntImm::make(0));
  ExprHandle folded = IRSimplifier::simplify(zeroAsBool);

  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<bool>(), false);
}
137 
// A constant sub-expression is folded even when a free variable prevents the
// whole expression from collapsing to an immediate.
TEST(Simplify, ConstantFoldWithVar) {
  // Integer: x * (2 + 4) keeps its Mul, with the folded immediate on the lhs.
  {
    VarHandle x("x", kInt);
    ExprHandle product = x * (ExprHandle(2) + ExprHandle(4));

    ExprHandle simplified = IRSimplifier::simplify(product);
    auto mul = simplified.AsNode<Mul>();
    ASSERT_NE(mul, nullptr);
    ASSERT_NE(to<IntImm>(mul->lhs()), nullptr);

    SimpleIRExprEval evaluator(simplified);
    evaluator.bindVar(x, ExprHandle(3));
    ASSERT_EQ(evaluator.value<int>(), 3 * (2 + 4));
  }

  // Float: the folded immediate is expected on the rhs of the Mul instead.
  {
    VarHandle x("x", kFloat);
    ExprHandle product = x * (ExprHandle(2.f) + ExprHandle(4.f));

    ExprHandle simplified = IRSimplifier::simplify(product);
    auto mul = simplified.AsNode<Mul>();
    ASSERT_NE(mul, nullptr);
    ASSERT_NE(to<FloatImm>(mul->rhs()), nullptr);

    SimpleIRExprEval evaluator(simplified);
    evaluator.bindVar(x, ExprHandle(3.f));
    ASSERT_EQ(evaluator.value<float>(), 3 * (2 + 4));
  }
}
167 
// Comparisons between float immediates fold to an IntImm holding 0 or 1.
TEST(Simplify, ConditionalSelectFoldSimple) {
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);
  ExprHandle alsoThree(3.0f);

  // Each case pairs a comparison with the 0/1 value it must fold to. The
  // assertions stay directly in the TEST body (not in a helper) so a fatal
  // failure still aborts the whole test.
  struct FoldCase {
    ExprHandle comparison;
    int expected;
  };
  FoldCase cases[] = {
      {three > four, 0},
      {three < four, 1},
      {three == alsoThree, 1},
      {three != alsoThree, 0},
  };

  for (const FoldCase& fc : cases) {
    ExprHandle folded = IRSimplifier::simplify(fc.comparison);
    auto imm = folded.AsNode<IntImm>();
    ASSERT_NE(imm, nullptr);
    ASSERT_EQ(imm->value(), fc.expected);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), fc.expected);
  }
}
213 
// Comparisons whose operands are themselves foldable sums still collapse to
// an IntImm 0/1: (3 + 2) vs (2 + 1), and (3 + 1) vs (2 + 2).
TEST(Simplify, ConditionalSelectFoldTwoLayer) {
  ExprHandle three(3.0f);
  ExprHandle two(2.0f);
  ExprHandle alsoTwo(2.0f);
  ExprHandle one(1.0f);

  // Each case pairs a comparison with the 0/1 value it must fold to.
  struct FoldCase {
    ExprHandle comparison;
    int expected;
  };
  FoldCase cases[] = {
      {three + two < alsoTwo + one, 0},
      {three + two > alsoTwo + one, 1},
      {three + one == two + alsoTwo, 1},
      {three + one != two + alsoTwo, 0},
  };

  for (const FoldCase& fc : cases) {
    ExprHandle folded = IRSimplifier::simplify(fc.comparison);
    auto imm = folded.AsNode<IntImm>();
    ASSERT_NE(imm, nullptr);
    ASSERT_EQ(imm->value(), fc.expected);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), fc.expected);
  }
}
260 
// A comparison against a free variable cannot fold to an immediate, but it
// still evaluates correctly once the variable is bound.
TEST(Simplify, ConditionalSelectFoldWithVar) {
  VarHandle x("x", kFloat);
  ExprHandle lessThanFour = x < 4.f;

  ExprHandle simplified = IRSimplifier::simplify(lessThanFour);
  ASSERT_EQ(simplified.AsNode<IntImm>(), nullptr);

  SimpleIRExprEval below(simplified);
  below.bindVar(x, ExprHandle(3.f));
  ASSERT_EQ(below.value<int>(), 1);

  SimpleIRExprEval above(simplified);
  above.bindVar(x, ExprHandle(5.f));
  ASSERT_EQ(above.value<int>(), 0);
}
280 
// 3 * x + 5 * y over float vars has nothing to fold: the Add survives with
// neither operand being a FloatImm, yet evaluation gives 3*3 + 5*2.
TEST(Simplify, UnFoldableExpr) {
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  ExprHandle linear = (ExprHandle(3) * x) + (ExprHandle(5) * y);

  ExprHandle simplified = IRSimplifier::simplify(linear);
  auto sum = simplified.AsNode<Add>();
  ASSERT_NE(sum, nullptr);
  ASSERT_EQ(to<FloatImm>(sum->lhs()), nullptr);
  ASSERT_EQ(to<FloatImm>(sum->rhs()), nullptr);

  SimpleIRExprEval evaluator(simplified);
  evaluator.bindVar(x, ExprHandle(3.f));
  evaluator.bindVar(y, ExprHandle(2.f));
  ASSERT_EQ(evaluator.value<float>(), 9 + 10);
}
297 
// A Var, an immediate and an expression built from them all hash to
// non-zero, pairwise-distinct values.
TEST(Simplify, HashSimple) {
  VarHandle x("x", kFloat);
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle expr = two + three * x;

  HashProvider hasher;

  auto varHash = hasher.hash(x.node());
  auto immHash = hasher.hash(two.node());
  auto exprHash = hasher.hash(expr.node());

  ASSERT_NE(varHash, (size_t)0);
  ASSERT_NE(immHash, (size_t)0);
  ASSERT_NE(exprHash, (size_t)0);
  ASSERT_NE(varHash, immHash);
  ASSERT_NE(varHash, exprHash);
  ASSERT_NE(immHash, exprHash);
}
317 
// Structurally identical expressions hash equally; distinct Vars do not, even
// when they share a name hint.
TEST(Simplify, HashEquivalence) {
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  ExprHandle doubled = (x * y) + (x * y);

  auto add = doubled.AsNode<Add>();
  ASSERT_NE(add, nullptr);

  HashProvider hasher;
  auto sumHash = hasher.hash(doubled.node());
  auto leftHash = hasher.hash(add->lhs());
  auto rightHash = hasher.hash(add->rhs());

  // The Add differs from each operand, but the two identical operands match.
  ASSERT_NE(sumHash, leftHash);
  ASSERT_NE(sumHash, rightHash);
  ASSERT_EQ(leftHash, rightHash);

  // Separately built but structurally identical trees hash equally.
  ExprHandle firstTwo(2);
  ExprHandle first = x + firstTwo / y;
  ExprHandle secondTwo(2);
  ExprHandle second = x + secondTwo / y;
  ASSERT_EQ(hasher.hash(first.node()), hasher.hash(second.node()));

  // A different Var with the same name hint "x" must not match.
  VarHandle otherX("x", kFloat);
  ExprHandle third = otherX + secondTwo / y;
  ASSERT_NE(hasher.hash(first.node()), hasher.hash(third.node()));

  // Sanity check: intrinsic calls produce a non-zero hash.
  ExprHandle trig = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x);
  ASSERT_NE(hasher.hash(trig.node()), (size_t)0);
}
353 
// Two Rand intrinsics (one kFloat, one kInt) hash differently from each other
// and from the Add that combines them.
TEST(Simplify, HashEquivalenceRand) {
  ExprHandle floatRand = Intrinsics::make(kRand, kFloat);
  ExprHandle intRand = Intrinsics::make(kRand, kInt);
  ExprHandle sum = floatRand + intRand;

  auto add = sum.AsNode<Add>();
  ASSERT_NE(add, nullptr);

  HashProvider hasher;
  auto sumHash = hasher.hash(sum.node());
  auto leftHash = hasher.hash(add->lhs());
  auto rightHash = hasher.hash(add->rhs());

  ASSERT_NE(sumHash, leftHash);
  ASSERT_NE(sumHash, rightHash);
  // The two operands must NOT be treated as equivalent.
  ASSERT_NE(leftHash, rightHash);
}
372 
// (2 + 3) * x and 5 * x hash differently as written, but identically once the
// simplifier has folded 2 + 3 into 5.
TEST(Simplify, HashEquivalenceAfterFolding) {
  VarHandle x("x", kFloat);
  ExprHandle a(2.0f);
  ExprHandle b(3.0f);
  ExprHandle c(5.0f);

  ExprHandle f1 = ((a + b) * x);
  ExprHandle f2 = (c * x);

  HashProvider hasher;
  auto hash_l = hasher.hash(f1.node());
  auto hash_r = hasher.hash(f2.node());

  // Before simplification the two trees differ structurally, so their hashes
  // differ too.
  ASSERT_NE(hash_l, hash_r);

  ExprHandle ff1 = IRSimplifier::simplify(f1);
  ExprHandle ff2 = IRSimplifier::simplify(f2);

  auto hash_l_n = hasher.hash(ff1.node());
  auto hash_r_n = hasher.hash(ff2.node());
  // After folding, both simplify to the same expression and their hashes
  // agree.
  ASSERT_EQ(hash_l_n, hash_r_n);
}
397 
// Immediates holding the same value but with different dtypes must hash
// differently, while expressions that simplify to the same typed immediate
// must hash equally.
TEST(Simplify, HashDifferenceTypes) {
  HashProvider hasher;

  // The value 1 wrapped in each immediate type.
  std::vector<ExprPtr> immediates{
      alloc<DoubleImm>(1),
      alloc<FloatImm>(1),
      alloc<HalfImm>(1),
      alloc<BoolImm>(true),
      alloc<CharImm>(1),
      alloc<ByteImm>(1),
      alloc<ShortImm>(1),
      alloc<IntImm>(1),
      alloc<LongImm>(1),
  };

  // Every pair of immediates of different types must hash differently.
  // size_t indices match vector::size() and avoid a narrowing comparison.
  for (size_t i = 0; i < immediates.size(); ++i) {
    for (size_t j = i + 1; j < immediates.size(); ++j) {
      ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j]))
          << "immediates at positions " << i << " and " << j
          << " hash equally";
    }
  }

  // Once simplified, 2.f + (char)1 and (float)3 are expected to reduce to the
  // same float immediate (3), so their hashes must agree.
  ExprHandle f1 = ExprHandle(2.f) + CharImm::make(1);
  ExprHandle f2 = Cast::make(kFloat, IntImm::make(3));

  ExprHandle ff1 = IRSimplifier::simplify(f1);
  ExprHandle ff2 = IRSimplifier::simplify(f2);

  ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node()));
}
429 
// Hashing a Cond root caches the hashes of both of its branches, and the
// root and the two branches all hash differently.
TEST(Simplify, HashLargeExpression) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  VarHandle i("i", kInt);

  // for i in [0, N): C[i] = (A[i] == B[i])
  auto elementsEqual = CompareSelect::make(
      Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kEQ);
  auto loop = For::make(i, 0, N, Store::make(c, {i}, elementsEqual));

  // E[ramp(0, 1, 4)] = D[ramp(0, 1, 4)]
  BufHandle d("D", {1}, kInt);
  BufHandle e("E", {1}, kInt);
  auto rampStore = Store::make(
      e, {Ramp::make(0, 1, 4)}, Load::make(d, {Ramp::make(0, 1, 4)}));

  // if (A[i] >= B[i]) { loop } else { rampStore }
  auto condition = CompareSelect::make(
      Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kGE);
  auto branch = Cond::make(condition, loop, rampStore);

  HashProvider hasher;
  auto rootHash = hasher.hash(branch);
  // Hashing the root must already have cached both branches.
  ASSERT_TRUE(hasher.cachedHash(loop));
  auto loopHash = hasher.hash(loop);
  ASSERT_TRUE(hasher.cachedHash(rampStore));
  auto storeHash = hasher.hash(rampStore);

  ASSERT_NE(rootHash, loopHash);
  ASSERT_NE(rootHash, storeHash);
  ASSERT_NE(loopHash, storeHash);
}
472 
// GPU block/thread index bindings on a For loop participate in its hash.
TEST(Simplify, HashForLoopOptions) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  VarHandle i("i", kInt);

  auto elementsEqual = CompareSelect::make(
      Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kEQ);
  auto loop = For::make(i, 0, N, Store::make(c, {i}, elementsEqual));

  // The same For node is mutated in place between measurements, so the cache
  // is cleared after each hash to avoid reusing a stale value.
  HashProvider hasher;
  auto plainHash = hasher.hash(loop);
  hasher.clearCache();

  loop->set_gpu_block_index(LoopOptions::IDX_X);
  auto blockHash = hasher.hash(loop);
  hasher.clearCache();

  ASSERT_NE(plainHash, blockHash);

  // Unsetting the block index restores the original hash.
  loop->set_gpu_block_index(LoopOptions::IDX_UNSET);
  auto resetHash = hasher.hash(loop);
  hasher.clearCache();

  ASSERT_EQ(plainHash, resetHash);

  loop->set_gpu_thread_index(LoopOptions::IDX_X);
  auto threadHash = hasher.hash(loop);

  ASSERT_NE(plainHash, threadHash);
  ASSERT_NE(blockHash, threadHash);
}
512 
513 /// (2 + x) + 4 => x + 6
TEST(Simplify,SimplifyAdd)514 TEST(Simplify, SimplifyAdd) {
515   VarHandle x("x", kInt);
516   VarHandle y("y", kInt);
517 
518   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
519   VarHandle m("m", kInt);
520   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
521   VarHandle n("n", kInt);
522   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
523   VarHandle n_1("n_1", kInt);
524   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
525   ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4);
526 
527   ExprHandle simplified = IRSimplifier::simplify(body);
528   AddPtr root = simplified.AsNode<Add>();
529   ASSERT_NE(root, nullptr);
530   VarPtr lhs = to<Var>(root->lhs());
531   ASSERT_NE(lhs, nullptr);
532   ASSERT_EQ(lhs->name_hint(), "x");
533   IntImmPtr rhs = to<IntImm>(root->rhs());
534   ASSERT_NE(rhs, nullptr);
535   ASSERT_EQ(rhs->value(), 6.f);
536 }
537 
538 /// (2 - x) - 4 => -2 - x
TEST(Simplify,SimplifySub)539 TEST(Simplify, SimplifySub) {
540   VarHandle x("x", kInt);
541   ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4);
542 
543   ExprHandle simplified = IRSimplifier::simplify(body);
544   SubPtr root = simplified.AsNode<Sub>();
545   ASSERT_NE(root, nullptr);
546   IntImmPtr lhs = to<IntImm>(root->lhs());
547   ASSERT_NE(lhs, nullptr);
548   ASSERT_EQ(lhs->value(), -2.f);
549   VarPtr rhs = to<Var>(root->rhs());
550   ASSERT_NE(rhs, nullptr);
551   ASSERT_EQ(rhs->name_hint(), "x");
552 }
553 
554 /// 2 * (1 - x) - 4 => 2 * (-3 - x)
TEST(Simplify,SimplifyMultiLayer)555 TEST(Simplify, SimplifyMultiLayer) {
556   VarHandle x("x", kInt);
557   ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4));
558   ExprHandle simplified = IRSimplifier::simplify(body);
559   IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
560   IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
561   IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
562   IS_IMM_WITH_VAL(Int, sub->lhs(), -3);
563   IS_VAR_WITH_NAME(sub->rhs(), "x");
564 }
565 
566 /// 2 * (3 * x) - (x * 4) => 2 * x
TEST(Simplify,SimplifyMultiTerm)567 TEST(Simplify, SimplifyMultiTerm) {
568   VarHandle x("x", kInt);
569   ExprHandle body =
570       (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
571 
572   ExprHandle simplified = IRSimplifier::simplify(body);
573   MulPtr root = simplified.AsNode<Mul>();
574   ASSERT_NE(root, nullptr);
575   IntImmPtr lhs = to<IntImm>(root->lhs());
576   ASSERT_NE(lhs, nullptr);
577   ASSERT_EQ(lhs->value(), 2);
578   VarPtr rhs = to<Var>(root->rhs());
579   ASSERT_NE(rhs, nullptr);
580   ASSERT_EQ(rhs->name_hint(), "x");
581 }
582 
583 /// 2 * (3 * (long)x) - (x * 4) => 2 * x
TEST(Simplify,SimplifyCasts)584 TEST(Simplify, SimplifyCasts) {
585   VarHandle x("x", kLong);
586   ExprHandle body =
587       (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
588 
589   ExprHandle simplified = IRSimplifier::simplify(body);
590   MulPtr root = simplified.AsNode<Mul>();
591   ASSERT_NE(root, nullptr);
592   LongImmPtr lhs = to<LongImm>(root->lhs());
593   ASSERT_NE(lhs, nullptr);
594   ASSERT_EQ(lhs->value(), 2);
595   VarPtr rhs = to<Var>(root->rhs());
596   ASSERT_NE(rhs, nullptr);
597   ASSERT_EQ(rhs->name_hint(), "x");
598 }
599 
600 /// (x + 0) * 1 => x
TEST(Simplify,SimplifyEliminatesNoOps)601 TEST(Simplify, SimplifyEliminatesNoOps) {
602   VarHandle x("x", kInt);
603   ExprHandle body = (x + ExprHandle(0)) * 1;
604 
605   ExprHandle simplified = IRSimplifier::simplify(body);
606   VarPtr root = simplified.AsNode<Var>();
607   ASSERT_NE(root, nullptr);
608   ASSERT_EQ(root->name_hint(), "x");
609 }
610 
611 /// Cannot simplify this.
TEST(Simplify,SimplifyMultiVar)612 TEST(Simplify, SimplifyMultiVar) {
613   VarHandle x("x", kInt);
614   VarHandle y("y", kInt);
615   ExprHandle body = x * 24 + y * 34;
616 
617   ExprHandle simplified = IRSimplifier::simplify(body);
618 
619   AddPtr root = simplified.AsNode<Add>();
620   ASSERT_NE(root, nullptr);
621   MulPtr lhs = to<Mul>(root->lhs());
622   ASSERT_NE(lhs, nullptr);
623   VarPtr varX = to<Var>(lhs->rhs());
624   ASSERT_NE(varX, nullptr);
625   ASSERT_EQ(varX->name_hint(), "x");
626   MulPtr rhs = to<Mul>(root->rhs());
627   ASSERT_NE(rhs, nullptr);
628   VarPtr varY = to<Var>(rhs->rhs());
629   ASSERT_NE(varY, nullptr);
630   ASSERT_EQ(varY->name_hint(), "y");
631 }
632 
633 // x + 2 + y => x + y + 2
TEST(Simplify,DISABLED_SimplifyReorderings)634 TEST(Simplify, DISABLED_SimplifyReorderings) {
635   VarHandle x("x", kInt);
636   VarHandle y("y", kInt);
637   ExprHandle body = x + 2 + y;
638   ExprHandle simplified = IRSimplifier::simplify(body);
639 
640   AddPtr root = simplified.AsNode<Add>();
641   ASSERT_NE(root, nullptr);
642 
643   IS_NODE_WITH_NAME(Add, root->lhs(), rhs);
644   IS_VAR_WITH_NAME(rhs->lhs(), "x");
645   IS_VAR_WITH_NAME(rhs->rhs(), "y");
646   IS_IMM_WITH_VAL(Int, root->rhs(), 2);
647 }
648 
649 /// y + x * 0 => y
TEST(Simplify,SimplifyEliminatesVar)650 TEST(Simplify, SimplifyEliminatesVar) {
651   VarHandle x("x", kInt);
652   VarHandle y("y", kInt);
653   ExprHandle body = y + x * ExprHandle(0);
654 
655   ExprHandle simplified = IRSimplifier::simplify(body);
656   IS_VAR_WITH_NAME(simplified.node(), "y");
657 }
658 
TEST(Simplify,SimplifyAdds)659 TEST(Simplify, SimplifyAdds) {
660   VarHandle x("x", kInt);
661   VarHandle y("y", kInt);
662 
663   {
664     // (x + y) + (x + y) => 2 * (x + y)
665     ExprHandle body = (x + y) + (x + y);
666     ExprHandle simplified = IRSimplifier::simplify(body);
667 
668     IS_NODE_WITH_NAME(Mul, simplified.node(), root);
669     IS_IMM_WITH_VAL(Int, root->lhs(), 2);
670     IS_NODE_WITH_NAME(Add, root->rhs(), add);
671     IS_VAR_WITH_NAME(add->lhs(), "x");
672     IS_VAR_WITH_NAME(add->rhs(), "y");
673   }
674 
675   {
676     // (x * y) + (x * y) => 2 * (x * y)
677     ExprHandle body = (x * y) + (x * y);
678     ExprHandle simplified = IRSimplifier::simplify(body);
679 
680     IS_NODE_WITH_NAME(Mul, simplified.node(), root);
681     IS_IMM_WITH_VAL(Int, root->lhs(), 2);
682     IS_NODE_WITH_NAME(Mul, root->rhs(), mul);
683     IS_VAR_WITH_NAME(mul->lhs(), "x");
684     IS_VAR_WITH_NAME(mul->rhs(), "y");
685   }
686 
687   {
688     // (x - y) + (x - y) => 2 * (x - y)
689     ExprHandle body = (x - y) + (x - y);
690     ExprHandle simplified = IRSimplifier::simplify(body);
691 
692     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
693     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
694 
695     IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
696     IS_VAR_WITH_NAME(rhs->lhs(), "x");
697     IS_VAR_WITH_NAME(rhs->rhs(), "y");
698   }
699 
700   {
701     // (x + x + x + x) => 4 * x
702     ExprHandle body = (x + x + x + x);
703     ExprHandle simplified = IRSimplifier::simplify(body);
704 
705     IS_NODE_WITH_NAME(Mul, simplified.node(), root);
706     IS_IMM_WITH_VAL(Int, root->lhs(), 4);
707     IS_VAR_WITH_NAME(root->rhs(), "x");
708   }
709 
710   {
711     // (x + 0) => x.
712     ExprHandle body = x + 0;
713     ExprHandle simplified = IRSimplifier::simplify(body);
714 
715     IS_VAR_WITH_NAME(simplified.node(), "x");
716   }
717 
718   {
719     // (x + 0.f) => float(x).
720     ExprHandle body = x + 0.f;
721     ExprHandle simplified = IRSimplifier::simplify(body);
722 
723     IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
724     ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
725     IS_VAR_WITH_NAME(cast->src_value(), "x");
726   }
727 }
728 
TEST(Simplify,SimplifyMuls)729 TEST(Simplify, SimplifyMuls) {
730   VarHandle x("x", kInt);
731   VarHandle y("y", kInt);
732 
733   {
734     // (x + y) * (x + y) => (x + y) * (x + y)
735     // We don't attempt to simplify multiplication of polynomials since the
736     // result is only very rarely more efficient.
737     ExprHandle body = (x + y) * (x + y);
738     ExprHandle simplified = IRSimplifier::simplify(body);
739 
740     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
741     IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
742     IS_VAR_WITH_NAME(lhs->lhs(), "x");
743     IS_VAR_WITH_NAME(lhs->rhs(), "y");
744     IS_NODE_WITH_NAME(Add, mul->rhs(), rhs);
745     IS_VAR_WITH_NAME(rhs->lhs(), "x");
746     IS_VAR_WITH_NAME(rhs->rhs(), "y");
747   }
748 
749   {
750     // x * y * x * y => x * x * y * y
751     // These get reordered only.
752     ExprHandle body = x * y * x * y;
753     ExprHandle simplified = IRSimplifier::simplify(body);
754 
755     IS_NODE_WITH_NAME(Mul, simplified.node(), mul1);
756     IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2);
757     IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3);
758     IS_VAR_WITH_NAME(mul1->rhs(), "y");
759     IS_VAR_WITH_NAME(mul2->rhs(), "y");
760     IS_VAR_WITH_NAME(mul3->lhs(), "x");
761     IS_VAR_WITH_NAME(mul3->rhs(), "x");
762   }
763 
764   {
765     // 1 * (x * 1) => x
766     // Ones cancel cleanly.
767     ExprHandle body = ExprHandle(1) * (x * ExprHandle(1));
768     ExprHandle simplified = IRSimplifier::simplify(body);
769 
770     IS_VAR_WITH_NAME(simplified.node(), "x");
771   }
772 
773   {
774     // 1.f * (x * 1.f) => x
775     // Even float ones cancel cleanly, but carry their type.
776     ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f));
777     ExprHandle simplified = IRSimplifier::simplify(body);
778 
779     IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
780     ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
781     IS_VAR_WITH_NAME(cast->src_value(), "x");
782   }
783 
784   {
785     // 1 * (x * 1.f) => x
786     // One float is enough to cast the expr.
787     ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f));
788     ExprHandle simplified = IRSimplifier::simplify(body);
789 
790     IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
791     ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
792     IS_VAR_WITH_NAME(cast->src_value(), "x");
793   }
794 
795   {
796     // 1 * (x * 0) => 0
797     // Zeroes are eliminated.
798     ExprHandle body = ExprHandle(1) * (x * ExprHandle(0));
799     ExprHandle simplified = IRSimplifier::simplify(body);
800 
801     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
802   }
803 
804   {
805     // 1 * (x * 0) => 0
806     // But not for Float since nan * 0 = nan.
807     ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f));
808     ExprHandle simplified = IRSimplifier::simplify(body);
809 
810     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
811     IS_NODE_WITH_NAME(Cast, mul->lhs(), cast);
812     ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
813     IS_VAR_WITH_NAME(cast->src_value(), "x");
814     IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0);
815   }
816 
817   {
818     // (x - y) * (x - y) => (x - y) * (x - y)
819     // As with Add we don't attempt simplification of this.
820     ExprHandle body = (x - y) * (x - y);
821     ExprHandle simplified = IRSimplifier::simplify(body);
822 
823     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
824     IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs);
825     IS_VAR_WITH_NAME(lhs->lhs(), "x");
826     IS_VAR_WITH_NAME(lhs->rhs(), "y");
827     IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
828     IS_VAR_WITH_NAME(rhs->lhs(), "x");
829     IS_VAR_WITH_NAME(rhs->rhs(), "y");
830   }
831 
832   {
833     // (x + y) * (x - y) => (x + y) * (x - y)
834     // Don't simplify with different ops on each side.
835     ExprHandle body = (x + y) * (x - y);
836     ExprHandle simplified = IRSimplifier::simplify(body);
837     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
838     IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
839     IS_VAR_WITH_NAME(lhs->lhs(), "x");
840     IS_VAR_WITH_NAME(lhs->rhs(), "y");
841     IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
842     IS_VAR_WITH_NAME(rhs->lhs(), "x");
843     IS_VAR_WITH_NAME(rhs->rhs(), "y");
844   }
845 
846   {
847     // Multiply a polynomial by a term.
848     //   - term with no scalar, poly with non-identity scalar.
849     // x * (y + 1) => x + x * y
850     ExprHandle body = x * (y + ExprHandle(1));
851     ExprHandle simplified = IRSimplifier::simplify(body);
852 
853     IS_NODE_WITH_NAME(Add, simplified.node(), add);
854     IS_VAR_WITH_NAME(add->lhs(), "x");
855     IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
856     IS_VAR_WITH_NAME(mul->lhs(), "x");
857     IS_VAR_WITH_NAME(mul->rhs(), "y");
858   }
859 
860   {
861     // Multiply a polynomial by a term.
862     //   - term with identity scalar, poly with non-identity scalar.
863     // (x * 1) * (y + 1) => x + x * y
864     ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1));
865     ExprHandle simplified = IRSimplifier::simplify(body);
866 
867     IS_NODE_WITH_NAME(Add, simplified.node(), add);
868     IS_VAR_WITH_NAME(add->lhs(), "x");
869     IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
870     IS_VAR_WITH_NAME(mul->lhs(), "x");
871     IS_VAR_WITH_NAME(mul->rhs(), "y");
872   }
873 
874   {
875     // Multiply a polynomial by a term.
876     //   - term with non-identity scalar, poly with non-identity scalar.
877     // (x * 2) * (y + 1) => 2 * (x + x * y)
878     ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1));
879     ExprHandle simplified = IRSimplifier::simplify(body);
880 
881     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
882     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
883     IS_NODE_WITH_NAME(Add, mul->rhs(), add);
884     IS_VAR_WITH_NAME(add->lhs(), "x");
885     IS_NODE_WITH_NAME(Mul, add->rhs(), mul2);
886     IS_VAR_WITH_NAME(mul2->lhs(), "x");
887     IS_VAR_WITH_NAME(mul2->rhs(), "y");
888   }
889 
890   {
891     // Multiply a polynomial by a term.
892     //   - term with non-identity scalar, poly with identity scalar.
893     // (x * 2) * (y + 0) => 2 * (x * y)
894     ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0));
895     ExprHandle simplified = IRSimplifier::simplify(body);
896 
897     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
898     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
899     IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2);
900     IS_VAR_WITH_NAME(mul2->lhs(), "x");
901     IS_VAR_WITH_NAME(mul2->rhs(), "y");
902   }
903 
904   {
905     // Multiply a polynomial by a term.
906     //   - term with identity scalar, poly with identity scalar.
907     // (x * 1) * (y + 0) => x * y
908     ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0));
909     ExprHandle simplified = IRSimplifier::simplify(body);
910 
911     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
912     IS_VAR_WITH_NAME(mul->lhs(), "x");
913     IS_VAR_WITH_NAME(mul->rhs(), "y");
914   }
915 
916   {
917     // Multiply a polynomial by a term.
918     //   - term with no scalar, poly with identity scalar.
919     // x * (y + 0) => x * y
920     ExprHandle body = x * (y + ExprHandle(0));
921     ExprHandle simplified = IRSimplifier::simplify(body);
922 
923     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
924     IS_VAR_WITH_NAME(mul->lhs(), "x");
925     IS_VAR_WITH_NAME(mul->rhs(), "y");
926   }
927 }
928 
929 // Sub an expr from itself will result in zero.
TEST(Simplify,SimplifySubs)930 TEST(Simplify, SimplifySubs) {
931   VarHandle x("x", kInt);
932   VarHandle y("y", kInt);
933 
934   {
935     // (x + y) - (x + y) => 0
936     ExprHandle body = (x + y) - (x + y);
937     ExprHandle simplified = IRSimplifier::simplify(body);
938     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
939   }
940 
941   {
942     // (x * y) - (x * y) => 0
943     ExprHandle body = (x * y) - (x * y);
944     ExprHandle simplified = IRSimplifier::simplify(body);
945     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
946   }
947 
948   {
949     // (x - y) - (x - y) => 0
950     ExprHandle body = (x - y) - (x - y);
951     ExprHandle simplified = IRSimplifier::simplify(body);
952     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
953   }
954 
955   {
956     // (x + y) - 2 * (x + y) => -1 * x - y
957     ExprHandle body = (x + y) - ExprHandle(2) * (x + y);
958     ExprHandle simplified = IRSimplifier::simplify(body);
959 
960     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
961     IS_NODE_WITH_NAME(Mul, sub->lhs(), mul);
962     IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
963     IS_VAR_WITH_NAME(mul->rhs(), "x");
964     IS_VAR_WITH_NAME(sub->rhs(), "y");
965   }
966 
967   {
968     // (x + y) - y => x
969     ExprHandle body = (x + y) - y;
970     ExprHandle simplified = IRSimplifier::simplify(body);
971     IS_VAR_WITH_NAME(simplified.node(), "x");
972   }
973 
974   {
975     // (x - 0) => x.
976     ExprHandle body = x - 0;
977     ExprHandle simplified = IRSimplifier::simplify(body);
978     IS_VAR_WITH_NAME(simplified.node(), "x");
979   }
980 
981   {
982     // (x - 0.f) => x.
983     // Simple enough to cancel in float.
984     ExprHandle body = x - ExprHandle(0.f);
985     ExprHandle simplified = IRSimplifier::simplify(body);
986     IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
987     ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
988     IS_VAR_WITH_NAME(cast->src_value(), "x");
989   }
990 
991   {
992     // (x - (float)(y - y)) => x.
993     ExprHandle body = x - Cast::make(kFloat, y - y);
994     ExprHandle simplified = IRSimplifier::simplify(body);
995     IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
996     ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
997     IS_VAR_WITH_NAME(cast->src_value(), "x");
998   }
999 
1000   {
1001     // (x - y) - y => x - 2 * y
1002     ExprHandle body = (x - y) - y;
1003     ExprHandle simplified = IRSimplifier::simplify(body);
1004 
1005     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1006     IS_VAR_WITH_NAME(sub->lhs(), "x");
1007     IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
1008     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1009     IS_VAR_WITH_NAME(mul->rhs(), "y");
1010   }
1011 
1012   {
1013     // 2 * x - x => x
1014     ExprHandle body = (ExprHandle(2) * x) - x;
1015     ExprHandle simplified = IRSimplifier::simplify(body);
1016     IS_VAR_WITH_NAME(simplified.node(), "x");
1017   }
1018 
1019   {
1020     // x - 2 * x = -1 * x
1021     // We don't have a unary negate, but this could be 0 -x I guess?
1022     ExprHandle body = x - (ExprHandle(2) * x);
1023     ExprHandle simplified = IRSimplifier::simplify(body);
1024     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1025 
1026     IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
1027     IS_VAR_WITH_NAME(mul->rhs(), "x");
1028   }
1029 
1030   {
1031     // (x + y + 5) * (x - x) => 0
1032     // Cancelling out one side of Mul cancels both.
1033     ExprHandle body = (x + y + 5) * (x - x);
1034     ExprHandle simplified = IRSimplifier::simplify(body);
1035 
1036     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1037   }
1038 
1039   {
1040     // Cancel out opaque modulus.
1041     ExprHandle body = (x % y + 2) - (x % y);
1042     ExprHandle simplified = IRSimplifier::simplify(body);
1043     IS_IMM_WITH_VAL(Int, simplified.node(), 2);
1044   }
1045 
1046   {
1047     // Cancel out opaque modulus with a bit more going on.
1048     ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y);
1049     ExprHandle simplified = IRSimplifier::simplify(body);
1050     IS_IMM_WITH_VAL(Int, simplified.node(), 2);
1051   }
1052 
1053   {
1054     // Sub where result is negative.
1055     ExprHandle body = x - (x + 1);
1056     ExprHandle simplified = IRSimplifier::simplify(body);
1057     IS_IMM_WITH_VAL(Int, simplified.node(), -1);
1058   }
1059 
1060   {
1061     // Sub where result is positive due to negative scalar on RHS.
1062     ExprHandle body = x - (x - 1);
1063     ExprHandle simplified = IRSimplifier::simplify(body);
1064     IS_IMM_WITH_VAL(Int, simplified.node(), 1);
1065   }
1066 
1067   {
1068     // Term - Polynomial sub where RHS must be negated.
1069     ExprHandle body = (x * 2) - (x * 2 + 1);
1070     ExprHandle simplified = IRSimplifier::simplify(body);
1071     IS_IMM_WITH_VAL(Int, simplified.node(), -1);
1072   }
1073 
1074   {
1075     // Term - Polynomial sub where the result is a Term.
1076     ExprHandle body = (y * x * 2) - (x * y);
1077     ExprHandle simplified = IRSimplifier::simplify(body);
1078     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1079 
1080     IS_VAR_WITH_NAME(mul->lhs(), "x");
1081     IS_VAR_WITH_NAME(mul->rhs(), "y");
1082   }
1083 
1084   {
1085     // Term - Polynomial sub where the result is a Polynomial.
1086     ExprHandle body = (x * 2) - (x + 1);
1087     ExprHandle simplified = IRSimplifier::simplify(body);
1088     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1089 
1090     IS_VAR_WITH_NAME(sub->lhs(), "x");
1091     IS_IMM_WITH_VAL(Int, sub->rhs(), 1);
1092   }
1093 }
1094 
TEST(Simplify, SimplifyDiv) {
  VarHandle x("x", kInt);

  {
    // A zero numerator folds away: 0 / x => 0
    ExprHandle expr = ExprHandle(0) / x;
    ExprHandle result = IRSimplifier::simplify(expr);

    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // Dividing by one is the identity: x / 1 => x
    ExprHandle expr = x / 1;
    ExprHandle result = IRSimplifier::simplify(expr);

    IS_VAR_WITH_NAME(result.node(), "x");
  }
}
1112 
TEST(Simplify, SimplifyDivWithLoopContext0) {
  // Loop nest fed to the simplifier:
  //   for (int i = 0; i < 100; i++) {
  //     A[i] = i / 100;
  //   }
  // With 0 <= i < 100 the quotient is always 0, so the store should fold to 0.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {100}, kInt);
  auto store = Store::make(a_buf, {i}, (i / 100));
  auto loop = For::make(i, 0, 100, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = 0;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1133 
TEST(Simplify, SimplifyDivWithLoopContext1) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(6)) {
  //     A[i] = (i + 24) / 6;
  //   }
  // The numerator ranges over [24, 29], so the quotient is always 4.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {6}, kInt);
  auto store = Store::make(a_buf, {i}, (i + 24) / 6);
  auto loop = For::make(i, 0, 6, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = 4;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1154 
TEST(Simplify, SimplifyDivWithLoopContext2) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(5)) {
  //     A[i] = (i + 25) / 6;
  //   }
  // The numerator ranges over [25, 29], so the quotient is always 4.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {5}, kInt);
  auto store = Store::make(a_buf, {i}, (i + 25) / 6);
  auto loop = For::make(i, 0, 5, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = 4;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1175 
TEST(Simplify, SimplifyDivWithLoopContext3) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(6)) {
  //     A[i] = (i + 24) / (-6);
  //   }
  // The divisor is negative; the test expects the store NOT to be folded to
  // the constant -4.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {6}, kInt);
  auto store = Store::make(a_buf, {i}, (i + 24) / (-6));
  auto loop = For::make(i, 0, 6, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NOT:   A[i] = -4;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1196 
TEST(Simplify, SimplifyDivWithLoopContext4) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(5)) {
  //     A[i] = (i - 5) / 6;
  //   }
  // The numerator is negative over the whole range; the test expects the
  // store NOT to be folded to 0.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {5}, kInt);
  auto store = Store::make(a_buf, {i}, (i + (-5)) / 6);
  auto loop = For::make(i, 0, 5, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NOT:   A[i] = 0;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1217 
TEST(Simplify, SimplifyDivWithLoopContext5) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(6)) {
  //     for (const auto j : c10::irange(10)) {
  //       A[i, j] = (i + 6*j) / 6;
  //     }
  //   }
  // With 0 <= i < 6 and j >= 0, i never reaches the divisor, so the quotient
  // reduces to j.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle a_buf("A", {6, 10}, kInt);
  auto store = Store::make(a_buf, {i, j}, (i + j * 6) / 6);
  auto inner_loop = For::make(j, 0, 10, store);
  auto outer_loop = For::make(i, 0, 6, inner_loop);

  const StmtPtr result = IRSimplifier::simplify(outer_loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NEXT:   A[i, j] = j;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1243 
TEST(Simplify, SimplifyDivWithLoopContext6) {
  // Stmt to simplify:
  // for (const auto i : c10::irange(6)) {
  //  for (int j = -1; j < 9; j++) {
  //    A[i, j+1] = (i + 6*j) / 6;
  //  }
  //}
  // j can be negative here, so (i + 6*j) / 6 must NOT be simplified to j.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle a_buf("A", {6, 10}, kInt);
  auto for_j =
      For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6));
  auto for_i = For::make(i, 0, 6, for_j);

  const StmtPtr simplified = IRSimplifier::simplify(for_i);

  std::ostringstream oss;
  oss << *(simplified);
  // The store index is `j + 1`, so the negative pattern must use that index.
  // A pattern on `A[i, j]` could never match and would make the check vacuous.
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NOT:   A[i, j + 1] = j;
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
1270 
TEST(Simplify, SimplifyDivWithLoopContext7) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(6)) {
  //     for (const auto j : c10::irange(10)) {
  //       A[i, j] = (i + 6*j) / (-6);
  //     }
  //   }
  // The divisor is negative; the test expects the store NOT to be rewritten
  // to -j.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle a_buf("A", {6, 10}, kInt);
  auto store = Store::make(a_buf, {i, j}, (i + j * 6) / (-6));
  auto inner_loop = For::make(j, 0, 10, store);
  auto outer_loop = For::make(i, 0, 6, inner_loop);

  const StmtPtr result = IRSimplifier::simplify(outer_loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NOT:   A[i, j] = -j;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1297 
TEST(Simplify, SimplifyModWithLoopContext0) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(100)) {
  //     A[i] = i % 100;
  //   }
  // i never reaches 100, so the modulus is the identity on i.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {100}, kInt);
  auto store = Store::make(a_buf, {i}, (i % 100));
  auto loop = For::make(i, 0, 100, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1318 
TEST(Simplify, SimplifyModWithLoopContext1) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(6)) {
  //     A[i] = (i + 24) % 6;
  //   }
  // 24 is a multiple of 6 and i < 6, so the remainder is just i.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {6}, kInt);
  auto store = Store::make(a_buf, {i}, (i + 24) % 6);
  auto loop = For::make(i, 0, 6, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1339 
TEST(Simplify, SimplifyModWithLoopContext2) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(5)) {
  //     A[i] = (i + 25) % 6;
  //   }
  // 25 % 6 == 1 and i + 1 < 6, so the remainder is i + 1.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {5}, kInt);
  auto store = Store::make(a_buf, {i}, (i + 25) % 6);
  auto loop = For::make(i, 0, 5, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = i + 1;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1360 
TEST(Simplify, SimplifyModWithLoopContext3) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(6)) {
  //     A[i] = (i + 24) % (-6);
  //   }
  // The divisor is negative; the test expects the store NOT to be rewritten
  // to i.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {6}, kInt);
  auto store = Store::make(a_buf, {i}, (i + 24) % (-6));
  auto loop = For::make(i, 0, 6, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NOT:   A[i] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1381 
TEST(Simplify, SimplifyModWithLoopContext4) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(5)) {
  //     A[i] = (i - 5) % 6;
  //   }
  // The dividend is negative over the whole range; the test expects the
  // store NOT to be rewritten to i - 5.
  VarHandle i("i", kInt);
  BufHandle a_buf("A", {5}, kInt);
  auto store = Store::make(a_buf, {i}, (i + (-5)) % 6);
  auto loop = For::make(i, 0, 5, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK-NOT:   A[i] = i - 5;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1402 
TEST(Simplify, SimplifyModWithLoopContext5) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(6)) {
  //     for (const auto j : c10::irange(10)) {
  //       A[i, j] = (i + 6*j) % 6;
  //     }
  //   }
  // 6*j is a multiple of the divisor and 0 <= i < 6, so the remainder is i.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle a_buf("A", {6, 10}, kInt);
  auto store = Store::make(a_buf, {i, j}, (i + j * 6) % 6);
  auto inner_loop = For::make(j, 0, 10, store);
  auto outer_loop = For::make(i, 0, 6, inner_loop);

  const StmtPtr result = IRSimplifier::simplify(outer_loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NEXT:   A[i, j] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1428 
TEST(Simplify, SimplifyModWithLoopContext6) {
  // Stmt to simplify:
  // for (const auto i : c10::irange(6)) {
  //  for (int j = -1; j < 9; j++) {
  //    A[i, j+1] = (i + 6*j) % 6;
  //  }
  //}
  // j can be negative here, so (i + 6*j) % 6 must NOT be simplified to i.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle a_buf("A", {6, 10}, kInt);
  auto for_j =
      For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6));
  auto for_i = For::make(i, 0, 6, for_j);

  const StmtPtr simplified = IRSimplifier::simplify(for_i);

  std::ostringstream oss;
  oss << *(simplified);
  // The store index is `j + 1`, so the negative pattern must use that index.
  // A pattern on `A[i, j]` could never match and would make the check vacuous.
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NOT:   A[i, j + 1] = i;
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
1455 
TEST(Simplify, SimplifyModWithLoopContext7) {
  // Loop nest fed to the simplifier:
  //   for (const auto i : c10::irange(6)) {
  //     for (const auto j : c10::irange(10)) {
  //       A[i, j] = (i + 6*j) % (-6);
  //     }
  //   }
  // The divisor is negative; the test expects the store NOT to be rewritten
  // to i.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle a_buf("A", {6, 10}, kInt);
  auto store = Store::make(a_buf, {i, j}, (i + j * 6) % (-6));
  auto inner_loop = For::make(j, 0, 10, store);
  auto outer_loop = For::make(i, 0, 6, inner_loop);

  const StmtPtr result = IRSimplifier::simplify(outer_loop);

  std::ostringstream printed;
  printed << *result;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NOT:   A[i, j] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1482 
TEST(Simplify,SimplifyMod)1483 TEST(Simplify, SimplifyMod) {
1484   VarHandle x("x", kInt);
1485   VarHandle y("y", kInt);
1486   VarHandle z("z", kInt);
1487 
1488   {
1489     // Constant folding works.
1490     ExprHandle body = ExprHandle(10) % 8;
1491     ExprHandle simplified = IRSimplifier::simplify(body);
1492     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
1493     IS_IMM_WITH_VAL(Int, simplified.node(), 2);
1494   }
1495 
1496   {
1497     // x % x => 0
1498     ExprHandle body = x % x;
1499     ExprHandle simplified = IRSimplifier::simplify(body);
1500     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1501   }
1502 
1503   {
1504     // 0 % x => 0
1505     ExprHandle body = ExprHandle(0) % x;
1506     ExprHandle simplified = IRSimplifier::simplify(body);
1507     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1508   }
1509 
1510   {
1511     // x % 1 => 0
1512     ExprHandle body = x % 1;
1513     ExprHandle simplified = IRSimplifier::simplify(body);
1514     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1515   }
1516 
1517   {
1518     // Doesn't change unknown mods.
1519     // x % y => x % y
1520     ExprHandle body = x % y;
1521     ExprHandle simplified = IRSimplifier::simplify(body);
1522     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1523     IS_VAR_WITH_NAME(mod->lhs(), "x");
1524     IS_VAR_WITH_NAME(mod->rhs(), "y");
1525   }
1526 
1527   {
1528     // don't touch if RHS is unknown.
1529     // 4 % x => 4 % x
1530     ExprHandle body = ExprHandle(4) % x;
1531     ExprHandle simplified = IRSimplifier::simplify(body);
1532     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1533     IS_IMM_WITH_VAL(Int, mod->lhs(), 4);
1534     IS_VAR_WITH_NAME(mod->rhs(), "x");
1535   }
1536 
1537   {
1538     // don't touch if LHS is unknown.
1539     // x % 4 => x % 4
1540     ExprHandle body = x % 4;
1541     ExprHandle simplified = IRSimplifier::simplify(body);
1542     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1543     IS_VAR_WITH_NAME(mod->lhs(), "x");
1544     IS_IMM_WITH_VAL(Int, mod->rhs(), 4);
1545   }
1546 
1547   {
1548     // if LHS is a multiple of RHS, mod is zero.
1549     // 2 * x % x => 0
1550     ExprHandle body = (x * 2) % x;
1551     ExprHandle simplified = IRSimplifier::simplify(body);
1552     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1553   }
1554 
1555   {
1556     // true even if the multiple is not constant.
1557     // x * y % x => 0
1558     ExprHandle body = (x * y) % x;
1559     ExprHandle simplified = IRSimplifier::simplify(body);
1560     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1561   }
1562 
1563   {
1564     // true with multiple unknown values in LHS.
1565     // x * y * z % x => 0
1566     ExprHandle body = (x * y * z) % x;
1567     ExprHandle simplified = IRSimplifier::simplify(body);
1568     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1569   }
1570 
1571   {
1572     // true if the denom is compound.
1573     // x * y * z % y * z => 0
1574     ExprHandle body = (x * y * z) % (y * z);
1575     ExprHandle simplified = IRSimplifier::simplify(body);
1576     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1577   }
1578 
1579   {
1580     // Sanity check true with scalars that are multiples.
1581     // 12 * x % 4 => 0
1582     ExprHandle body = (x * 12) % 4;
1583     ExprHandle simplified = IRSimplifier::simplify(body);
1584     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1585   }
1586 
1587   {
1588     // Sanity check not true if the smaller scalar is on LHS.
1589     // 4 * x % 12 => 4 * x % 12
1590     ExprHandle body = (x * 4) % 12;
1591     ExprHandle simplified = IRSimplifier::simplify(body);
1592     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1593     IS_NODE_WITH_NAME(Mul, mod->lhs(), mul);
1594     IS_IMM_WITH_VAL(Int, mul->lhs(), 4);
1595     IS_VAR_WITH_NAME(mul->rhs(), "x");
1596     IS_IMM_WITH_VAL(Int, mod->rhs(), 12);
1597   }
1598 
1599   {
1600     // Both scalar and symbolic in multiple.
1601     // (6 * x * y) % (3 * x * y) => 0
1602     ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3);
1603     ExprHandle simplified = IRSimplifier::simplify(body);
1604     IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1605   }
1606 }
1607 
1608 // Test that mixing ops together simplifies as expected.
TEST(Simplify,SimplifyMultiOp)1609 TEST(Simplify, SimplifyMultiOp) {
1610   VarHandle x("x", kInt);
1611   VarHandle y("y", kInt);
1612 
1613   {
1614     // (x * y) + (x - y) => (x + x * y) - y
1615     ExprHandle body = (x * y) + (x - y);
1616     ExprHandle simplified = IRSimplifier::simplify(body);
1617 
1618     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1619     IS_NODE_WITH_NAME(Add, sub->lhs(), add);
1620     IS_VAR_WITH_NAME(add->lhs(), "x");
1621     IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
1622     IS_VAR_WITH_NAME(mul->lhs(), "x");
1623     IS_VAR_WITH_NAME(mul->rhs(), "y");
1624     IS_VAR_WITH_NAME(sub->rhs(), "y");
1625   }
1626 
1627   {
1628     // (x + y) - x * y => (x + y) - x * y
1629     ExprHandle body = (x + y) - x * y;
1630     ExprHandle simplified = IRSimplifier::simplify(body);
1631     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1632     IS_NODE_WITH_NAME(Add, sub->lhs(), add);
1633     IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
1634     IS_VAR_WITH_NAME(add->lhs(), "x");
1635     IS_VAR_WITH_NAME(add->rhs(), "y");
1636     IS_VAR_WITH_NAME(mul->lhs(), "x");
1637     IS_VAR_WITH_NAME(mul->rhs(), "y");
1638   }
1639 
1640   {
1641     // (x - y) - (x + y) => -2 * y
1642     ExprHandle body = (x - y) - (x + y);
1643     ExprHandle simplified = IRSimplifier::simplify(body);
1644 
1645     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1646     IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
1647     IS_VAR_WITH_NAME(mul->rhs(), "y");
1648   }
1649 
1650   {
1651     // (x - 0) + (x * 1) - (x + 0) => x
1652     ExprHandle body = (x - 0) + (x * 1) - (x + 0);
1653     ExprHandle simplified = IRSimplifier::simplify(body);
1654 
1655     IS_VAR_WITH_NAME(simplified.node(), "x");
1656   }
1657 
1658   {
1659     // (x - 0.f) + (x * 1.f) - (x + 0.f) => float(x) + float(x) - float(x)
1660     // Even in Float simple terms cancel out, but the variable ones cannot.
1661     ExprHandle body =
1662         (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f));
1663     ExprHandle simplified = IRSimplifier::simplify(body);
1664 
1665     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1666     IS_NODE_WITH_NAME(Add, sub->lhs(), add);
1667     IS_NODE_WITH_NAME(Cast, add->lhs(), cast1);
1668     IS_VAR_WITH_NAME(cast1->src_value(), "x");
1669     IS_NODE_WITH_NAME(Cast, add->rhs(), cast2);
1670     IS_VAR_WITH_NAME(cast2->src_value(), "x");
1671     IS_NODE_WITH_NAME(Cast, sub->rhs(), cast3);
1672     IS_VAR_WITH_NAME(cast3->src_value(), "x");
1673   }
1674 }
1675 
1676 // Test that chaining many ops together works as expected.
TEST(Simplify,SimplifyManyOps)1677 TEST(Simplify, SimplifyManyOps) {
1678   VarHandle x("x", kInt);
1679   VarHandle y("y", kInt);
1680 
1681   {
1682     // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x
1683     ExprHandle body = x + y + x + x + y + y + x + y + x;
1684     ExprHandle simplified = IRSimplifier::simplify(body);
1685 
1686     IS_NODE_WITH_NAME(Add, simplified.node(), add);
1687 
1688     IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
1689     IS_IMM_WITH_VAL(Int, lhs->lhs(), 4);
1690     IS_VAR_WITH_NAME(lhs->rhs(), "y");
1691 
1692     IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
1693     IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
1694     IS_VAR_WITH_NAME(rhs->rhs(), "x");
1695   }
1696 
1697   {
1698     // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y
1699     ExprHandle body = x - y + x + x - y - y + x - y + x;
1700     ExprHandle simplified = IRSimplifier::simplify(body);
1701 
1702     IS_NODE_WITH_NAME(Sub, simplified.node(), add);
1703 
1704     IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
1705     IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
1706     IS_VAR_WITH_NAME(lhs->rhs(), "x");
1707 
1708     IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
1709     IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
1710     IS_VAR_WITH_NAME(rhs->rhs(), "y");
1711   }
1712 
1713   {
1714     // x + y + x - x - y - y + x + y + x = 3 * x
1715     ExprHandle body = x + y + x - x - y - y + x + y + x;
1716     ExprHandle simplified = IRSimplifier::simplify(body);
1717 
1718     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1719     IS_IMM_WITH_VAL(Int, mul->lhs(), 3);
1720     IS_VAR_WITH_NAME(mul->rhs(), "x");
1721   }
1722 }
1723 
TEST(Simplify,SimplifyFactorization)1724 TEST(Simplify, SimplifyFactorization) {
1725   VarHandle x("x", kInt);
1726   VarHandle y("y", kInt);
1727 
1728   {
1729     // (2 * x) + (2 * y) => 2 * (x + y)
1730     ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y);
1731     ExprHandle simplified = IRSimplifier::simplify(body);
1732 
1733     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1734     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1735 
1736     IS_NODE_WITH_NAME(Add, mul->rhs(), add);
1737     IS_VAR_WITH_NAME(add->lhs(), "x");
1738     IS_VAR_WITH_NAME(add->rhs(), "y");
1739   }
1740 
1741   {
1742     // Factorization when scalars have common divider.
1743     // (2 * x) + (4 * y) => 2 * (2 * y + x)
1744     ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y);
1745     ExprHandle simplified = IRSimplifier::simplify(body);
1746 
1747     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1748     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1749 
1750     IS_NODE_WITH_NAME(Add, mul->rhs(), add);
1751     IS_VAR_WITH_NAME(add->lhs(), "x");
1752     IS_NODE_WITH_NAME(Mul, add->rhs(), mul2);
1753     IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
1754     IS_VAR_WITH_NAME(mul2->rhs(), "y");
1755   }
1756 
1757   {
1758     // Factorization attempt without a common divider.
1759     // (2 * x) + (5 * y) =>  (5 * y) + (2 * x)
1760     ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y);
1761     ExprHandle simplified = IRSimplifier::simplify(body);
1762 
1763     IS_NODE_WITH_NAME(Add, simplified.node(), add);
1764 
1765     IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
1766     IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
1767     IS_VAR_WITH_NAME(lhs->rhs(), "x");
1768 
1769     IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
1770     IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
1771     IS_VAR_WITH_NAME(rhs->rhs(), "y");
1772   }
1773 
1774   {
1775     // Factorization after merging.
1776     // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y)
1777     ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) +
1778         (ExprHandle(8) * x + ExprHandle(6) * y);
1779     ExprHandle simplified = IRSimplifier::simplify(body);
1780 
1781     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1782     IS_IMM_WITH_VAL(Int, mul->lhs(), 10);
1783 
1784     IS_NODE_WITH_NAME(Add, mul->rhs(), add);
1785     IS_VAR_WITH_NAME(add->lhs(), "x");
1786     IS_VAR_WITH_NAME(add->rhs(), "y");
1787   }
1788 
1789   {
1790     // Factorization with common divider but different signs.
1791     // (2 * x) + (-4 * y) => 2 * (x - 2 * y)
1792     ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y);
1793     ExprHandle simplified = IRSimplifier::simplify(body);
1794 
1795     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1796     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1797 
1798     IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
1799     IS_VAR_WITH_NAME(sub->lhs(), "x");
1800     IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2);
1801     IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
1802     IS_VAR_WITH_NAME(mul2->rhs(), "y");
1803   }
1804 
1805   {
1806     // Factorization with all negative numbers.
1807     // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y)
1808     ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y;
1809     ExprHandle simplified = IRSimplifier::simplify(body);
1810 
1811     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1812     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1813 
1814     IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
1815     IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2);
1816     IS_IMM_WITH_VAL(Int, mul2->lhs(), -1);
1817     IS_VAR_WITH_NAME(mul2->rhs(), "x");
1818     IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3);
1819     IS_IMM_WITH_VAL(Int, mul3->lhs(), 2);
1820     IS_VAR_WITH_NAME(mul3->rhs(), "y");
1821   }
1822 
1823   {
1824     // The following test ensures that there in no infinite recursion during
1825     // factorization when negative numbers are involved.
1826     VarHandle a("a", kInt);
1827     VarHandle b("b", kInt);
1828     VarHandle c("c", kInt);
1829     VarHandle d("d", kInt);
1830     VarHandle e("e", kInt);
1831     VarHandle f("f", kInt);
1832     VarHandle g("g", kInt);
1833     VarHandle h("h", kInt);
1834 
1835     ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 +
1836         f * 32 + g * (-1024) + h * (-32);
1837     ExprHandle simplified = IRSimplifier::simplify(body);
1838     checkExprIR(
1839         simplified,
1840         "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h");
1841   }
1842 }
1843 
1844 // (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x)
TEST(Simplify,SimplifyFactorizeUneven)1845 TEST(Simplify, SimplifyFactorizeUneven) {
1846   VarHandle x("x", kInt);
1847   VarHandle y("y", kInt);
1848   VarHandle z("z", kInt);
1849   ExprHandle body =
1850       (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4);
1851   ExprHandle simplified = IRSimplifier::simplify(body);
1852 
1853   IS_NODE_WITH_NAME(Mul, simplified.node(), root);
1854   IS_IMM_WITH_VAL(Int, root->lhs(), 2);
1855   IS_NODE_WITH_NAME(Add, root->rhs(), add1);
1856   IS_NODE_WITH_NAME(Add, add1->lhs(), add2);
1857 
1858   IS_VAR_WITH_NAME(add2->lhs(), "y");
1859   IS_NODE_WITH_NAME(Mul, add2->rhs(), zmul);
1860   IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul);
1861 
1862   IS_IMM_WITH_VAL(Int, xmul->lhs(), 4);
1863   IS_VAR_WITH_NAME(xmul->rhs(), "x");
1864 
1865   IS_IMM_WITH_VAL(Int, zmul->lhs(), 3);
1866   IS_VAR_WITH_NAME(zmul->rhs(), "z");
1867 }
1868 
1869 // (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y)
1870 // This is kind of a placeholder test for variable factorization.
TEST(Simplify,SimplifyDeeperTerms)1871 TEST(Simplify, SimplifyDeeperTerms) {
1872   VarHandle x("x", kInt);
1873   VarHandle y("y", kInt);
1874   ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y);
1875   ExprHandle simplified = IRSimplifier::simplify(body);
1876 
1877   IS_NODE_WITH_NAME(Add, simplified.node(), add);
1878 
1879   IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
1880   IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
1881   IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm);
1882   IS_VAR_WITH_NAME(xxTerm->lhs(), "x");
1883   IS_VAR_WITH_NAME(xxTerm->rhs(), "x");
1884 
1885   IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
1886   IS_IMM_WITH_VAL(Int, rhs->lhs(), 3);
1887   IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm);
1888   IS_VAR_WITH_NAME(xyTerm->lhs(), "x");
1889   IS_VAR_WITH_NAME(xyTerm->rhs(), "y");
1890 }
1891 
1892 // Tests the difference between two less trivial expressions.
1893 // (m * (1 * n_1) + (n  + 1)) - (m *  (1 * n_1) + n) => 1
TEST(Simplify,SimplifyDeeperDifference)1894 TEST(Simplify, SimplifyDeeperDifference) {
1895   VarHandle n("n", kInt);
1896   VarHandle n_1("n_1", kInt);
1897   VarHandle m("m", kInt);
1898   ExprHandle body =
1899       (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n);
1900   ExprHandle simplified = IRSimplifier::simplify(body);
1901 
1902   IS_IMM_WITH_VAL(Int, simplified.node(), 1);
1903 }
1904 
1905 // Test constant folding into the difference between expressions.
1906 // 2 + char((m * (1 * n_1) + (n  + 1)) - (m *  (1 * n_1) + n)) => 3
TEST(Simplify,SimplifyFoldComplexDifference)1907 TEST(Simplify, SimplifyFoldComplexDifference) {
1908   VarHandle n("n", kInt);
1909   VarHandle n_1("n_1", kInt);
1910   VarHandle m("m", kInt);
1911   ExprHandle body =
1912       (IntImm::make(2) +
1913        (Cast::make(
1914            kChar,
1915            (m * (ExprHandle(1) * n_1) + (n + 1)) -
1916                (m * (ExprHandle(1) * n_1) + n))));
1917   ExprHandle simplified = IRSimplifier::simplify(body);
1918   IS_IMM_WITH_VAL(Int, simplified.node(), 3);
1919 }
1920 
// The condition and both branch values of an IfThenElse are simplified:
//   IfThenElse((5 - 4) * x > y, 2 * x - x, 2 * y - y)
//     => IfThenElse(x > y, x, y)
TEST(Simplify, SimplifyIfComponents) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  ExprHandle condition = ((ExprHandle(5) - ExprHandle(4)) * x) > y;
  ExprHandle thenValue = ExprHandle(2) * x - x;
  ExprHandle elseValue = ExprHandle(2) * y - y;
  ExprHandle result =
      IRSimplifier::simplify(IfThenElse::make(condition, thenValue, elseValue));

  IS_NODE_WITH_NAME(IfThenElse, result.node(), ite);

  // Condition: x > y.
  IS_NODE_WITH_NAME(CompareSelect, ite->condition(), cond);
  ASSERT_EQ(cond->compare_select_op(), kGT);
  IS_VAR_WITH_NAME(cond->lhs(), "x");
  IS_VAR_WITH_NAME(cond->rhs(), "y");

  // Branches: x and y.
  IS_VAR_WITH_NAME(ite->true_value(), "x");
  IS_VAR_WITH_NAME(ite->false_value(), "y");
}
1941 
TEST(Simplify, SimplifyOpaqueTerms) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  {
    // Terms built around an opaque division still combine as like terms:
    //   2 * (x / y) * y - (x / y) * y => (x / y) * y
    ExprHandle twice = ((ExprHandle(2)) * (x / y) * y);
    ExprHandle once = ((x / y) * y);
    ExprHandle result = IRSimplifier::simplify(twice - once);

    IS_NODE_WITH_NAME(Mul, result.node(), product);
    IS_NODE_WITH_NAME(Div, product->lhs(), quotient);
    IS_VAR_WITH_NAME(quotient->lhs(), "x");
    IS_VAR_WITH_NAME(quotient->rhs(), "y");
    IS_VAR_WITH_NAME(product->rhs(), "y");
  }

  {
    // Identical opaque modulo terms cancel in a difference:
    //   x % y - (x % y - 1) => 1
    ExprHandle result = IRSimplifier::simplify((x % y) - ((x % y) - 1));

    IS_IMM_WITH_VAL(Int, result.node(), 1);
  }
}
TEST(Simplify, SimplifySymbolicMinMax) {
  {
    // Operands that differ only by a constant: Min keeps the smaller one.
    //   Min(x + 3, x + 7) => x + 3
    VarHandle x("x", kInt);
    ExprHandle result = IRSimplifier::simplify(Min::make(x + 3, x + 7, true));

    IS_NODE_WITH_NAME(Add, result.node(), sum);
    IS_VAR_WITH_NAME(sum->lhs(), "x");
    IS_IMM_WITH_VAL(Int, sum->rhs(), 3);
  }

  {
    // Operands that differ only by a constant: Max keeps the larger one.
    //   Max(x + 3, x + 7) => x + 7
    VarHandle x("x", kInt);
    ExprHandle result = IRSimplifier::simplify(Max::make(x + 3, x + 7, true));

    IS_NODE_WITH_NAME(Add, result.node(), sum);
    IS_VAR_WITH_NAME(sum->lhs(), "x");
    IS_IMM_WITH_VAL(Int, sum->rhs(), 7);
  }

  {
    // Max(x * 3, x * 7) stays a Max: which product is larger depends on the
    // sign of x.
    // TODO: this might be resolvable for unsigned types.
    VarHandle x("x", kInt);
    ExprHandle result = IRSimplifier::simplify(Max::make(x * 3, x * 7, true));

    IS_NODE(Max, result.node());
  }
}
TEST(Simplify,SimplifyNestedMax)2001 TEST(Simplify, SimplifyNestedMax) {
2002   VarHandle x("x", kInt);
2003   VarHandle y("y", kInt);
2004   VarHandle z("z", kInt);
2005 
2006   {
2007     // Max(x + y, x + y) => x + y
2008     ExprHandle body = Max::make(x + y, x + y, true);
2009     ExprHandle simplified = IRSimplifier::simplify(body);
2010 
2011     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
2012     IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y");
2013   }
2014 
2015   {
2016     // Max(x + y, Max(x + y, z)) => Max(x + y, z)
2017     ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true);
2018     ExprHandle simplified = IRSimplifier::simplify(body);
2019 
2020     IS_NODE_WITH_NAME(Max, simplified.node(), max);
2021     IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
2022     IS_VAR_WITH_NAME(max->rhs(), "z");
2023   }
2024 
2025   {
2026     // Max(x + y, Max(z, x + y)) => Max(x + y, z)
2027     ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true);
2028     ExprHandle simplified = IRSimplifier::simplify(body);
2029 
2030     IS_NODE_WITH_NAME(Max, simplified.node(), max);
2031     IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
2032     IS_VAR_WITH_NAME(max->rhs(), "z");
2033   }
2034 
2035   {
2036     // Max(Max(x + y, z), x + y) => Max(x + y, z)
2037     ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true);
2038     ExprHandle simplified = IRSimplifier::simplify(body);
2039 
2040     IS_NODE_WITH_NAME(Max, simplified.node(), max);
2041     IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
2042     IS_VAR_WITH_NAME(max->rhs(), "z");
2043   }
2044 
2045   {
2046     // Max(Max(z, x + y), x + y) => Max(x + y, z)
2047     ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true);
2048     ExprHandle simplified = IRSimplifier::simplify(body);
2049 
2050     IS_NODE_WITH_NAME(Max, simplified.node(), max);
2051     IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
2052     IS_VAR_WITH_NAME(max->rhs(), "z");
2053   }
2054 
2055   {
2056     // Max(Max(x, y), x) => Max(Max(x, y), x)
2057     // Nested Max ops with different propagate_nans should not be simplified.
2058     ExprHandle body = Max::make(Max::make(x, y, true), x, false);
2059     ExprHandle simplified = IRSimplifier::simplify(body);
2060 
2061     IS_NODE_WITH_NAME(Max, simplified.node(), max);
2062     IS_BINOP_W_VARS(Max, max->lhs(), max1, "x", "y");
2063     ASSERT_TRUE(max1->propagate_nans());
2064     IS_VAR_WITH_NAME(max->rhs(), "x");
2065     ASSERT_FALSE(max->propagate_nans());
2066   }
2067 
2068   {
2069     // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x)
2070     ExprHandle body =
2071         Max::make(Min::make(x, y, true), Min::make(x, z, true), true);
2072     ExprHandle simplified = IRSimplifier::simplify(body);
2073     checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
2074   }
2075 
2076   {
2077     // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x)
2078     ExprHandle body =
2079         Max::make(Min::make(x, y, true), Min::make(z, x, true), true);
2080     ExprHandle simplified = IRSimplifier::simplify(body);
2081     checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
2082   }
2083 
2084   {
2085     // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x)
2086     ExprHandle body =
2087         Max::make(Min::make(y, x, true), Min::make(x, z, true), true);
2088     ExprHandle simplified = IRSimplifier::simplify(body);
2089     checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
2090   }
2091 
2092   {
2093     // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x)
2094     ExprHandle body =
2095         Max::make(Min::make(y, x, true), Min::make(z, x, true), true);
2096     ExprHandle simplified = IRSimplifier::simplify(body);
2097     checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
2098   }
2099 
2100   {
2101     // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z))
2102     // When all the ops in the pattern do not have the same propagate_nans,
2103     // it should not be simplified.
2104     ExprHandle body =
2105         Max::make(Min::make(y, x, true), Min::make(z, x, false), true);
2106     ExprHandle simplified = IRSimplifier::simplify(body);
2107 
2108     IS_NODE_WITH_NAME(Max, simplified.node(), max);
2109     IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y");
2110     ASSERT_TRUE(min1->propagate_nans());
2111     IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z");
2112     ASSERT_FALSE(min2->propagate_nans());
2113     ASSERT_TRUE(max->propagate_nans());
2114   }
2115 
2116   {
2117     // Max(5, Max(x, 8)) => Max(x, 8)
2118     ExprHandle body = Max::make(5, Max::make(x, 8, true), true);
2119     ExprHandle simplified = IRSimplifier::simplify(body);
2120 
2121     IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
2122     ASSERT_TRUE(max->propagate_nans());
2123   }
2124 
2125   {
2126     // Max(8, Max(x, 5)) => Max(x, 8)
2127     ExprHandle body = Max::make(8, Max::make(x, 5, true), true);
2128     ExprHandle simplified = IRSimplifier::simplify(body);
2129 
2130     IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
2131     ASSERT_TRUE(max->propagate_nans());
2132   }
2133 
2134   {
2135     // Max(Max(x, 8), 5) => Max(x, 8)
2136     ExprHandle body = Max::make(Max::make(x, 8, true), 5, true);
2137     ExprHandle simplified = IRSimplifier::simplify(body);
2138 
2139     IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
2140     ASSERT_TRUE(max->propagate_nans());
2141   }
2142 
2143   {
2144     // Max(Max(x, 5), 8) => Max(x, 8)
2145     ExprHandle body = Max::make(Max::make(x, 5, true), 8, true);
2146     ExprHandle simplified = IRSimplifier::simplify(body);
2147 
2148     IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
2149     ASSERT_TRUE(max->propagate_nans());
2150   }
2151 
2152   {
2153     // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z)
2154     ExprHandle body = Max::make(
2155         5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true);
2156     ExprHandle simplified = IRSimplifier::simplify(body);
2157 
2158     IS_NODE_WITH_NAME(Max, simplified.node(), max1);
2159     IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
2160     IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
2161     ASSERT_TRUE(max3->propagate_nans());
2162     IS_VAR_WITH_NAME(max2->rhs(), "y");
2163     IS_VAR_WITH_NAME(max1->rhs(), "z");
2164   }
2165 
2166   {
2167     // Max(8, Max(Max(y, Max(z, 5)), x)) => Max(Max(Max(x, 8), y), z)
2168     ExprHandle body = Max::make(
2169         8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true);
2170     ExprHandle simplified = IRSimplifier::simplify(body);
2171 
2172     IS_NODE_WITH_NAME(Max, simplified.node(), max1);
2173     IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
2174     IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
2175     ASSERT_TRUE(max3->propagate_nans());
2176     IS_VAR_WITH_NAME(max2->rhs(), "y");
2177     IS_VAR_WITH_NAME(max1->rhs(), "z");
2178   }
2179 
2180   {
2181     // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z)
2182     ExprHandle body = Max::make(
2183         5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true);
2184     ExprHandle simplified = IRSimplifier::simplify(body);
2185 
2186     IS_NODE_WITH_NAME(Max, simplified.node(), max1);
2187     IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
2188     IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
2189     ASSERT_TRUE(max3->propagate_nans());
2190     IS_VAR_WITH_NAME(max2->rhs(), "y");
2191     IS_VAR_WITH_NAME(max1->rhs(), "z");
2192   }
2193 
2194   {
2195     // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z)
2196     ExprHandle body = Max::make(
2197         Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true);
2198     ExprHandle simplified = IRSimplifier::simplify(body);
2199 
2200     IS_NODE_WITH_NAME(Max, simplified.node(), max1);
2201     IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
2202     IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
2203     ASSERT_TRUE(max3->propagate_nans());
2204     IS_VAR_WITH_NAME(max2->rhs(), "y");
2205     IS_VAR_WITH_NAME(max1->rhs(), "z");
2206   }
2207 
2208   {
2209     // Max(Max(Max(y, Max(8, z)), x), 5) => Max(Max(Max(x, 8), y), z)
2210     ExprHandle body = Max::make(
2211         Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true);
2212     ExprHandle simplified = IRSimplifier::simplify(body);
2213 
2214     IS_NODE_WITH_NAME(Max, simplified.node(), max1);
2215     IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
2216     IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
2217     ASSERT_TRUE(max3->propagate_nans());
2218     IS_VAR_WITH_NAME(max2->rhs(), "y");
2219     IS_VAR_WITH_NAME(max1->rhs(), "z");
2220   }
2221 
2222   {
2223     // Max(Max(Max(Max(5, z), y), x), 8) => Max(Max(Max(x, 8), y), z)
2224     ExprHandle body = Max::make(
2225         Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true);
2226     ExprHandle simplified = IRSimplifier::simplify(body);
2227 
2228     IS_NODE_WITH_NAME(Max, simplified.node(), max1);
2229     IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
2230     IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
2231     ASSERT_TRUE(max3->propagate_nans());
2232     IS_VAR_WITH_NAME(max2->rhs(), "y");
2233     IS_VAR_WITH_NAME(max1->rhs(), "z");
2234   }
2235 
2236   {
2237     // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(x, Max(Max(z, 5), y)), 8)
2238     // Do not simplify when all the Max ops do not have the same
2239     // propagate_nans.
2240     ExprHandle body = Max::make(
2241         Max::make(Max::make(Max::make(z, 5, true), y, false), x, true),
2242         8,
2243         false);
2244     ExprHandle simplified = IRSimplifier::simplify(body);
2245     checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)");
2246   }
2247 
2248   {
2249     // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z)
2250     ExprHandle body = Max::make(
2251         8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true);
2252     ExprHandle simplified = IRSimplifier::simplify(body);
2253 
2254     IS_NODE_WITH_NAME(Max, simplified.node(), max1);
2255     IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
2256     IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
2257     ASSERT_TRUE(max3->propagate_nans());
2258     IS_VAR_WITH_NAME(max2->rhs(), "y");
2259     IS_VAR_WITH_NAME(max1->rhs(), "z");
2260   }
2261 
2262   {
2263     // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z)
2264     ExprHandle body = Max::make(
2265         Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true);
2266     ExprHandle simplified = IRSimplifier::simplify(body);
2267 
2268     IS_NODE_WITH_NAME(Max, simplified.node(), max1);
2269     IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
2270     IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
2271     ASSERT_TRUE(max3->propagate_nans());
2272     IS_VAR_WITH_NAME(max2->rhs(), "y");
2273     IS_VAR_WITH_NAME(max1->rhs(), "z");
2274   }
2275 }
2276 
TEST(Simplify,SimplifyNestedMin)2277 TEST(Simplify, SimplifyNestedMin) {
2278   VarHandle x("x", kInt);
2279   VarHandle y("y", kInt);
2280   VarHandle z("z", kInt);
2281 
2282   {
2283     // Min(x + y, x + y) => x + y
2284     ExprHandle body = Min::make(x + y, x + y, true);
2285     ExprHandle simplified = IRSimplifier::simplify(body);
2286 
2287     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
2288     IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y");
2289   }
2290 
2291   {
2292     // Min(x + y, Min(x + y, z)) => Min(x + y, z)
2293     ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true);
2294     ExprHandle simplified = IRSimplifier::simplify(body);
2295 
2296     IS_NODE_WITH_NAME(Min, simplified.node(), min);
2297     IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
2298     IS_VAR_WITH_NAME(min->rhs(), "z");
2299   }
2300 
2301   {
2302     // Min(x + y, Min(z, x + y)) => Min(x + y, z)
2303     ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true);
2304     ExprHandle simplified = IRSimplifier::simplify(body);
2305 
2306     IS_NODE_WITH_NAME(Min, simplified.node(), min);
2307     IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
2308     IS_VAR_WITH_NAME(min->rhs(), "z");
2309   }
2310 
2311   {
2312     // Min(Min(x + y, z), x + y) => Min(x + y, z)
2313     ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true);
2314     ExprHandle simplified = IRSimplifier::simplify(body);
2315 
2316     IS_NODE_WITH_NAME(Min, simplified.node(), min);
2317     IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
2318     IS_VAR_WITH_NAME(min->rhs(), "z");
2319   }
2320 
2321   {
2322     // Min(Min(z, x + y), x + y) => Min(x + y, z)
2323     ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true);
2324     ExprHandle simplified = IRSimplifier::simplify(body);
2325 
2326     IS_NODE_WITH_NAME(Min, simplified.node(), min);
2327     IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
2328     IS_VAR_WITH_NAME(min->rhs(), "z");
2329   }
2330 
2331   {
2332     // Min(Min(x, y), x) => Min(Min(x, y), x)
2333     // Nested Min ops with different propagate_nans should not be simplified.
2334     ExprHandle body = Min::make(Min::make(x, y, true), x, false);
2335     ExprHandle simplified = IRSimplifier::simplify(body);
2336 
2337     IS_NODE_WITH_NAME(Min, simplified.node(), min1);
2338     IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x", "y");
2339     ASSERT_TRUE(min2->propagate_nans());
2340     IS_VAR_WITH_NAME(min1->rhs(), "x");
2341     ASSERT_FALSE(min1->propagate_nans());
2342   }
2343 
2344   {
2345     // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x)
2346     ExprHandle body =
2347         Min::make(Max::make(x, y, true), Max::make(x, z, true), true);
2348     ExprHandle simplified = IRSimplifier::simplify(body);
2349     checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
2350   }
2351 
2352   {
2353     // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x)
2354     ExprHandle body =
2355         Min::make(Max::make(x, y, true), Max::make(z, x, true), true);
2356     ExprHandle simplified = IRSimplifier::simplify(body);
2357     checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
2358   }
2359 
2360   {
2361     // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x)
2362     ExprHandle body =
2363         Min::make(Max::make(y, x, true), Max::make(x, z, true), true);
2364     ExprHandle simplified = IRSimplifier::simplify(body);
2365     checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
2366   }
2367 
2368   {
2369     // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x)
2370     ExprHandle body =
2371         Min::make(Max::make(y, x, true), Max::make(z, x, true), true);
2372     ExprHandle simplified = IRSimplifier::simplify(body);
2373     checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
2374   }
2375 
2376   {
2377     // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z))
2378     // When all the ops in the pattern do not have the same propagate_nans,
2379     // it should not be simplified.
2380     ExprHandle body =
2381         Min::make(Max::make(y, x, true), Max::make(z, x, false), true);
2382     ExprHandle simplified = IRSimplifier::simplify(body);
2383 
2384     IS_NODE_WITH_NAME(Min, simplified.node(), min);
2385     IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y");
2386     ASSERT_TRUE(max1->propagate_nans());
2387     IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z");
2388     ASSERT_FALSE(max2->propagate_nans());
2389     ASSERT_TRUE(min->propagate_nans());
2390   }
2391 
2392   {
2393     // Min(5, Min(x, 8)) => Min(x, 8)
2394     ExprHandle body = Min::make(5, Min::make(x, 8, true), true);
2395     ExprHandle simplified = IRSimplifier::simplify(body);
2396 
2397     IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
2398     ASSERT_TRUE(min->propagate_nans());
2399   }
2400 
2401   {
2402     // Min(8, Min(x, 5)) => Min(x, 8)
2403     ExprHandle body = Min::make(8, Min::make(x, 5, true), true);
2404     ExprHandle simplified = IRSimplifier::simplify(body);
2405 
2406     IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
2407     ASSERT_TRUE(min->propagate_nans());
2408   }
2409 
2410   {
2411     // Min(Min(x, 8), 5) => Min(x, 8)
2412     ExprHandle body = Min::make(Min::make(x, 8, true), 5, true);
2413     ExprHandle simplified = IRSimplifier::simplify(body);
2414 
2415     IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
2416     ASSERT_TRUE(min->propagate_nans());
2417   }
2418 
2419   {
2420     // Min(Min(x, 5), 8) => Min(x, 8)
2421     ExprHandle body = Min::make(Min::make(x, 5, true), 8, true);
2422     ExprHandle simplified = IRSimplifier::simplify(body);
2423 
2424     IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
2425     ASSERT_TRUE(min->propagate_nans());
2426   }
2427 
2428   {
2429     // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z)
2430     ExprHandle body = Min::make(
2431         5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true);
2432     ExprHandle simplified = IRSimplifier::simplify(body);
2433 
2434     IS_NODE_WITH_NAME(Min, simplified.node(), min1);
2435     IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
2436     IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
2437     ASSERT_TRUE(min3->propagate_nans());
2438     IS_VAR_WITH_NAME(min2->rhs(), "y");
2439     IS_VAR_WITH_NAME(min1->rhs(), "z");
2440   }
2441 
2442   {
2443     // Min(5, Min(Min(y, Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z)
2444     ExprHandle body = Min::make(
2445         5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true);
2446     ExprHandle simplified = IRSimplifier::simplify(body);
2447 
2448     IS_NODE_WITH_NAME(Min, simplified.node(), min1);
2449     IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
2450     IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
2451     ASSERT_TRUE(min3->propagate_nans());
2452     IS_VAR_WITH_NAME(min2->rhs(), "y");
2453     IS_VAR_WITH_NAME(min1->rhs(), "z");
2454   }
2455 
2456   {
2457     // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z)
2458     ExprHandle body = Min::make(
2459         5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true);
2460     ExprHandle simplified = IRSimplifier::simplify(body);
2461 
2462     IS_NODE_WITH_NAME(Min, simplified.node(), min1);
2463     IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
2464     IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
2465     ASSERT_TRUE(min3->propagate_nans());
2466     IS_VAR_WITH_NAME(min2->rhs(), "y");
2467     IS_VAR_WITH_NAME(min1->rhs(), "z");
2468   }
2469 
2470   {
2471     // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z)
2472     ExprHandle body = Min::make(
2473         Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true);
2474     ExprHandle simplified = IRSimplifier::simplify(body);
2475 
2476     IS_NODE_WITH_NAME(Min, simplified.node(), min1);
2477     IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
2478     IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
2479     ASSERT_TRUE(min3->propagate_nans());
2480     IS_VAR_WITH_NAME(min2->rhs(), "y");
2481     IS_VAR_WITH_NAME(min1->rhs(), "z");
2482   }
2483 
2484   {
2485     // Min(Min(Min(y, Min(8, z)), x), 5) => Min(Min(Min(x, 5), y), z)
2486     ExprHandle body = Min::make(
2487         Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true);
2488     ExprHandle simplified = IRSimplifier::simplify(body);
2489 
2490     IS_NODE_WITH_NAME(Min, simplified.node(), min1);
2491     IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
2492     IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
2493     ASSERT_TRUE(min3->propagate_nans());
2494     IS_VAR_WITH_NAME(min2->rhs(), "y");
2495     IS_VAR_WITH_NAME(min1->rhs(), "z");
2496   }
2497 
2498   {
2499     // Min(Min(Min(Min(8, z), y), x), 5) => Min(Min(Min(x, 5), y), z)
2500     ExprHandle body = Min::make(
2501         Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true);
2502     ExprHandle simplified = IRSimplifier::simplify(body);
2503 
2504     IS_NODE_WITH_NAME(Min, simplified.node(), min1);
2505     IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
2506     IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
2507     ASSERT_TRUE(min3->propagate_nans());
2508     IS_VAR_WITH_NAME(min2->rhs(), "y");
2509     IS_VAR_WITH_NAME(min1->rhs(), "z");
2510   }
2511 
2512   {
2513     // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8)
2514     // Do not simplify when all the Min ops do not have the same
2515     // propagate_nans.
2516     ExprHandle body = Min::make(
2517         Min::make(Min::make(Min::make(z, 5, true), y, false), x, true),
2518         8,
2519         false);
2520     ExprHandle simplified = IRSimplifier::simplify(body);
2521     checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)");
2522   }
2523 
2524   {
2525     // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z)
2526     ExprHandle body = Min::make(
2527         8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true);
2528     ExprHandle simplified = IRSimplifier::simplify(body);
2529 
2530     IS_NODE_WITH_NAME(Min, simplified.node(), min1);
2531     IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
2532     IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
2533     ASSERT_TRUE(min3->propagate_nans());
2534     IS_VAR_WITH_NAME(min2->rhs(), "y");
2535     IS_VAR_WITH_NAME(min1->rhs(), "z");
2536   }
2537 
2538   {
2539     // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z)
2540     ExprHandle body = Min::make(
2541         Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true);
2542     ExprHandle simplified = IRSimplifier::simplify(body);
2543 
2544     IS_NODE_WITH_NAME(Min, simplified.node(), min1);
2545     IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
2546     IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
2547     ASSERT_TRUE(min3->propagate_nans());
2548     IS_VAR_WITH_NAME(min2->rhs(), "y");
2549     IS_VAR_WITH_NAME(min1->rhs(), "z");
2550   }
2551 }
2552 
TEST(Simplify,SimplifyWontReorderFloat)2553 TEST(Simplify, SimplifyWontReorderFloat) {
2554   {
2555     // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y)
2556     // This is an expression we can simplify.
2557     VarHandle x("x", kInt);
2558     VarHandle y("y", kInt);
2559 
2560     ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
2561         ExprHandle(3) * (ExprHandle(3) * y);
2562     ExprHandle simplified = IRSimplifier::simplify(body);
2563 
2564     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
2565     IS_IMM_WITH_VAL(Int, mul->lhs(), 9);
2566     IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
2567     IS_VAR_WITH_NAME(sub->lhs(), "x");
2568     IS_VAR_WITH_NAME(sub->rhs(), "y");
2569   }
2570 
2571   {
2572     // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y).
2573     // If the vars are floating point, ops are not associative and we can't
2574     // reorder.
2575     VarHandle x("x", kFloat);
2576     VarHandle y("y", kFloat);
2577 
2578     ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
2579         ExprHandle(3) * (ExprHandle(3) * y);
2580     ExprHandle simplified = IRSimplifier::simplify(body);
2581 
2582     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
2583     IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
2584     IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
2585     IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
2586     IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3);
2587     IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
2588 
2589     IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
2590     IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
2591     IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
2592     IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
2593     IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y");
2594   }
2595 
2596   {
2597     // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y).
2598     // We will simplify subexprs if they dont reorder floating point ops.
2599     VarHandle x("x", kDouble);
2600     VarHandle y("y", kInt);
2601 
2602     ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
2603         ExprHandle(3) * (ExprHandle(3) * y);
2604     ExprHandle simplified = IRSimplifier::simplify(body);
2605 
2606     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
2607     IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
2608     IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3);
2609     IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
2610     IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3);
2611     IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
2612 
2613     IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double);
2614     IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9);
2615     IS_VAR_WITH_NAME(rhsMul->rhs(), "y");
2616   }
2617 
2618   {
2619     // Prevent reordering if FP propagated from dtypes.
2620     VarHandle x("x", kInt);
2621     VarHandle y("y", kInt);
2622 
2623     ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) -
2624         ExprHandle(3) * (ExprHandle(3.f) * y);
2625     ExprHandle simplified = IRSimplifier::simplify(body);
2626 
2627     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
2628     IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
2629     IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
2630     IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float);
2631     IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3);
2632     IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
2633 
2634     IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
2635     IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
2636     IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
2637     IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
2638     IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast);
2639     IS_VAR_WITH_NAME(yCast->src_value(), "y");
2640   }
2641 
2642   {
2643     VarHandle x("x", kFloat);
2644     VarHandle y("y", kFloat);
2645     // x%y - (x%y - 1) => x%y - (x%y - 1).
2646     // We wont reorder opaque ops if they are FP.
2647     ExprHandle body = (x % y) - ((x % y) - 1);
2648     ExprHandle simplified = IRSimplifier::simplify(body);
2649 
2650     IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
2651     IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod);
2652     IS_VAR_WITH_NAME(lhsMod->lhs(), "x");
2653     IS_VAR_WITH_NAME(lhsMod->rhs(), "y");
2654 
2655     IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub);
2656     IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod);
2657     IS_VAR_WITH_NAME(rhsMod->lhs(), "x");
2658     IS_VAR_WITH_NAME(rhsMod->rhs(), "y");
2659     IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1);
2660   }
2661 }
2662 
TEST(Simplify,SimplifyRoundModPattern)2663 TEST(Simplify, SimplifyRoundModPattern) {
2664   {
2665     // (x/y)*y + x%y => x.
2666     VarHandle x("x", kInt);
2667     VarHandle y("y", kInt);
2668     ExprHandle body = ((x / y) * y) + (x % y);
2669     ExprHandle simplified = IRSimplifier::simplify(body);
2670     IS_VAR_WITH_NAME(simplified.node(), "x");
2671   }
2672 
2673   {
2674     // Reverse order.
2675     // x%y + (x/y)*y => x.
2676     VarHandle x("x", kInt);
2677     VarHandle y("y", kInt);
2678     ExprHandle body = (x % y) + ((x / y) * y);
2679     ExprHandle simplified = IRSimplifier::simplify(body);
2680     IS_VAR_WITH_NAME(simplified.node(), "x");
2681   }
2682 
2683   {
2684     // Non opaque denominator.
2685     // (x / (4+y)) * (4+y)) + (x % (y + 4)) => x.
2686     VarHandle x("x", kInt);
2687     VarHandle y("y", kInt);
2688     ExprHandle body = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) +
2689         (x % (y + ExprHandle(4)));
2690     ExprHandle simplified = IRSimplifier::simplify(body);
2691     IS_VAR_WITH_NAME(simplified.node(), "x");
2692   }
2693 
2694   {
2695     // Reverse order.
2696     // (x % (y + 4)) + (x / (4+y)) * (4+y)) => x.
2697     VarHandle x("x", kInt);
2698     VarHandle y("y", kInt);
2699     ExprHandle body = (x % (y + ExprHandle(4))) +
2700         ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y));
2701     ExprHandle simplified = IRSimplifier::simplify(body);
2702     IS_VAR_WITH_NAME(simplified.node(), "x");
2703   }
2704 
2705   {
2706     // Opaque denominator.
2707     // (x / (2/y)) * (2/y)) + (x % (2/y)) => x.
2708     VarHandle x("x", kInt);
2709     VarHandle y("y", kInt);
2710     ExprHandle body = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) +
2711         (x % (ExprHandle(2) / y));
2712     ExprHandle simplified = IRSimplifier::simplify(body);
2713     IS_VAR_WITH_NAME(simplified.node(), "x");
2714   }
2715 
2716   {
2717     // Non opaque numerator
2718     // ((2*x)/y * y) + ((2*x) % y) => 2 * x.
2719     VarHandle x("x", kInt);
2720     VarHandle y("y", kInt);
2721     ExprHandle body =
2722         (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y);
2723     ExprHandle simplified = IRSimplifier::simplify(body);
2724     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
2725     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
2726     IS_VAR_WITH_NAME(mul->rhs(), "x");
2727   }
2728 
2729   {
2730     // Opaque numerator.
2731     // ((x/2) / y * y) + (x/2 % y) => x / 2.
2732     VarHandle x("x", kInt);
2733     VarHandle y("y", kInt);
2734     ExprHandle body =
2735         (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y);
2736     ExprHandle simplified = IRSimplifier::simplify(body);
2737 
2738     IS_NODE_WITH_NAME(Div, simplified.node(), div);
2739     IS_VAR_WITH_NAME(div->lhs(), "x");
2740     IS_IMM_WITH_VAL(Int, div->rhs(), 2);
2741   }
2742 
2743   {
2744     // Numerator and denominator.
2745     // ((2*x)/(2*y) * (2*y)) + ((2*x) % (2*y)) => 2 * x.
2746     VarHandle x("x", kInt);
2747     VarHandle y("y", kInt);
2748     ExprHandle body =
2749         (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) +
2750         ((ExprHandle(2) * x) % (ExprHandle(2) * y));
2751     ExprHandle simplified = IRSimplifier::simplify(body);
2752 
2753     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
2754     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
2755     IS_VAR_WITH_NAME(mul->rhs(), "x");
2756   }
2757 
2758   {
2759     // Reverse order.
2760     // ((2*x) % (2*y)) + ((2*x)/(2*y) * (2*y)) => 2 * x.
2761     VarHandle x("x", kInt);
2762     VarHandle y("y", kInt);
2763     ExprHandle body = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) +
2764         (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y));
2765     ExprHandle simplified = IRSimplifier::simplify(body);
2766 
2767     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
2768     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
2769     IS_VAR_WITH_NAME(mul->rhs(), "x");
2770   }
2771 
2772   {
2773     // Negated Subtraction of Round Mod.
2774     // (x/y) * y - (0 - x%y) => x.
2775     VarHandle x("x", kInt);
2776     VarHandle y("y", kInt);
2777     ExprHandle body = ((x / y) * y) - (ExprHandle(0) - (x % y));
2778     ExprHandle simplified = IRSimplifier::simplify(body);
2779     IS_VAR_WITH_NAME(simplified.node(), "x");
2780   }
2781 
2782   {
2783     // Other terms are preserved.
2784     // (x/y)*y + x%y + (y * x) => x + (y * x).
2785     VarHandle x("x", kInt);
2786     VarHandle y("y", kInt);
2787     ExprHandle body = ((x / y) * y) + (x % y) + (y * x);
2788     ExprHandle simplified = IRSimplifier::simplify(body);
2789     IS_NODE_WITH_NAME(Add, simplified.node(), add);
2790     IS_VAR_WITH_NAME(add->lhs(), "x");
2791     IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
2792     IS_VAR_WITH_NAME(mul->lhs(), "x");
2793     IS_VAR_WITH_NAME(mul->rhs(), "y");
2794   }
2795 
2796   {
2797     // Sanity checking we wont do the optimization on floats.
2798     VarHandle x("x", kFloat);
2799     VarHandle y("y", kFloat);
2800     ExprHandle body = ((x / y) * y) + (x % y);
2801     ExprHandle simplified = IRSimplifier::simplify(body);
2802     IS_NODE_WITH_NAME(Add, simplified.node(), add);
2803     IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul);
2804     IS_NODE_WITH_NAME(Div, roundMul->lhs(), roundDiv);
2805     IS_VAR_WITH_NAME(roundDiv->lhs(), "x");
2806     IS_VAR_WITH_NAME(roundDiv->rhs(), "y");
2807     IS_VAR_WITH_NAME(roundMul->rhs(), "y");
2808     IS_NODE_WITH_NAME(Mod, add->rhs(), mod);
2809     IS_VAR_WITH_NAME(mod->lhs(), "x");
2810     IS_VAR_WITH_NAME(mod->rhs(), "y");
2811   }
2812 
2813   {
2814     // Sanity check we wont do it if the mod term doesn't match.
2815     VarHandle x("x", kInt);
2816     VarHandle y("y", kInt);
2817     VarHandle z("z", kInt);
2818     ExprHandle body = ((x / y) * y) + (x % z);
2819     ExprHandle simplified = IRSimplifier::simplify(body);
2820     checkExprIR(simplified, "(x / y) * y + x % z");
2821   }
2822 
2823   {
2824     // Sanity check we wont do it if the div term doesn't match.
2825     VarHandle x("x", kInt);
2826     VarHandle y("y", kInt);
2827     VarHandle z("z", kInt);
2828     ExprHandle body = (y * (x / z)) + (x % y);
2829     ExprHandle simplified = IRSimplifier::simplify(body);
2830     checkExprIR(simplified, "x % y + (x / z) * y");
2831   }
2832 
2833   {
2834     // Sanity check we wont do it if the mul term doesn't match.
2835     VarHandle x("x", kInt);
2836     VarHandle y("y", kInt);
2837     VarHandle z("z", kInt);
2838     ExprHandle body = ((x / y) * z) + (x % y);
2839     ExprHandle simplified = IRSimplifier::simplify(body);
2840     checkExprIR(simplified, "x % y + (x / y) * z");
2841   }
2842 }
2843 
TEST(Simplify, SimplifyRoundModPatternFactorization) {
  // The round-mod rewrite ((a / b) * b + a % b => a) must still fire when the
  // round and mod terms carry scalar multipliers that share a common factor,
  // including factors that only become visible after constant folding.
  {
    // Full factorization: both terms carry the same factor 2.
    // 2 * ((x / y) * y) + 2 * (x % y)  =>  2 * x
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Partial factorization: 32 * (x / 8) == 4 * ((x / 8) * 8).
    // 32 * (x / 8) + 4 * (x % 8)  =>  4 * x
    VarHandle x("x", kInt);
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
    ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 4);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Factorization requiring constant folding: 16 / 2 and 7 + 1 both fold to
    // 8, and 11 % 6 folds to 5, so the sum is 5 * ((x / 8) * 8 + x % 8).
    // 40 * (x / (16 / 2)) + (11 % 6) * (x % (7 + 1))  =>  5 * x
    VarHandle x("x", kInt);
    ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) +
        (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1));
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 5);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Trailing multiplier on the round term: (x / 5) * 10 == 2 * ((x / 5) * 5).
    // (x / 5) * 10 + 2 * (x % 5)  =>  2 * x
    VarHandle x("x", kInt);
    ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // A round term multiplied by zero vanishes, leaving the mod term as-is.
    // (x / 10) * 0 + x % 5  =>  x % 5
    VarHandle x("x", kInt);
    ExprHandle body = (x / 10) * 0 + x % 5;
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
    IS_VAR_WITH_NAME(mod->lhs(), "x");
    IS_IMM_WITH_VAL(Int, mod->rhs(), 5);
  }
}
2901 
TEST(Simplify, SimplifyRoundModPatternMultivar) {
  // Round-mod rewrites in sums that involve several variables: each pattern
  // must be matched against the correct variable, and unmatched terms kept.
  {
    // Two independent patterns, one per variable.
    // (x / 8) * 8 + (y / 5) * 5 + x % 8 + y % 5  =>  x + y
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) +
        (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Add, simplified.node(), add);
    IS_VAR_WITH_NAME(add->lhs(), "x");
    IS_VAR_WITH_NAME(add->rhs(), "y");
  }

  {
    // Find the right var: only y has both its round and mod terms present,
    // so x % 8 and z % 8 must be left untouched.
    // (y / 8) * 8 + x % 8 + y % 8 + z % 8  =>  x % 8 + y + z % 8
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);
    ExprHandle body =
        (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Add, simplified.node(), add);
    IS_NODE_WITH_NAME(Add, add->lhs(), add2);
    IS_NODE_WITH_NAME(Mod, add2->lhs(), xMod);
    IS_VAR_WITH_NAME(xMod->lhs(), "x");
    IS_IMM_WITH_VAL(Int, xMod->rhs(), 8);
    IS_VAR_WITH_NAME(add2->rhs(), "y");
    IS_NODE_WITH_NAME(Mod, add->rhs(), zMod);
    IS_VAR_WITH_NAME(zMod->lhs(), "z");
    IS_IMM_WITH_VAL(Int, zMod->rhs(), 8);
  }

  {
    // Compound numerator mixed with an unrelated term.
    // x + (z + 512 * y) % 16 + ((z + 512 * y) / 16) * 16
    //   =>  x + (z + 512 * y)
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);

    ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "x + (z + 512 * y)");
  }
}
2949 
TEST(Simplify,SimplifyModRoundModPattern)2950 TEST(Simplify, SimplifyModRoundModPattern) {
2951   {
2952     // t/7 % 9 * 7 + t % 7 => t%63
2953     VarHandle t("t", kInt);
2954     ExprHandle body = (t / 7 % 9) * 7 + t % 7;
2955     ExprHandle simplified = IRSimplifier::simplify(body);
2956     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
2957     IS_VAR_WITH_NAME(mod->lhs(), "t");
2958     IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
2959   }
2960 
2961   {
2962     // 2*t/7 % 9 * 7 + 2*t % 7 => 2*t % 63
2963     VarHandle t("t", kInt);
2964     ExprHandle body = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7;
2965     ExprHandle simplified = IRSimplifier::simplify(body);
2966     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
2967     IS_NODE_WITH_NAME(Mul, mod->lhs(), mul);
2968     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
2969     IS_VAR_WITH_NAME(mul->rhs(), "t");
2970     IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
2971   }
2972 
2973   {
2974     // t/x % y * x + t % x => t%(x*y)
2975     VarHandle t("t", kInt);
2976     VarHandle x("x", kInt);
2977     VarHandle y("y", kInt);
2978     ExprHandle body = (t / x % y) * x + t % x;
2979     ExprHandle simplified = IRSimplifier::simplify(body);
2980     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
2981     IS_VAR_WITH_NAME(mod->lhs(), "t");
2982     IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
2983     IS_VAR_WITH_NAME(mul->lhs(), "x");
2984     IS_VAR_WITH_NAME(mul->rhs(), "y");
2985   }
2986 
2987   {
2988     // k*t/x % y * x + k*t % x => k*t%(x*y)
2989     VarHandle t("t", kInt);
2990     VarHandle x("x", kInt);
2991     VarHandle y("y", kInt);
2992     VarHandle k("k", kInt);
2993     ExprHandle body = (k * t / x % y) * x + k * t % x;
2994     ExprHandle simplified = IRSimplifier::simplify(body);
2995     checkExprIR(simplified, "(k * t) % (x * y)");
2996   }
2997 
2998   {
2999     // t/k/x % y * x + t/k % x => t/k%(x*y)
3000     VarHandle t("t", kInt);
3001     VarHandle x("x", kInt);
3002     VarHandle y("y", kInt);
3003     VarHandle k("k", kInt);
3004     ExprHandle body = (t / k / x % y) * x + t / k % x;
3005     ExprHandle simplified = IRSimplifier::simplify(body);
3006     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
3007     IS_NODE_WITH_NAME(Div, mod->lhs(), div);
3008     IS_VAR_WITH_NAME(div->lhs(), "t");
3009     IS_VAR_WITH_NAME(div->rhs(), "k");
3010     IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
3011     IS_VAR_WITH_NAME(mul->lhs(), "x");
3012     IS_VAR_WITH_NAME(mul->rhs(), "y");
3013   }
3014 
3015   {
3016     // Sanity checking we wont do the optimization on floats.
3017     VarHandle x("x", kFloat);
3018     VarHandle y("y", kFloat);
3019     VarHandle z("z", kFloat);
3020     ExprHandle body = ((x / y % z) * y) + (x % y);
3021     ExprHandle simplified = IRSimplifier::simplify(body);
3022     IS_NODE_WITH_NAME(Add, simplified.node(), add);
3023     IS_NODE_WITH_NAME(Mul, add->lhs(), mul);
3024     IS_NODE_WITH_NAME(Mod, mul->lhs(), mod);
3025     IS_NODE_WITH_NAME(Div, mod->lhs(), div);
3026     IS_VAR_WITH_NAME(div->lhs(), "x");
3027     IS_VAR_WITH_NAME(div->rhs(), "y");
3028     IS_VAR_WITH_NAME(mod->rhs(), "z");
3029     IS_VAR_WITH_NAME(mul->rhs(), "y");
3030     IS_NODE_WITH_NAME(Mod, add->rhs(), mod2);
3031     IS_VAR_WITH_NAME(mod2->lhs(), "x");
3032     IS_VAR_WITH_NAME(mod2->rhs(), "y");
3033   }
3034 }
3035 
TEST(Simplify,SimplifyModRoundModPatternFactorization)3036 TEST(Simplify, SimplifyModRoundModPatternFactorization) {
3037   {
3038     // 2 * (t /7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63)
3039     VarHandle t("t", kInt);
3040     ExprHandle body =
3041         ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7);
3042     ExprHandle simplified = IRSimplifier::simplify(body);
3043     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
3044     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
3045     IS_NODE_WITH_NAME(Mod, mul->rhs(), mod);
3046     IS_VAR_WITH_NAME(mod->lhs(), "t");
3047     IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
3048   }
3049 
3050   {
3051     // t /7 % 9 * 14 + 2* (t % 7) => 2* (t % 63)
3052     VarHandle t("t", kInt);
3053     ExprHandle body = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7);
3054     ExprHandle simplified = IRSimplifier::simplify(body);
3055     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
3056     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
3057     IS_NODE_WITH_NAME(Mod, mul->rhs(), mod);
3058     IS_VAR_WITH_NAME(mod->lhs(), "t");
3059     IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
3060   }
3061 
3062   {
3063     // t/14 % 9 * 7 + t/2 % 7 => t/2 % 63
3064     VarHandle t("t", kInt);
3065     ExprHandle body = (t / 14 % 9) * 7 + t / 2 % 7;
3066     ExprHandle simplified = IRSimplifier::simplify(body);
3067     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
3068     IS_NODE_WITH_NAME(Div, mod->lhs(), div);
3069     IS_VAR_WITH_NAME(div->lhs(), "t");
3070     IS_IMM_WITH_VAL(Int, div->rhs(), 2);
3071     IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
3072   }
3073 
3074   {
3075     // t/(7*3) % 9 * 7*3 + t % (7*3) => t % 189
3076     VarHandle t("t", kInt);
3077     ExprHandle body = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 +
3078         t % (ExprHandle(7) * ExprHandle(3));
3079     ExprHandle simplified = IRSimplifier::simplify(body);
3080     IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
3081     IS_VAR_WITH_NAME(mod->lhs(), "t");
3082     IS_IMM_WITH_VAL(Int, mod->rhs(), 189);
3083   }
3084 
3085   {
3086     // 2*(t/x % y * x) + 2*(t % x) => 2*(t%(x*y))
3087     VarHandle t("t", kInt);
3088     VarHandle x("x", kInt);
3089     VarHandle y("y", kInt);
3090     ExprHandle body =
3091         ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x);
3092     ExprHandle simplified = IRSimplifier::simplify(body);
3093     IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
3094     IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
3095     IS_NODE_WITH_NAME(Mod, mul->rhs(), mod);
3096     IS_VAR_WITH_NAME(mod->lhs(), "t");
3097     IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2);
3098     IS_VAR_WITH_NAME(mul2->lhs(), "x");
3099     IS_VAR_WITH_NAME(mul2->rhs(), "y");
3100   }
3101 }
3102 
TEST(Simplify,SimplifyModRoundModPatternMultivar)3103 TEST(Simplify, SimplifyModRoundModPatternMultivar) {
3104   {
3105     // t/7 % 9 * 7 + t % 7 + t => t % 63 + t
3106     VarHandle t("t", kInt);
3107     ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t;
3108     ExprHandle simplified = IRSimplifier::simplify(body);
3109     checkExprIR(simplified, "t % 63 + t");
3110   }
3111 
3112   {
3113     // t/7 % 9 * 7 + t/8 % 9 * 8 + t % 7 + t % 8  => t % 63 + t % 72
3114     VarHandle t("t", kInt);
3115     ExprHandle body = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8;
3116     ExprHandle simplified = IRSimplifier::simplify(body);
3117     IS_NODE_WITH_NAME(Add, simplified.node(), add);
3118     IS_NODE_WITH_NAME(Mod, add->lhs(), mod1);
3119     IS_VAR_WITH_NAME(mod1->lhs(), "t");
3120     IS_IMM_WITH_VAL(Int, mod1->rhs(), 63);
3121     IS_NODE_WITH_NAME(Mod, add->rhs(), mod2);
3122     IS_VAR_WITH_NAME(mod2->lhs(), "t");
3123     IS_IMM_WITH_VAL(Int, mod2->rhs(), 72);
3124   }
3125 
3126   {
3127     // k + t/x % y * x + t % x => k + t%(x*y)
3128     VarHandle t("t", kInt);
3129     VarHandle x("x", kInt);
3130     VarHandle y("y", kInt);
3131     VarHandle k("k", kInt);
3132     ExprHandle body = k + (t / x % y) * x + t % x;
3133     ExprHandle simplified = IRSimplifier::simplify(body);
3134     IS_NODE_WITH_NAME(Add, simplified.node(), add);
3135     IS_VAR_WITH_NAME(add->lhs(), "k");
3136     IS_NODE_WITH_NAME(Mod, add->rhs(), mod);
3137     IS_VAR_WITH_NAME(mod->lhs(), "t");
3138     IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
3139     IS_VAR_WITH_NAME(mul->lhs(), "x");
3140     IS_VAR_WITH_NAME(mul->rhs(), "y");
3141   }
3142 
3143   {
3144     // t/x % y * x + t % x + (t/k / x % y) * x + t/k % x
3145     // => t%(x*y) + t/k % (x*y)
3146     VarHandle t("t", kInt);
3147     VarHandle x("x", kInt);
3148     VarHandle y("y", kInt);
3149     VarHandle k("k", kInt);
3150     ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x;
3151     ExprHandle simplified = IRSimplifier::simplify(body);
3152     checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)");
3153   }
3154 
3155   {
3156     // 3D: (7 * ((i0_flat / 7) % 9) + i0_flat % 7) + 63 * (i0_flat / 63)
3157     // => io_flat
3158     VarHandle t("io_flat", kInt);
3159     ExprHandle body =
3160         ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63);
3161     ExprHandle simplified = IRSimplifier::simplify(body);
3162     IS_VAR_WITH_NAME(simplified.node(), "io_flat");
3163   }
3164 
3165   { // 5D: i0_flat / (11 * 10 * 9 * 7)  * (7 * 9 * 10 * 11) +
3166     // (i0_flat / (10 * 9 * 7) % 11)  * 7 * 9 * 10 +
3167     // (i0_flat / (9 * 7) % 10) * 7 * 9 +
3168     // (i0_flat / 7 % 9)  * 7 +
3169     // i0_flat % 7 => io_flat
3170     VarHandle t("io_flat", kInt);
3171     ExprHandle body = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) +
3172         (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 +
3173         (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7;
3174     ExprHandle simplified = IRSimplifier::simplify(body);
3175     IS_VAR_WITH_NAME(simplified.node(), "io_flat");
3176   }
3177 
3178   {
3179     // 3D: (m * ((i0_flat / m) % n) + i0_flat % m) + (m * n) *
3180     // (i0_flat / (m * n)) => io_flat
3181     VarHandle t("io_flat", kInt);
3182     VarHandle m("m", kInt);
3183     VarHandle n("n", kInt);
3184     ExprHandle body = m * (t / m % n) + t % m + (m * n) * (t / (m * n));
3185     ExprHandle simplified = IRSimplifier::simplify(body);
3186     IS_VAR_WITH_NAME(simplified.node(), "io_flat");
3187   }
3188 
3189   { // 5D: i0_flat / (k * l * n * m)  * (m * n * l * k) +
3190     // (i0_flat / (l * n * m) % k)  * m * n * l +
3191     // (i0_flat / (n * m) % l) * m * n +
3192     // (i0_flat / m % n)  * m +
3193     // i0_flat % m => io_flat
3194     VarHandle t("io_flat", kInt);
3195     VarHandle m("m", kInt);
3196     VarHandle n("n", kInt);
3197     VarHandle l("l", kInt);
3198     VarHandle k("k", kInt);
3199     ExprHandle body = (t / (k * l * n * m)) * (m * n * l * k) +
3200         (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n +
3201         (t / m % n) * m + t % m;
3202     ExprHandle simplified = IRSimplifier::simplify(body);
3203     IS_VAR_WITH_NAME(simplified.node(), "io_flat");
3204   }
3205 }
3206 
TEST(Simplify,SimplifyDivisionScalarFactorization)3207 TEST(Simplify, SimplifyDivisionScalarFactorization) {
3208   {
3209     // Simple factorization of numerator and denominator.
3210     // 8x / 4y => 2x / y.
3211     VarHandle x("x", kInt);
3212     VarHandle y("y", kInt);
3213     ExprHandle body = (x * 8) / (y * 4);
3214     ExprHandle simplified = IRSimplifier::simplify(body);
3215     IS_NODE_WITH_NAME(Div, simplified.node(), div);
3216     IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
3217     IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
3218     IS_VAR_WITH_NAME(lhs->rhs(), "x");
3219     IS_VAR_WITH_NAME(div->rhs(), "y");
3220   }
3221 
3222   {
3223     // Don't change anything if we can't factorize.
3224     VarHandle x("x", kInt);
3225     VarHandle y("y", kInt);
3226     ExprHandle body = (x * 7) / (y * 4);
3227     ExprHandle simplified = IRSimplifier::simplify(body);
3228     IS_NODE_WITH_NAME(Div, simplified.node(), div);
3229     IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
3230     IS_IMM_WITH_VAL(Int, lhs->lhs(), 7);
3231     IS_VAR_WITH_NAME(lhs->rhs(), "x");
3232     IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
3233     IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
3234     IS_VAR_WITH_NAME(rhs->rhs(), "y");
3235   }
3236 
3237   {
3238     // Don't reorder floats.
3239     VarHandle x("x", kFloat);
3240     VarHandle y("y", kFloat);
3241     ExprHandle body = (x * 8) / (y * 4);
3242     ExprHandle simplified = IRSimplifier::simplify(body);
3243     IS_NODE_WITH_NAME(Div, simplified.node(), div);
3244     IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
3245     IS_VAR_WITH_NAME(lhs->lhs(), "x");
3246     IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f);
3247     IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
3248     IS_VAR_WITH_NAME(rhs->lhs(), "y");
3249     IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f);
3250   }
3251 
3252   {
3253     // Sanity check we do nothing if there are only scalar parts.
3254     VarHandle x("x", kInt);
3255     VarHandle y("y", kInt);
3256     ExprHandle body = (x * 1) / (y * 1);
3257     ExprHandle simplified = IRSimplifier::simplify(body);
3258     IS_NODE_WITH_NAME(Div, simplified.node(), div);
3259     IS_VAR_WITH_NAME(div->lhs(), "x");
3260     IS_VAR_WITH_NAME(div->rhs(), "y");
3261   }
3262 
3263   {
3264     // Can factorize amounts of variables.
3265     VarHandle x("x", kInt);
3266     VarHandle y("y", kInt);
3267     ExprHandle body = (x + x + x + x) / (y + y);
3268     ExprHandle simplified = IRSimplifier::simplify(body);
3269     IS_NODE_WITH_NAME(Div, simplified.node(), div);
3270     IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
3271     IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
3272     IS_VAR_WITH_NAME(lhs->rhs(), "x");
3273     IS_VAR_WITH_NAME(div->rhs(), "y");
3274   }
3275 }
3276 
TEST(Simplify, SimplifyConstantBranches) {
  // IfThenElse expressions are resolved by the simplifier when the condition
  // folds to a constant, or when both branches are (or simplify to) the same
  // expression.
  {
    // If the condition is constant true then take the true_value.
    // 1 ? x : y  =>  x
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle t(1);
    ExprHandle body = IfThenElse::make(t, x, y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "x");
  }

  {
    // If the condition is constant false then take the false_value.
    // 0 ? x : y  =>  y
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle t(0);
    ExprHandle body = IfThenElse::make(t, x, y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "y");
  }

  {
    // The condition is simplified before checking: x - x folds to 0.
    // (x - x) ? x : y  =>  y
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = IfThenElse::make(x - x, x, y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "y");
  }

  {
    // If both branches are the same then the condition is irrelevant.
    // y ? x : x  =>  x
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = IfThenElse::make(y, x, x);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "x");
  }

  {
    // Branches that only become equal after simplification are also merged.
    // y ? (x + x) : (2 * x)  =>  2 * x
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = IfThenElse::make(y, x + x, ExprHandle(2) * x);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }
}
3332 
TEST(Simplify,SimplifyConstantCond)3333 TEST(Simplify, SimplifyConstantCond) {
3334   {
3335     // If the condition is constant true then take the true_value.
3336     // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1
3337     BufHandle a("A", {1}, kInt);
3338     BufHandle b("B", {1}, kInt);
3339     ExprHandle condition(1);
3340     StmtPtr true_val = Store::make(a, {0}, 1);
3341     StmtPtr false_val = Store::make(b, {0}, 1);
3342 
3343     CondPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3344     StmtPtr simplified = IRSimplifier::simplify(body);
3345     BlockPtr block = to<Block>(simplified);
3346     IS_NODE_WITH_NAME(Store, block->front(), store);
3347     IS_VAR_WITH_NAME(store->base_handle(), "A");
3348   }
3349 
3350   {
3351     // If the condition is constant false then take the false_value.
3352     // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1
3353     BufHandle a("A", {1}, kInt);
3354     BufHandle b("B", {1}, kInt);
3355     ExprHandle condition(0);
3356     StmtPtr true_val = Store::make(a, {0}, 1);
3357     StmtPtr false_val = Store::make(b, {0}, 1);
3358 
3359     StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3360     StmtPtr simplified = IRSimplifier::simplify(body);
3361     BlockPtr block = to<Block>(simplified);
3362     IS_NODE_WITH_NAME(Store, block->front(), store);
3363     IS_VAR_WITH_NAME(store->base_handle(), "B");
3364   }
3365 
3366   {
3367     // condition is simplified before checking.
3368     // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1
3369     VarHandle x("x", kInt);
3370     BufHandle a("A", {1}, kInt);
3371     BufHandle b("B", {1}, kInt);
3372     ExprHandle condition(x - x);
3373     StmtPtr true_val = Store::make(a, {0}, 1);
3374     StmtPtr false_val = Store::make(b, {0}, 1);
3375 
3376     StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3377     StmtPtr simplified = IRSimplifier::simplify(body);
3378     BlockPtr block = to<Block>(simplified);
3379     IS_NODE_WITH_NAME(Store, block->front(), store);
3380     IS_VAR_WITH_NAME(store->base_handle(), "B");
3381   }
3382 
3383   {
3384     // If both branches are the same then don't do the condition.
3385     // x ? A[0] = x : A[0] = x => A[0] = x
3386     VarHandle x("x", kInt);
3387     BufHandle a("A", {1}, kInt);
3388     ExprHandle condition(x - x);
3389     StmtPtr true_val = Store::make(a, {0}, x);
3390     StmtPtr false_val = Store::make(a, {0}, x);
3391 
3392     StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3393     StmtPtr simplified = IRSimplifier::simplify(body);
3394     BlockPtr block = to<Block>(simplified);
3395     IS_NODE_WITH_NAME(Store, block->front(), store);
3396     IS_VAR_WITH_NAME(store->base_handle(), "A");
3397   }
3398 
3399   {
3400     // If both branches simplify to the same thing it still works.
3401     // x ? (x + x) : (2 * x) => x
3402     VarHandle x("x", kInt);
3403     BufHandle a("A", {1}, kInt);
3404     ExprHandle condition(x - x);
3405     StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x);
3406     StmtPtr false_val = Store::make(a, {0}, x + x);
3407 
3408     StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3409     StmtPtr simplified = IRSimplifier::simplify(body);
3410     BlockPtr block = to<Block>(simplified);
3411     IS_NODE_WITH_NAME(Store, block->front(), store);
3412     IS_VAR_WITH_NAME(store->base_handle(), "A");
3413   }
3414 
3415   {
3416     // But not if they dont
3417     // x ? x : (2 * x) => x ? x : (2 * x)
3418     VarHandle x("x", kInt);
3419     BufHandle a("A", {1}, kInt);
3420     ExprHandle condition(x);
3421     StmtPtr true_val = Store::make(a, {0}, x);
3422     StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x);
3423 
3424     StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3425     StmtPtr simplified = IRSimplifier::simplify(body);
3426     BlockPtr block = to<Block>(simplified);
3427     ASSERT_EQ(block, nullptr);
3428   }
3429 
3430   {
3431     StmtPtr cond = alloc<Cond>(
3432         ExprHandle(false).node(),
3433         alloc<Block>(std::vector<StmtPtr>({})),
3434         nullptr);
3435     StmtPtr simplified = IRSimplifier::simplify(cond);
3436     ASSERT_EQ(simplified, nullptr);
3437   }
3438 
3439   {
3440     StmtPtr cond = alloc<Cond>(
3441         ExprHandle(true).node(),
3442         nullptr,
3443         alloc<Block>(std::vector<StmtPtr>({})));
3444     StmtPtr simplified = IRSimplifier::simplify(cond);
3445     ASSERT_EQ(simplified, nullptr);
3446   }
3447 }
3448 
TEST(Simplify, SimplifyEliminateEmptyCond) {
  // A Cond under a non-constant condition whose one branch is an empty Block
  // and whose other branch is missing should simplify to an empty Block,
  // whichever side the empty Block is on.
  auto makeEmptyBlock = []() -> StmtPtr {
    return alloc<Block>(std::vector<StmtPtr>({}));
  };

  {
    // Empty true branch, null false branch.
    VarHandle x("x", kInt);
    ExprHandle condition(x);
    StmtPtr emptyTrue = makeEmptyBlock();

    StmtPtr cond = alloc<Cond>(condition.node(), emptyTrue, nullptr);
    BlockPtr result = to<Block>(IRSimplifier::simplify(cond));
    ASSERT_NE(result, nullptr);
    ASSERT_EQ(result->nstmts(), 0);
  }

  {
    // Null true branch, empty false branch.
    VarHandle x("x", kInt);
    ExprHandle condition(x);
    StmtPtr emptyFalse = makeEmptyBlock();

    StmtPtr cond = alloc<Cond>(condition.node(), nullptr, emptyFalse);
    BlockPtr result = to<Block>(IRSimplifier::simplify(cond));
    ASSERT_NE(result, nullptr);
    ASSERT_EQ(result->nstmts(), 0);
  }
}
3475 
TEST(Simplify, SimplifyConstantComparisons) {
  // A CompareSelect over two integer immediates folds to an IntImm: 1 when
  // the comparison holds and 0 otherwise, or the explicitly supplied
  // true/false values.
  auto expectFoldsTo = [](ExprHandle lhs,
                          ExprHandle rhs,
                          CompareSelectOperation op,
                          int expected) {
    ExprHandle folded =
        IRSimplifier::simplify(CompareSelect::make(lhs, rhs, op));
    IS_IMM_WITH_VAL(Int, folded.node(), expected);
  };

  struct ComparisonCase {
    int lhs;
    int rhs;
    CompareSelectOperation op;
    int expected;
  };

  const ComparisonCase cases[] = {
      // Equals.
      {2, 2, kEQ, 1},
      {1, 2, kEQ, 0},
      {2, 1, kEQ, 0},
      // Greater than.
      {2, 2, kGT, 0},
      {1, 2, kGT, 0},
      {2, 1, kGT, 1},
      // Greater or equal.
      {2, 2, kGE, 1},
      {1, 2, kGE, 0},
      {2, 1, kGE, 1},
      // Less than.
      {2, 2, kLT, 0},
      {1, 2, kLT, 1},
      {2, 1, kLT, 0},
      // Less or equal.
      {2, 2, kLE, 1},
      {1, 2, kLE, 1},
      {2, 1, kLE, 0},
      // Not equal.
      {2, 2, kNE, 0},
      {1, 2, kNE, 1},
      {2, 1, kNE, 1},
  };

  for (const ComparisonCase& c : cases) {
    expectFoldsTo(c.lhs, c.rhs, c.op, c.expected);
  }

  // With explicit result values: 2 != 2 is false, so the false value (42)
  // is selected.
  ExprHandle withResults = CompareSelect::make(2, 2, 5, 42, kNE);
  ExprHandle folded = IRSimplifier::simplify(withResults);
  IS_IMM_WITH_VAL(Int, folded.node(), 42);
}
3519 
TEST(Simplify,SimplifySymbolicComparisons)3520 TEST(Simplify, SimplifySymbolicComparisons) {
3521   VarHandle x("x", kInt);
3522   VarHandle y("y", kInt);
3523 
3524   auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); };
3525   auto TookFalseBranch = [](ExprHandle a) {
3526     IS_IMM_WITH_VAL(Int, a.node(), 0);
3527   };
3528 
3529   // EQ
3530 
3531   // x == x => 1
3532   ExprHandle body = CompareSelect::make(x, x, kEQ);
3533   TookTrueBranch(IRSimplifier::simplify(body));
3534 
3535   // x == x+1 => 0
3536   body = CompareSelect::make(x, x + 1, kEQ);
3537   TookFalseBranch(IRSimplifier::simplify(body));
3538 
3539   // x == x * 2 cannot simplify since we don't know x is nonzero.
3540   body = CompareSelect::make(x, x * 2, kEQ);
3541   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3542   IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
3543 
3544   // x == x * 1 => 1
3545   body = CompareSelect::make(x, x * 1, kEQ);
3546   TookTrueBranch(IRSimplifier::simplify(body));
3547 
3548   {
3549     // x == y => x == y
3550     body = CompareSelect::make(x, y, kEQ);
3551     ExprHandle simplified = IRSimplifier::simplify(body);
3552     IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
3553     ASSERT_EQ(cmp->compare_select_op(), kEQ);
3554     IS_VAR_WITH_NAME(cmp->lhs(), "x");
3555     IS_VAR_WITH_NAME(cmp->rhs(), "y");
3556   }
3557 
3558   {
3559     // x == 5 => x == 5
3560     body = CompareSelect::make(x, 5, kEQ);
3561     ExprHandle simplified = IRSimplifier::simplify(body);
3562     IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
3563     ASSERT_EQ(cmp->compare_select_op(), kEQ);
3564     IS_VAR_WITH_NAME(cmp->lhs(), "x");
3565     IS_IMM_WITH_VAL(Int, cmp->rhs(), 5);
3566   }
3567 
3568   // GT
3569 
3570   // x+1 > x => 1
3571   body = CompareSelect::make(x + 1, x, kGT);
3572   TookTrueBranch(IRSimplifier::simplify(body));
3573 
3574   // x > x + 1 => 0
3575   body = CompareSelect::make(x, x + 1, kGT);
3576   TookFalseBranch(IRSimplifier::simplify(body));
3577 
3578   // x > x - 1 => 1
3579   body = CompareSelect::make(x, x - 1, kGT);
3580   TookTrueBranch(IRSimplifier::simplify(body));
3581 
3582   // x - 1 > x => 0
3583   body = CompareSelect::make(x - 1, x, kGT);
3584   TookFalseBranch(IRSimplifier::simplify(body));
3585 
3586   // x > x => 0
3587   body = CompareSelect::make(x, x, kGT);
3588   TookFalseBranch(IRSimplifier::simplify(body));
3589 
3590   // x * 2 > x => x * 2 > x
3591   // since we don't know the sign of x.
3592   body = CompareSelect::make(x * 2, x, kGT);
3593   IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
3594 
3595   // GE
3596 
3597   // x+1 >= x => 1
3598   body = CompareSelect::make(x + 1, x, kGE);
3599   TookTrueBranch(IRSimplifier::simplify(body));
3600 
3601   // x >= x + 1 => 0
3602   body = CompareSelect::make(x, x + 1, kGE);
3603   TookFalseBranch(IRSimplifier::simplify(body));
3604 
3605   // x >= x => 1
3606   body = CompareSelect::make(x, x, kGE);
3607   TookTrueBranch(IRSimplifier::simplify(body));
3608 
3609   // x * 2 >= x => x * 2 >= x
3610   // since we don't know the sign of x.
3611   body = CompareSelect::make(x * 2, x, kGE);
3612   IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
3613 
3614   // LT
3615 
3616   // x+1 < x => 0
3617   body = CompareSelect::make(x + 1, x, kLT);
3618   TookFalseBranch(IRSimplifier::simplify(body));
3619 
3620   // x < x + 1 => 1
3621   body = CompareSelect::make(x, x + 1, kLT);
3622   TookTrueBranch(IRSimplifier::simplify(body));
3623 
3624   // x < x => 0
3625   body = CompareSelect::make(x, x, kLT);
3626   TookFalseBranch(IRSimplifier::simplify(body));
3627 
3628   // LE
3629 
3630   // x+1 <= x => 0
3631   body = CompareSelect::make(x + 1, x, kLE);
3632   TookFalseBranch(IRSimplifier::simplify(body));
3633 
3634   // x <= x + 1 => 1
3635   body = CompareSelect::make(x, x + 1, kLE);
3636   TookTrueBranch(IRSimplifier::simplify(body));
3637 
3638   // x <= x => 1
3639   body = CompareSelect::make(x, x, kLE);
3640   TookTrueBranch(IRSimplifier::simplify(body));
3641 
3642   // NE
3643 
3644   // x+1 != x => 1
3645   body = CompareSelect::make(x + 1, x, kNE);
3646   TookTrueBranch(IRSimplifier::simplify(body));
3647 
3648   // x != x + 1 => 1
3649   body = CompareSelect::make(x, x + 1, kNE);
3650   TookTrueBranch(IRSimplifier::simplify(body));
3651 
3652   // x != x => 0
3653   body = CompareSelect::make(x, x, kNE);
3654   TookFalseBranch(IRSimplifier::simplify(body));
3655 }
3656 
// A For whose start and stop are equal (or simplify to equal values) runs no
// iterations, so the simplifier should replace it with an empty Block.
// Each case checks the result really is a Block before dereferencing it, so a
// regression fails the test instead of crashing the test binary.
TEST(Simplify, SimplifyEliminateZeroLengthFor) {
  {
    // Will eliminate zero loop For.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    ASSERT_EQ(block->nstmts(), 0);
  }

  {
    // still works if start is not zero.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    ASSERT_EQ(block->nstmts(), 0);
  }

  {
    // works if both terms are variable.
    VarHandle x("x", kInt);
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    ASSERT_EQ(block->nstmts(), 0);
  }

  {
    // works if one term simplifies down.
    VarHandle x("x", kInt);
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    ASSERT_EQ(block->nstmts(), 0);
  }

  {
    // Sanity check does nothing if the condition is not met.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE(For, simplified);
  }
}
3714 
TEST(Simplify,SimplifyOneLoopFor)3715 TEST(Simplify, SimplifyOneLoopFor) {
3716   {
3717     // Will remove the loop if the body is run once.
3718     BufHandle a("A", {4}, kInt);
3719     BufHandle c("C", {4}, kInt);
3720     VarHandle i("i", kInt);
3721     auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
3722     StmtPtr simplified = IRSimplifier::simplify(body);
3723     BlockPtr block = to<Block>(simplified);
3724     IS_NODE_WITH_NAME(Store, block->front(), store);
3725     IS_VAR_WITH_NAME(store->base_handle(), "C");
3726     IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
3727   }
3728 
3729   {
3730     // still works if start is not zero.
3731     BufHandle a("A", {4}, kInt);
3732     BufHandle c("C", {4}, kInt);
3733     VarHandle i("i", kInt);
3734     auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i})));
3735     StmtPtr simplified = IRSimplifier::simplify(body);
3736     BlockPtr block = to<Block>(simplified);
3737     IS_NODE_WITH_NAME(Store, block->front(), store);
3738     IS_VAR_WITH_NAME(store->base_handle(), "C");
3739     IS_IMM_WITH_VAL(Int, store->flat_index(), 2);
3740   }
3741 
3742   {
3743     // works if both terms are variable.
3744     VarHandle x("x", kInt);
3745     BufHandle a("A", {4}, kInt);
3746     BufHandle c("C", {4}, kInt);
3747     VarHandle i("i", kInt);
3748     auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i})));
3749     StmtPtr simplified = IRSimplifier::simplify(body);
3750     BlockPtr block = to<Block>(simplified);
3751     IS_NODE_WITH_NAME(Store, block->front(), store);
3752     IS_VAR_WITH_NAME(store->base_handle(), "C");
3753     IS_VAR_WITH_NAME(store->flat_index(), "x");
3754   }
3755 
3756   {
3757     // works if one term simplifies down.
3758     VarHandle x("x", kInt);
3759     BufHandle a("A", {4}, kInt);
3760     BufHandle c("C", {4}, kInt);
3761     VarHandle i("i", kInt);
3762     auto body =
3763         For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i})));
3764     StmtPtr simplified = IRSimplifier::simplify(body);
3765     BlockPtr block = to<Block>(simplified);
3766     IS_NODE_WITH_NAME(Store, block->front(), store);
3767     IS_VAR_WITH_NAME(store->base_handle(), "C");
3768     IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
3769   }
3770 
3771   {
3772     // Sanity check does nothing if the condition is not met.
3773     BufHandle a("A", {4}, kInt);
3774     BufHandle c("C", {4}, kInt);
3775     VarHandle i("i", kInt);
3776     auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
3777     StmtPtr simplified = IRSimplifier::simplify(body);
3778     IS_NODE(For, simplified);
3779   }
3780 }
3781 
// A single-iteration For would normally be replaced by its body (see
// SimplifyOneLoopFor), but a loop carrying LoopOptions must be kept so the
// options are not silently dropped.
TEST(Simplify, SimplifyForWontLoseLoopOptions) {
  {
    // The loop runs exactly once but is bound to a GPU block index, so the
    // simplifier must return the For itself with its options intact.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    LoopOptions options;
    options.set_gpu_block_index(LoopOptions::IDX_W);
    auto body =
        For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options);
    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(For, simplified, for_);
    LoopOptions options2 = for_->loop_options();
    ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index());
  }
}
3798 
// Single-iteration loop elimination must compose across nesting levels:
// every unit-length loop goes away while loops with more iterations stay.
TEST(Simplify, SimplifyMultilevelFor) {
  {
    // Multiple layers of For will be simplified out.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);
    auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
    auto outer = For::make(j, 0, 1, body);
    StmtPtr simplified = IRSimplifier::simplify(outer);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
  }

  {
    // Will maintain an outer loop if the inner loop is eliminated.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);
    auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
    auto outer = For::make(j, 0, 2, body);
    StmtPtr simplified = IRSimplifier::simplify(outer);
    // Check the dynamic type of the result directly. Going through
    // static_to<For> first would make this check vacuous, because a
    // dynamic cast from For to For can never fail.
    IS_NODE_WITH_NAME(For, simplified, for_);
    IS_VAR_WITH_NAME(for_->var(), "j");
    IS_IMM_WITH_VAL(Int, for_->start(), 0);
    IS_IMM_WITH_VAL(Int, for_->stop(), 2);
    BlockPtr block = to<Block>(for_->body());
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
  }

  {
    // Will maintain the inner loop if the outer loop is eliminated.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);
    auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i})));
    auto outer = For::make(j, 0, 1, body);
    StmtPtr simplified = IRSimplifier::simplify(outer);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(For, block->front(), for_);
    IS_VAR_WITH_NAME(for_->var(), "i");
    IS_IMM_WITH_VAL(Int, for_->start(), 0);
    IS_IMM_WITH_VAL(Int, for_->stop(), 2);
    IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    IS_VAR_WITH_NAME(store->flat_index(), "i");
  }
}
3855 
// A Compute over {1, 12, 1} lowered with prepareForCodegen has unit-extent
// loops around the 12-wide one; after simplification only the 12-wide loop
// should remain, storing its own index.
TEST(Simplify, SimplifyForCleansUp) {
  {
    Tensor b = Compute(
        "x",
        {1, 12, 1},
        [](const VarHandle& i, const VarHandle& m, const VarHandle& n) {
          return i + m + n;
        });
    LoopNest l({b});
    l.prepareForCodegen();

    StmtPtr body = LoopNest::sanitizeNames(l.root_stmt());
    StmtPtr simplified = IRSimplifier::simplify(body);

    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(For, block->front(), for_);
    // The surviving loop is the one over `m`; sanitizeNames renames its
    // variable to "j".
    IS_VAR_WITH_NAME(for_->var(), "j");
    // x[j] = j;  (the unit-extent indices fold to 0 and vanish from the sum)
    IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
    IS_VAR_WITH_NAME(store->flat_index(), "j");
    IS_VAR_WITH_NAME(store->value(), "j");
  }
}
3882 
// Loops whose body is empty do no work no matter how deeply they nest, so the
// simplifier should collapse the whole nest to a single empty Block.
TEST(Simplify, SimplifyEliminateEmptyFor) {
  {
    StmtPtr nest = alloc<Block>(std::vector<StmtPtr>({}));
    // Wrap the empty block in 11 loops of 10 iterations each.
    for (const auto depth : c10::irange(11)) {
      static_cast<void>(depth); // only the iteration count matters
      VarHandle loopVar("loopVar", kInt);
      nest = For::make(loopVar, 0, 10, nest);
    }

    StmtPtr simplified = IRSimplifier::simplify(nest);
    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 0);
  }
}
3898 
TEST(Simplify,SimplifyFlattenBlock)3899 TEST(Simplify, SimplifyFlattenBlock) {
3900   {
3901     // Flatten multiple blocks down to one.
3902     // { { { stmt1, stmt2 } } } =>  { stmt1, stmt2 }
3903     BufHandle a("A", {1}, kInt);
3904     StorePtr store1 = Store::make(a, {0}, 1);
3905     StorePtr store2 = Store::make(a, {0}, 0);
3906 
3907     BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1, store2}));
3908     BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1}));
3909 
3910     BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block2}));
3911     StmtPtr simplified = IRSimplifier::simplify(enclosing);
3912 
3913     IS_NODE_WITH_NAME(Block, simplified, block);
3914     ASSERT_EQ(block->nstmts(), 2);
3915 
3916     IS_NODE_WITH_NAME(Store, block->front(), store1_);
3917     IS_NODE_WITH_NAME(Store, block->back(), store2_);
3918 
3919     ASSERT_EQ(store1->value(), store1_->value());
3920     ASSERT_EQ(store2->value(), store2_->value());
3921   }
3922 
3923   {
3924     // Flatten multiple sub blocks containing statements.
3925     // { { stmt1 }, { stmt2 } } =>  { stmt1, stmt2 }
3926     BufHandle a("A", {1}, kInt);
3927     StorePtr store1 = Store::make(a, {0}, 1);
3928     StorePtr store2 = Store::make(a, {0}, 0);
3929 
3930     BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1}));
3931     BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({store2}));
3932 
3933     BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block1, block2}));
3934     StmtPtr simplified = IRSimplifier::simplify(enclosing);
3935 
3936     IS_NODE_WITH_NAME(Block, simplified, block);
3937     ASSERT_EQ(block->nstmts(), 2);
3938 
3939     IS_NODE_WITH_NAME(Store, block->front(), store1_);
3940     IS_NODE_WITH_NAME(Store, block->back(), store2_);
3941 
3942     ASSERT_EQ(store1->value(), store1_->value());
3943     ASSERT_EQ(store2->value(), store2_->value());
3944   }
3945 
3946   {
3947     // Flatten sub blocks with different depths.
3948     // { stmt1 , { { stmt2 } } } =>  { stmt1, stmt2 }
3949     BufHandle a("A", {1}, kInt);
3950     StorePtr store1 = Store::make(a, {0}, 1);
3951     StorePtr store2 = Store::make(a, {0}, 0);
3952 
3953     BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store2}));
3954     BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1}));
3955 
3956     BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({store1, block2}));
3957     StmtPtr simplified = IRSimplifier::simplify(enclosing);
3958 
3959     IS_NODE_WITH_NAME(Block, simplified, block);
3960     ASSERT_EQ(block->nstmts(), 2);
3961 
3962     IS_NODE_WITH_NAME(Store, block->front(), store1_);
3963     IS_NODE_WITH_NAME(Store, block->back(), store2_);
3964 
3965     ASSERT_EQ(store1->value(), store1_->value());
3966     ASSERT_EQ(store2->value(), store2_->value());
3967   }
3968 
3969   {
3970     // Flatten many layers around an empty block to an empty block.
3971     StmtPtr last = alloc<Block>(std::vector<StmtPtr>({}));
3972     for (const auto i : c10::irange(11)) {
3973       (void)i; // Suppress unused variable warning
3974       last = alloc<Block>(std::vector<StmtPtr>({last}));
3975     }
3976 
3977     StmtPtr simplified = IRSimplifier::simplify(last);
3978     IS_NODE_WITH_NAME(Block, simplified, block);
3979     ASSERT_EQ(block->nstmts(), 0);
3980   }
3981 }
3982 
TEST(Simplify,SimplifyEliminateZeroLengthAlloc)3983 TEST(Simplify, SimplifyEliminateZeroLengthAlloc) {
3984   {
3985     // Simple positive case.
3986     BufHandle b("x", {0}, kInt);
3987 
3988     AllocatePtr alloc_ = Allocate::make(b);
3989     FreePtr free_ = Free::make(b);
3990 
3991     BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_}));
3992     ASSERT_EQ(block1->nstmts(), 2);
3993 
3994     StmtPtr simplified = IRSimplifier::simplify(block1);
3995     IS_NODE_WITH_NAME(Block, simplified, block2);
3996     ASSERT_EQ(block2->nstmts(), 0);
3997   }
3998 
3999   {
4000     // Simple negative case.
4001     BufHandle b("x", {2}, kInt);
4002 
4003     AllocatePtr alloc_ = Allocate::make(b);
4004     FreePtr free_ = Free::make(b);
4005 
4006     BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_}));
4007     ASSERT_EQ(block1->nstmts(), 2);
4008 
4009     StmtPtr simplified = IRSimplifier::simplify(block1);
4010     IS_NODE_WITH_NAME(Block, simplified, block2);
4011     ASSERT_EQ(block2->nstmts(), 2);
4012   }
4013 
4014   {
4015     // Finds right Alloc/Free.
4016     BufHandle b1("x", {0}, kInt);
4017     BufHandle b2("y", {2}, kInt);
4018 
4019     AllocatePtr alloc1 = Allocate::make(b1);
4020     AllocatePtr alloc2 = Allocate::make(b2);
4021     FreePtr free2_ = Free::make(b2);
4022     FreePtr free1_ = Free::make(b1);
4023 
4024     BlockPtr block1 =
4025         alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_}));
4026     ASSERT_EQ(block1->nstmts(), 4);
4027 
4028     StmtPtr simplified = IRSimplifier::simplify(block1);
4029     IS_NODE_WITH_NAME(Block, simplified, block2);
4030     ASSERT_EQ(block2->nstmts(), 2);
4031     IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc);
4032     IS_VAR_WITH_NAME(simplified_alloc->buffer_var(), "y");
4033     IS_NODE_WITH_NAME(Free, block2->stmts().back(), simplified_free);
4034     ASSERT_EQ(simplified_alloc->buffer_var(), simplified_free->buffer_var());
4035   }
4036 
4037   {
4038     // Dynamic shape.
4039     VarHandle z("z", kInt);
4040     BufHandle b1("x", {0}, kInt);
4041     BufHandle b2("y", {z}, kInt);
4042 
4043     AllocatePtr alloc1 = Allocate::make(b1);
4044     AllocatePtr alloc2 = Allocate::make(b2);
4045     FreePtr free2_ = Free::make(b2);
4046     FreePtr free1_ = Free::make(b1);
4047 
4048     BlockPtr block1 =
4049         alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_}));
4050     ASSERT_EQ(block1->nstmts(), 4);
4051     StmtPtr simplified = IRSimplifier::simplify(block1);
4052     IS_NODE_WITH_NAME(Block, simplified, block2);
4053     ASSERT_EQ(block2->nstmts(), 2);
4054   }
4055 }
4056 
// Two kRand intrinsics are not the same value, so the simplifier must not
// merge them into one term or cancel them against each other.
TEST(Simplify, DontSimplifyRand) {
  {
    // rand() + rand() stays an Add rather than becoming 2 * rand().
    ExprHandle lhsRand = Intrinsics::make(kRand, kInt);
    ExprHandle rhsRand = Intrinsics::make(kRand, kInt);
    ExprHandle result = IRSimplifier::simplify(lhsRand + rhsRand);
    IS_NODE_WITH_NAME(Add, result.node(), add);
    IS_RAND(add->lhs());
    IS_RAND(add->rhs());
  }

  {
    // rand() - rand() stays a Sub rather than folding to 0.
    ExprHandle lhsRand = Intrinsics::make(kRand, kFloat);
    ExprHandle rhsRand = Intrinsics::make(kRand, kFloat);
    ExprHandle result = IRSimplifier::simplify(lhsRand - rhsRand);
    IS_NODE_WITH_NAME(Sub, result.node(), sub);
    IS_RAND(sub->lhs());
    IS_RAND(sub->rhs());
  }

  {
    // rand() * rand() stays a Mul of two independent calls.
    ExprHandle lhsRand = Intrinsics::make(kRand, kInt);
    ExprHandle rhsRand = Intrinsics::make(kRand, kInt);
    ExprHandle result = IRSimplifier::simplify(lhsRand * rhsRand);
    IS_NODE_WITH_NAME(Mul, result.node(), mul);
    IS_RAND(mul->lhs());
    IS_RAND(mul->rhs());
  }
}
4088 
TEST(Simplify,SimplifyReorderForCond)4089 TEST(Simplify, SimplifyReorderForCond) {
4090   BufHandle a("A", {4}, kInt);
4091   BufHandle b("B", {1}, kInt);
4092   BufHandle c("C", {4}, kInt);
4093   VarHandle i("i", kInt);
4094   VarHandle j("j", kInt);
4095 
4096   {
4097     // for ( if ( ... ) ) => if ( for ( ... ) ).
4098     auto body = For::make(
4099         i,
4100         0,
4101         4,
4102         Cond::make(
4103             CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4104             Store::make(c, {i}, Load::make(a, {i})),
4105             nullptr));
4106 
4107     StmtPtr simplified = IRSimplifier::simplify(body);
4108     IS_NODE_WITH_NAME(Cond, simplified, cond);
4109     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
4110     IS_NODE_WITH_NAME(For, true_block->front(), loop);
4111   }
4112 
4113   {
4114     // Can't reorder if condition is dependent on the loop var.
4115     auto body = For::make(
4116         i,
4117         0,
4118         4,
4119         Cond::make(
4120             CompareSelect::make(i, 2, CompareSelectOperation::kEQ),
4121             Store::make(c, {i}, Load::make(a, {i})),
4122             nullptr));
4123 
4124     StmtPtr simplified = IRSimplifier::simplify(body);
4125     IS_NODE_WITH_NAME(For, simplified, loop);
4126     IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
4127   }
4128 
4129   {
4130     // Can't reorder if condition is dependent on a var that is modified inside
4131     // the loop.
4132     auto body = For::make(
4133         i,
4134         0,
4135         4,
4136         Cond::make(
4137             CompareSelect::make(
4138                 Load::make(c, {0}), 10, CompareSelectOperation::kLT),
4139             Store::make(c, {0}, Load::make(a, {i})),
4140             nullptr));
4141 
4142     StmtPtr simplified = IRSimplifier::simplify(body);
4143     IS_NODE_WITH_NAME(For, simplified, loop);
4144     IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
4145   }
4146 
4147   {
4148     // Condition based on buffer not referenced in body. Can reorder here.
4149     auto body = For::make(
4150         i,
4151         0,
4152         4,
4153         Cond::make(
4154             CompareSelect::make(
4155                 Load::make(b, {0}), 10, CompareSelectOperation::kLT),
4156             Store::make(c, {0}, Load::make(a, {i})),
4157             nullptr));
4158 
4159     StmtPtr simplified = IRSimplifier::simplify(body);
4160     IS_NODE_WITH_NAME(Cond, simplified, cond);
4161     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
4162     IS_NODE_WITH_NAME(For, true_block->front(), loop);
4163   }
4164 
4165   {
4166     // Condition based on buffer read only in body. Can reorder here.
4167     auto body = For::make(
4168         i,
4169         0,
4170         4,
4171         Cond::make(
4172             CompareSelect::make(
4173                 Load::make(a, {0}), 10, CompareSelectOperation::kLT),
4174             Store::make(c, {0}, Load::make(a, {i})),
4175             nullptr));
4176 
4177     StmtPtr simplified = IRSimplifier::simplify(body);
4178     IS_NODE_WITH_NAME(Cond, simplified, cond);
4179     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
4180     IS_NODE_WITH_NAME(For, true_block->front(), loop);
4181   }
4182 
4183   {
4184     // Condition depends on Let in the loop. Cannot reorder.
4185     auto body = For::make(
4186         i,
4187         0,
4188         4,
4189         Block::make(
4190             {Let::make(j, 3),
4191              Cond::make(
4192                  CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4193                  Store::make(c, {0}, Load::make(a, {i})),
4194                  nullptr)}));
4195 
4196     StmtPtr simplified = IRSimplifier::simplify(body);
4197     IS_NODE_WITH_NAME(For, simplified, loop);
4198     IS_NODE_WITH_NAME(Let, loop->body()->front(), let);
4199     IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond);
4200   }
4201 
4202   {
4203     // Multi level Ifs where all conditions are distinct. Move BOTH Cond
4204     // statements outside the loop.
4205     auto body = For::make(
4206         i,
4207         0,
4208         4,
4209         Cond::make(
4210             CompareSelect::make(
4211                 Load::make(a, {0}), 10, CompareSelectOperation::kLT),
4212             Cond::make(
4213                 CompareSelect::make(j, 10, CompareSelectOperation::kEQ),
4214                 Store::make(c, {0}, Load::make(a, {i})),
4215                 nullptr),
4216             nullptr));
4217 
4218     StmtPtr simplified = IRSimplifier::simplify(body);
4219     IS_NODE_WITH_NAME(Cond, simplified, cond);
4220     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
4221     IS_NODE_WITH_NAME(Cond, true_block->front(), cond2);
4222     IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_block2);
4223     IS_NODE_WITH_NAME(For, true_block2->front(), loop);
4224   }
4225 
4226   {
4227     // Multi level Ifs where the inner condition does depend on a loop var,
4228     // reorder only the first Cond.
4229     auto body = For::make(
4230         i,
4231         0,
4232         4,
4233         Cond::make(
4234             CompareSelect::make(
4235                 Load::make(a, {0}), 10, CompareSelectOperation::kLT),
4236             Cond::make(
4237                 CompareSelect::make(i, 3, CompareSelectOperation::kEQ),
4238                 Store::make(c, {0}, Load::make(a, {i})),
4239                 nullptr),
4240             nullptr));
4241 
4242     StmtPtr simplified = IRSimplifier::simplify(body);
4243     IS_NODE_WITH_NAME(Cond, simplified, cond);
4244     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
4245     IS_NODE_WITH_NAME(For, true_block->front(), loop);
4246     IS_NODE_WITH_NAME(Block, loop->body(), loop_body);
4247     IS_NODE_WITH_NAME(Cond, loop_body->front(), cond2);
4248   }
4249 
4250   {
4251     // Don't reorder if there's an else block of the Cond.
4252     // We could, but is it much better?
4253     auto body = For::make(
4254         i,
4255         0,
4256         4,
4257         Cond::make(
4258             CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4259             Store::make(c, {0}, Load::make(a, {i})),
4260             Store::make(c, {0}, 0)));
4261 
4262     StmtPtr simplified = IRSimplifier::simplify(body);
4263     IS_NODE_WITH_NAME(For, simplified, loop);
4264     IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
4265   }
4266 
4267   {
4268     // Condition uses distinct region of Tensor.
4269     // We could reorder here wih better analysis, but we don't. Included for
4270     // completeness.
4271     auto body = For::make(
4272         i,
4273         0,
4274         4,
4275         Cond::make(
4276             CompareSelect::make(
4277                 Load::make(c, {0}), 10, CompareSelectOperation::kLT),
4278             Store::make(c, {1}, Load::make(a, {i})),
4279             nullptr));
4280 
4281     StmtPtr simplified = IRSimplifier::simplify(body);
4282     IS_NODE_WITH_NAME(For, simplified, loop);
4283     IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
4284   }
4285 }
4286 
TEST(Simplify,SimplifyFuseConditions)4287 TEST(Simplify, SimplifyFuseConditions) {
4288   BufHandle a("A", {2}, kInt);
4289   BufHandle b("B", {2}, kInt);
4290   VarHandle i("i", kInt);
4291   VarHandle j("j", kInt);
4292 
4293   {
4294     // Can fuse since the conditions are identical.
4295     // if (A) { X }; if (A) { Y }; => if (A) { X; Y }
4296     auto body = Block::make(
4297         {Cond::make(
4298              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4299              Store::make(a, {0}, i),
4300              nullptr),
4301          Cond::make(
4302              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4303              Store::make(a, {1}, i),
4304              nullptr)});
4305 
4306     StmtPtr simplified = IRSimplifier::simplify(body);
4307     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
4308     IS_NODE_WITH_NAME(Block, simplified, block);
4309     ASSERT_EQ(block->nstmts(), 1);
4310     IS_NODE_WITH_NAME(Cond, block->front(), cond);
4311     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4312     ASSERT_EQ(true_stmt->nstmts(), 2);
4313     ASSERT_EQ(cond->false_stmt(), nullptr);
4314   }
4315 
4316   {
4317     // Can't fuse, conditions are not identical in lhs (i != j).
4318     auto body = Block::make(
4319         {Cond::make(
4320              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4321              Store::make(a, {0}, i),
4322              nullptr),
4323          Cond::make(
4324              CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4325              Store::make(a, {1}, i),
4326              nullptr)});
4327 
4328     StmtPtr simplified = IRSimplifier::simplify(body);
4329     IS_NODE_WITH_NAME(Block, simplified, block);
4330     ASSERT_EQ(block->nstmts(), 2);
4331     IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4332     IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4333 
4334     IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
4335     IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
4336     ASSERT_EQ(true_stmt1->nstmts(), 1);
4337     ASSERT_EQ(true_stmt2->nstmts(), 1);
4338 
4339     ASSERT_EQ(cond1->false_stmt(), nullptr);
4340     ASSERT_EQ(cond2->false_stmt(), nullptr);
4341   }
4342   {
4343     // Can't fuse, conditions are not identical in rhs (10 != 11).
4344     auto body = Block::make(
4345         {Cond::make(
4346              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4347              Store::make(a, {0}, i),
4348              nullptr),
4349          Cond::make(
4350              CompareSelect::make(i, 11, CompareSelectOperation::kLT),
4351              Store::make(a, {1}, i),
4352              nullptr)});
4353 
4354     StmtPtr simplified = IRSimplifier::simplify(body);
4355     IS_NODE_WITH_NAME(Block, simplified, block);
4356     ASSERT_EQ(block->nstmts(), 2);
4357     IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4358     IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4359 
4360     IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
4361     IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
4362     ASSERT_EQ(true_stmt1->nstmts(), 1);
4363     ASSERT_EQ(true_stmt2->nstmts(), 1);
4364 
4365     ASSERT_EQ(cond1->false_stmt(), nullptr);
4366     ASSERT_EQ(cond2->false_stmt(), nullptr);
4367   }
4368 
4369   {
4370     // Can't fuse, conditions are not identical in operation (LT vs GT).
4371     auto body = Block::make(
4372         {Cond::make(
4373              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4374              Store::make(a, {0}, i),
4375              nullptr),
4376          Cond::make(
4377              CompareSelect::make(i, 10, CompareSelectOperation::kGT),
4378              Store::make(a, {1}, i),
4379              nullptr)});
4380 
4381     StmtPtr simplified = IRSimplifier::simplify(body);
4382     IS_NODE_WITH_NAME(Block, simplified, block);
4383     ASSERT_EQ(block->nstmts(), 2);
4384     IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4385     IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4386 
4387     IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
4388     IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
4389     ASSERT_EQ(true_stmt1->nstmts(), 1);
4390     ASSERT_EQ(true_stmt2->nstmts(), 1);
4391 
4392     ASSERT_EQ(cond1->false_stmt(), nullptr);
4393     ASSERT_EQ(cond2->false_stmt(), nullptr);
4394   }
4395 
4396   {
4397     // Can't fuse, CompareSelect results are different.
4398     // Actually we totally could if we normalized CompareSelect results, but
4399     // TODO for later.
4400     auto body = Block::make(
4401         {Cond::make(
4402              CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT),
4403              Store::make(a, {0}, i),
4404              nullptr),
4405          Cond::make(
4406              CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT),
4407              Store::make(a, {1}, i),
4408              nullptr)});
4409 
4410     StmtPtr simplified = IRSimplifier::simplify(body);
4411     IS_NODE_WITH_NAME(Block, simplified, block);
4412     ASSERT_EQ(block->nstmts(), 2);
4413     IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4414     IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4415 
4416     IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
4417     IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
4418     ASSERT_EQ(true_stmt1->nstmts(), 1);
4419     ASSERT_EQ(true_stmt2->nstmts(), 1);
4420 
4421     ASSERT_EQ(cond1->false_stmt(), nullptr);
4422     ASSERT_EQ(cond2->false_stmt(), nullptr);
4423   }
4424 
4425   {
4426     // Can fuse with false stmt only.
4427     auto body = Block::make(
4428         {Cond::make(
4429              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4430              nullptr,
4431              Store::make(a, {0}, i)),
4432          Cond::make(
4433              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4434              nullptr,
4435              Store::make(a, {1}, i))});
4436 
4437     StmtPtr simplified = IRSimplifier::simplify(body);
4438     IS_NODE_WITH_NAME(Block, simplified, block);
4439     ASSERT_EQ(block->nstmts(), 1);
4440     IS_NODE_WITH_NAME(Cond, block->front(), cond);
4441     IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt);
4442     ASSERT_EQ(false_stmt->nstmts(), 2);
4443     ASSERT_EQ(cond->true_stmt(), nullptr);
4444   }
4445 
4446   {
4447     // Can fuse with both true and false stmt.
4448     auto body = Block::make(
4449         {Cond::make(
4450              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4451              Store::make(a, {0}, i),
4452              Store::make(b, {0}, i)),
4453          Cond::make(
4454              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4455              Store::make(a, {1}, i),
4456              Store::make(b, {1}, i))});
4457 
4458     StmtPtr simplified = IRSimplifier::simplify(body);
4459     IS_NODE_WITH_NAME(Block, simplified, block);
4460     ASSERT_EQ(block->nstmts(), 1);
4461     IS_NODE_WITH_NAME(Cond, block->front(), cond);
4462     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4463     ASSERT_EQ(true_stmt->nstmts(), 2);
4464     IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt);
4465     ASSERT_EQ(false_stmt->nstmts(), 2);
4466   }
4467 
4468   {
4469     // Can fuse with mismatched true / false stmt existing
4470     auto body = Block::make(
4471         {Cond::make(
4472              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4473              Store::make(a, {0}, i),
4474              nullptr),
4475          Cond::make(
4476              CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4477              nullptr,
4478              Store::make(b, {1}, i))});
4479 
4480     StmtPtr simplified = IRSimplifier::simplify(body);
4481     IS_NODE_WITH_NAME(Block, simplified, block);
4482     ASSERT_EQ(block->nstmts(), 1);
4483     IS_NODE_WITH_NAME(Cond, block->front(), cond);
4484     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4485     ASSERT_EQ(true_stmt->nstmts(), 1);
4486     IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt);
4487     ASSERT_EQ(false_stmt->nstmts(), 1);
4488   }
4489 
4490   {
4491     // Can fuse partial block contents, ie when there are non fused stmts before
4492     // and after.
4493     // before:
4494     // if (j < 10) { A[0] = j; }
4495     // if (i < 10) { A[0] = i; }
4496     // if (i < 10) { A[1] = i; }
4497     // if (i < 11) { A[1] = j; }
4498     //
4499     // after:
4500     //
4501     // if (j < 10) { A[0] = j; }
4502     // if (i < 10) {
4503     //   A[0] = i;
4504     //   A[1] = i;
4505     // }
4506     // if (i < 11) { A[1] = j; }
4507 
4508     auto body = Block::make({
4509         Cond::make(
4510             CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4511             Store::make(a, {0}, j),
4512             nullptr),
4513         Cond::make(
4514             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4515             Store::make(a, {0}, i),
4516             nullptr),
4517         Cond::make(
4518             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4519             Store::make(a, {1}, i),
4520             nullptr),
4521         Cond::make(
4522             CompareSelect::make(i, 11, CompareSelectOperation::kLT),
4523             Store::make(a, {1}, j),
4524             nullptr),
4525     });
4526     StmtPtr simplified = IRSimplifier::simplify(body);
4527     IS_NODE_WITH_NAME(Block, simplified, block);
4528     ASSERT_EQ(block->nstmts(), 3);
4529     auto it = block->begin();
4530     it++;
4531     IS_NODE_WITH_NAME(Cond, *it, cond);
4532     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4533     ASSERT_EQ(true_stmt->nstmts(), 2);
4534     ASSERT_EQ(cond->false_stmt(), nullptr);
4535   }
4536 
4537   {
4538     // Can fuse longer sequences of identical conditions.
4539     auto body = Block::make({
4540         Cond::make(
4541             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4542             Store::make(a, {0}, j),
4543             nullptr),
4544         Cond::make(
4545             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4546             Store::make(a, {0}, i),
4547             nullptr),
4548         Cond::make(
4549             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4550             Store::make(a, {1}, i),
4551             nullptr),
4552         Cond::make(
4553             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4554             Store::make(a, {1}, j),
4555             nullptr),
4556     });
4557     StmtPtr simplified = IRSimplifier::simplify(body);
4558     IS_NODE_WITH_NAME(Block, simplified, block);
4559     ASSERT_EQ(block->nstmts(), 1);
4560     IS_NODE_WITH_NAME(Cond, block->front(), cond);
4561     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4562     ASSERT_EQ(true_stmt->nstmts(), 4);
4563     ASSERT_EQ(cond->false_stmt(), nullptr);
4564   }
4565 
4566   {
4567     // Can't fuse through a non condition.
4568     auto body = Block::make({
4569         Cond::make(
4570             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4571             Store::make(a, {0}, j),
4572             nullptr),
4573         Cond::make(
4574             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4575             Store::make(a, {0}, i),
4576             nullptr),
4577         Store::make(b, {1}, i + j),
4578         Cond::make(
4579             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4580             Store::make(a, {1}, i),
4581             nullptr),
4582         Cond::make(
4583             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4584             Store::make(a, {1}, j),
4585             nullptr),
4586     });
4587     StmtPtr simplified = IRSimplifier::simplify(body);
4588     IS_NODE_WITH_NAME(Block, simplified, block);
4589     ASSERT_EQ(block->nstmts(), 3);
4590     IS_NODE_WITH_NAME(Cond, block->front(), cond);
4591     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4592     ASSERT_EQ(true_stmt->nstmts(), 2);
4593     ASSERT_EQ(cond->false_stmt(), nullptr);
4594 
4595     IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4596     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2);
4597     ASSERT_EQ(true_stmt2->nstmts(), 2);
4598     ASSERT_EQ(cond2->false_stmt(), nullptr);
4599 
4600     auto it = block->begin();
4601     it++;
4602     IS_NODE_WITH_NAME(Store, *it, middle);
4603   }
4604 
4605   {
4606     // Can fuse if the conditions simplify to the same thing.
4607     auto body = Block::make(
4608         {Cond::make(
4609              CompareSelect::make(
4610                  i * 2,
4611                  ExprHandle(87) % ExprHandle(11),
4612                  CompareSelectOperation::kLT),
4613              Store::make(a, {0}, i),
4614              nullptr),
4615          Cond::make(
4616              CompareSelect::make(
4617                  i * 2,
4618                  ExprHandle(300) / ExprHandle(30),
4619                  CompareSelectOperation::kLT),
4620              Store::make(a, {1}, i),
4621              nullptr)});
4622     StmtPtr simplified = IRSimplifier::simplify(body);
4623     IS_NODE_WITH_NAME(Block, simplified, block);
4624     ASSERT_EQ(block->nstmts(), 1);
4625     IS_NODE_WITH_NAME(Cond, block->front(), cond);
4626     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4627     ASSERT_EQ(true_stmt->nstmts(), 2);
4628     ASSERT_EQ(cond->false_stmt(), nullptr);
4629   }
4630 
4631   {
4632     // Can fuse non-CompareSelects.
4633     // if (i) { X } if (i) { Y } => if (i) { X; Y }
4634     auto body = Block::make(
4635         {Cond::make(i, Store::make(a, {0}, i), nullptr),
4636          Cond::make(i, Store::make(a, {1}, i), nullptr)});
4637 
4638     StmtPtr simplified = IRSimplifier::simplify(body);
4639     IS_NODE_WITH_NAME(Block, simplified, block);
4640     ASSERT_EQ(block->nstmts(), 1);
4641     IS_NODE_WITH_NAME(Cond, block->front(), cond);
4642     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4643     ASSERT_EQ(true_stmt->nstmts(), 2);
4644     ASSERT_EQ(cond->false_stmt(), nullptr);
4645   }
4646 
4647   {
4648     // Sanity check wont fuse different non-CompareSelects.
4649     auto body = Block::make(
4650         {Cond::make(i, Store::make(a, {0}, i), nullptr),
4651          Cond::make(j, Store::make(a, {1}, i), nullptr)});
4652 
4653     StmtPtr simplified = IRSimplifier::simplify(body);
4654     IS_NODE_WITH_NAME(Block, simplified, block);
4655     ASSERT_EQ(block->nstmts(), 2);
4656     IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4657     IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4658   }
4659 
4660   {
4661     // Sanity check constant condition elimination still occurs when merging is
4662     // possible.
4663     auto body = Block::make(
4664         {Cond::make(1, Store::make(a, {0}, i), nullptr),
4665          Cond::make(1, Store::make(a, {1}, i), nullptr)});
4666     StmtPtr simplified = IRSimplifier::simplify(body);
4667     IS_NODE_WITH_NAME(Block, simplified, block);
4668     ASSERT_EQ(block->nstmts(), 2);
4669     IS_NODE_WITH_NAME(Store, block->front(), store1);
4670     IS_NODE_WITH_NAME(Store, block->back(), store2);
4671   }
4672 
4673   {
4674     // Sanity check for-cond reordering occurs after fusing.
4675     auto body = For::make(
4676         i,
4677         0,
4678         4,
4679         Block::make(
4680             {Cond::make(
4681                  CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4682                  Store::make(a, {1}, Load::make(b, {0})),
4683                  nullptr),
4684              Cond::make(
4685                  CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4686                  Store::make(a, {2}, Load::make(b, {0})),
4687                  nullptr)}));
4688 
4689     StmtPtr simplified = IRSimplifier::simplify(body);
4690     IS_NODE_WITH_NAME(Cond, simplified, cond);
4691     IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
4692     IS_NODE_WITH_NAME(For, true_block->front(), loop);
4693   }
4694 }
4695 
TEST(Simplify,SimplifySyncThreads)4696 TEST(Simplify, SimplifySyncThreads) {
4697   BufHandle a("A", {4}, kInt);
4698   VarHandle i("i", kInt);
4699 
4700   {
4701     // Merge two inner SyncThreads.
4702     auto body = Block::make(
4703         // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
4704         {Store::make(a, {0}, 1),
4705          alloc<SyncThreads>(),
4706          alloc<SyncThreads>(),
4707          Store::make(a, {1}, 0)});
4708     StmtPtr simplified = IRSimplifier::simplify(body);
4709     IS_NODE_WITH_NAME(Block, simplified, block);
4710     ASSERT_EQ(block->nstmts(), 3);
4711     auto it = block->begin();
4712     IS_NODE(Store, *it++);
4713     IS_NODE(SyncThreads, *it++);
4714     IS_NODE(Store, *it++);
4715   }
4716 
4717   {
4718     // Eliminate outer SyncThreads.
4719     auto body = Block::make(
4720         {alloc<SyncThreads>(), Store::make(a, {1}, 0), alloc<SyncThreads>()});
4721 
4722     StmtPtr simplified = IRSimplifier::simplify(body);
4723     IS_NODE_WITH_NAME(Block, simplified, block);
4724     ASSERT_EQ(block->nstmts(), 1);
4725     auto it = block->begin();
4726     IS_NODE(Store, *it);
4727   }
4728 
4729   {
4730     // Merge many inner SyncThreads.
4731     auto body = Block::make(
4732         {Store::make(a, {0}, 1),
4733          alloc<SyncThreads>(),
4734          alloc<SyncThreads>(),
4735          alloc<SyncThreads>(),
4736          alloc<SyncThreads>(),
4737          alloc<SyncThreads>(),
4738          Store::make(a, {1}, 0)});
4739 
4740     StmtPtr simplified = IRSimplifier::simplify(body);
4741     IS_NODE_WITH_NAME(Block, simplified, block);
4742     ASSERT_EQ(block->nstmts(), 3);
4743     auto it = block->begin();
4744     IS_NODE(Store, *it++);
4745     IS_NODE(SyncThreads, *it++);
4746     IS_NODE(Store, *it++);
4747   }
4748 
4749   {
4750     // Merge multiple outer SyncThreads.
4751     auto body = Block::make(
4752         {alloc<SyncThreads>(),
4753          alloc<SyncThreads>(),
4754          Store::make(a, {1}, 0),
4755          alloc<SyncThreads>(),
4756          alloc<SyncThreads>(),
4757          alloc<SyncThreads>(),
4758          alloc<SyncThreads>()});
4759 
4760     StmtPtr simplified = IRSimplifier::simplify(body);
4761     IS_NODE_WITH_NAME(Block, simplified, block);
4762     ASSERT_EQ(block->nstmts(), 1);
4763     auto it = block->begin();
4764     IS_NODE(Store, *it);
4765   }
4766 
4767   {
4768     // Merge multiple sections;
4769     auto body = Block::make(
4770         {Store::make(a, {0}, 1),
4771          alloc<SyncThreads>(),
4772          alloc<SyncThreads>(),
4773          Store::make(a, {1}, 0),
4774          Store::make(a, {2}, 0),
4775          alloc<SyncThreads>(),
4776          alloc<SyncThreads>(),
4777          alloc<SyncThreads>(),
4778          Store::make(a, {3}, 0)});
4779 
4780     StmtPtr simplified = IRSimplifier::simplify(body);
4781     IS_NODE_WITH_NAME(Block, simplified, block);
4782     ASSERT_EQ(block->nstmts(), 6);
4783     auto it = block->begin();
4784     IS_NODE(Store, *it++);
4785     IS_NODE(SyncThreads, *it++);
4786     IS_NODE(Store, *it++);
4787     IS_NODE(Store, *it++);
4788     IS_NODE(SyncThreads, *it++);
4789     IS_NODE(Store, *it++);
4790   }
4791 }
4792 
// Subtracting a Broadcast from a Ramp with the same lane count should fold
// into a single Ramp whose base absorbs the broadcast value:
//   Ramp(0, 6, 4) - Broadcast(-5, 4)  =>  Ramp(5, 6, 4)
TEST(Simplify, SimplifyRampSubBroadcast) {
  int num_lanes = 4;
  ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes);
  ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes);
  ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast);
  RampPtr newRamp = simplified.AsNode<Ramp>();
  // AsNode yields null when the result is not a Ramp; fail the test cleanly
  // rather than dereferencing a null pointer below.
  ASSERT_NE(newRamp, nullptr);
  IS_NODE_WITH_NAME(IntImm, newRamp->base(), base);
  ASSERT_EQ(base->value(), 5);
  IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride);
  ASSERT_EQ(stride->value(), 6);
  ASSERT_EQ(newRamp->lanes(), num_lanes);
}
4805 
// Exercises the TermExpander::mutate path for vector (Broadcast) terms. The
// expression bc1 + bc0 / bc2 + bc1 should evaluate to 1 + 0 / 2 + 1 == 2 in
// every lane once simplified.
TEST(Simplify, SimplifyBroadcastTermExpander) {
  int num_lanes = 8;
  ExprHandle zeros = Broadcast::make(ExprHandle(0), num_lanes);
  ExprHandle ones = Broadcast::make(ExprHandle(1), num_lanes);
  ExprHandle twos = Broadcast::make(ExprHandle(2), num_lanes);

  // The middle term (zeros / twos) is not simplified away, which is what routes
  // the expression through TermExpander::mutate. The two `ones` terms get
  // gathered into 2 * ones, and that scalar 2 then has to be made multi-lane.
  ExprHandle folded = IRSimplifier::simplify(ones + (zeros / twos) + ones);

  // The simplified form is not fully canonical, so matching its structure
  // would be brittle. Store it into a buffer and check the computed lanes.
  BufHandle buf("buf", {num_lanes}, kInt);
  auto write_all_lanes =
      Store::make(buf, {Ramp::make(0, 1, num_lanes)}, folded);
  SimpleIREvaluator evaluator(write_all_lanes, {buf});
  std::vector<int> lanes(num_lanes);
  evaluator(lanes);

  for (int lane_value : lanes) {
    ASSERT_EQ(lane_value, 2);
  }
}
4826 
TEST(Simplify,CompareSelectLoopBounds)4827 TEST(Simplify, CompareSelectLoopBounds) {
4828   constexpr int N = 8;
4829   BufHandle b("b", {N}, kFloat);
4830   VarHandle n("n", kInt);
4831   VarHandle m("m", kInt);
4832   VarHandle var_N("var_N", kInt);
4833   VarHandle var_M("var_M", kInt);
4834 
4835   auto test_case_fn = [](const VarHandle& n,
4836                          const BufHandle& b,
4837                          const ExprHandle& start,
4838                          const ExprHandle& stop,
4839                          const int& cmp_val,
4840                          const CompareSelectOperation& cmp_op,
4841                          const std::string& check_string) {
4842     StmtPtr s = For::make(
4843         n,
4844         start,
4845         stop,
4846         b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op)));
4847     s = IRSimplifier::simplify(s);
4848     std::ostringstream oss;
4849     oss << *s;
4850     std::string target_string = "# CHECK: ";
4851     target_string += check_string;
4852     torch::jit::testing::FileCheck().run(target_string, oss.str());
4853   };
4854 
4855   auto test_case_nest_loops_fn = [](const VarHandle& n,
4856                                     const VarHandle& m,
4857                                     const BufHandle& b,
4858                                     const ExprHandle& n_start,
4859                                     const ExprHandle& n_stop,
4860                                     const ExprHandle& m_start,
4861                                     const ExprHandle& m_stop,
4862                                     const CompareSelectOperation& cmp_op,
4863                                     const std::string& check_string) {
4864     StmtPtr s = For::make(
4865         m,
4866         m_start,
4867         m_stop,
4868         b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op)));
4869     StmtPtr root_s = For::make(n, n_start, n_stop, s);
4870     root_s = IRSimplifier::simplify(root_s);
4871     std::ostringstream oss;
4872     oss << *root_s;
4873     std::string target_string = "# CHECK: ";
4874     target_string += check_string;
4875     torch::jit::testing::FileCheck().run(target_string, oss.str());
4876   };
4877 
4878   // Before:
4879   //   for (const auto n : c10::irange(1, N)) {
4880   //     b[n] = n < 1 ? 0.f : 1.f;
4881   //   }
4882   // After:
4883   //   for (const auto n : c10::irange(1, N)) {
4884   //     b[n] = 1.f;
4885   //   }
4886   test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;");
4887 
4888   // Before:
4889   //   for (const auto n : c10::irange(1, N)) {
4890   //     b[n] = n <= 1 ? 0.f : 1.f;
4891   //   }
4892   // After:
4893   //   for (const auto n : c10::irange(1, N)) {
4894   //     b[n] = n <= 1 ? 0.f : 1.f;
4895   //   }
4896   test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;");
4897 
4898   // Before:
4899   //   for (const auto n : c10::irange(1, N)) {
4900   //     b[n] = n <= 0 ? 0.f : 1.f;
4901   //   }
4902   // After:
4903   //   for (const auto n : c10::irange(1, N)) {
4904   //     b[n] = 1.f;
4905   //   }
4906   test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;");
4907 
4908   // Before:
4909   //   for (const auto n : c10::irange(1, N)) {
4910   //     b[n] = n < 0 ? 0.f : 1.f;
4911   //   }
4912   // After:
4913   //   for (const auto n : c10::irange(1, N)) {
4914   //     b[n] = 1.f;
4915   //   }
4916   test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;");
4917 
4918   // Before:
4919   //   for (const auto n : c10::irange(1, N)) {
4920   //     b[n] = n < 8 ? 0.f : 1.f;
4921   //   }
4922   // After:
4923   //   for (const auto n : c10::irange(1, N)) {
4924   //     b[n] = 0.f;
4925   //   }
4926   test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;");
4927 
4928   // Before:
4929   //   for (const auto n : c10::irange(1, N)) {
4930   //     b[n] = n <= 7 ? 0.f : 1.f;
4931   //   }
4932   // After:
4933   //   for (const auto n : c10::irange(1, N)) {
4934   //     b[n] = 0.f;
4935   //   }
4936   test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;");
4937 
4938   // Before:
4939   //   for (const auto n : c10::irange(1, N)) {
4940   //     b[n] = n <= 8 ? 0.f : 1.f;
4941   //   }
4942   // After:
4943   //   for (const auto n : c10::irange(1, N)) {
4944   //     b[n] = 0.f;
4945   //   }
4946   test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;");
4947 
4948   // Before:
4949   //   for (const auto n : c10::irange(1, N)) {
4950   //     b[n] = n < 7 ? 0.f : 1.f;
4951   //   }
4952   // After:
4953   //   for (const auto n : c10::irange(1, N)) {
4954   //     b[n] = n < 7 ? 0.f : 1.f;
4955   //   }
4956   test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;");
4957 
4958   // Before:
4959   //   for (const auto n : c10::irange(1, N)) {
4960   //     b[n] = n > 0 ? 0.f : 1.f;
4961   //   }
4962   // After:
4963   //   for (const auto n : c10::irange(1, N)) {
4964   //     b[n] = 0.f;
4965   //   }
4966   test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;");
4967 
4968   // Before:
4969   //   for (const auto n : c10::irange(1, N)) {
4970   //     b[n] = n > 1 ? 0.f : 1.f;
4971   //   }
4972   // After:
4973   //   for (const auto n : c10::irange(1, N)) {
4974   //     b[n] = n > 1 ? 0.f : 1.f;
4975   //   }
4976   test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;");
4977 
4978   // Before:
4979   //   for (const auto n : c10::irange(1, N)) {
4980   //     b[n] = n >= 1 ? 0.f : 1.f;
4981   //   }
4982   // After:
4983   //   for (const auto n : c10::irange(1, N)) {
4984   //     b[n] = 0.f;
4985   //   }
4986   test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;");
4987 
4988   // Before:
4989   //   for (const auto n : c10::irange(1, N)) {
4990   //     b[n] = n > 7 ? 0.f : 1.f;
4991   //   }
4992   // After:
4993   //   for (const auto n : c10::irange(1, N)) {
4994   //     b[n] = 1.f;
4995   //   }
4996   test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;");
4997 
4998   // Before:
4999   //   for (const auto n : c10::irange(1, N)) {
5000   //     b[n] = n >= 7 ? 0.f : 1.f;
5001   //   }
5002   // After:
5003   //   for (const auto n : c10::irange(1, N)) {
5004   //     b[n] = n >= 7 ? 0.f : 1.f;
5005   //   }
5006   test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;");
5007 
5008   // Before:
5009   //   for (const auto n : c10::irange(1, N)) {
5010   //     b[n] = n > 5 ? 0.f : 1.f;
5011   //   }
5012   // After:
5013   //   for (const auto n : c10::irange(1, N)) {
5014   //     b[n] = n > 5 ? 0.f : 1.f;
5015   //   }
5016   test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;");
5017 
5018   // Before:
5019   //   for (const auto n : c10::irange(1, N)) {
5020   //     b[n] = n >= 5 ? 0.f : 1.f;
5021   //   }
5022   // After:
5023   //   for (const auto n : c10::irange(1, N)) {
5024   //     b[n] = n >= 5 ? 0.f : 1.f;
5025   //   }
5026   test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;");
5027 
5028   // Before:
5029   //   for (const auto n : c10::irange(1, N)) {
5030   //     b[n] = n > 8 ? 0.f : 1.f;
5031   //   }
5032   // After:
5033   //   for (const auto n : c10::irange(1, N)) {
5034   //     b[n] = 1.f;
5035   //   }
5036   test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;");
5037 
5038   // Before:
5039   //   for (const auto n : c10::irange(1, N)) {
5040   //     b[n] = n >= 8 ? 0.f : 1.f;
5041   //   }
5042   // After:
5043   //   for (const auto n : c10::irange(1, N)) {
5044   //     b[n] = 1.f;
5045   //   }
5046   test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;");
5047 
5048   // Before:
5049   //   for (const auto n : c10::irange(1, 2)) {
5050   //     b[n] = n == 1 ? 0.f : 1.f;
5051   //   }
5052   // After:
5053   //   for (const auto n : c10::irange(1, 2)) {
5054   //     b[1] = 0.f;
5055   //   }
5056   test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;");
5057 
5058   // Before:
5059   //   for (const auto n : c10::irange(1, N)) {
5060   //     b[n] = n == 1 ? 0.f : 1.f;
5061   //   }
5062   // After:
5063   //   for (const auto n : c10::irange(1, N)) {
5064   //     b[n] = n == 1 ? 0.f : 1.f;
5065   //   }
5066   test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;");
5067 
5068   // Before:
5069   //   for (const auto n : c10::irange(1, N)) {
5070   //     b[n] = n == 0 ? 0.f : 1.f;
5071   //   }
5072   // After:
5073   //   for (const auto n : c10::irange(1, N)) {
5074   //     b[n] = 1.f;
5075   //   }
5076   test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;");
5077 
5078   // Before:
5079   //   for (const auto n : c10::irange(1, N)) {
5080   //     b[n] = n == 7 ? 0.f : 1.f;
5081   //   }
5082   // After:
5083   //   for (const auto n : c10::irange(1, N)) {
5084   //     b[n] = n == 7 ? 0.f : 1.f;
5085   //   }
5086   test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;");
5087 
5088   // Before:
5089   //   for (const auto n : c10::irange(1, N)) {
5090   //     b[n] = n == 8 ? 0.f : 1.f;
5091   //   }
5092   // After:
5093   //   for (const auto n : c10::irange(1, N)) {
5094   //     b[n] = 1.f;
5095   //   }
5096   test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;");
5097 
5098   // Before:
5099   //   for (const auto n : c10::irange(1, N)) {
5100   //     b[n] = n != 1 ? 0.f : 1.f;
5101   //   }
5102   // After:
5103   //   for (const auto n : c10::irange(1, N)) {
5104   //     b[n] = n != 1 ? 0.f : 1.f;
5105   //   }
5106   test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;");
5107 
5108   // Before:
5109   //   for (const auto n : c10::irange(1, N)) {
5110   //     b[n] = n != 7 ? 0.f : 1.f;
5111   //   }
5112   // After:
5113   //   for (const auto n : c10::irange(1, N)) {
5114   //     b[n] = n != 7 ? 0.f : 1.f;
5115   //   }
5116   test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;");
5117 
5118   // Before:
5119   //   for (const auto n : c10::irange(1, N)) {
5120   //     b[n] = n != 5 ? 0.f : 1.f;
5121   //   }
5122   // After:
5123   //   for (const auto n : c10::irange(1, N)) {
5124   //     b[n] = n != 5 ? 0.f : 1.f;
5125   //   }
5126   test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;");
5127 
5128   // Before:
5129   //   for (const auto n : c10::irange(1, N)) {
5130   //     b[n] = n != 0 ? 0.f : 1.f;
5131   //   }
5132   // After:
5133   //   for (const auto n : c10::irange(1, N)) {
5134   //     b[n] = 0.f;
5135   //   }
5136   test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;");
5137 
5138   // Before:
5139   //   for (const auto n : c10::irange(1, N)) {
5140   //     b[n] = n != 8 ? 0.f : 1.f;
5141   //   }
5142   // After:
5143   //   for (const auto n : c10::irange(1, N)) {
5144   //     b[n] = 0.f;
5145   //   }
5146   test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;");
5147 
5148   // Before:
5149   //   for (const auto n : c10::irange(10, 20)) {
5150   //     for(const auto m : c10::irange(30, 40)) {
5151   //       b[n, m] = (n != m) ? 0.f : 1.f;
5152   //     }
5153   //   }
5154   // After:
5155   //   for (const auto n : c10::irange(10, 20)) {
5156   //     for(const auto m : c10::irange(30, 40)) {
5157   //       b[n, m] = 0.f;
5158   //     }
5159   //   }
5160   test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;");
5161   test_case_nest_loops_fn(
5162       n,
5163       m,
5164       b,
5165       var_N + 10,
5166       var_N + 20,
5167       var_N + 30,
5168       var_N + 40,
5169       kNE,
5170       "b[n, m] = 0.f;");
5171   test_case_nest_loops_fn(
5172       n,
5173       m,
5174       b,
5175       var_N + 10,
5176       var_N + 20,
5177       var_M + 30,
5178       var_M + 40,
5179       kNE,
5180       "b[n, m] = n!=m ? 0.f : 1.f;");
5181 
5182   // Before:
5183   //   for (const auto n : c10::irange(30, 40)) {
5184   //     for(const auto m : c10::irange(10, 20)) {
5185   //       b[n, m] = (n != m) ? 0.f : 1.f;
5186   //     }
5187   //   }
5188   // After:
5189   //   for (const auto n : c10::irange(30, 40)) {
5190   //     for(const auto m : c10::irange(10, 20)) {
5191   //       b[n, m] = 0.f;
5192   //     }
5193   //   }
5194   test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;");
5195   test_case_nest_loops_fn(
5196       n,
5197       m,
5198       b,
5199       var_N + 30,
5200       var_N + 40,
5201       var_N + 10,
5202       var_N + 20,
5203       kNE,
5204       "b[n, m] = 0.f;");
5205   test_case_nest_loops_fn(
5206       n,
5207       m,
5208       b,
5209       var_N + 30,
5210       var_N + 40,
5211       var_M + 10,
5212       var_M + 20,
5213       kNE,
5214       "b[n, m] = n!=m ? 0.f : 1.f;");
5215 
5216   // Before:
5217   //   for (const auto n : c10::irange(30, 40)) {
5218   //     for(const auto m : c10::irange(10, 31)) {
5219   //       b[n, m] = (n != m) ? 0.f : 1.f;
5220   //     }
5221   //   }
5222   // After:
5223   //   for (const auto n : c10::irange(30, 40)) {
5224   //     for(const auto m : c10::irange(10, 31)) {
5225   //       b[n, m] = (n != m) ? 0.f : 1.f;
5226   //     }
5227   //   }
5228   test_case_nest_loops_fn(
5229       n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;");
5230   test_case_nest_loops_fn(
5231       n,
5232       m,
5233       b,
5234       var_N + 30,
5235       var_N + 40,
5236       var_N + 10,
5237       var_N + 31,
5238       kNE,
5239       "b[n, m] = n!=m ? 0.f : 1.f;");
5240   test_case_nest_loops_fn(
5241       n,
5242       m,
5243       b,
5244       var_N + 30,
5245       var_N + 40,
5246       var_M + 10,
5247       var_M + 31,
5248       kNE,
5249       "b[n, m] = n!=m ? 0.f : 1.f;");
5250 
5251   // Before:
5252   //   for (const auto n : c10::irange(10, 31)) {
5253   //     for(const auto m : c10::irange(30, 40)) {
5254   //       b[n, m] = (n != m) ? 0.f : 1.f;
5255   //     }
5256   //   }
5257   // After:
5258   //   for (const auto n : c10::irange(10, 31)) {
5259   //     for(const auto m : c10::irange(30, 40)) {
5260   //       b[n, m] = (n != m) ? 0.f : 1.f;
5261   //     }
5262   //   }
5263   test_case_nest_loops_fn(
5264       n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;");
5265   test_case_nest_loops_fn(
5266       n,
5267       m,
5268       b,
5269       var_N + 10,
5270       var_N + 31,
5271       var_N + 30,
5272       var_N + 40,
5273       kNE,
5274       "b[n, m] = n!=m ? 0.f : 1.f;");
5275   test_case_nest_loops_fn(
5276       n,
5277       m,
5278       b,
5279       var_N + 10,
5280       var_N + 31,
5281       var_M + 30,
5282       var_M + 40,
5283       kNE,
5284       "b[n, m] = n!=m ? 0.f : 1.f;");
5285 
5286   // Before:
5287   //   for (const auto n : c10::irange(10, 20)) {
5288   //     for(const auto m : c10::irange(30, 40)) {
5289   //       b[n, m] = (n < m) ? 0.f : 1.f;
5290   //     }
5291   //   }
5292   // After:
5293   //   for (const auto n : c10::irange(10, 20)) {
5294   //     for(const auto m : c10::irange(30, 40)) {
5295   //       b[n, m] = 0.f;
5296   //     }
5297   //   }
5298   test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;");
5299   test_case_nest_loops_fn(
5300       n,
5301       m,
5302       b,
5303       var_N + 10,
5304       var_N + 20,
5305       var_N + 30,
5306       var_N + 40,
5307       kLT,
5308       "b[n, m] = 0.f;");
5309   test_case_nest_loops_fn(
5310       n,
5311       m,
5312       b,
5313       var_N + 10,
5314       var_N + 20,
5315       var_M + 30,
5316       var_M + 40,
5317       kLT,
5318       "b[n, m] = n<m ? 0.f : 1.f;");
5319 
5320   // Before:
5321   //   for (const auto n : c10::irange(30, 40)) {
5322   //     for(const auto m : c10::irange(10, 31)) {
5323   //       b[n, m] = (n < m) ? 0.f : 1.f;
5324   //     }
5325   //   }
5326   // After:
5327   //   for (const auto n : c10::irange(30, 40)) {
5328   //     for(const auto m : c10::irange(10, 31)) {
5329   //       b[n, m] = 1.f;
5330   //     }
5331   //   }
5332   test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kLT, "b[n, m] = 1.f;");
5333   test_case_nest_loops_fn(
5334       n,
5335       m,
5336       b,
5337       var_N + 30,
5338       var_N + 40,
5339       var_N + 10,
5340       var_N + 31,
5341       kLT,
5342       "b[n, m] = 1.f;");
5343   test_case_nest_loops_fn(
5344       n,
5345       m,
5346       b,
5347       var_N + 30,
5348       var_N + 40,
5349       var_M + 10,
5350       var_M + 31,
5351       kLT,
5352       "b[n, m] = n<m ? 0.f : 1.f;");
5353 
5354   // Before:
5355   //   for (const auto n : c10::irange(30, 40)) {
5356   //     for(const auto m : c10::irange(10, 20)) {
5357   //       b[n, m] = (n > m) ? 0.f : 1.f;
5358   //     }
5359   //   }
5360   // After:
5361   //   for (const auto n : c10::irange(30, 40)) {
5362   //     for(const auto m : c10::irange(10, 20)) {
5363   //       b[n, m] = 0.f;
5364   //     }
5365   //   }
5366   test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;");
5367   test_case_nest_loops_fn(
5368       n,
5369       m,
5370       b,
5371       var_N + 30,
5372       var_N + 40,
5373       var_N + 10,
5374       var_N + 20,
5375       kGT,
5376       "b[n, m] = 0.f;");
5377   test_case_nest_loops_fn(
5378       n,
5379       m,
5380       b,
5381       var_N + 30,
5382       var_N + 40,
5383       var_M + 10,
5384       var_M + 20,
5385       kGT,
5386       "b[n, m] = n>m ? 0.f : 1.f;");
5387 
5388   // Before:
5389   //   for (const auto n : c10::irange(10, 31)) {
5390   //     for(const auto m : c10::irange(30, 40)) {
5391   //       b[n, m] = (n > m) ? 0.f : 1.f;
5392   //     }
5393   //   }
5394   // After:
5395   //   for (const auto n : c10::irange(10, 31)) {
5396   //     for(const auto m : c10::irange(30, 40)) {
5397   //       b[n, m] = 1.f;
5398   //     }
5399   //   }
5400   test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;");
5401   test_case_nest_loops_fn(
5402       n,
5403       m,
5404       b,
5405       var_N + 10,
5406       var_N + 31,
5407       var_N + 30,
5408       var_N + 40,
5409       kGT,
5410       "b[n, m] = 1.f;");
5411   test_case_nest_loops_fn(
5412       n,
5413       m,
5414       b,
5415       var_N + 10,
5416       var_N + 31,
5417       var_M + 30,
5418       var_M + 40,
5419       kGT,
5420       "b[n, m] = n>m ? 0.f : 1.f;");
5421 
5422   // Before:
5423   //   for (const auto n : c10::irange(30, 40)) {
5424   //     for(const auto m : c10::irange(10, 31)) {
5425   //       b[n, m] = (n >= m) ? 0.f : 1.f;
5426   //     }
5427   //   }
5428   // After:
5429   //   for (const auto n : c10::irange(30, 40)) {
5430   //     for(const auto m : c10::irange(10, 31)) {
5431   //       b[n, m] = 0.f;
5432   //     }
5433   //   }
5434   test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;");
5435   test_case_nest_loops_fn(
5436       n,
5437       m,
5438       b,
5439       var_N + 30,
5440       var_N + 40,
5441       var_N + 10,
5442       var_N + 31,
5443       kGE,
5444       "b[n, m] = 0.f;");
5445   test_case_nest_loops_fn(
5446       n,
5447       m,
5448       b,
5449       var_N + 30,
5450       var_N + 40,
5451       var_M + 10,
5452       var_M + 31,
5453       kGE,
5454       "b[n, m] = n>=m ? 0.f : 1.f;");
5455 
5456   // Before:
5457   //   for (const auto n : c10::irange(10, 20)) {
5458   //     for(const auto m : c10::irange(30, 40)) {
5459   //       b[n, m] = (n >= m) ? 0.f : 1.f;
5460   //     }
5461   //   }
5462   // After:
5463   //   for (const auto n : c10::irange(10, 20)) {
5464   //     for(const auto m : c10::irange(30, 40)) {
5465   //       b[n, m] = 1.f;
5466   //     }
5467   //   }
5468   test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;");
5469   test_case_nest_loops_fn(
5470       n,
5471       m,
5472       b,
5473       var_N + 10,
5474       var_N + 20,
5475       var_N + 30,
5476       var_N + 40,
5477       kGE,
5478       "b[n, m] = 1.f;");
5479   test_case_nest_loops_fn(
5480       n,
5481       m,
5482       b,
5483       var_N + 10,
5484       var_N + 20,
5485       var_M + 30,
5486       var_M + 40,
5487       kGE,
5488       "b[n, m] = n>=m ? 0.f : 1.f;");
5489 
5490   // Before:
5491   //   for (const auto n : c10::irange(10, 31)) {
5492   //     for(const auto m : c10::irange(30, 40)) {
5493   //       b[n, m] = (n <= m) ? 0.f : 1.f;
5494   //     }
5495   //   }
5496   // After:
5497   //   for (const auto n : c10::irange(10, 31)) {
5498   //     for(const auto m : c10::irange(30, 40)) {
5499   //       b[n, m] = 0.f;
5500   //     }
5501   //   }
5502   test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;");
5503   test_case_nest_loops_fn(
5504       n,
5505       m,
5506       b,
5507       var_N + 10,
5508       var_N + 31,
5509       var_N + 30,
5510       var_N + 40,
5511       kLE,
5512       "b[n, m] = 0.f;");
5513   test_case_nest_loops_fn(
5514       n,
5515       m,
5516       b,
5517       var_N + 10,
5518       var_N + 31,
5519       var_M + 30,
5520       var_M + 40,
5521       kLE,
5522       "b[n, m] = n<=m ? 0.f : 1.f;");
5523 
5524   // Before:
5525   //   for (const auto n : c10::irange(30, 40)) {
5526   //     for(const auto m : c10::irange(10, 20)) {
5527   //       b[n, m] = (n <= m) ? 0.f : 1.f;
5528   //     }
5529   //   }
5530   // After:
5531   //   for (const auto n : c10::irange(30, 40)) {
5532   //     for(const auto m : c10::irange(10, 20)) {
5533   //       b[n, m] = 1.f;
5534   //     }
5535   //   }
5536   test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;");
5537   test_case_nest_loops_fn(
5538       n,
5539       m,
5540       b,
5541       var_N + 30,
5542       var_N + 40,
5543       var_N + 10,
5544       var_N + 20,
5545       kLE,
5546       "b[n, m] = 1.f;");
5547   test_case_nest_loops_fn(
5548       n,
5549       m,
5550       b,
5551       var_N + 30,
5552       var_N + 40,
5553       var_M + 10,
5554       var_M + 20,
5555       kLE,
5556       "b[n, m] = n<=m ? 0.f : 1.f;");
5557 }
5558 
TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) {
  // The loop variable starts at 1, so the predicate `n < 1` is false for every
  // iteration; the test expects the CompareSelect to be folded to its "false"
  // operand.
  //
  // Input:
  //   for (const auto n : c10::irange(1, N)) {
  //     b[n] = n < 1 ? 0.f : 1.f;
  //   }
  // Expected:
  //   for (const auto n : c10::irange(1, N)) {
  //     b[n] = 1.f;
  //   }
  constexpr int kExtent = 8;
  BufHandle out("b", {kExtent}, kFloat);
  VarHandle idx("n", kInt);

  auto selected = CompareSelect::make(idx, 1, 0.f, 1.0f, kLT);
  StmtPtr body = out.store({idx}, selected);
  StmtPtr loop = For::make(idx, 1, kExtent, body);
  StmtPtr simplified = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *simplified;

  const std::string pattern = R"IR(
# CHECK: b[n] = 1.f;
)IR";
  torch::jit::testing::FileCheck().run(pattern, printed.str());
}
5582 
TEST(Simplify, IfThenCondAlwaysInLoopBounds) {
  // The loop range [1, N) contains no value for which `n < 1` holds, so the
  // IfThenElse condition is always false inside the body and the test expects
  // the whole expression to collapse to its else-branch.
  //
  // Input:
  //   for (const auto n : c10::irange(1, N)) {
  //     b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f);
  //   }
  // Expected:
  //   for (const auto n : c10::irange(1, N)) {
  //     b[n] = 1.f;
  //   }
  constexpr int kExtent = 8;
  BufHandle out("b", {kExtent}, kFloat);
  VarHandle idx("n", kInt);

  ExprHandle guard = idx < 1;
  auto chosen = IfThenElse::make(guard, 0.f, 1.0f);
  StmtPtr loop = For::make(idx, 1, kExtent, out.store({idx}, chosen));

  std::ostringstream printed;
  printed << *IRSimplifier::simplify(loop);
  torch::jit::testing::FileCheck().run(
      R"IR(
# CHECK: b[n] = 1.f;
)IR",
      printed.str());
}
5606 
TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) {
  // Modeled on the unpadded (interior) region of a conv2d: a conditional that
  // is provably satisfied (or provably unsatisfied) over the entire loop range
  // should be removed. Here every clause of the out-of-bounds guard is false
  // for i, j in [1, 7), so only the else-value should survive.
  //
  // Input:
  //   for (const auto i : c10::irange(1, 7)) {
  //     for (const auto j : c10::irange(1, 7)) {
  //       b[i, j] = IfThenElse(
  //         j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f);
  //     }
  //   }
  // Expected:
  //   for (const auto i : c10::irange(1, 7)) {
  //     for (const auto j : c10::irange(1, 7)) {
  //       b[i, j] = 1.f;
  //     }
  //   }
  constexpr int kExtent = 8;
  constexpr int kLast = kExtent - 1;
  BufHandle out("b", {kExtent, kExtent}, kFloat);
  VarHandle row("i", kInt);
  VarHandle col("j", kInt);

  // The guard is assembled innermost-first. Each step yields 1 when its own
  // comparison holds and otherwise falls back to everything built before it.
  auto guard_i_lo = CompareSelect::make(row, 1, kLT);
  auto guard_j_lo = CompareSelect::make(col, 1, 1, guard_i_lo, kLT);
  auto guard_i_hi = CompareSelect::make(row, kLast, 1, guard_j_lo, kGE);
  auto guard_j_hi = CompareSelect::make(col, kLast, 1, guard_i_hi, kGE);

  StmtPtr body =
      out.store({row, col}, IfThenElse::make(guard_j_hi, 0.f, 1.0f));
  StmtPtr nest = For::make(row, 1, kLast, For::make(col, 1, kLast, body));
  StmtPtr simplified = IRSimplifier::simplify(nest);

  std::ostringstream printed;
  printed << *simplified;
  torch::jit::testing::FileCheck().run(
      R"IR(
# CHECK: b[i, j] = 1.f;
)IR",
      printed.str());
}
5640 
TEST(Simplify, DISABLED_SimplifyLoopBounds) {
  // This test mimics the padded region of a conv2d.  We want to adjust the
  // loop bounds so that the guard condition is never met inside the loop.
  // For i == 0 or j == 0 the guard below is true, so those iterations only add
  // 0.f to b[i, j]; tightening both loops to start at 1 lets the IfThenElse be
  // dropped. Note that this could also be solved by peeling those first
  // iterations and then applying the range-based conditional simplification
  // exercised by the *CondAlwaysInLoopBounds tests.
  // Before:
  //   for (const auto i : c10::irange(3)) {
  //     for (const auto j : c10::irange(3)) {
  //       b[i, j] = (b[i, j]) + (IfThenElse(
  //         j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j]));
  //     }
  //   }
  // After:
  //   for (const auto i : c10::irange(1, 3)) {
  //     for (const auto j : c10::irange(1, 3)) {
  //       b[i, j] = (b[i, j]) + (a[i, j]);
  //     }
  //   }
  constexpr int N = 8;
  constexpr int K = 3;
  BufHandle a("a", {N, N}, kFloat);
  BufHandle b("b", {N, N}, kFloat);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto csel = CompareSelect::make(i, 1, kLT);
  csel = CompareSelect::make(j, 1, 1, csel, kLT);
  csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
  csel = CompareSelect::make(j, N - 1, 1, csel, kGE);
  StmtPtr s = b.store(
      {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j})));
  s = For::make(j, 0, K, s);
  s = For::make(i, 0, K, s);
  s = IRSimplifier::simplify(s);
  std::ostringstream oss;
  oss << *s;
  // The IR printer renders loops in C style (`for (int i = 1; i < 3; i++)`),
  // not in the c10::irange pseudocode used in the comments above, so the
  // check lines must match the printer's syntax to ever succeed.
  torch::jit::testing::FileCheck().run(
      R"IR(
# CHECK: for (int i = 1; i < 3; i++) {
# CHECK: for (int j = 1; j < 3; j++) {
# CHECK-NOT: IfThenElse
)IR",
      oss.str());
}
5680 
5681 } // namespace jit
5682 } // namespace torch
5683