#include <gtest/gtest.h>

#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>

#include <test/cpp/tensorexpr/test_base.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

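// These tests exercise Reduce(name, output_dims, reducer, input, reduce_dims):
// it builds a Tensor with the given output dimensions whose elements are
// accumulated over reduce_dims using the given Reducer (Sum(), Maximum(),
// Minimum(), or a custom Reducer built from an initial value and a binary op).
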
TEST(Reductions, ReduceSum0D_1) {
  const int M = 10;

  BufHandle b("b", {M}, kFloat);
  std::vector<float> in(M);
  for (const auto j : c10::irange(M)) {
    in[j] = j;
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], in[i]);
  }
}

TEST(Reductions, ReduceSum0D_2) {
  BufHandle b("b", {}, kFloat);
  std::vector<float> in(1);
  in[0] = 77.7;

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], in[0]);
}

// Sum an array to a single value.
TEST(Reductions, ReduceSum1D) {
  BufHandle b("b", {10}, kFloat);
  std::vector<float> in(10);
  for (const auto j : c10::irange(10)) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {10});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], 45);
}
// Sum a 2D tensor to a 1D tensor with dynamic shapes.
TEST(Reductions, ReduceSum2D) {
  const int M = 3;
  const int N = 7;

  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  BufHandle b("b", {m, n}, kFloat);
  std::vector<float> in(M * N);
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      in[i * N + j] = j;
    }
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {N});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, n, m});

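  // Note: the loop extents were fixed to M and N when the Tensor was built;
  // the runtime n and m here only bind b's symbolic dims. Since the input
  // pattern repeats every N elements, any N consecutive elements sum to
  // 0 + 1 + ... + (N - 1), so the checks below hold.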
  cg.call({in, out, 5, 7});

  float expected = 0;
  for (const auto i : c10::irange(N)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }

  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], expected);
  }
}

// Sum a 3D tensor to both a 2D and a 1D tensor, then reduce the 2D result
// along its remaining axis to cross-check the work.
TEST(Reductions, ReduceSum3D) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m});

  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> cData(2 * 3, 6.0f);
  std::vector<float> dData(2, 1.0f);
  std::vector<float> eData(2, 1.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j;
    }
  }

  cg.call({bData, cData, M});
  float expected = 0;
  for (const auto i : c10::irange(M)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }

  for (int i = 0; i < 2 * 3; ++i) {
    ASSERT_EQ(cData[i], expected);
  }

  Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m});
  LoopNest loop2({d});
  loop2.prepareForCodegen();
  StmtPtr s2 = loop2.root_stmt();
  s2 = IRSimplifier::simplify(s2);

  SimpleIREvaluator cg2(s2, {b, d, m});
  cg2.call({bData, dData, M});

  // We're combining an additional dimension of 3, so the sum is 3x.
  expected = expected * 3;

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(dData[i], expected);
  }

  // This is the same as just reducing the original result across that axis.
  BufHandle c_buf(c.buf());
  Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3});
  LoopNest loop3({e});
  loop3.prepareForCodegen();
  StmtPtr s3 = loop3.root_stmt();
  s3 = IRSimplifier::simplify(s3);

  SimpleIREvaluator cg3(s3, {c, e});
  cg3.call({cData, eData});

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(eData[i], expected);
  }
}

// Sum a large (10-D) tensor down to 5 dimensions.
TEST(Reductions, ReduceSum10D) {
  BufHandle in_("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat);
  const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3;
  BufHandle out_("out_", {2, 3, 2, 3, 2}, kFloat);
  const int OutputSize = 2 * 3 * 2 * 3 * 2;

  std::vector<float> in(InputSize, 1.f);
  std::vector<float> out(OutputSize, -1.f);

  Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in_, c});

  cg.call({in, out});

  // NOLINTNEXTLINE(bugprone-integer-division)
  float expected = InputSize / OutputSize;
  for (const auto i : c10::irange(OutputSize)) {
    ASSERT_EQ(out[i], expected);
  }
}

// Reduce via Mul rather than Add using a custom Reducer.
TEST(Reductions, ReduceProduct) {
  const int M = 4;
  const int N = 4;

  BufHandle b("b", {M, N}, kFloat);
  std::vector<float> in(M * N);
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      in[i * N + j] = 2 + j;
    }
  }

  std::vector<float> out(M, -1.f);

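  // A Reducer pairs an initial value with a binary interaction function;
  // Sum() is in effect Reducer(0, a + b). This one starts at 1 and multiplies.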
  Reducer product(
      ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; });

  Tensor c = Reduce("product", {M}, product, b, {N});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});

  float expected = 1;
  for (const auto i : c10::irange(N)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected *= 2 + i;
  }

  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], expected);
  }
}

// Maximum reductions.
TEST(Reductions, ReduceMax) {
  BufHandle in_("b", {10}, kFloat);

  std::vector<float> in(10);
  std::vector<float> out(1, -1.f);
  for (const auto j : c10::irange(10)) {
    in[j] = j;
  }

  Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10});

  LoopNest loop({dm1});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);
  SimpleIREvaluator cg(s, {in_, dm1});

  cg.call({in, out});

  ASSERT_EQ(out[0], 9);

  BufHandle in2_("b", {2, 5}, kFloat);
  std::vector<float> out2(2, -1.f);

  Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5});

  LoopNest loop2({m2d});
  loop2.prepareForCodegen();
  s = loop2.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg2(s, {in2_, m2d});
  cg2.call({in, out2});

  ASSERT_EQ(out2[0], 4);
  ASSERT_EQ(out2[1], 9);
}

// Minimum reduction, with custom initialization.
TEST(Reductions, ReduceMinCustomInitializer) {
  VarHandle minInit("minInit", kFloat);
  BufHandle in_("b", {10}, kFloat);

  std::vector<float> in(10);
  std::vector<float> out(1, -1.f);
  for (const auto j : c10::irange(10)) {
    in[j] = 10 + j;
  }

  Tensor min = Reduce(
      "min",
      {},
      Minimum(ExprHandle(minInit)),
      [&](ParameterList& v) { return in_.load(v); },
      {10});

  LoopNest loop({min});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in_, min, minInit});

  // Works normally (note that the out data starts lower than the correct
  // minimum).
  cg.call({in, out, std::numeric_limits<float>::max()});
  ASSERT_EQ(out[0], 10);

  // With an initializer lower than the true minimum, the initializer wins.
  cg.call({in, out, 5.f});
  ASSERT_EQ(out[0], 5);
}

// Example implementation of Any/All.
// TODO: this is very awkward without logical And/Or operators.
TEST(Reductions, ReduceAnyAll) {
  VarHandle searchValue("searchValue", kInt);
  BufHandle b("b", {4, 10}, kInt);

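  // CompareSelect::make(a, 1, 1, b, kEQ) reads as (a == 1 ? 1 : b): once the
  // accumulator reaches 1 it stays 1, otherwise the next element is taken.
  // That is an OR-like combine; the "all" reducer below starts at 1 and
  // latches 0 the same way, giving an AND-like combine.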
  Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) {
    return CompareSelect::make(a, 1, 1, b, kEQ);
  });

  Tensor any = Reduce(
      "anyEqual",
      {4},
      anyEqSV,
      [&](const auto& i, const auto& j) {
        return CompareSelect::make(b.load(i, j), searchValue, kEQ);
      },
      {10});

  LoopNest loop({any});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, any, searchValue});

  std::vector<int> in(40, 0);
  std::vector<int> out(4, 0);

  // The input holds 0-39 across 4 rows.
  for (const auto i : c10::irange(40)) {
    in[i] = i;
  }
  cg.call({in, out, 1});

  // Only the first row contains 1.
  ASSERT_EQ(out[0], 1);
  ASSERT_EQ(out[1], 0);
  ASSERT_EQ(out[2], 0);
  ASSERT_EQ(out[3], 0);

  cg.call({in, out, 15});

  // 15 is in the second row (row index 1).
  ASSERT_EQ(out[0], 0);
  ASSERT_EQ(out[1], 1);
  ASSERT_EQ(out[2], 0);
  ASSERT_EQ(out[3], 0);

  Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) {
    return CompareSelect::make(a, 0, 0, b, kEQ);
  });

  Tensor allGreaterThan = Reduce(
      "allGreaterThan",
      {4},
      allGTSV,
      [&](const auto& i, const auto& j) {
        return CompareSelect::make(b.load(i, j), searchValue, kGT);
      },
      {10});

  LoopNest loop2({allGreaterThan});
  loop2.prepareForCodegen();
  s = loop2.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue});

  cg2.call({in, out, 11});

  // Only rows 2 and 3 are entirely greater than 11.
  ASSERT_EQ(out[0], 0);
  ASSERT_EQ(out[1], 0);
  ASSERT_EQ(out[2], 1);
  ASSERT_EQ(out[3], 1);

  cg2.call({in, out, -3});

  // Every element is greater than -3.
  ASSERT_EQ(out[0], 1);
  ASSERT_EQ(out[1], 1);
  ASSERT_EQ(out[2], 1);
  ASSERT_EQ(out[3], 1);
}

TEST(Reductions, ReduceMatmul2D) {
  BufHandle tA("tA", {3, 2}, kFloat);
  BufHandle tB("tB", {2, 3}, kFloat);

  std::vector<float> tA_(6);
  std::vector<float> tB_(6);

  std::vector<float> out(9, -1.f);
  for (const auto i : c10::irange(3)) {
    for (const auto j : c10::irange(2)) {
      tA_[i * 2 + j] = i * 2 + j;
      tB_[j * 3 + i] = i * 2 + j;
    }
  }

  Tensor mm = Reduce(
      "mm",
      {3, 3},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return tA.load(m, k) * tB.load(k, n);
      },
      {2});

  LoopNest loop({mm});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {tA, tB, mm});
  cg.call({tA_, tB_, out});

  std::vector<float> expected(
      {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f});

  for (const auto i : c10::irange(9)) {
    ASSERT_EQ(out[i], expected[i]);
  }
}

TEST(Reductions, ReduceRfactorLike) {
  BufHandle in("in", {10, 10}, kFloat);
  std::vector<float> in_(100);
  for (const auto i : c10::irange(100)) {
    in_[i] = i;
  }
  std::vector<float> in_rf_(10, -2.f);
  std::vector<float> out(1, -1.f);

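  // Two chained reductions written by hand: l1 produces per-row partial sums
  // and l2 folds the partials into a scalar. This is the same shape of
  // computation that rfactor builds automatically in the tests further down.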
  Tensor l1 = Reduce("l1", {10}, Sum(), in, {10});
  BufHandle in_rf(l1.buf());

  Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10});

  LoopNest loop({l1, l2});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, l1, l2});
  cg.call({in_, in_rf_, out});

  ASSERT_EQ(out[0], 99 * 50);
}

TEST(Reductions, ReduceAsProducer) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle a("a", {2, 3}, kFloat);
  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
  Tensor d =
      Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) {
        return c.load(l, n) * a.load(l, n);
      });
  LoopNest loop({d}, {c, d});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {a, b, d, m});

  std::vector<float> aData(2 * 3, 0);
  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> dData(2 * 3, 6.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    aData[i] = 6 - i;
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j;
    }
  }

  cg.call({aData, bData, dData, M});
  float expected = 0;
  for (const auto i : c10::irange(M)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }
  for (int i = 0; i < 2 * 3; ++i) {
    ASSERT_EQ(dData[i], expected * (6 - i));
  }
}

TEST(Reductions, ReduceAsConsumer) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle a("a", {2, 3, m}, kFloat);
  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Compute(
      "scale",
      {2, 3, m},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {2}, Sum(), c, {3, m});
  LoopNest loop({d}, {c, d});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {a, b, d, m});

  std::vector<float> aData(2 * 3 * M, 0);
  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> dData(2, 6.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j + 1;
      aData[i * M + j] = 6 - i;
    }
  }

  cg.call({aData, bData, dData, M});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  float expected[2] = {0, 0};
  for (const auto i : c10::irange(2)) {
    for (const auto j : c10::irange(3)) {
      for (const auto k : c10::irange(M)) {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        expected[i] += (k + 1) * (6 - (i * 3 + j));
      }
    }
  }

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(dData[i], expected[i]);
  }
}

TEST(Reductions, SplitReduceAxis) {
  BufHandle in("in", {16, 8}, kFloat);

  std::vector<float> in_(16 * 8);
  for (const auto i : c10::irange(16)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out(16, -1.f);

  Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  LoopNest::splitWithTail(loops[1], 2);

  l.prepareForCodegen();

  StmtPtr s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, tensor});
  cg.call({in_, out});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out[i], i * 8);
  }
}

TEST(Reductions, SplitNonReduceAxis) {
  BufHandle in("in", {16, 8}, kFloat);

  std::vector<float> in_(16 * 8);
  for (const auto i : c10::irange(16)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out(16, -1.f);
  Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  LoopNest::splitWithTail(loops[0], 2);
  LoopNest::splitWithTail(loops[0], 2);

  l.prepareForCodegen();

  StmtPtr s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, tensor});
  cg.call({in_, out});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out[i], i * 8);
  }
}

TEST(Reductions, ReorderedReductionInitializer) {
  /* From the quip:
  for k in 0..1:  // blockIdx
    for m in 0..128:
      for n in 0..64: // threadIdx
        SumOp(c(k, n), 0, a(k, m, n), {m})
  */

  BufHandle in("in", {1, 12, 6}, kFloat);
  std::vector<float> in_(12 * 6, 1.f);

  Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6});
  LoopNest l_({tensor_});

  l_.prepareForCodegen();
  StmtPtr s_ = Stmt::clone(l_.root_stmt());
  s_ = IRSimplifier::simplify(s_);

  Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6});
  LoopNest l({tensor});

  auto loops = l.getLoopStmtsFor(tensor);
  loops[0]->set_gpu_block_index(0);
  loops[1]->set_gpu_thread_index(0);

  LoopNest::reorderAxis(loops[1], loops[2]);

  StmtPtr s = l.root_stmt();
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  s = IRSimplifier::simplify(s);

  l.prepareForCodegen();

  s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  std::vector<float> out1(16, -1.f);
  SimpleIREvaluator cg(s_, {in, tensor_});
  cg.call({in_, out1});

  std::vector<float> out2(16, -1.f);
  SimpleIREvaluator cg2(s, {in, tensor});
  cg2.call({in_, out2});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out1[i], out2[i]);
  }
}

TEST(Reductions, ReduceRfactor) {
  const int M = 10;
  const int N = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  BufHandle b("b", {m, n}, kFloat);
  std::vector<float> in(M * N);
  for (int j = 0; j < M * N; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n});
  LoopNest loop({c});
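  // rfactor splits the reduction in two: partial sums are accumulated into a
  // temporary buffer along the chosen axis, and a second ReduceOp folds that
  // buffer into the final output, hence the two ReduceOps checked below.
  // getAllWritesToBuf(...) returns the initializer store at [0] and the
  // accumulating store at [1].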
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 2);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n});

  cg.call({in, out, M, N});
  ASSERT_EQ(out[0], 4950);
}

TEST(Reductions, Reduce3DRfactorInner) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("b", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
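  // Rfactoring the innermost reduction axis is expected to fail here: with no
  // reduction loops left inside it to form a partial-reduction stage, rfactor
  // returns false and the nest keeps its single ReduceOp.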
  ASSERT_FALSE(loop.rfactor(c_body, loops.at(2)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 1);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n, k});

  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

TEST(Reductions, Reduce3DRfactorOuter) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("b", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 2);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n, k});
  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

TEST(Reductions, ReduceRepeatedInternalRfactor) {
  BufHandle in_("in_", {2, 3, 4, 5, 6}, kFloat);
  const int InputSize = 2 * 3 * 4 * 5 * 6;

  std::vector<float> in(InputSize, 1.f);
  std::vector<float> out(1, -1.f);
  std::vector<float> ref(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6});
  LoopNest orig_loop({c});

  // Try rfactoring N outer loops
  for (const auto rfac_number : c10::irange(1, 5)) {
    LoopNest refloop(orig_loop);
    LoopNest loop(orig_loop);
    refloop.prepareForCodegen();
    SimpleIREvaluator ref_cg(
        IRSimplifier::simplify(refloop.root_stmt()), {in_, c});
    ref_cg.call({in, ref});

    BufPtr tmp_buf = c.buf();

    for (const auto idx : c10::irange(rfac_number)) {
      auto reduce = loop.getAllWritesToBuf(tmp_buf)[1];
      ASSERT_TRUE(loop.rfactor(
          reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf));
    }

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {in_, c});
    cg.call({in, out});

    ASSERT_EQ(ref[0], out[0]);
  }
}

// Split a reduction axis with a tail loop.
TEST(Reductions, ReduceSplitTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
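    // Splitting an extent-10 loop by 8 gives an outer loop of extent 1, an
    // inner loop of extent 8, and a tail loop covering the remaining 2
    // iterations.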
    LoopNest::splitWithTail(loops[i], 8);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis cleanly so there is no tail loop.
TEST(Reductions, ReduceSplitNoTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 5);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis with only a tail loop (the split loop will have
// extent 0 and be eliminated).
TEST(Reductions, ReduceOverSplitTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 16);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis with a mask.
TEST(Reductions, ReduceSplitMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
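    // Unlike splitWithTail, splitWithMask rounds the trip count up instead of
    // emitting a tail loop: extent 10 split by 8 becomes 2 x 8 iterations
    // with a condition masking off the 6 out-of-range ones.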
    LoopNest::splitWithMask(loops[i], 8);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis cleanly not requiring a mask.
TEST(Reductions, ReduceSplitNoMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 5);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis with all logic in the mask.
TEST(Reductions, ReduceOverSplitMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 16);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Test an rfactor when there are two ReduceOps in the graph due to a
// splitWithTail.
TEST(Reductions, ReduceSplitRfactor) {
  const int M = 2;
  const int N = 10;
  const int K = 10;
  const int SPLIT_FACTOR = 4;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (const auto m : c10::irange(M)) {
    for (int j = 0; j < N * K; ++j) {
      in[m * N * K + j] = j;
    }
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  LoopNest::splitWithTail(loops[2], SPLIT_FACTOR);

  auto c_body = loop.getAllWritesToBuf(c.buf())[2];
  auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3);
  LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]);
  all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3);
  ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1]));
  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = loop.root_stmt();

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  for (const auto i : c10::irange(M)) {
    (void)i; // Suppress unused variable warning
    ASSERT_EQ(out[0], 4950);
  }
}

// Test an rfactor which ends up being eliminated since the total loop size is
// smaller than the split factor.
TEST(Reductions, ReduceOverSplitRfactor) {
  const int N = 10;
  const int K = 10;
  const int SPLIT_FACTOR = 16;

  BufHandle b("b", {N, K}, kFloat);
  std::vector<float> in(N * K);
  for (int j = 0; j < N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {N, K});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr i, t;
  LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t);
  LoopNest::reorderAxis(loops[0], i);

  auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0]));
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]);

  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = loop.root_stmt();

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], 4950);

  std::ostringstream oss;
  oss << *cg.stmt();

  // Check the IR to verify the rfactored reduce is eliminated.
  // TODO: The alloc free should be eliminated here since it is size 0.
  /*
  const std::string& verification_pattern =
      R"IR(
# CHECK: Allocate(tmp_buf); // dtype=float, dims=[0]
# CHECK: sum[0] = 0.f;
# CHECK: for (int n = 0; n < 10; n++) {
# CHECK:   for (int k_tail = 0; k_tail < 10; k_tail++) {
# CHECK:     sum[0] = (sum[0]) + (b[k_tail + 10 * n]);
# CHECK:   }
# CHECK: }
# CHECK: Free(tmp_buf);)IR";
  */
  // TODO: rfactor output is not consistent yet, will fix (@nickg).
  // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

TEST(Reductions, ReduceInlineReduction) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K});
  Tensor y = Compute(
      "y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); });

  PaddedBuffer<float> a_v(M);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    a_v(i) = i * i;
  }
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        b_v(i, j, k) = j * j * k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  // Cannot inline a reduction computation
  ASSERT_FALSE(l1.computeInline(x.buf()));
}

TEST(Reductions, ReduceInlineConsumer) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M, N, K}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n, k) + b_buf.load(m, n, k);
      });
  Tensor y = Reduce("y", {M}, Sum(), x, {N, K});

  PaddedBuffer<float> a_v(M, N, K);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        a_v(i, j, k) = i * i + k;
        b_v(i, j, k) = j * j + k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});

  PaddedBuffer<float> y_1(M);
  PaddedBuffer<float> y_2(M);

  eval1(a_v, b_v, y_1);
  eval2(a_v, b_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}

TEST(Reductions, ReduceInlineReducerInternal) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M, N, K}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n, k) + b_buf.load(m, n, k);
      });

  Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) {
    return Add::make(ExprHandle(1.f), Min::make(a, b, false));
  });
  Tensor y = Reduce("y", {M}, minimum, x, {N, K});

  PaddedBuffer<float> a_v(M, N, K);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        a_v(i, j, k) = i * i + k;
        b_v(i, j, k) = j * j + k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});

  PaddedBuffer<float> y_1(M);
  PaddedBuffer<float> y_2(M);

  eval1(a_v, b_v, y_1);
  eval2(a_v, b_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}

TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
  int L = 4;
  int N = 3;
  int M = 2;

  BufHandle a("a", {L, N, M}, kFloat);
  BufHandle b("b", {L, N, M}, kFloat);

  Tensor c = Compute(
      "scale",
      {L, N, M},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});

  Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});
  LoopNest l_before(l);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(
      LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e});

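  // cacheAccesses stages the reduction through a local buffer scoped to the
  // given loop: d_local is initialized and accumulated into in place of sum,
  // then written back after the loop, as the IR checks below verify.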
  StmtPtr d_loop = l.getLoopStmtsFor(d)[0];
  l.cacheAccesses(d.buf(), "d_local", d_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg_after(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg_after.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[4]
#CHECK: for (int i_2
#CHECK:   d_local[i_2] = 0.f
#CHECK:   for (int
#CHECK:     for (int
#CHECK:       d_local[i_2] = (d_local[i_2]) + (scale[
#CHECK:     }
#CHECK:   }
#CHECK: }
#CHECK: for (int i_3
#CHECK:   sum[i_3] = d_local[i_3]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  PaddedBuffer<float> a_v(L, M, N, "a");
  PaddedBuffer<float> b_v(L, M, N, "b");
  PaddedBuffer<float> c_v(L, M, N, "c");
  PaddedBuffer<float> d_v(L, "d");
  PaddedBuffer<float> e_before(L, "e_before");
  PaddedBuffer<float> e_after(L, "e_after");

  for (const auto l : c10::irange(L)) {
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        a_v(l, m, n) = at::randn({1}).item().to<float>();
        b_v(l, m, n) = at::randn({1}).item().to<float>();
      }
    }
  }

  cg_before.call({a_v, b_v, e_before});
  cg_after.call({a_v, b_v, e_after});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(e_before, e_after, 1e-5);
}

TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
  int L = 4;
  int N = 3;
  int M = 2;

  BufHandle a("a", {L, N, M}, kFloat);
  BufHandle b("b", {L, N, M}, kFloat);

  Tensor c = Compute(
      "scale",
      {L, N, M},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});

  Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});
  LoopNest l_before(l);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
  l.cacheAccesses(d.buf(), "d_local", d_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg_after(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg_after.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
#CHECK: sum[i_1] = 0
#CHECK: d_local[0] = sum[i_1]
#CHECK: for (int j_1
#CHECK:   for (int k_1
#CHECK: d_local[0] = (d_local[0]) + (scale[
#CHECK:   }
#CHECK: }
#CHECK: sum[i_1] = d_local[0]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  PaddedBuffer<float> a_v(L, M, N, "a");
  PaddedBuffer<float> b_v(L, M, N, "b");
  PaddedBuffer<float> c_v(L, M, N, "c");
  PaddedBuffer<float> d_v(L, "d");
  PaddedBuffer<float> e_before(L, "e_before");
  PaddedBuffer<float> e_after(L, "e_after");

  for (const auto l : c10::irange(L)) {
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        a_v(l, m, n) = at::randn({1}).item().to<float>();
        b_v(l, m, n) = at::randn({1}).item().to<float>();
      }
    }
  }

  cg_before.call({a_v, b_v, e_before});
  cg_after.call({a_v, b_v, e_after});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(e_before, e_after, 1e-5);
}

TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
  int L = 4;
  int N = 3;
  int M = 2;

  BufHandle a("a", {L, N, M}, kFloat);
  BufHandle b("b", {L, N, M}, kFloat);

  Tensor c = Compute(
      "scale",
      {L, N, M},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});

  Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});
  LoopNest l_before(l);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[2];
  l.cacheAccesses(d.buf(), "d_local", d_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg_after(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg_after.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
#CHECK: sum[i_1] = 0
#CHECK: for (int
#CHECK:   d_local[0] = 0
#CHECK:   for (int
#CHECK:     d_local[0] = (d_local[0]) + (scale[
#CHECK:   }
#CHECK:   sum[i_1] = (sum[i_1]) + (d_local[0])
#CHECK: }
#CHECK: Free(d_local);
#CHECK-NOT: d_local
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  PaddedBuffer<float> a_v(L, M, N, "a");
  PaddedBuffer<float> b_v(L, M, N, "b");
  PaddedBuffer<float> c_v(L, M, N, "c");
  PaddedBuffer<float> d_v(L, "d");
  PaddedBuffer<float> e_before(L, "e_before");
  PaddedBuffer<float> e_after(L, "e_after");

  for (const auto l : c10::irange(L)) {
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        a_v(l, m, n) = at::randn({1}).item().to<float>();
        b_v(l, m, n) = at::randn({1}).item().to<float>();
      }
    }
  }

  cg_before.call({a_v, b_v, e_before});
  cg_after.call({a_v, b_v, e_after});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(e_before, e_after, 1e-5);
}

TEST(Reductions, ReductionCacheBodyAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
  l.cacheAccesses(c.buf(), "scale_local", d_loop);

  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12]
#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) {
#CHECK:   for (int k_1 = 0; k_1 < 12; k_1++) {
#CHECK:     scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1];
#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]);
#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]);
#CHECK: Free(scale_local);
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

TEST(Reductions, ReductionCacheConsumerAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4);

  StmtPtr e_loop = l.getLoopStmtsFor(e)[1];
  l.cacheAccesses(d.buf(), "sum_local", e_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK: sum[i_1] = (sum[i_1]) + (scale[
#CHECK: for (int j_2 = 0; j_2 < 4
#CHECK:   sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK:   scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

TEST(Reductions, ReductionSplitCacheConsumerAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner;

  // Split outer reduction axis.
  LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner);

  // Split reduction consumer.
  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);

  l.cacheAccesses(d.buf(), "sum_local", inner);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  // The reduction changes but the cache does not.
  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK:         sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]);
#CHECK: for (int i_2 = 0; i_2 < 6
#CHECK:   for (int j_2 = 0; j_2 < 4
#CHECK:     sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK:   for (int j_3 = 0; j_3 < 4
#CHECK:     scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

TEST(Reductions, ReductionReorderCacheConsumerAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner;

  // Reorder outer reduction axes.
  auto loops = l.getLoopStmtsFor(d);
  LoopNest::reorderAxis(loops[0], loops[1]);

  // Split reduction consumer.
  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);

  l.cacheAccesses(d.buf(), "sum_local", inner);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  // Neither the reduction body nor the cache changes.
  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK:        sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]);
#CHECK:  for (int i_3 = 0; i_3 < 6;
#CHECK:    for (int j_2 = 0; j_2 < 4;
#CHECK:      sum_local[j_2] = sum[j_2 + 4 * i_3];
#CHECK:    for (int j_3 = 0; j_3 < 4;
#CHECK:      scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]);
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

TEST(Reductions, ReductionRfactorCacheTempOuter) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("B", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});

  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  LoopNest::reorderAxis(loops.at(0), loops.at(1));
  loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  BufPtr rfac_buf;
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
  loop.distributeLoop(loops.at(0));

  auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]);

  all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]);
  loop.simplify();
  loop.prepareForCodegen();
  StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
  SimpleIREvaluator cg(s, {b, c, m, n, k});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[n]
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK:   for (int j = 0; j < n
#CHECK:     tmp[j] = 0
#CHECK:   }
#CHECK:   for (int j_1 = 0; j_1 < n
#CHECK:     for (int k
#CHECK:       tmp[j_1] = (tmp[j_1]) + (B[
#CHECK:     }
#CHECK:   }
#CHECK:   for (int j_2 = 0; j_2 < n
#CHECK:     sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]);
#CHECK:   }
#CHECK:   Free(tmp);
#CHECK-NOT: tmp
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

TEST(Reductions, ReductionRfactorCacheTempInner) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("B", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];

  LoopNest::reorderAxis(loops.at(0), loops.at(1));
  loops = loop.getLoopStmtsFor(c);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  BufPtr rfac_buf;
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
  loop.distributeLoop(loops.at(0));
  auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]);

  all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
  LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]);
  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
  SimpleIREvaluator cg(s, {b, c, m, n, k});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[1]
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK:   for (int j = 0; j < n
#CHECK:     tmp[0] = 0
#CHECK:     for (int k
#CHECK:       tmp[0] = (tmp[0]) + (B[
#CHECK:     }
#CHECK:   sum_rfac[j] = (sum_rfac[j]) + (tmp[0]);
#CHECK:   Free(tmp);
#CHECK-NOT: tmp
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

TEST(Reductions, ReductionVectorize) {
  std::vector<float> in_(8 * 8);
  for (const auto i : c10::irange(8)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out_before(8, -1.f);
  std::vector<float> out_after(8, -1.f);

  BufHandle in("in", {8, 8}, kFloat);

  Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
  LoopNest l_before({tensor});
  LoopNest l(l_before);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor});
  cg_before.call({in_, out_before});

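  // Vectorizing the outer (output) axis turns the scalar initializer into a
  // Broadcast store and each accumulation into Ramp-indexed vector loads, as
  // the IR check below verifies. The reduction axis itself cannot be
  // vectorized this way (see ReductionVectorizeInner).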
1806 
1807   StmtPtr s = l.root_stmt();
1808   s = LoopNest::sanitizeNames(IRSimplifier::simplify(s));
1809 
1810   std::ostringstream oss;
1811   oss << *s;
1812   const std::string& expected_ir =
1813       R"IR(
1814 #CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8);
1815 #CHECK: for (int i = 0; i < 8; i++) {
1816 #CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i});
1817 #CHECK: }
1818       )IR";
1819   torch::jit::testing::FileCheck().run(expected_ir, oss.str());
1820 
1821   // Vectorizing should not change result.
1822   l.prepareForCodegen();
1823   s = IRSimplifier::simplify(l.root_stmt());
1824   SimpleIREvaluator cg_after(s, {in, tensor});
1825   cg_after.call({in_, out_after});
1826   for (const auto i : c10::irange(8)) {
1827     ASSERT_EQ(out_before[i], out_after[i]);
1828   }
1829 }
1830 
TEST(Reductions,ReductionVectorizeInner)1831 TEST(Reductions, ReductionVectorizeInner) {
1832   BufHandle in("in", {8, 8}, kFloat);
1833 
1834   Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
1835   LoopNest l({tensor});
1836 
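  // Vectorizing the reduction axis itself is rejected: vectorize() returns
  // false rather than emitting a horizontal (cross-lane) reduction.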
  ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));
}

TEST(Reductions, ReductionVectorizeRfactor) {
  std::vector<float> in_(8 * 8);
  for (const auto i : c10::irange(8)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out_before(1, -1.f);
  std::vector<float> out_after(1, -1.f);

  BufHandle in("in", {8, 8}, kFloat);

  Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8});

  LoopNest l_before({tensor});
  LoopNest l(l_before);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor});
  cg_before.call({in_, out_before});

  ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));

  // But if we rfactor so that this axis is no longer a reduction axis, we can
  // vectorize the resulting loop.
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  LoopNest::reorderAxis(loops[0], loops[1]);
  loops = l.getLoopStmtsFor(tensor);
  auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1];
  BufPtr rfac_buf = nullptr;
  ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf));

  LoopNest::distributeLoop(loops.at(0));
  auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf);

  ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0]));
  l.simplify();

  StmtPtr s = LoopNest::sanitizeNames(l.root_stmt());

  std::ostringstream oss;
  oss << *s;
  const std::string& expected_ir =
      R"IR(
#CHECK: sum = 0.f;
#CHECK: for (int i = 0; i < 8; i++) {
#CHECK:   sum_rfac[i] = 0.f;
#CHECK: }
#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) {
#CHECK:   sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1});
#CHECK: }
#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) {
#CHECK:   sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2});
#CHECK: }
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  // Vectorizing should not change the result.
  l.prepareForCodegen();
  s = IRSimplifier::simplify(l.root_stmt());
  SimpleIREvaluator cg_after(s, {in, tensor});
  cg_after.call({in_, out_after});

  ASSERT_EQ(out_before[0], out_after[0]);
}

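// This Reduce overload takes two lambdas: the first supplies the per-element
// initializer (here B[i], replacing the reducer's default of zero) and the
// second supplies the body expression to accumulate.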
TEST(Reductions, InitFunction) {
  constexpr int M = 32;
  constexpr int N = 16;
  BufHandle A("A", {M, N}, kFloat);
  BufHandle B("B", {N}, kFloat);
  Tensor C = Reduce(
      "C",
      {N},
      Sum(),
      [&](const std::vector<VarHandle>& v) { return B.load(v[0]); },
      [&](const std::vector<VarHandle>& v) { return A.load(v[1], v[0]); },
      {M});
  LoopNest nest({C});
  nest.prepareForCodegen();
  StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt()));
  std::ostringstream oss;
  oss << *s << "\n";
  const std::string& expected_ir =
      R"IR(
#CHECK:  for (int i = 0; i < 16; i++) {
#CHECK:    C[i] = B[i];
#CHECK:    for (int j = 0; j < 32; j++) {
#CHECK:      C[i] = (C[i]) + (A[i + 16 * j]);
#CHECK:    }
#CHECK:  }
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
} // namespace jit
} // namespace torch