#include <gtest/gtest.h>

#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>

#include <test/cpp/tensorexpr/test_base.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

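// Most tests below follow the same pattern: build a Reduce (or Compute)
// tensor, lower it with LoopNest (prepareForCodegen + IRSimplifier), then run
// the resulting statement through SimpleIREvaluator and check the numerics.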
TEST(Reductions, ReduceSum0D_1) {
  const int M = 10;

  BufHandle b("b", {M}, kFloat);
  std::vector<float> in(M);
  for (const auto j : c10::irange(M)) {
    in[j] = j;
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], in[i]);
  }
}

TEST(Reductions, ReduceSum0D_2) {
  BufHandle b("b", {}, kFloat);
  std::vector<float> in(1);
  in[0] = 77.7;

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], in[0]);
}

// Sum an array to a single value.
TEST(Reductions, ReduceSum1D) {
  BufHandle b("b", {10}, kFloat);
  std::vector<float> in(10);
  for (const auto j : c10::irange(10)) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {10});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], 45);
}
// Sum a 2D tensor to a 1D tensor with dynamic shapes.
TEST(Reductions, ReduceSum2D) {
  const int M = 3;
  const int N = 7;

  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  BufHandle b("b", {m, n}, kFloat);
  std::vector<float> in(M * N);
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      in[i * N + j] = j;
    }
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {N});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, n, m});

  cg.call({in, out, 5, 7});

  float expected = 0;
  for (const auto i : c10::irange(N)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }

  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], expected);
  }
}

// Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat
// to check our work.
TEST(Reductions, ReduceSum3D) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m});

  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> cData(2 * 3, 6.0f);
  std::vector<float> dData(2, 1.0f);
  std::vector<float> eData(2, 1.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j;
    }
  }

  cg.call({bData, cData, M});
  float expected = 0;
  for (const auto i : c10::irange(M)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }

  for (int i = 0; i < 2 * 3; ++i) {
    ASSERT_EQ(cData[i], expected);
  }

  Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m});
  LoopNest loop2({d});
  loop2.prepareForCodegen();
  StmtPtr s2 = loop2.root_stmt();
  s2 = IRSimplifier::simplify(s2);

  SimpleIREvaluator cg2(s2, {b, d, m});
  cg2.call({bData, dData, M});

  // We're combining an additional dimension of 3, so the sum is 3x.
  expected = expected * 3;

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(dData[i], expected);
  }

  // This is the same as just reducing the original result across that axis.
  BufHandle c_buf(c.buf());
  Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3});
  LoopNest loop3({e});
  loop3.prepareForCodegen();
  StmtPtr s3 = loop3.root_stmt();
  s3 = IRSimplifier::simplify(s3);

  SimpleIREvaluator cg3(s3, {c, e});
  cg3.call({cData, eData});

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(eData[i], expected);
  }
}

// Sum a large (10-D) tensor down 5 dimensions.
TEST(Reductions, ReduceSum10D) {
  BufHandle in_("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat);
  const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3;
  BufHandle out_("out_", {2, 3, 2, 3, 2}, kFloat);
  const int OutputSize = 2 * 3 * 2 * 3 * 2;

  std::vector<float> in(InputSize, 1.f);
  std::vector<float> out(OutputSize, -1.f);

  Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in_, c});

  cg.call({in, out});

  // NOLINTNEXTLINE(bugprone-integer-division)
  float expected = InputSize / OutputSize;
  for (const auto i : c10::irange(OutputSize)) {
    ASSERT_EQ(out[i], expected);
  }
}

// Reduce via Mul rather than Add using a custom Reducer.
TEST(Reductions, ReduceProduct) {
  const int M = 4;
  const int N = 4;

  BufHandle b("b", {M, N}, kFloat);
  std::vector<float> in(M * N);
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      in[i * N + j] = 2 + j;
    }
  }

  std::vector<float> out(M, -1.f);

  Reducer product(
      ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; });

  Tensor c = Reduce("product", {M}, product, b, {N});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});

  float expected = 1;
  for (const auto i : c10::irange(N)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected *= 2 + i;
  }

  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], expected);
  }
}

// Maximum reductions.
TEST(Reductions, ReduceMax) {
  BufHandle in_("b", {10}, kFloat);

  std::vector<float> in(10);
  std::vector<float> out(1, -1.f);
  for (const auto j : c10::irange(10)) {
    in[j] = j;
  }

  Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10});

  LoopNest loop({dm1});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);
  SimpleIREvaluator cg(s, {in_, dm1});

  cg.call({in, out});

  ASSERT_EQ(out[0], 9);

  BufHandle in2_("b", {2, 5}, kFloat);
  std::vector<float> out2(2, -1.f);

  Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5});

  LoopNest loop2({m2d});
  loop2.prepareForCodegen();
  s = loop2.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg2(s, {in2_, m2d});
  cg2.call({in, out2});

  ASSERT_EQ(out2[0], 4);
  ASSERT_EQ(out2[1], 9);
}

// Minimum reduction, with custom initialization.
TEST(Reductions, ReduceMinCustomInitializer) {
  VarHandle minInit("minInit", kFloat);
  BufHandle in_("b", {10}, kFloat);

  std::vector<float> in(10);
  std::vector<float> out(1, -1.f);
  for (const auto j : c10::irange(10)) {
    in[j] = 10 + j;
  }

  Tensor min = Reduce(
      "min",
      {},
      Minimum(ExprHandle(minInit)),
      [&](ParameterList& v) { return in_.load(v); },
      {10});

  LoopNest loop({min});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in_, min, minInit});

  // Works normally (note that the out data starts lower than the correct
  // minimum).
  cg.call({in, out, std::numeric_limits<float>::max()});
  ASSERT_EQ(out[0], 10);

  // With an initializer lower than the minimum, the initializer wins.
  cg.call({in, out, 5.f});
  ASSERT_EQ(out[0], 5);
}

// Example implementation of Any/All.
// TODO: this is very awkward without logical And/Or operators.
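// The reducers below emulate short-circuiting with CompareSelect: the "any"
// reducer starts at 0 and returns 1 forever once the accumulator becomes 1
// (a == 1 ? 1 : b), while the "all" reducer starts at 1 and latches to 0 the
// first time an element fails the predicate (a == 0 ? 0 : b).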
TEST(Reductions, ReduceAnyAll) {
  VarHandle searchValue("searchValue", kInt);
  BufHandle b("b", {4, 10}, kInt);

  Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) {
    return CompareSelect::make(a, 1, 1, b, kEQ);
  });

  Tensor any = Reduce(
      "anyEqual",
      {4},
      anyEqSV,
      [&](const auto& i, const auto& j) {
        return CompareSelect::make(b.load(i, j), searchValue, kEQ);
      },
      {10});

  LoopNest loop({any});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, any, searchValue});

  std::vector<int> in(40, 0);
  std::vector<int> out(4, 0);

  // Input has 0-39 in 4 rows.
  for (const auto i : c10::irange(40)) {
    in[i] = i;
  }
  cg.call({in, out, 1});

  // Only the first row contains 1.
  ASSERT_EQ(out[0], 1);
  ASSERT_EQ(out[1], 0);
  ASSERT_EQ(out[2], 0);
  ASSERT_EQ(out[3], 0);

  cg.call({in, out, 15});

  // 15 is in the second row (index 1).
  ASSERT_EQ(out[0], 0);
  ASSERT_EQ(out[1], 1);
  ASSERT_EQ(out[2], 0);
  ASSERT_EQ(out[3], 0);

  Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) {
    return CompareSelect::make(a, 0, 0, b, kEQ);
  });

  Tensor allGreaterThan = Reduce(
      "allGreaterThan",
      {4},
      allGTSV,
      [&](const auto& i, const auto& j) {
        return CompareSelect::make(b.load(i, j), searchValue, kGT);
      },
      {10});

  LoopNest loop2({allGreaterThan});
  loop2.prepareForCodegen();
  s = loop2.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue});

  cg2.call({in, out, 11});

  // 11 lands in row 1, so only rows 2 and 3 are entirely greater than 11.
  ASSERT_EQ(out[0], 0);
  ASSERT_EQ(out[1], 0);
  ASSERT_EQ(out[2], 1);
  ASSERT_EQ(out[3], 1);

  cg2.call({in, out, -3});

  // All elements are greater than -3.
  ASSERT_EQ(out[0], 1);
  ASSERT_EQ(out[1], 1);
  ASSERT_EQ(out[2], 1);
  ASSERT_EQ(out[3], 1);
}

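// A 3x3 matmul written directly as a Reduce: the body lambda receives the two
// output indices (m, n) plus the reduce index k, and Sum() accumulates
// tA(m, k) * tB(k, n) over k.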
TEST(Reductions, ReduceMatmul2D) {
  BufHandle tA("tA", {3, 2}, kFloat);
  BufHandle tB("tB", {2, 3}, kFloat);

  std::vector<float> tA_(6);
  std::vector<float> tB_(6);

  std::vector<float> out(9, -1.f);
  for (const auto i : c10::irange(3)) {
    for (const auto j : c10::irange(2)) {
      tA_[i * 2 + j] = i * 2 + j;
      tB_[j * 3 + i] = i * 2 + j;
    }
  }

  Tensor mm = Reduce(
      "mm",
      {3, 3},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return tA.load(m, k) * tB.load(k, n);
      },
      {2});

  LoopNest loop({mm});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {tA, tB, mm});
  cg.call({tA_, tB_, out});

  std::vector<float> expected(
      {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f});

  for (const auto i : c10::irange(9)) {
    ASSERT_EQ(out[i], expected[i]);
  }
}

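// A manual two-stage ("rfactor-like") reduction: first reduce 10x10 down to a
// row of 10 partial sums, then reduce that row to a scalar. The sum of 0..99
// is 99 * 100 / 2 = 99 * 50 = 4950, which the final check expects.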
TEST(Reductions, ReduceRfactorLike) {
  BufHandle in("in", {10, 10}, kFloat);
  std::vector<float> in_(100);
  for (const auto i : c10::irange(100)) {
    in_[i] = i;
  }
  std::vector<float> in_rf_(10, -2.f);
  std::vector<float> out(1, -1.f);

  Tensor l1 = Reduce("l1", {10}, Sum(), in, {10});
  BufHandle in_rf(l1.buf());

  Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10});

  LoopNest loop({l1, l2});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, l1, l2});
  cg.call({in_, in_rf_, out});

  ASSERT_EQ(out[0], 99 * 50);
}

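// The next two tests wire a reduction into a larger graph: first as the
// producer feeding an elementwise Compute, then as the consumer of one. In
// both cases the LoopNest is constructed with the intermediate tensor listed
// alongside the output so the intermediate buffer gets lowered too.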
TEST(Reductions, ReduceAsProducer) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle a("a", {2, 3}, kFloat);
  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
  Tensor d =
      Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) {
        return c.load(l, n) * a.load(l, n);
      });
  LoopNest loop({d}, {c, d});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {a, b, d, m});

  std::vector<float> aData(2 * 3, 0);
  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> dData(2 * 3, 6.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    aData[i] = 6 - i;
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j;
    }
  }

  cg.call({aData, bData, dData, M});
  float expected = 0;
  for (const auto i : c10::irange(M)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }
  for (int i = 0; i < 2 * 3; ++i) {
    ASSERT_EQ(dData[i], expected * (6 - i));
  }
}

TEST(Reductions, ReduceAsConsumer) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle a("a", {2, 3, m}, kFloat);
  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Compute(
      "scale",
      {2, 3, m},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {2}, Sum(), c, {3, m});
  LoopNest loop({d}, {c, d});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {a, b, d, m});

  std::vector<float> aData(2 * 3 * M, 0);
  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> dData(2, 6.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j + 1;
      aData[i * M + j] = 6 - i;
    }
  }

  cg.call({aData, bData, dData, M});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  float expected[2] = {0, 0};
  for (const auto i : c10::irange(2)) {
    for (const auto j : c10::irange(3)) {
      for (const auto k : c10::irange(M)) {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        expected[i] += (k + 1) * (6 - (i * 3 + j));
      }
    }
  }

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(dData[i], expected[i]);
  }
}

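// splitWithTail(loop, factor) splits one loop into an outer loop, an inner
// loop of the given factor, and (when the extent does not divide evenly) a
// separate tail loop for the remainder. The next two tests apply it to the
// reduce axis and to the output axis of the same 16x8 sum; either way the
// results must be unchanged.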
TEST(Reductions, SplitReduceAxis) {
  BufHandle in("in", {16, 8}, kFloat);

  std::vector<float> in_(16 * 8);
  for (const auto i : c10::irange(16)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out(16, -1.f);

  Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  LoopNest::splitWithTail(loops[1], 2);

  l.prepareForCodegen();

  StmtPtr s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, tensor});
  cg.call({in_, out});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out[i], i * 8);
  }
}

TEST(Reductions, SplitNonReduceAxis) {
  BufHandle in("in", {16, 8}, kFloat);

  std::vector<float> in_(16 * 8);
  for (const auto i : c10::irange(16)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out(16, -1.f);
  Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  LoopNest::splitWithTail(loops[0], 2);
  LoopNest::splitWithTail(loops[0], 2);

  l.prepareForCodegen();

  StmtPtr s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, tensor});
  cg.call({in_, out});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out[i], i * 8);
  }
}

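// Reordering a reduce axis past the axis that carries the reduction's
// initializer must not change the result. The test lowers the same reduction
// twice: once untouched, and once with GPU axis bindings plus a reorderAxis,
// then compares the two outputs elementwise.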
TEST(Reductions, ReorderedReductionInitializer) {
  /* From the quip:
  for k in 0..1:  // blockIdx
    for m in 0..128:
      for n in 0..64:  // threadIdx
        SumOp(c(k, n), 0, a(k, m, n), {m})
  */

  BufHandle in("in", {1, 12, 6}, kFloat);
  std::vector<float> in_(12 * 6, 1.f);

  Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6});
  LoopNest l_({tensor_});

  l_.prepareForCodegen();
  StmtPtr s_ = Stmt::clone(l_.root_stmt());
  s_ = IRSimplifier::simplify(s_);

  Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6});
  LoopNest l({tensor});

  auto loops = l.getLoopStmtsFor(tensor);
  loops[0]->set_gpu_block_index(0);
  loops[1]->set_gpu_thread_index(0);

  LoopNest::reorderAxis(loops[1], loops[2]);

  StmtPtr s = l.root_stmt();
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  s = IRSimplifier::simplify(s);

  l.prepareForCodegen();

  s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  std::vector<float> out1(16, -1.f);
  SimpleIREvaluator cg(s_, {in, tensor_});
  cg.call({in_, out1});

  std::vector<float> out2(16, -1.f);
  SimpleIREvaluator cg2(s, {in, tensor});
  cg2.call({in_, out2});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out1[i], out2[i]);
  }
}

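// LoopNest::rfactor(reduce, loop) moves the given reduce axis into a fresh
// temporary buffer of partial results, leaving a second ReduceOp that folds
// the temporary into the original output; hence the NodeFinder check for two
// ReduceOps after a successful rfactor. rfactoring the innermost reduce axis
// is rejected and leaves the single original ReduceOp in place.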
TEST(Reductions, ReduceRfactor) {
  const int M = 10;
  const int N = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  BufHandle b("b", {m, n}, kFloat);
  std::vector<float> in(M * N);
  for (int j = 0; j < M * N; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 2);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n});

  cg.call({in, out, M, N});
  ASSERT_EQ(out[0], 4950);
}

TEST(Reductions, Reduce3DRfactorInner) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("b", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_FALSE(loop.rfactor(c_body, loops.at(2)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 1);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n, k});

  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

TEST(Reductions, Reduce3DRfactorOuter) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("b", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 2);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n, k});
  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

TEST(Reductions, ReduceRepeatedInternalRfactor) {
  BufHandle in_("in_", {2, 3, 4, 5, 6}, kFloat);
  const int InputSize = 2 * 3 * 4 * 5 * 6;

  std::vector<float> in(InputSize, 1.f);
  std::vector<float> out(1, -1.f);
  std::vector<float> ref(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6});
  LoopNest orig_loop({c});

  // Try rfactoring N outer loops
  for (const auto rfac_number : c10::irange(1, 5)) {
    LoopNest refloop(orig_loop);
    LoopNest loop(orig_loop);
    refloop.prepareForCodegen();
    SimpleIREvaluator ref_cg(
        IRSimplifier::simplify(refloop.root_stmt()), {in_, c});
    ref_cg.call({in, ref});

    BufPtr tmp_buf = c.buf();

    for (const auto idx : c10::irange(rfac_number)) {
      auto reduce = loop.getAllWritesToBuf(tmp_buf)[1];
      ASSERT_TRUE(loop.rfactor(
          reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf));
    }

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {in_, c});
    cg.call({in, out});

    ASSERT_EQ(ref[0], out[0]);
  }
}

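// The six tests below run the same 10x10x10 row-sum through both split
// flavors, at factors that under-divide (8), divide exactly (5), and exceed
// (16) the loop extent of 10. splitWithTail emits a separate remainder loop,
// while splitWithMask keeps a single loop and guards the remainder with a
// conditional; in every case out[0] must remain 4950.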
// Split a reduction axis with a tail loop.
TEST(Reductions, ReduceSplitTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 8);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis cleanly so there is no tail loop.
TEST(Reductions, ReduceSplitNoTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 5);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis with only a tail loop (the split loop will be size 0
// and eliminated).
TEST(Reductions, ReduceOverSplitTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 16);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis with a mask.
TEST(Reductions, ReduceSplitMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 8);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis cleanly, not requiring a mask.
TEST(Reductions, ReduceSplitNoMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 5);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis with all logic in the mask.
TEST(Reductions, ReduceOverSplitMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 16);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Test an rfactor when there are two ReduceOps in the graph due to a
// splitWithTail.
TEST(Reductions, ReduceSplitRfactor) {
  const int M = 2;
  const int N = 10;
  const int K = 10;
  const int SPLIT_FACTOR = 4;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (const auto m : c10::irange(M)) {
    for (int j = 0; j < N * K; ++j) {
      in[m * N * K + j] = j;
    }
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  LoopNest::splitWithTail(loops[2], SPLIT_FACTOR);

  auto c_body = loop.getAllWritesToBuf(c.buf())[2];
  auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3);
  LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]);
  all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3);
  ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1]));
  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = loop.root_stmt();

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  for (const auto i : c10::irange(M)) {
    (void)i; // Suppress unused variable warning
    ASSERT_EQ(out[0], 4950);
  }
}

// Test an rfactor which ends up being eliminated since the total loop size is
// smaller than the split factor.
TEST(Reductions, ReduceOverSplitRfactor) {
  const int N = 10;
  const int K = 10;
  const int SPLIT_FACTOR = 16;

  BufHandle b("b", {N, K}, kFloat);
  std::vector<float> in(N * K);
  for (int j = 0; j < N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {N, K});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr i, t;
  LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t);
  LoopNest::reorderAxis(loops[0], i);

  auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0]));
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]);

  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = loop.root_stmt();

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], 4950);

  std::ostringstream oss;
  oss << *cg.stmt();

  // Check the IR to verify the rfactored reduce is eliminated.
  // TODO: The alloc/free should be eliminated here since it is size 0.
  /*
  const std::string& verification_pattern =
      R"IR(
# CHECK: Allocate(tmp_buf); // dtype=float, dims=[0]
# CHECK: sum[0] = 0.f;
# CHECK: for (int n = 0; n < 10; n++) {
# CHECK:   for (int k_tail = 0; k_tail < 10; k_tail++) {
# CHECK:     sum[0] = (sum[0]) + (b[k_tail + 10 * n]);
# CHECK:   }
# CHECK: }
# CHECK: Free(tmp_buf);)IR";
  */
  // TODO: rfactor output is not consistent yet, will fix (@nickg).
  // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

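// computeInline substitutes a producer's defining expression directly into
// its consumers. A reduction accumulates into its buffer across iterations,
// so it has no single defining expression and cannot be inlined (the first
// test expects failure). Inlining a pointwise producer *into* a reduction is
// fine and should shrink the lowered IR, which the string-length comparisons
// below check.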
TEST(Reductions, ReduceInlineReduction) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K});
  Tensor y = Compute(
      "y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); });

  PaddedBuffer<float> a_v(M);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    a_v(i) = i * i;
  }
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        b_v(i, j, k) = j * j * k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  // Cannot inline a reduction computation
  ASSERT_FALSE(l1.computeInline(x.buf()));
}

TEST(Reductions, ReduceInlineConsumer) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M, N, K}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n, k) + b_buf.load(m, n, k);
      });
  Tensor y = Reduce("y", {M}, Sum(), x, {N, K});

  PaddedBuffer<float> a_v(M, N, K);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        a_v(i, j, k) = i * i + k;
        b_v(i, j, k) = j * j + k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});

  PaddedBuffer<float> y_1(M);
  PaddedBuffer<float> y_2(M);

  eval1(a_v, b_v, y_1);
  eval2(a_v, b_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}

TEST(Reductions, ReduceInlineReducerInternal) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M, N, K}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n, k) + b_buf.load(m, n, k);
      });

  Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) {
    return Add::make(ExprHandle(1.f), Min::make(a, b, false));
  });
  Tensor y = Reduce("y", {M}, minimum, x, {N, K});

  PaddedBuffer<float> a_v(M, N, K);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        a_v(i, j, k) = i * i + k;
        b_v(i, j, k) = j * j + k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});

  PaddedBuffer<float> y_1(M);
  PaddedBuffer<float> y_2(M);

  eval1(a_v, b_v, y_1);
  eval2(a_v, b_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}

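// cacheAccesses(buf, name, stmt) rewrites all accesses to buf inside stmt to
// go through a freshly allocated local buffer. For a reduction output, the
// shape and lifetime of the cache depend on which loop level is passed in:
// the next three tests cache the same "sum" at the operator axis (a 4-wide
// cache filled once), the outer reduce axis (a scalar cache spanning both
// reduce loops), and the inner reduce axis (a scalar cache re-initialized per
// outer iteration), FileChecking the lowered IR for each placement.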
TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
  int L = 4;
  int N = 3;
  int M = 2;

  BufHandle a("a", {L, N, M}, kFloat);
  BufHandle b("b", {L, N, M}, kFloat);

  Tensor c = Compute(
      "scale",
      {L, N, M},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});

  Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});
  LoopNest l_before(l);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(
      LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[0];
  l.cacheAccesses(d.buf(), "d_local", d_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg_after(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg_after.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[4]
#CHECK: for (int i_2
#CHECK:   d_local[i_2] = 0.f
#CHECK:   for (int
#CHECK:     for (int
#CHECK:       d_local[i_2] = (d_local[i_2]) + (scale[
#CHECK:     }
#CHECK:   }
#CHECK: }
#CHECK: for (int i_3
#CHECK:   sum[i_3] = d_local[i_3]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  PaddedBuffer<float> a_v(L, M, N, "a");
  PaddedBuffer<float> b_v(L, M, N, "b");
  PaddedBuffer<float> c_v(L, M, N, "c");
  PaddedBuffer<float> d_v(L, "d");
  PaddedBuffer<float> e_before(L, "e_before");
  PaddedBuffer<float> e_after(L, "e_after");

  for (const auto l : c10::irange(L)) {
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        a_v(l, m, n) = at::randn({1}).item().to<float>();
        b_v(l, m, n) = at::randn({1}).item().to<float>();
      }
    }
  }

  cg_before.call({a_v, b_v, e_before});
  cg_after.call({a_v, b_v, e_after});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(e_before, e_after, 1e-5);
}

TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
  int L = 4;
  int N = 3;
  int M = 2;

  BufHandle a("a", {L, N, M}, kFloat);
  BufHandle b("b", {L, N, M}, kFloat);

  Tensor c = Compute(
      "scale",
      {L, N, M},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});

  Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});
  LoopNest l_before(l);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
  l.cacheAccesses(d.buf(), "d_local", d_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg_after(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg_after.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
#CHECK: sum[i_1] = 0
#CHECK: d_local[0] = sum[i_1]
#CHECK: for (int j_1
#CHECK:   for (int k_1
#CHECK:     d_local[0] = (d_local[0]) + (scale[
#CHECK:   }
#CHECK: }
#CHECK: sum[i_1] = d_local[0]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  PaddedBuffer<float> a_v(L, M, N, "a");
  PaddedBuffer<float> b_v(L, M, N, "b");
  PaddedBuffer<float> c_v(L, M, N, "c");
  PaddedBuffer<float> d_v(L, "d");
  PaddedBuffer<float> e_before(L, "e_before");
  PaddedBuffer<float> e_after(L, "e_after");

  for (const auto l : c10::irange(L)) {
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        a_v(l, m, n) = at::randn({1}).item().to<float>();
        b_v(l, m, n) = at::randn({1}).item().to<float>();
      }
    }
  }

  cg_before.call({a_v, b_v, e_before});
  cg_after.call({a_v, b_v, e_after});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(e_before, e_after, 1e-5);
}

TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
  int L = 4;
  int N = 3;
  int M = 2;

  BufHandle a("a", {L, N, M}, kFloat);
  BufHandle b("b", {L, N, M}, kFloat);

  Tensor c = Compute(
      "scale",
      {L, N, M},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});

  Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});
  LoopNest l_before(l);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[2];
  l.cacheAccesses(d.buf(), "d_local", d_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg_after(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg_after.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
#CHECK: sum[i_1] = 0
#CHECK: for (int
#CHECK:   d_local[0] = 0
#CHECK:   for (int
#CHECK:     d_local[0] = (d_local[0]) + (scale[
#CHECK:   }
#CHECK:   sum[i_1] = (sum[i_1]) + (d_local[0])
#CHECK: }
#CHECK: Free(d_local);
#CHECK-NOT: d_local
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  PaddedBuffer<float> a_v(L, M, N, "a");
  PaddedBuffer<float> b_v(L, M, N, "b");
  PaddedBuffer<float> c_v(L, M, N, "c");
  PaddedBuffer<float> d_v(L, "d");
  PaddedBuffer<float> e_before(L, "e_before");
  PaddedBuffer<float> e_after(L, "e_after");

  for (const auto l : c10::irange(L)) {
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        a_v(l, m, n) = at::randn({1}).item().to<float>();
        b_v(l, m, n) = at::randn({1}).item().to<float>();
      }
    }
  }

  cg_before.call({a_v, b_v, e_before});
  cg_after.call({a_v, b_v, e_after});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(e_before, e_after, 1e-5);
}

TEST(Reductions, ReductionCacheBodyAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
  l.cacheAccesses(c.buf(), "scale_local", d_loop);

  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12]
#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) {
#CHECK:   for (int k_1 = 0; k_1 < 12; k_1++) {
#CHECK:     scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1];
#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]);
#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]);
#CHECK: Free(scale_local);
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

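// The consumer-side variants below cache the reduction result where it is
// read by the second "scale" Compute rather than where it is written.
// Splitting or reordering the reduction itself alters the reduction loops but
// should leave the consumer-side cache (sum_local) untouched.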
TEST(Reductions, ReductionCacheConsumerAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4);

  StmtPtr e_loop = l.getLoopStmtsFor(e)[1];
  l.cacheAccesses(d.buf(), "sum_local", e_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK: sum[i_1] = (sum[i_1]) + (scale[
#CHECK: for (int j_2 = 0; j_2 < 4
#CHECK:   sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

TEST(Reductions, ReductionSplitCacheConsumerAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner;

  // Split the outer reduction axis.
  LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner);

  // Split the reduction consumer.
  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);

  l.cacheAccesses(d.buf(), "sum_local", inner);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  // The reduction changes but the cache does not.
  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]);
#CHECK: for (int i_2 = 0; i_2 < 6
#CHECK:   for (int j_2 = 0; j_2 < 4
#CHECK:     sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK:   for (int j_3 = 0; j_3 < 4
#CHECK:     scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

TEST(Reductions, ReductionReorderCacheConsumerAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner;

  // Reorder the outer reduction axes.
  auto loops = l.getLoopStmtsFor(d);
  LoopNest::reorderAxis(loops[0], loops[1]);

  // Split the reduction consumer.
  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);

  l.cacheAccesses(d.buf(), "sum_local", inner);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  // Neither the reduction body nor the cache changes.
  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]);
#CHECK: for (int i_3 = 0; i_3 < 6;
#CHECK:   for (int j_2 = 0; j_2 < 4;
#CHECK:     sum_local[j_2] = sum[j_2 + 4 * i_3];
#CHECK:   for (int j_3 = 0; j_3 < 4;
#CHECK:     scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]);
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

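// The next two tests combine rfactor with cacheAccesses on the rfactored
// temporary (sum_rfac): caching at the outer loop level produces an n-sized
// tmp buffer, while caching at the innermost loop produces a single-element
// tmp, as the FileCheck patterns verify.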
TEST(Reductions, ReductionRfactorCacheTempOuter) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("B", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});

  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  LoopNest::reorderAxis(loops.at(0), loops.at(1));
  loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  BufPtr rfac_buf;
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
  loop.distributeLoop(loops.at(0));

  auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]);

  all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]);
  loop.simplify();
  loop.prepareForCodegen();
  StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
  SimpleIREvaluator cg(s, {b, c, m, n, k});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[n]
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK:   for (int j = 0; j < n
#CHECK:     tmp[j] = 0
#CHECK:   }
#CHECK:   for (int j_1 = 0; j_1 < n
#CHECK:     for (int k
#CHECK:       tmp[j_1] = (tmp[j_1]) + (B[
#CHECK:     }
#CHECK:   }
#CHECK:   for (int j_2 = 0; j_2 < n
#CHECK:     sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]);
#CHECK:   }
#CHECK:   Free(tmp);
#CHECK-NOT: tmp
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

TEST(Reductions, ReductionRfactorCacheTempInner) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("B", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];

  LoopNest::reorderAxis(loops.at(0), loops.at(1));
  loops = loop.getLoopStmtsFor(c);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  BufPtr rfac_buf;
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
  loop.distributeLoop(loops.at(0));
  auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]);

  all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
  LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]);
  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
  SimpleIREvaluator cg(s, {b, c, m, n, k});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[1]
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK:   for (int j = 0; j < n
#CHECK:     tmp[0] = 0
#CHECK:     for (int k
#CHECK:       tmp[0] = (tmp[0]) + (B[
#CHECK:     }
#CHECK:   sum_rfac[j] = (sum_rfac[j]) + (tmp[0]);
#CHECK:   Free(tmp);
#CHECK-NOT: tmp
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

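// Vectorization of reductions: a non-reduce output axis can be vectorized
// (the ReduceOp becomes a Ramp/Broadcast over the output lanes), but
// vectorizing a reduce axis directly is rejected. ReductionVectorizeRfactor
// shows the workaround: rfactor the reduction so the target axis is no longer
// a reduce axis of the partial-sum buffer, then vectorize that loop.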
TEST(Reductions, ReductionVectorize) {
  std::vector<float> in_(8 * 8);
  for (const auto i : c10::irange(8)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out_before(8, -1.f);
  std::vector<float> out_after(8, -1.f);

  BufHandle in("in", {8, 8}, kFloat);

  Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
  LoopNest l_before({tensor});
  LoopNest l(l_before);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor});
  cg_before.call({in_, out_before});

  ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0]));

  StmtPtr s = l.root_stmt();
  s = LoopNest::sanitizeNames(IRSimplifier::simplify(s));

  std::ostringstream oss;
  oss << *s;
  const std::string& expected_ir =
      R"IR(
#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8);
#CHECK: for (int i = 0; i < 8; i++) {
#CHECK:   sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i});
#CHECK: }
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  // Vectorizing should not change the result.
  l.prepareForCodegen();
  s = IRSimplifier::simplify(l.root_stmt());
  SimpleIREvaluator cg_after(s, {in, tensor});
  cg_after.call({in_, out_after});
  for (const auto i : c10::irange(8)) {
    ASSERT_EQ(out_before[i], out_after[i]);
  }
}

TEST(Reductions, ReductionVectorizeInner) {
  BufHandle in("in", {8, 8}, kFloat);

  Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
  LoopNest l({tensor});

  ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));
}

TEST(Reductions, ReductionVectorizeRfactor) {
  std::vector<float> in_(8 * 8);
  for (const auto i : c10::irange(8)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out_before(1, -1.f);
  std::vector<float> out_after(1, -1.f);

  BufHandle in("in", {8, 8}, kFloat);

  Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8});

  LoopNest l_before({tensor});
  LoopNest l(l_before);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor});
  cg_before.call({in_, out_before});

  ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));

  // But if we rfactor this so it's not a reduce axis we can vectorize that
  // loop.
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  LoopNest::reorderAxis(loops[0], loops[1]);
  loops = l.getLoopStmtsFor(tensor);
  auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1];
  BufPtr rfac_buf = nullptr;
  ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf));

  LoopNest::distributeLoop(loops.at(0));
  auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf);

  ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0]));
  l.simplify();

  StmtPtr s = LoopNest::sanitizeNames(l.root_stmt());

  std::ostringstream oss;
  oss << *s;
  const std::string& expected_ir =
      R"IR(
#CHECK: sum = 0.f;
#CHECK: for (int i = 0; i < 8; i++) {
#CHECK:   sum_rfac[i] = 0.f;
#CHECK: }
#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) {
#CHECK:   sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1});
#CHECK: }
#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) {
#CHECK:   sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2});
#CHECK: }
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  // Vectorizing should not change the result.
  l.prepareForCodegen();
  s = IRSimplifier::simplify(l.root_stmt());
  SimpleIREvaluator cg_after(s, {in, tensor});
  cg_after.call({in_, out_after});

  ASSERT_EQ(out_before[0], out_after[0]);
}

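// This Reduce overload takes an explicit init-function as well as a body
// lambda: the accumulator for C[i] starts at B[i] instead of the reducer's
// default identity, which is why the lowered IR begins with "C[i] = B[i]".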
TEST(Reductions, InitFunction) {
  constexpr int M = 32;
  constexpr int N = 16;
  BufHandle A("A", {M, N}, kFloat);
  BufHandle B("B", {N}, kFloat);
  Tensor C = Reduce(
      "C",
      {N},
      Sum(),
      [&](const std::vector<VarHandle>& v) { return B.load(v[0]); },
      [&](const std::vector<VarHandle>& v) { return A.load(v[1], v[0]); },
      {M});
  LoopNest nest({C});
  nest.prepareForCodegen();
  StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt()));
  std::ostringstream oss;
  oss << *s << "\n";
  const std::string& expected_ir =
      R"IR(
#CHECK: for (int i = 0; i < 16; i++) {
#CHECK:   C[i] = B[i];
#CHECK:   for (int j = 0; j < 32; j++) {
#CHECK:     C[i] = (C[i]) + (A[i + 16 * j]);
#CHECK:   }
#CHECK: }
)IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
} // namespace jit
} // namespace torch