xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_registerizer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include "test/cpp/tensorexpr/test_base.h"
3 
4 #include "test/cpp/tensorexpr/test_utils.h"
5 #include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
6 #include "torch/csrc/jit/tensorexpr/registerizer.h"
7 
8 #include <iostream>
9 
10 namespace torch {
11 namespace jit {
12 using namespace torch::jit::tensorexpr;
13 
14 // Can replace a simple scalar access with a local variable.
TEST(Registerizer,RegisterizerSimple)15 TEST(Registerizer, RegisterizerSimple) {
16   BufHandle a("A", {1}, kInt);
17   VarHandle x("x", kInt);
18   StmtPtr stmt = Block::make(
19       {Store::make(a, {0}, 0),
20        For::make(
21            x,
22            0,
23            10,
24            Block::make(
25                {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
26 
27   /*
28    * A[0] = 0;
29    * for (int x = 0; x < 10; x++) {
30    *   A[0] = (A[0]) + x;
31    * }
32    */
33 
34   stmt = registerize(stmt);
35 
36   /*
37    * int A_1 = 0;
38    * for (int x = 0; x < 10; x++) {
39    *   A_1 = x + A_1;
40    * }
41    * A[0] = A_1;
42    */
43 
44   std::ostringstream oss;
45   oss << *stmt;
46 
47   const std::string& verification_pattern =
48       R"IR(
49 # CHECK: int A_1 = 0;
50 # CHECK: for (int x = 0; x < 10; x++)
51 # CHECK-NOT: A[
52 # CHECK:   A_1 =
53 # CHECK: A[0] = A_1;)IR";
54 
55   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
56 }
57 
58 // Won't do replacement of a loop access.
TEST(Registerizer,RegisterizerLoop)59 TEST(Registerizer, RegisterizerLoop) {
60   BufHandle a("A", {10}, kInt);
61   VarHandle x("x", kInt);
62   StmtPtr stmt = Block::make(
63       {Store::make(a, {0}, 0),
64        For::make(
65            x,
66            0,
67            10,
68            Block::make(
69                {Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
70 
71   /*
72    * A[0] = 0;
73    * for (int x = 0; x < 10; x++) {
74    *   A[x] = (A[x]) + x;
75    * }
76    */
77 
78   // No change.
79   stmt = registerize(stmt);
80 
81   /*
82    * A[0] = 0;
83    * for (int x = 0; x < 10; x++) {
84    *   A[x] = (A[x]) + x;
85    * }
86    */
87 
88   std::ostringstream oss;
89   oss << *stmt;
90 
91   const std::string& verification_pattern =
92       R"IR(
93 # CHECK-NOT: int
94 # CHECK: A[0] = 0;
95 # CHECK: for (int x = 0; x < 10; x++)
96 # CHECK-NOT: A_
97 # CHECK:   A[x] =
98 # CHECK-NOT: A[0] = A_1;)IR";
99 
100   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
101 }
102 
103 // Won't replace even if the load is a fixed scalar, since the store could
104 // invalidate it.
TEST(Registerizer,RegisterizerLoopFixedLoad)105 TEST(Registerizer, RegisterizerLoopFixedLoad) {
106   BufHandle a("A", {1}, kInt);
107   VarHandle x("x", kInt);
108   StmtPtr stmt = Block::make(
109       {Store::make(a, {0}, 0),
110        For::make(
111            x,
112            0,
113            10,
114            Block::make(
115                {Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))});
116 
117   /*
118    * A[0] = 0;
119    * for (int x = 0; x < 10; x++) {
120    *   A[x] = (A[0]) + x;
121    * }
122    */
123 
124   // No change.
125   stmt = registerize(stmt);
126 
127   /*
128    * A[0] = 0;
129    * for (int x = 0; x < 10; x++) {
130    *   A[x] = (A[0]) + x;
131    * }
132    */
133 
134   std::ostringstream oss;
135   oss << *stmt;
136 
137   const std::string& verification_pattern =
138       R"IR(
139 # CHECK-NOT: int
140 # CHECK: A[0] = 0;
141 # CHECK: for (int x = 0; x < 10; x++)
142 # CHECK-NOT: A_
143 # CHECK:   A[x] =
144 # CHECK-NOT: A[0] = A_1;)IR";
145 
146   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
147 }
148 
149 // We can registerize accesses that occur entirely within inner scopes, even if
150 // they depend on the loop var.
TEST(Registerizer,RegisterizerLoopInternal)151 TEST(Registerizer, RegisterizerLoopInternal) {
152   BufHandle a("A", {1}, kInt);
153   VarHandle x("x", kInt);
154   StmtPtr stmt = Block::make({For::make(
155       x,
156       0,
157       10,
158       Block::make(
159           {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)),
160            Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
161 
162   /*
163    * for (int x = 0; x < 10; x++) {
164    *   A[x] = (A[x]) + x;
165    *   A[x] = (A[x]) + x;
166    * }
167    */
168 
169   stmt = registerize(stmt);
170 
171   // TODO: the order of terms in addition changes and in general depends on
172   // some hash value. This results in unpredictable swaps of the operands from
173   // random changes, which is not great. Ideally, we should ensure some
174   // specific order (ideally, the original one).
175   /*
176    * for (int x = 0; x < 10; x++) {
177    *   int A_1 = A[x];
178    *   A_1 = x + A_1;
179    *   A_1 = x + A_1;
180    *   A[x] = A_1;
181    * }
182    */
183 
184   std::ostringstream oss;
185   oss << *stmt;
186 
187   const std::string& verification_pattern =
188       R"IR(
189 # CHECK: for (int x = 0; x < 10; x++)
190 # CHECK: int A_1 = A[x];
191 # CHECK:   A_1 = A_1 + x;
192 # CHECK:   A_1 = A_1 + x;
193 # CHECK:   A[x] = A_1;
194 # CHECK: })IR";
195 
196   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
197 }
198 
199 // An access can be overlapped by another read in the same Expr. In this case
200 // B[z] and B[y] overlap and prevent registerization of both accesses.
TEST(Registerizer,RegisterizerLoopInternalLoadOverlap)201 TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) {
202   BufHandle a("A", {10}, kInt);
203   BufHandle b("B", {10}, kInt);
204   VarHandle x("x", kInt);
205   VarHandle y("y", kInt);
206   VarHandle z("z", kInt);
207   StmtPtr stmt = Block::make({For::make(
208       x,
209       0,
210       10,
211       Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))});
212   stmt = IRSimplifier::simplify(stmt);
213 
214   /*
215    * for (int x = 0; x < 10; x++) {
216    *   A[x] = (B[y]) + (B[z]);
217    * }
218    */
219 
220   std::ostringstream before;
221   before << *stmt;
222 
223   // No change.
224   stmt = registerize(stmt);
225 
226   std::ostringstream after;
227   after << *stmt;
228 
229   ASSERT_EQ(before.str(), after.str());
230 }
231 
// The same two scalar accesses (a load of A[1], a store to A[0]) repeated
// across two sibling loops are registerized once for the whole program: both
// scalars are declared before the first loop and A[0] is written back after
// the second.
TEST(Registerizer, RegisterizerLoopInternalRepeated) {
  // Two elements, because both A[0] and A[1] are accessed (the buffer was
  // previously declared with a single element).
  BufHandle a("A", {2}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
                Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
                Store::make(a, {0}, Add::make(Load::make(a, {1}), x))}))

      });

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[1]) + x;
   *   A[0] = (A[1]) + x;
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[1]) + x;
   *   A[0] = (A[1]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[1];
   * int A_2 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_1 + x;
   *   A_2 = A_1 + x;
   * }
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_1 + x;
   *   A_2 = A_1 + x;
   * }
   * A[0] = A_2;
   */

  std::ostringstream oss;
  oss << *stmt;

  // A[1] is only read, so it must not reappear after the loops.
  const std::string verification_pattern =
      R"IR(
# CHECK: int A_1 = A[1];
# CHECK: int A_2 = A[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   A_2 = A_1 + x;
# CHECK:   A_2 = A_1 + x;
# CHECK: }
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   A_2 = A_1 + x;
# CHECK:   A_2 = A_1 + x;
# CHECK: }
# CHECK-NOT: A[1]
# CHECK: A[0] = A_2;
# CHECK-NOT: A[1]
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
302 
// Same layout as RegisterizerLoopInternalRepeated, but the load is indexed by
// the loop variable. A[x] may alias A[0], so nothing is registerized and the
// printed program must be identical before and after the pass.
TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) {
  // A is sized to the loop extent so that every A[x] load is in bounds
  // (the buffer was previously declared with a single element).
  BufHandle a("A", {10}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),
                Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),
                Store::make(a, {0}, Add::make(Load::make(a, {x}), x))}))

      });
  stmt = IRSimplifier::simplify(stmt);

  /*
   * As constructed, before simplification:
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[x]) + x;
   *   A[0] = (A[x]) + x;
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[x]) + x;
   *   A[0] = (A[x]) + x;
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
347 
// Same layout as RegisterizerLoopInternalRepeated, but the load is indexed by
// an unrelated symbolic variable y. The test expects the printed program to
// be identical before and after the pass.
TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Every store gets its own freshly built nodes.
  auto make_store = [&]() {
    return Store::make(a, {0}, Add::make(x, Load::make(a, {y})));
  };

  // As constructed, before simplification:
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = x + (A[y]);
  //     A[0] = x + (A[y]);
  //   }
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = x + (A[y]);
  //     A[0] = x + (A[y]);
  //   }
  StmtPtr first_loop =
      For::make(x, 0, 10, Block::make({make_store(), make_store()}));
  StmtPtr second_loop =
      For::make(x, 0, 10, Block::make({make_store(), make_store()}));
  StmtPtr program =
      IRSimplifier::simplify(Block::make({first_loop, second_loop}));

  std::ostringstream pre;
  pre << *program;

  // Expected to leave the program as it was.
  program = registerize(program);

  std::ostringstream post;
  post << *program;

  ASSERT_EQ(pre.str(), post.str());
}
392 
393 // Will registerize multiple accesses of different items of the same buffer.
TEST(Registerizer,RegisterizerMultiVar)394 TEST(Registerizer, RegisterizerMultiVar) {
395   BufHandle a("A", {2}, kInt);
396   VarHandle x("x", kInt);
397   StmtPtr stmt = Block::make({
398       Store::make(a, {0}, 0),
399       Store::make(a, {1}, 0),
400       For::make(
401           x,
402           0,
403           10,
404           Block::make(
405               {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
406                Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
407   });
408 
409   /*
410    * A[0] = 0;
411    * A[1] = 0;
412    * for (int x = 0; x < 10; x++) {
413    *   A[0] = (A[0]) + x;
414    *   A[1] = (A[1]) - x;
415    * }
416    */
417 
418   stmt = registerize(stmt);
419 
420   /*
421    * int A_1 = 0;
422    * int A_2 = 0;
423    * for (int x = 0; x < 10; x++) {
424    *   A_2 = x + A_2;
425    *   A_1 = A_1 - x;
426    * }
427    * A[1] = A_2;
428    * A[0] = A_1;
429    */
430 
431   std::ostringstream oss;
432   oss << *stmt;
433 
434   const std::string& verification_pattern =
435       R"IR(
436 # CHECK: int A_1 = 0;
437 # CHECK: int A_2 = 0;
438 # CHECK: for (int x = 0; x < 10; x++)
439 # CHECK-NOT: A[
440 # CHECK:   A_1 =
441 # CHECK:   A_2 =
442 # CHECK: A[1] = A_2
443 # CHECK: A[0] = A_1;)IR";
444 
445   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
446 }
447 
448 // Will registerize the valid accesses while skipping invalid replacements.
TEST(Registerizer,RegisterizerVariableLoad)449 TEST(Registerizer, RegisterizerVariableLoad) {
450   BufHandle a("A", {1}, kInt);
451   BufHandle b("B", {10}, kInt);
452   VarHandle x("x", kInt);
453   VarHandle x2("x", kInt);
454   StmtPtr stmt = Block::make(
455       {Store::make(a, {0}, 0),
456        For::make(x, 0, 10, Store::make(b, {x}, x)),
457        For::make(
458            x2,
459            0,
460            10,
461            Block::make({Store::make(
462                a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))});
463 
464   /*
465    * A[0] = 0;
466    * for (int x = 0; x < 10; x++) {
467    *   B[x] = x;
468    * }
469    * for (int x_1 = 0; x_1 < 10; x_1++) {
470    *   A[0] = (A[0]) + (B[x_1]);
471    * }
472    */
473 
474   stmt = registerize(stmt);
475 
476   /*
477    * int A_1 = 0;
478    * for (int x = 0; x < 10; x++) {
479    *   B[x] = x;
480    * }
481    * for (int x_1 = 0; x_1 < 10; x_1++) {
482    *   A_1 = A_1 + (B[x_1]);
483    * }
484    * A[0] = A_1;
485    */
486 
487   std::ostringstream oss;
488   oss << *stmt;
489 
490   const std::string& verification_pattern =
491       R"IR(
492 # CHECK: int A_1 = 0;
493 # CHECK: for (int x = 0; x < 10; x++)
494 # CHECK:   B[x] = x
495 # CHECK: for (int x_1 = 0; x_1 < 10; x_1++)
496 # CHECK-NOT: A[
497 # CHECK:   A_1 =
498 # CHECK: A[0] = A_1;)IR";
499 
500   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
501 }
502 
503 // Can registerize variable accesses so long as the variable does not change.
TEST(Registerizer,RegisterizerSymbolicIndices)504 TEST(Registerizer, RegisterizerSymbolicIndices) {
505   VarHandle i("i", kInt);
506   VarHandle N("N", kInt);
507   BufHandle a("A", {N}, kInt);
508   VarHandle x("x", kInt);
509   StmtPtr stmt = Block::make(
510       {Store::make(a, {i}, 0),
511        For::make(
512            x,
513            0,
514            10,
515            Block::make(
516                {Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))});
517 
518   /*
519    * A[i] = 0;
520    * for (int x = 0; x < 10; x++) {
521    *   A[i] = (A[i]) + x;
522    * }
523    */
524 
525   stmt = registerize(stmt);
526 
527   /*
528    * int A_1 = 0;
529    * for (int x = 0; x < 10; x++) {
530    *   A_1 = x + A_1;
531    * }
532    * A[i] = A_1;
533    */
534 
535   std::ostringstream oss;
536   oss << *stmt;
537 
538   const std::string& verification_pattern =
539       R"IR(
540 # CHECK: int A_1 = 0;
541 # CHECK: for (int x = 0; x < 10; x++)
542 # CHECK-NOT: A[
543 # CHECK:   A_1 =
544 # CHECK: A[i] = A_1;)IR";
545 
546   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
547 }
548 
549 // Can registerize accesses dependent on multiple loop vars.
TEST(Registerizer,RegisterizerMultiLoop)550 TEST(Registerizer, RegisterizerMultiLoop) {
551   BufHandle a("A", {1}, kInt);
552   VarHandle x("x", kInt);
553   VarHandle y("y", kInt);
554   StmtPtr stmt = Block::make(
555       {Store::make(a, {0}, 0),
556        For::make(
557            x,
558            0,
559            10,
560            For::make(
561                y,
562                0,
563                10,
564                Block::make({Store::make(
565                    a,
566                    {0},
567                    Mul::make(Add::make(Load::make(a, {0}), x), y))})))});
568 
569   /*
570    * A[0] = 0;
571    * for (int x = 0; x < 10; x++) {
572    *   for (int y = 0; y < 10; y++) {
573    *     A[0] = x * y + (A[0]) * y;
574    *   }
575    * }
576    */
577 
578   stmt = registerize(stmt);
579 
580   /*
581    * int A_1 = 0;
582    * for (int x = 0; x < 10; x++) {
583    *   for (int y = 0; y < 10; y++) {
584    *     A_1 = x * y + y * A_1;
585    *   }
586    * }
587    * A[0] = A_1;
588    */
589 
590   std::ostringstream oss;
591   oss << *stmt;
592 
593   const std::string& verification_pattern =
594       R"IR(
595 # CHECK: int A_1 = 0;
596 # CHECK: for (int x = 0; x < 10; x++)
597 # CHECK:   for (int y = 0; y < 10; y++)
598 # CHECK-NOT: A[
599 # CHECK:     A_1 =
600 # CHECK: A[0] = A_1;)IR";
601 
602   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
603 }
604 
605 // Can registerize correctly if scalars already exist in the program.
TEST(Registerizer,RegisterizerRepeated)606 TEST(Registerizer, RegisterizerRepeated) {
607   BufHandle a("A", {2}, kInt);
608   VarHandle x("x", kInt);
609   StmtPtr stmt = Block::make({
610       Store::make(a, {0}, 0),
611       Store::make(a, {1}, 0),
612       For::make(
613           x,
614           0,
615           10,
616           Block::make(
617               {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
618                Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
619   });
620 
621   // Registerize manually to make sure we only replace a single target.
622   {
623     registerizer::RegisterizerAnalysis analysis;
624     stmt->accept(&analysis);
625     auto candidates = analysis.getCandidates();
626     ASSERT_EQ(candidates.size(), 2);
627 
628     candidates.pop_back();
629     registerizer::RegisterizerReplacer replacer(candidates);
630     stmt = stmt->accept_mutator(&replacer);
631   }
632 
633   // Re-analyze and replace the second target.
634   {
635     registerizer::RegisterizerAnalysis analysis;
636     stmt->accept(&analysis);
637     auto candidates = analysis.getCandidates();
638     ASSERT_EQ(candidates.size(), 1);
639 
640     registerizer::RegisterizerReplacer replacer(candidates);
641     stmt = stmt->accept_mutator(&replacer);
642   }
643 
644   std::ostringstream oss;
645   oss << *stmt;
646 
647   const std::string& verification_pattern =
648       R"IR(
649 # CHECK: int A_1 = 0;
650 # CHECK: int A_1_1 = 0;
651 # CHECK: for (int x = 0; x < 10; x++)
652 # CHECK-NOT: A[
653 # CHECK:   A_1 =
654 # CHECK:   A_1_1 =
655 # CHECK: A[1] = A_1_1;
656 # CHECK: A[0] = A_1;)IR";
657 
658   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
659 }
660 
661 // Can registerize the load of A.
TEST(Registerizer,RegisterizerNoLoads)662 TEST(Registerizer, RegisterizerNoLoads) {
663   BufHandle a("A", {1}, kInt);
664   VarHandle x("x", kInt);
665   StmtPtr stmt = Block::make(
666       {Store::make(a, {0}, 0),
667        For::make(
668            x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))});
669 
670   /*
671    * A[0] = 0;
672    * for (int x = 0; x < 10; x++) {
673    *   A[0] = x + 1;
674    * }
675    */
676 
677   stmt = registerize(stmt);
678 
679   /*
680    * int A_1 = 0;
681    * for (int x = 0; x < 10; x++) {
682    *   A_1 = x + 1;
683    * }
684    * A[0] = A_1;
685    */
686 
687   std::ostringstream oss;
688   oss << *stmt;
689 
690   const std::string& verification_pattern =
691       R"IR(
692 # CHECK: int A_1 = 0;
693 # CHECK: for (int x = 0; x < 10; x++)
694 # CHECK-NOT: A[
695 # CHECK:   A_1 =
696 # CHECK: A[0] = A_1;)IR";
697 
698   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
699 }
700 
701 // Can registerize the load of A but not the store of B.
TEST(Registerizer,RegisterizerNoRepeatedStores)702 TEST(Registerizer, RegisterizerNoRepeatedStores) {
703   BufHandle a("A", {1}, kInt);
704   BufHandle b("B", {10}, kInt);
705   VarHandle x("x", kInt);
706   StmtPtr stmt = Block::make(
707       {Store::make(a, {0}, 0),
708        For::make(
709            x,
710            0,
711            10,
712            Block::make(
713                {Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))});
714 
715   /*
716    * A[0] = 0;
717    * for (int x = 0; x < 10; x++) {
718    *   B[x] = (A[0]) + x;
719    * }
720    */
721 
722   stmt = registerize(stmt);
723 
724   // TODO: its unnecessary to reorder the initializer of A[0], but it's not
725   // actually worse so lets not worry for now.
726 
727   /*
728    * int A_1 = 0;
729    * for (int x = 0; x < 10; x++) {
730    *   B[x] = x + A_1;
731    * }
732    * A[0] = A_1;
733    */
734 
735   std::ostringstream oss;
736   oss << *stmt;
737 
738   const std::string& verification_pattern =
739       R"IR(
740 # CHECK: int A_1 = 0;
741 # CHECK: for (int x = 0; x < 10; x++)
742 # CHECK-NOT: A_
743 # CHECK:   B[x] =
744 # CHECK: A[0] = A_1;)IR";
745 
746   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
747 }
748 
749 // Won't registerize if there are multiple accesses which may overlap.
TEST(Registerizer,RegisterizerMultiVarOverlap)750 TEST(Registerizer, RegisterizerMultiVarOverlap) {
751   BufHandle a("A", {2}, kInt);
752   VarHandle x("x", kInt);
753   StmtPtr stmt = Block::make({
754       Store::make(a, {0}, 0),
755       Store::make(a, {1}, 0),
756       For::make(
757           x,
758           0,
759           10,
760           Block::make(
761               {Store::make(a, {x}, Add::make(Load::make(a, {0}), x)),
762                Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})),
763   });
764   stmt = IRSimplifier::simplify(stmt);
765 
766   std::ostringstream before;
767   before << *stmt;
768 
769   // No change.
770   stmt = registerize(stmt);
771 
772   std::ostringstream after;
773   after << *stmt;
774 
775   ASSERT_EQ(before.str(), after.str());
776 }
777 
// Registerization in the presence of Allocate/Free, where the allocated
// buffer's size is itself a load (C[0]) that is also used elsewhere.
TEST(Registerizer, RegisterizerAllocs) {
  BufHandle a("A", {2}, kInt);
  BufHandle c("C", {1}, kInt);
  VarHandle x("x", kInt);

  // B's extent is the value loaded from C[0].
  BufHandle b("B", {Load::make(c, {0})}, kInt);

  // Input program:
  //   Allocate(B, int, {C[0]});
  //   A[0] = C[0];
  //   B[0] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     B[0] = (B[0]) + x;
  //     A[0] = C[0];
  //   }
  //   Free(B);
  StmtPtr allocate_b = Allocate::make(b);
  StmtPtr init_a = Store::make(a, {0}, Load::make(c, {0}));
  StmtPtr init_b = Store::make(b, {0}, 0);
  StmtPtr loop = For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)),
           Store::make(a, {0}, Load::make(c, {0}))}));
  StmtPtr free_b = Free::make(b);
  StmtPtr program = Block::make({allocate_b, init_a, init_b, loop, free_b});

  program = registerize(program);

  // Expected shape:
  //   int C_1 = C[0];
  //   Allocate(B, ...);
  //   int A_1 = C_1;
  //   int B_1 = 0;
  //   for (int x = 0; x < 10; x++) {
  //     B_1 = ...;
  //     A_1 = C_1;
  //   }
  //   B[0] = B_1;
  //   A[0] = A_1;
  //   Free(B);
  std::ostringstream printed;
  printed << *program;

  const std::string expected_ir =
      R"IR(
# CHECK: int C_1 = C[0];
# CHECK: Allocate(B
# CHECK: int A_1 = C_1;
# CHECK: int B_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   B_1 =
# CHECK:   A_1 = C_
# CHECK: B[0] = B_1;
# CHECK: A[0] = A_1;
# CHECK: Free(B)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
843 
// With no store before the loop to act as initializer, the scalar is
// initialized by loading A[0].
TEST(Registerizer, RegisterizerNoInitializer) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Input program:
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = (A[0]) + x;
  //   }
  StmtPtr accumulate =
      Store::make(a, {0}, Add::make(Load::make(a, {0}), x));
  StmtPtr loop = For::make(x, 0, 10, Block::make({accumulate}));
  StmtPtr program = Block::make({loop});

  program = registerize(program);

  // Expected shape:
  //   int A_1 = A[0];
  //   for (int x = 0; x < 10; x++) {
  //     A_1 = ...;
  //   }
  //   A[0] = A_1;
  std::ostringstream printed;
  printed << *program;

  const std::string expected_ir =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
882 
// A single load/store pair indexed by the loop variable, with no initializer
// before the loop. The test expects the printed program to be identical
// before and after the pass.
TEST(Registerizer, RegisterizerNoInitializerLoopVar) {
  // A is sized to the loop extent so that every A[x] access is in bounds
  // (the buffer was previously declared with a single element).
  BufHandle a("A", {10}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
  stmt = IRSimplifier::simplify(stmt);

  /*
   * As constructed, before simplification:
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[x]) + x;
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
910 
// Two buffers feed each other inside the loop (B[0] is computed from A[0],
// then A[0] is overwritten with B[0]); both elements are registerized.
TEST(Registerizer, RegisterizerLoadThenStore) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  VarHandle x("x", kInt);

  // Input program:
  //   for (int x = 0; x < 10; x++) {
  //     B[0] = (A[0]) + x;
  //     A[0] = B[0];
  //   }
  StmtPtr compute_b = Store::make(b, {0}, Add::make(Load::make(a, {0}), x));
  StmtPtr copy_to_a = Store::make(a, {0}, Load::make(b, {0}));
  StmtPtr loop = For::make(x, 0, 10, Block::make({compute_b, copy_to_a}));
  StmtPtr program = Block::make({loop});

  program = registerize(program);

  // Expected shape:
  //   int A_1 = A[0];
  //   int B_1 = B[0];
  //   for (int x = 0; x < 10; x++) {
  //     B_1 = ...;
  //     A_1 = B_1;
  //   }
  //   B[0] = B_1;
  //   A[0] = A_1;
  std::ostringstream printed;
  printed << *program;

  const std::string expected_ir =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: int B_1 = B[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: B[
# CHECK:   B_1 =
# CHECK-NOT: A[
# CHECK:   A_1 = B_
# CHECK: B[0] = B_
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
960 
// A loop that still carries a GPU block index is rejected: registerize() is
// expected to throw with a message about parallelism flattening.
TEST(Registerizer, RegisterizerParallelized) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  LoopOptions gpu_block_opts;
  gpu_block_opts.set_gpu_block_index(0);

  // Input program (the loop is bound to GPU block index 0):
  //   A[0] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = (A[0]) + x;
  //   }
  StmtPtr init = Store::make(a, {0}, 0);
  StmtPtr accumulate =
      Store::make(a, {0}, Add::make(Load::make(a, {0}), x));
  StmtPtr loop =
      For::make(x, 0, 10, Block::make({accumulate}), gpu_block_opts);
  StmtPtr stmt = Block::make({init, loop});

  ASSERT_THROWS_WITH(
      registerize(stmt),
      "Registerization must occur after parallelism flattening");
}
986 
987 // Should be able to registerize this since the scalar would exist before the
988 // branch.
TEST(Registerizer,RegisterizerConditionAfter)989 TEST(Registerizer, RegisterizerConditionAfter) {
990   BufHandle a("A", {5}, kInt);
991   BufHandle b("B", {5}, kInt);
992   BufHandle c("C", {5}, kInt);
993   VarHandle x("x", kInt);
994 
995   StmtPtr stmt = Block::make(
996       {Store::make(a, {x}, Load::make(b, {x})),
997        Store::make(c, {x}, Load::make(a, {x})),
998        Cond::make(
999            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1000            Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1001            nullptr)});
1002 
1003   /*
1004    * A[x] = B[x];
1005    * C[x] = A[x];
1006    * if (x<5 ? 1 : 0) {
1007    *   A[x] = (A[x]) + 1;
1008    * }
1009    */
1010 
1011   stmt = registerize(stmt);
1012 
1013   /*
1014    * int A_1 = B[x];
1015    * C[x] = A_1;
1016    * if (x<5 ? 1 : 0) {
1017    *   A_1 = A_1 + 1;
1018    * }
1019    * A[x] = A_1;
1020    */
1021 
1022   std::ostringstream oss;
1023   oss << *stmt;
1024 
1025   const std::string& verification_pattern =
1026       R"IR(
1027 # CHECK: int A_1 = B[x];
1028 # CHECK: C[x] = A_1;
1029 # CHECK: if (
1030 # CHECK:   A_1 = A_1 + 1;
1031 # CHECK: A[x] = A_1;)IR";
1032 
1033   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1034 }
1035 
1036 // Should be able to registerize this since the scalar exists in the same form
1037 // after the branch and there is no overlap.
TEST(Registerizer,RegisterizerConditionBefore)1038 TEST(Registerizer, RegisterizerConditionBefore) {
1039   BufHandle a("A", {5}, kInt);
1040   BufHandle b("B", {5}, kInt);
1041   BufHandle c("C", {5}, kInt);
1042   VarHandle x("x", kInt);
1043 
1044   StmtPtr stmt = Block::make(
1045       {Cond::make(
1046            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1047            Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1048            nullptr),
1049        Store::make(a, {x}, Load::make(b, {x})),
1050        Store::make(c, {x}, Load::make(a, {x}))});
1051 
1052   /*
1053    * if (x<5 ? 1 : 0) {
1054    *   A[x] = (A[x]) + 1;
1055    * }
1056    * A[x] = B[x];
1057    * C[x] = A[x];
1058    */
1059 
1060   stmt = registerize(stmt);
1061 
1062   /*
1063    * int A_ 1 = A[x];
1064    * if (x<5 ? 1 : 0) {
1065    *   A_1 = A_1 + 1;
1066    * }
1067    * A_1 = B[x];
1068    * C[x] = A_1;
1069    * A[x] = A_1;
1070    */
1071 
1072   std::ostringstream oss;
1073   oss << *stmt;
1074 
1075   const std::string& verification_pattern =
1076       R"IR(
1077 # CHECK: int A_1 = A[x];
1078 # CHECK: if (
1079 # CHECK:   A_1 = A_1 + 1;
1080 # CHECK: }
1081 # CHECK: A_1 = B[x];
1082 # CHECK: C[x] = A_1;
1083 # CHECK: A[x] = A_1;)IR";
1084 
1085   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1086 }
1087 
1088 // Should be able to registerize this as the combination of the two above rules.
TEST(Registerizer,RegisterizerConditionInside)1089 TEST(Registerizer, RegisterizerConditionInside) {
1090   BufHandle a("A", {5}, kInt);
1091   BufHandle b("B", {5}, kInt);
1092   BufHandle c("C", {5}, kInt);
1093   VarHandle x("x", kInt);
1094 
1095   StmtPtr stmt = Block::make(
1096       {Store::make(a, {x}, Load::make(b, {x})),
1097        Store::make(c, {x}, Load::make(a, {x})),
1098        Cond::make(
1099            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1100            Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1101            nullptr),
1102        Store::make(b, {x}, Load::make(a, {x})),
1103        Store::make(a, {x}, Load::make(c, {x}))});
1104 
1105   /*
1106    * A[x] = B[x];
1107    * C[x] = A[x];
1108    * if (x<5 ? 1 : 0) {
1109    *   A[x] = (A[x]) + 1;
1110    * }
1111    * B[x] = A[x];
1112    * A[x] = C[x];
1113    */
1114 
1115   stmt = registerize(stmt);
1116 
1117   /*
1118    * int A_1 = B[x];
1119    * C[x] = A_1;
1120    * if (x<5 ? 1 : 0) {
1121    *   A_1 = A_1 + 1;
1122    * }
1123    * B[x] = A_1;
1124    * A_1 = C[x];
1125    * A[x] = A_1;
1126    */
1127 
1128   std::ostringstream oss;
1129   oss << *stmt;
1130 
1131   const std::string& verification_pattern =
1132       R"IR(
1133 # CHECK: int A_1 = B[x];
1134 # CHECK: C[x] = A_1;
1135 # CHECK: if (
1136 # CHECK:   A_1 = A_1 + 1;
1137 # CHECK: }
1138 # CHECK: B[x] = A_1;
1139 # CHECK: A_1 = C[x];
1140 # CHECK: A[x] = A_1;)IR";
1141 
1142   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1143 }
1144 
1145 // An example where an access is cut by an overlapping access inside a
1146 // condition, and both sides are large enough to be registerized but cannot be
1147 // because there is no safe place to put the initializer or finalizer.
TEST(Registerizer, RegisterizerConditionInsideOverlap1) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr stmt = Block::make(
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
      {Store::make(a, {x}, Load::make(b, {x})),
       Store::make(c, {x}, Load::make(a, {x})),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
               Store::make(a, {0}, 3),
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           }),
           nullptr),
       Store::make(b, {x}, Load::make(a, {x})),
       Store::make(a, {x}, Load::make(c, {x}))});

  /*
   * A[x] = B[x];
   * C[x] = A[x];
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   *   A[0] = 3;
   *   A[x] = (A[x]) + 1;
   * }
   * B[x] = A[x];
   * A[x] = C[x];
   */

  // The A[0] store overlaps A[x], cutting the region that can be registerized
  // into two groups.
  // Each group has 2 loads and 2 stores however, so we could registerize it,
  // but the first group would need to be finalized inside the condition block,
  // the second would need to be initialized inside the condition block. There's
  // no safe place to put these that's visible to the other uses in the group
  // and so neither registerization is possible.

  // Snapshot the printed IR so it can be compared after the pass runs.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
1201 
// Same as the above, but the access group before the condition (and after the
// condition) are large enough to be registerized without needing the accesses
// inside the condition. Registerization occurs but does not include any
// accesses in the condition, and the first group must be finalized before the
// Cond, the second initialized after it.
TEST(Registerizer, RegisterizerConditionInsideOverlap2) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr stmt = Block::make(
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
      {Store::make(a, {x}, Load::make(b, {x})),
       Store::make(a, {x}, Load::make(b, {x + 1})),
       Store::make(c, {x}, Load::make(a, {x})),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
               Store::make(a, {0}, 3),
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           }),
           nullptr),
       Store::make(b, {x}, Load::make(a, {x})),
       Store::make(b, {x + 1}, Load::make(a, {x})),
       Store::make(a, {x}, Load::make(c, {x}))});

  /*
   * A[x] = B[x];
   * A[x] = B[x + 1];
   * C[x] = A[x];
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   *   A[0] = 3;
   *   A[x] = (A[x]) + 1;
   * }
   * B[x] = A[x];
   * B[x + 1] = A[x];
   * A[x] = C[x];
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = B[x];              // A_1 initializer
   * A_1 = B[x + 1];              //
   * C[x] = A_1;                  //
   * A[x] = A_1;                  // A_1 finalizer
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   *   A[0] = 3;
   *   A[x] = (A[x]) + 1;
   * }
   * int A_2 = A[x];              // A_2 initializer
   * B[x] = A_2;                  //
   * B[x + 1] = A_2;              //
   * A_2 = C[x];                  //
   * A[x] = A_2;                  // A_2 finalizer
   */

  std::ostringstream oss;
  oss << *stmt;

  // The CHECK-NOT line asserts the accesses inside the condition were left
  // untouched rather than rewritten to use A_1.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = B[x];
# CHECK: A_1 = B[x + 1];
# CHECK: C[x] = A_1;
# CHECK: A[x] = A_1;
# CHECK: if (
# CHECK-NOT:   A_1 = A_1 + 1;
# CHECK:   A[x] = (A[x]
# CHECK:   A[0] =
# CHECK:   A[x] = (A[x]
# CHECK: }
# CHECK: int A_2 = A[x];
# CHECK: B[x] = A_2;
# CHECK: B[x + 1] = A_2;
# CHECK: A_2 = C[x];
# CHECK: A[x] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
1287 
1288 // When accesses are within conditional blocks they are not visible to the wider
1289 // program, because we don't know if the branch would be taken and if it isn't
1290 // the accesses in it don't need to be valid (think size checks on the index).
1291 // In this case the accesses cannot be registerized.
TEST(Registerizer, RegisterizerConditionHidden) {
  // Only A is accessed by this program.
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   * if (x>5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass runs.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
1328 
1329 // But... if the same access is found in a non conditional scope, that means
1330 // that that access is valid in the higher scope (or at least if its not it's
1331 // the user's fault). It "unhides" the conditional accesses, allowing
1332 // registerization to occur.
TEST(Registerizer, RegisterizerConditionUnhidden) {
  // Only A is accessed by this program.
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr),
       Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   * A[x] = (A[x]) + 1;            <-- this is doing the unhiding.
   * if (x>5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * if (x<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A_1 = A_1 + 1;
   * if (x>5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if (x<5
# CHECK:   A_1 = A_1 + 1;
# CHECK: }
# CHECK: A_1 = A_1 + 1;
# CHECK: if (x>5
# CHECK:   A_1 = A_1 + 1;
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
1391 
1392 // Can registerize a load that occurs in the condition of a Cond.
TEST(Registerizer, RegisterizerCondCondition) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  // Two plain stores, then a Cond whose condition itself loads A[x].
  StmtPtr load_b_into_a = Store::make(a, {x}, Load::make(b, {x}));
  StmtPtr copy_a_into_c = Store::make(c, {x}, Load::make(a, {x}));
  const ExprHandle a_lt_5 = CompareSelect::make(
      Load::make(a, {x}), 5, CompareSelectOperation::kLT);
  StmtPtr bump_c = Store::make(c, {x}, Add::make(Load::make(c, {x}), 1));

  StmtPtr stmt = Block::make(
      {load_b_into_a, copy_a_into_c, Cond::make(a_lt_5, bump_c, nullptr)});

  /*
   * Input IR:
   *
   * A[x] = B[x];
   * C[x] = A[x];
   * if ((A[x])<5 ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * Expected IR:
   *
   * int A_1 = B[x];
   * int C_1 = A_1;
   * if (A_1<5 ? 1 : 0) {
   *   C_1 = C_1 + 1;
   * }
   * C[x] = C_1;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: int A_1 = B[x];
# CHECK: int C_1 = A_1;
# CHECK: if (A_1<5
# CHECK:   C_1 = C_1 + 1;
# CHECK: C[x] = C_1;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1440 
1441 // Appearing in the condition of a Cond makes it visible to the enclosing scope,
1442 // and so we can registerize internal usages.
TEST(Registerizer, RegisterizerCondConditionUnhidden) {
  // Only A is accessed by this program.
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT),
      Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
      Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))});

  /*
   * if ((A[x])<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * } else {
   *   A[x] = (A[x]) + 10;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * if (A_1<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * } else {
   *   A_1 = A_1 + 10;
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if (A_1<5
# CHECK:   A_1 = A_1 + 1;
# CHECK: } else {
# CHECK:   A_1 = A_1 + 10;
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
1489 
1490 // Conditional hiding also works for IfThenElse exprs.
TEST(Registerizer, RegisterizerIfThenElseHidden) {
  // A is only read inside IfThenElse branches; B is the store target.
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  StmtPtr stmt = Block::make(
      {Store::make(
           b,
           {y},
           IfThenElse::make(
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
               Add::make(Load::make(a, {x}), 1),
               Add::make(Load::make(a, {x + 1}), 2))),
       Store::make(
           b,
           {y + 1},
           IfThenElse::make(
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
               Add::make(Load::make(a, {x}), 1),
               Add::make(Load::make(a, {x + 1}), 2)))});

  /*
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   */

  // Snapshot the printed IR so it can be compared after the pass runs.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
1530 
1531 // Conditional unhiding also works for IfThenElse exprs.
TEST(Registerizer, RegisterizerIfThenElseUnhidden) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // The leading unconditional store to A[x] is the access that unhides the
  // A[x] loads inside the IfThenElse true branches.
  StmtPtr stmt = Block::make({
      Store::make(a, {x}, 0),
      Store::make(
          b,
          {y},
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Add::make(Load::make(a, {x}), 1),
              Add::make(Load::make(a, {x + 1}), 2))),
      Store::make(
          b,
          {y + 1},
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Add::make(Load::make(a, {x}), 1),
              Add::make(Load::make(a, {x + 1}), 2))),
  });

  /*
   * A[x] = 0;
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
1584 
1585 // Nested IfThenElse exprs can't promote to higher level scopes.
TEST(Registerizer, RegisterizerIfThenElseNested) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  BufHandle d("D", {5}, kInt);
  VarHandle x("x", kInt);

  // Inner selects: D[x] is loaded once in each outer branch.
  const ExprHandle when_x_lt_3 = IfThenElse::make(
      CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
      Load::make(d, {x}),
      Load::make(b, {x}));
  const ExprHandle otherwise = IfThenElse::make(
      CompareSelect::make(x, 5, CompareSelectOperation::kEQ),
      Load::make(c, {x}),
      Load::make(d, {x}));
  const ExprHandle nested_select = IfThenElse::make(
      CompareSelect::make(x, 3, CompareSelectOperation::kLT),
      when_x_lt_3,
      otherwise);

  StmtPtr stmt = Block::make({Store::make(a, {x}, nested_select)});

  /*
   * Input IR:
   *
   * A[x] = IfThenElse(x<3 ? 1 : 0,
   *          IfThenElse(x==2 ? 1 : 0, D[x], B[x]),
   *            IfThenElse(x==5 ? 1 : 0, C[x], D[x]));
   */

  std::ostringstream ir_before;
  ir_before << *stmt;

  // The pass is expected to leave the program untouched.
  stmt = registerize(stmt);

  std::ostringstream ir_after;
  ir_after << *stmt;

  ASSERT_EQ(ir_before.str(), ir_after.str());
}
1624 
1625 // Cannot registerize an access completely contained within an IfThenElse
1626 // branch, since it is not a Stmt and cannot hold variable definitions. We need
1627 // to check that we don't promote the initializer/finalizer to the enclosing
1628 // Block.
TEST(Registerizer, RegisterizerIfThenElseInternal) {
  // Float buffers are used so the accesses don't get simplified to a single
  // access.
  BufHandle a("A", {5}, kFloat);
  BufHandle b("B", {5}, kFloat);
  VarHandle x("x", kInt);

  // Part 1: both B[x] loads live inside one IfThenElse branch.
  const ExprHandle select_b = IfThenElse::make(
      CompareSelect::make(x, 3, CompareSelectOperation::kLT),
      Add::make(Load::make(b, {x}), Load::make(b, {x})),
      Load::make(b, {x}));
  StmtPtr stmt = Block::make({Store::make(a, {x}, select_b)});

  /*
   * Input IR:
   *
   * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
   */

  std::ostringstream ir_before;
  ir_before << *stmt;

  // The pass is expected to leave the program untouched.
  stmt = registerize(stmt);

  std::ostringstream ir_after;
  ir_after << *stmt;

  ASSERT_EQ(ir_before.str(), ir_after.str());

  // Part 2: the same computation expressed with a Cond statement instead of
  // an IfThenElse expression. Here the two B[x] accesses in the true branch
  // can be registerized; verify that.
  const ExprHandle x_lt_3 =
      CompareSelect::make(x, 3, CompareSelectOperation::kLT);
  StmtPtr store_doubled =
      Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x})));
  StmtPtr store_single = Store::make(a, {x}, Load::make(b, {x}));
  stmt = Block::make({Cond::make(x_lt_3, store_doubled, store_single)});

  /*
   * Input IR:
   *
   * if (x<3 ? 1 : 0) {
   *   A[x] = (B[x]) + (B[x]);
   * } else {
   *   A[x] = B[x];
   * }
   */

  stmt = registerize(stmt);

  /*
   * Expected IR:
   *
   * if (x<3 ? 1 : 0) {
   *   float B_1 = B[x];
   *   A[x] = B_1 + B_1;
   * } else {
   *   A[x] = B[x];
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK-NOT: int
# CHECK-NOT: float
# CHECK: if (x<3
# CHECK:   float B_1 =
# CHECK:   A[x] = B_1 + B_1
# CHECK: } else {
# CHECK:   A[x] = B[x]
# CHECK: }
# CHECK-NOT: A[x]
# CHECK-NOT: B[x])IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1705 
// Can registerize a load that occurs in the condition of an IfThenElse.
TEST(Registerizer, RegisterizerIfThenElseCondition) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  // A self-assignment, present only so there are enough accesses to combine.
  StmtPtr self_assign = Store::make(a, {x}, Load::make(a, {x}));

  // A store whose value is selected by a condition that loads A[x].
  const ExprHandle a_lt_5 = CompareSelect::make(
      Load::make(a, {x}), 5, CompareSelectOperation::kLT);
  const ExprHandle pick_b_or_c =
      IfThenElse::make(a_lt_5, Load::make(b, {0}), Load::make(c, {0}));
  StmtPtr select_store = Store::make(a, {x}, pick_b_or_c);

  StmtPtr stmt = Block::make({self_assign, select_store});

  /*
   * Input IR:
   *
   * A[x] = A[x];       <---- just here so there are enough accesses to combine.
   * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]);
   */

  stmt = registerize(stmt);

  /*
   * Expected IR:
   *
   * int A_1 = A[x];
   * A_1 = A_1;
   * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
   * A[x] = A_1;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1749 
// Appearing in the condition of an IfThenElse makes it visible to the
// enclosing scope, and so we can registerize internal usages.
TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) {
  // A is read in the condition and in both branches; B is the store target.
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr stmt = Block::make({Store::make(
      b,
      {x},
      IfThenElse::make(
          CompareSelect::make(
              Load::make(a, {x}), 5, CompareSelectOperation::kLT),
          Add::make(Load::make(a, {x}), 1),
          Add::make(Load::make(a, {x}), 10)))});

  /*
   * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
1788 
// Cannot promote accesses internal to IfThenElse branches even if the enclosing
// scope is conditional.
TEST(Registerizer, RegisterizerConditionBranchOnly) {
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  // Both stores write the same select expression; build a fresh copy of it
  // for each use so the two stores do not share IR nodes.
  const auto make_select = [&]() {
    return IfThenElse::make(
        CompareSelect::make(x, 5, CompareSelectOperation::kLT),
        Add::make(Load::make(a, {x}), x),
        Add::make(Load::make(a, {x - 5}), x));
  };

  StmtPtr stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make({
          Cond::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Store::make(a, {x}, make_select()),
              Store::make(a, {x - 5}, make_select())),
      }))});
  stmt = IRSimplifier::simplify(stmt);

  std::ostringstream ir_before;
  ir_before << *stmt;

  /*
   * Input IR:
   *
   * for (int x = 0; x < 10; x++) {
   *   if (x<5 ? 1 : 0) {
   *     A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
   *   } else {
   *     A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
   *   }
   * }
   */

  // The pass is expected to leave the program untouched.
  stmt = registerize(stmt);

  std::ostringstream ir_after;
  ir_after << *stmt;

  ASSERT_EQ(ir_before.str(), ir_after.str());
}
1838 
1839 // We can registerize an IfThenElse that appears in the condition branch of a
1840 // Cond. This is a weird but valid thing to do.
TEST(Registerizer, RegisterizerCondIfThenElse) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  // The Cond's condition compares an IfThenElse (which reads A[x] twice and
  // B[x] once) against x.
  const ExprHandle a_lt_5 = CompareSelect::make(
      Load::make(a, {x}), 5, CompareSelectOperation::kLT);
  const ExprHandle pick_a_or_b =
      IfThenElse::make(a_lt_5, Load::make(a, {x}), Load::make(b, {x}));
  const ExprHandle picked_eq_x =
      CompareSelect::make(pick_a_or_b, x, CompareSelectOperation::kEQ);
  StmtPtr bump_c = Store::make(c, {x}, Add::make(Load::make(c, {x}), 1));

  StmtPtr stmt = Block::make({Cond::make(picked_eq_x, bump_c, nullptr)});

  /*
   * Input IR:
   *
   * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  // access to A can be registerized, but not B or C

  /*
   * Expected IR:
   *
   * int A_1 = A[x];
   * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]
# CHECK:   C[x] = (C[x]) + 1;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1887 
// Can registerize a conditional access in the RHS of a store unhidden by its
// LHS, and hoist it out of a loop.
TEST(Registerizer, RegisterizerIfThenElseLoop) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Loop body: A[x] is the store target and is also read in the true branch.
  const ExprHandle keep_or_load_b = IfThenElse::make(
      CompareSelect::make(x, 3, CompareSelectOperation::kLT),
      Load::make(a, {x}),
      Load::make(b, {y}));
  StmtPtr stmt = For::make(y, 0, 10, Store::make(a, {x}, keep_or_load_b));

  /*
   * Input IR:
   *
   * for (int y = 0; y < 10; y++) {
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]);
   * }
   */

  stmt = registerize(stmt);

  /*
   * Expected IR:
   *
   * int A_1 = A[x];
   * for (int y = 0; y < 10; y++) {
   *   A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
   * }
   * A[x] = A_1;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: for (
# CHECK:   A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
1937 
1938 // Cannot registerize if the RHS overlaps the access creating visibility.
TEST(Registerizer, RegisterizerIfThenElseLoopCut) {
  // Only A is accessed by this program: A[x] is stored, and the IfThenElse
  // reads A[x] in one branch and A[y] in the other.
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  StmtPtr stmt = Block::make({For::make(
      y,
      0,
      10,
      Store::make(
          a,
          {x},
          IfThenElse::make(
              CompareSelect::make(x, 3, CompareSelectOperation::kLT),
              Load::make(a, {x}),
              Load::make(a, {y}))))});

  /*
   * for (int y = 0; y < 10; y++) {
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]);
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass runs.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
1974 
1975 // Simple case where an access is cut by an overlapping access later in the
1976 // program, we can registerize up until the overlap.
TEST(Registerizer, RegisterizerPartialAfter) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // A[0] is initialized and accumulated into; the trailing loop then touches
  // A[x] and A[x - 1].
  StmtPtr zero_a0 = Store::make(a, {0}, 0);
  StmtPtr accumulate = For::make(
      x,
      0,
      10,
      Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}));
  StmtPtr shift = For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));

  StmtPtr stmt = Block::make({zero_a0, accumulate, shift});

  /*
   * Input IR:
   *
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   */

  stmt = registerize(stmt);

  /*
   * Expected IR:
   *
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (
# CHECK:   A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK:   A[x] = A[x - 1];
# CHECK: }
# CHECK-NOT: A)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
2030 
2031 // We can registerize an access which overlaps a previous access, the
2032 // initializer must be inserted after the previous access.
TEST(Registerizer, RegisterizerPartialBefore) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // The loop over A[x] / A[x - 1] comes first this time, followed by the
  // initialization of and accumulation into A[0].
  StmtPtr shift = For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
  StmtPtr zero_a0 = Store::make(a, {0}, 0);
  StmtPtr accumulate = For::make(
      x,
      0,
      10,
      Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}));

  StmtPtr stmt = Block::make({shift, zero_a0, accumulate});

  /*
   * Input IR:
   *
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * Expected IR:
   *
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK-NOT: int
# CHECK: for (
# CHECK:   A[x] = A[x - 1];
# CHECK: }
# CHECK: int A_1 = 0;
# CHECK: for (
# CHECK:   A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
2086 
2087 // The combination of the previous two tests, an access is cut by an overlapping
2088 // access in both directions.
TEST(Registerizer, RegisterizerPartialInside) {
  BufHandle a("A", {1}, kInt);
  VarHandle x1("x1", kInt);
  VarHandle x2("x2", kInt);
  VarHandle x3("x3", kInt);

  // Builds `for (v = 0; v < 10; v++) A[0] = A[0] + v;` from fresh nodes.
  const auto make_accumulate_loop = [&](const VarHandle& v) {
    return For::make(
        v, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), v)));
  };

  // Two accumulation loops over A[0], separated by a loop over A[x2].
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 2),
       make_accumulate_loop(x1),
       For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))),
       make_accumulate_loop(x3)});

  /*
   * Input IR:
   *
   * A[0] = 2;
   * for (int x1 = 0; x1 < 10; x1++) {
   *   A[0] = (A[0]) + x1;
   * }
   * for (int x2 = 1; x2 < 10; x2++) {
   *   A[x2] = A[x2 - 1];
   * }
   * for (int x3 = 0; x3 < 10; x3++) {
   *   A[0] = (A[0]) + x3;
   * }
   */

  stmt = registerize(stmt);

  /*
   * Expected IR:
   *
   * int A_1 = 2;
   * for (int x1 = 0; x1 < 10; x1++) {
   *   A_1 = A_1 + x1;
   * }
   * A[0] = A_1;
   * for (int x2 = 1; x2 < 10; x2++) {
   *   A[x2] = A[x2 - 1];
   * }
   * int A_2 = A[0];
   * for (int x3 = 0; x3 < 10; x3++) {
   *   A_2 = A_2 + x3;
   * }
   * A[0] = A_2;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: int A_1 = 2;
# CHECK: for (
# CHECK:   A_1 = A_1 + x1;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK:   A[x2] =
# CHECK: }
# CHECK: int A_2 = A[0];
# CHECK: for (
# CHECK:   A_2 = A_2 + x3;
# CHECK: }
# CHECK: A[0] = A_2;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}
2154 
2155 // An element could be registerized program wide but is cut by a conditional
2156 // access, we should break this into two scalars and write back to the buffer
2157 // before the condition.
TEST(Registerizer,RegisterizerPartialCondition)2158 TEST(Registerizer, RegisterizerPartialCondition) {
2159   BufHandle a("A", {1}, kInt);
2160   VarHandle x("x", kInt);
2161   StmtPtr stmt = Block::make(
2162       {Store::make(a, {0}, 2),
2163        For::make(
2164            x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))),
2165        Cond::make(
2166            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2167            Store::make(a, {x}, Load::make(a, {x - 1})),
2168            nullptr),
2169        For::make(
2170            x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))});
2171 
2172   /*
2173    * A[0] = 2;
2174    * for (int x = 0; x < 10; x++) {
2175    *   A[0] = (A[0]) + x;
2176    * }
2177    * if (x<5 ? 1 : 0) {
2178    *   A[x] = A[x - 1];
2179    * }
2180    * for (int x = 0; x < 10; x++) {
2181    *   A[0] = (A[0]) + x;
2182    * }
2183    */
2184 
2185   stmt = registerize(stmt);
2186 
2187   /*
2188    * int A_1 = 2;
2189    * for (int x = 0; x < 10; x++) {
2190    *   A_1 = A_1 + x;
2191    * }
2192    * A[0] = A_1;
2193    * if (x<5 ? 1 : 0) {
2194    *   A[x] = A[x - 1];
2195    * }
2196    * int A_2 = A[0];
2197    * for (int x = 0; x < 10; x++) {
2198    *   A_2 = A_2 + x;
2199    * }
2200    * A[0] = A_2;
2201    */
2202 
2203   std::ostringstream oss;
2204   oss << *stmt;
2205 
2206   const std::string& verification_pattern =
2207       R"IR(
2208 # CHECK: int A_1 = 2;
2209 # CHECK: for (
2210 # CHECK:   A_1 = A_1 + x;
2211 # CHECK: }
2212 # CHECK: A[0] = A_1;
2213 # CHECK: if (
2214 # CHECK:   A[x] =
2215 # CHECK: }
2216 # CHECK: int A_2 = A[0];
2217 # CHECK: for (
2218 # CHECK:   A_2 = A_2 + x;
2219 # CHECK: }
2220 # CHECK: A[0] = A_2;)IR";
2221 
2222   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2223 }
2224 
2225 // Tests case where an access is cut by an internal conditional access which
2226 // itself is registerized.
TEST(Registerizer,RegisterizerPartialConditionInternalCut)2227 TEST(Registerizer, RegisterizerPartialConditionInternalCut) {
2228   BufHandle a("A", {1}, kInt);
2229   VarHandle x("x", kInt);
2230   StmtPtr stmt = Block::make(
2231       {Store::make(a, {0}, 1),
2232        Store::make(a, {0}, 3),
2233        Cond::make(
2234            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2235            Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
2236            nullptr),
2237        Store::make(a, {0}, 4),
2238        Store::make(a, {0}, 6)});
2239 
2240   /*
2241    * A[0] = 1;
2242    * A[0] = 3;
2243    * if (x<5 ? 1 : 0) {
2244    *   A[x] = 1;
2245    *   A[x] = 3;
2246    * }
2247    * A[0] = 4;
2248    * A[0] = 6;
2249    */
2250 
2251   stmt = registerize(stmt);
2252 
2253   /*
2254    * int A_1 = 1;
2255    * A_1 = 3;
2256    * A[0] = A_1;
2257    * if (x<5 ? 1 : 0) {
2258    *   int A_2 = 1;
2259    *   A_2 = 3;
2260    *   A[x] = A_2;
2261    * }
2262    * int A_3 = 4;
2263    * A_3 = 6;
2264    * A[0] = A_3;
2265    */
2266 
2267   std::ostringstream oss;
2268   oss << *stmt;
2269 
2270   const std::string& verification_pattern =
2271       R"IR(
2272 # CHECK: int A_1 = 1;
2273 # CHECK: A_1 = 3
2274 # CHECK: A[0] = A_1;
2275 # CHECK: if (
2276 # CHECK:   int A_2 = 1;
2277 # CHECK:   A_2 = 3;
2278 # CHECK:   A[x] = A_2;
2279 # CHECK: }
2280 # CHECK: int A_3 = 4;
2281 # CHECK: A_3 = 6;
2282 # CHECK: A[0] = A_3;)IR";
2283 
2284   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2285 }
2286 
2287 // First statement in condition closes outer access, but can be registerized
2288 // with later statements.
TEST(Registerizer,RegisterizerPartialConditionInternalStart)2289 TEST(Registerizer, RegisterizerPartialConditionInternalStart) {
2290   BufHandle a("A", {1}, kInt);
2291   VarHandle x("x", kInt);
2292   StmtPtr stmt = Block::make(
2293       {Store::make(a, {0}, 1),
2294        Store::make(a, {0}, 3),
2295        Cond::make(
2296            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2297            Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
2298            nullptr),
2299        Store::make(a, {x}, 4),
2300        Store::make(a, {x}, 6)});
2301 
2302   /*
2303    * A[0] = 1;
2304    * A[0] = 3;
2305    * if (x<5 ? 1 : 0) {
2306    *   A[x] = 1;
2307    *   A[x] = 3;
2308    * }
2309    * A[x] = 4;
2310    * A[x] = 6;
2311    */
2312 
2313   stmt = registerize(stmt);
2314 
2315   /*
2316    * int A_1 = 1;
2317    * A_1 = 3;
2318    * A[0] = A_1;
2319    * int A_2 = A[x];    <--- must read from the input here.
2320    * if (x<5 ? 1 : 0) {
2321    *   A_2 = 1;
2322    *   A_2 = 3;
2323    * }
2324    * A_2 = 4;
2325    * A_2 = 6;
2326    * A[x] = A_2;
2327    */
2328 
2329   // TODO: I suppose we could refactor with a conditional initializer?
2330 
2331   std::ostringstream oss;
2332   oss << *stmt;
2333 
2334   const std::string& verification_pattern =
2335       R"IR(
2336 # CHECK: int A_1 = 1;
2337 # CHECK: A_1 = 3
2338 # CHECK: A[0] = A_1;
2339 # CHECK: int A_2 = A[x];
2340 # CHECK: if (
2341 # CHECK:   A_2 = 1;
2342 # CHECK:   A_2 = 3;
2343 # CHECK: }
2344 # CHECK: A_2 = 4;
2345 # CHECK: A_2 = 6;
2346 # CHECK: A[x] = A_2;)IR";
2347 
2348   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2349 }
2350 
2351 // An access cuts two open overlaps and creates four scalar variables.
TEST(Registerizer,RegisterizerPartialOverlapsTwo)2352 TEST(Registerizer, RegisterizerPartialOverlapsTwo) {
2353   BufHandle a("A", {1}, kInt);
2354   VarHandle x("x", kInt);
2355   StmtPtr stmt = Block::make(
2356       {Store::make(a, {1}, Load::make(a, {0})),
2357        Store::make(a, {0}, Load::make(a, {1})),
2358        Store::make(a, {0}, Load::make(a, {1})),
2359        For::make(x, 1, 10, Store::make(a, {x}, x)),
2360        Store::make(a, {1}, Load::make(a, {0})),
2361        Store::make(a, {0}, Load::make(a, {1})),
2362        Store::make(a, {0}, Load::make(a, {1}))});
2363 
2364   /*
2365    * A[1] = A[0];
2366    * A[0] = A[1];
2367    * A[0] = A[1];
2368    * for (int x = 1; x < 10; x++) {
2369    *   A[x] = x;
2370    * }
2371    * A[1] = A[0];
2372    * A[0] = A[1];
2373    * A[0] = A[1];
2374    */
2375 
2376   stmt = registerize(stmt);
2377 
2378   /*
2379    * int A_1 = A[0];
2380    * int A_2 = A_1;
2381    * A_1 = A_2;
2382    * A_1 = A_2;
2383    * A[1] = A_2;
2384    * A[0] = A_1;
2385    * for (int x = 1; x < 10; x++) {
2386    *   A[x] = x;
2387    * }
2388    * int A_3 = A[0];
2389    * int A_4 = A_3;
2390    * A_3 = A_4;
2391    * A_3 = A_4;
2392    * A[1] = A_4;
2393    * A[0] = A_3;
2394    */
2395 
2396   std::ostringstream oss;
2397   oss << *stmt;
2398 
2399   const std::string& verification_pattern =
2400       R"IR(
2401 # CHECK: int A_1 = A[0];
2402 # CHECK: int A_2 = A_1;
2403 # CHECK: A_1 = A_2;
2404 # CHECK: A_1 = A_2;
2405 # CHECK: A[1] = A_2;
2406 # CHECK: A[0] = A_1;
2407 # CHECK: for (
2408 # CHECK:   A[x] = x;
2409 # CHECK: }
2410 # CHECK: int A_3 = A[0];
2411 # CHECK: int A_4 = A_3;
2412 # CHECK: A_3 = A_4;
2413 # CHECK: A_3 = A_4;
2414 # CHECK: A[1] = A_4;
2415 # CHECK: A[0] = A_3;)IR";
2416 
2417   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2418 }
2419 
2420 // Nested blocks will automatically be flattened and do not provent
2421 // registerization of enclosed accesses.
TEST(Registerizer,RegisterizerNestedBlocks)2422 TEST(Registerizer, RegisterizerNestedBlocks) {
2423   BufHandle a("A", {1}, kInt);
2424   VarHandle x("x", kInt);
2425   StmtPtr stmt = Block::make(
2426       // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
2427       {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2428        Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}),
2429        Block::make(
2430            {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)),
2431             Block::make(
2432                 {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})});
2433 
2434   /*
2435    * A[0] = (A[0]) + 1;
2436    * {
2437    *   A[0] = (A[0]) + 2;
2438    * }
2439    * {
2440    *   A[0] = (A[0]) + 3;
2441    *   {
2442    *     A[0] = (A[0]) + 4;
2443    *   }
2444    * }
2445    */
2446 
2447   stmt = registerize(stmt);
2448 
2449   /*
2450    * int A_1 = A[0];
2451    * A_1 = A_1 + 1;
2452    * A_1 = A_1 + 2;
2453    * A_1 = A_1 + 3;
2454    * A_1 = A_1 + 4;
2455    * A[0] = A_1;
2456    */
2457 
2458   std::ostringstream oss;
2459   oss << *stmt;
2460 
2461   const std::string& verification_pattern =
2462       R"IR(
2463 # CHECK: int A_1 = A[0];
2464 # CHECK: A_1 = A_1 + 1;
2465 # CHECK: A_1 = A_1 + 2;
2466 # CHECK: A_1 = A_1 + 3;
2467 # CHECK: A_1 = A_1 + 4;
2468 # CHECK: A[0] = A_1;)IR";
2469 
2470   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2471 }
2472 
2473 // The access can be registerized internally to a condition, but must ensure
2474 // that both initializer and finalizer are within the same condition.
TEST(Registerizer,RegisterizerNestedConditions)2475 TEST(Registerizer, RegisterizerNestedConditions) {
2476   BufHandle a("A", {1}, kInt);
2477   VarHandle x("x", kInt);
2478   StmtPtr stmt = Block::make({Cond::make(
2479       CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2480       Block::make(
2481           {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2482            Cond::make(
2483                CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2484                Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2485                nullptr)}),
2486       nullptr)});
2487 
2488   /*
2489    * if (x<5 ? 1 : 0) {
2490    *   A[0] = (A[0]) + 1;
2491    *   if (x==2 ? 1 : 0) {
2492    *
2493    *     A[0] = (A[0]) + 1;
2494    *   }
2495    * }
2496    */
2497 
2498   stmt = registerize(stmt);
2499 
2500   /*
2501    * if (x<5 ? 1 : 0) {
2502    *   int A_1 = A[0];
2503    *   A_1 = A_1 + 1;
2504    *   if (x==2 ? 1 : 0) {
2505    *     A_1 = A_1 + 1;
2506    *   }
2507    * A[0] = A_1;
2508    * }
2509    */
2510 
2511   std::ostringstream oss;
2512   oss << *stmt;
2513 
2514   const std::string& verification_pattern =
2515       R"IR(
2516 # CHECK: if (x<5
2517 # CHECK:   int A_1 = A[0];
2518 # CHECK:   A_1 = A_1 + 1;
2519 # CHECK:   if (x==2
2520 # CHECK:     A_1 = A_1 + 1;
2521 # CHECK:   }
2522 # CHECK: A[0] = A_1;
2523 # CHECK: })IR";
2524 
2525   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2526 }
2527 
2528 // If an access exists outside the scope of the condition then we can lift
2529 // nested conditional usages into the same scalar.
TEST(Registerizer,RegisterizerNestedConditionsUnhidden)2530 TEST(Registerizer, RegisterizerNestedConditionsUnhidden) {
2531   BufHandle a("A", {1}, kInt);
2532   VarHandle x("x", kInt);
2533   StmtPtr stmt = Block::make(
2534       {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2535        Cond::make(
2536            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2537            Block::make(
2538                {Store::make(a, {1}, 1),
2539                 Cond::make(
2540                     CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2541                     Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2542                     nullptr)}),
2543            nullptr)});
2544 
2545   /*
2546    * A[0] = (A[0]) + 1;
2547    * if (x<5 ? 1 : 0) {
2548    *   A[1] = 1;
2549    *   if (x==2 ? 1 : 0) {
2550    *     A[0] = (A[0]) + 1;
2551    *   }
2552    * }
2553    */
2554 
2555   stmt = registerize(stmt);
2556 
2557   /*
2558    * int A_1 = A[0];
2559    * A_1 = A_1 + 1;
2560    * if (x<5 ? 1 : 0) {
2561    *   A[1] = 1;
2562    *   if (x==2 ? 1 : 0) {
2563    *     A_1 = A_1 + 1;
2564    *   }
2565    * }
2566    * A[0] = A_1;
2567    */
2568 
2569   std::ostringstream oss;
2570   oss << *stmt;
2571 
2572   const std::string& verification_pattern =
2573       R"IR(
2574 # CHECK: int A_1 = A[0];
2575 # CHECK: A_1 = A_1 + 1;
2576 # CHECK: if (x<5
2577 # CHECK:   A[1] = 1;
2578 # CHECK:   if (x==2
2579 # CHECK:     A_1 = A_1 + 1;
2580 # CHECK: A[0] = A_1;)IR";
2581 
2582   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2583 }
2584 
// Every access to A[0] is inside a condition: first directly under `x==2`,
// then under `x==2` nested within `x<5`. The test expects registerize() to
// leave the printed IR unchanged.
TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
           Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Cond::make(
               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
               nullptr)}),
           nullptr)});

  // IR before registerization:
  /*
   * if (x==2 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   * }
   * if (x<5 ? 1 : 0) {
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());

  // Run the pass a second time; the result is not inspected, hence the
  // dead-store suppression.
  // NOTE(review): presumably this only checks that a repeated pass does not
  // fail — confirm.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  stmt = registerize(stmt);
}
2626 
// Every access to A[0] is inside a condition: first under `x==2` nested
// within `x<5`, then directly under `x==2`. The test expects registerize() to
// leave the printed IR unchanged.
TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Cond::make(
               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
               nullptr)}),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
           Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           nullptr)});

  // IR before registerization:
  /*
   * if (x<5 ? 1 : 0) {
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   * if (x==2 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());

  // Run the pass a second time; the result is not inspected, hence the
  // dead-store suppression.
  // NOTE(review): presumably this only checks that a repeated pass does not
  // fail — confirm.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  stmt = registerize(stmt);
}
2668 
2669 // If an access is cut by another access internal to a condition block, it still
2670 // cuts the access.
TEST(Registerizer,RegisterizerNestedConditionsCut)2671 TEST(Registerizer, RegisterizerNestedConditionsCut) {
2672   BufHandle a("A", {1}, kInt);
2673   VarHandle x("x", kInt);
2674   StmtPtr stmt = Block::make(
2675       {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2676        Cond::make(
2677            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2678            Block::make(
2679                {Store::make(a, {x}, 1),
2680                 Cond::make(
2681                     CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2682                     Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2683                     nullptr)}),
2684            nullptr)});
2685 
2686   /*
2687    * A[0] = (A[0]) + 1;
2688    * if (x<5 ? 1 : 0) {
2689    *   A[x] = 1;
2690    *   if (x==2 ? 1 : 0) {
2691    *
2692    *     A[0] = (A[0]) + 1;
2693    *   }
2694    * }
2695    */
2696 
2697   std::ostringstream before;
2698   before << *stmt;
2699 
2700   // No change.
2701   stmt = registerize(stmt);
2702 
2703   std::ostringstream after;
2704   after << *stmt;
2705 
2706   ASSERT_EQ(before.str(), after.str());
2707 }
2708 
// A[0] is accessed under a top-level condition and again under a condition
// inside a loop; there is no unconditional access. The test expects
// registerize() to leave the printed IR unchanged.
TEST(Registerizer, RegisterizerNestedConditionLoopHidden) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
           Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           nullptr),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(b, {x}, 0),
                Cond::make(
                    CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
                    Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
                    nullptr)}))});

  // IR before registerization:
  /*
   * if (x==2 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   * }
   * for (int x = 0; x < 10; x++) {
   *   B[x] = 0;     <-- this is only here to prevent Loop/Cond reordering.
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
2752 
2753 // Three loops and four element regions, three of which should be registerized
2754 // at different levels of the IR.
TEST(Registerizer,RegisterizerNestedConditionThreeDeep)2755 TEST(Registerizer, RegisterizerNestedConditionThreeDeep) {
2756   BufHandle a("A", {10}, kInt);
2757   BufHandle b("B", {10}, kInt);
2758   VarHandle x("x", kInt);
2759   StmtPtr stmt = Block::make(
2760       {Store::make(a, {4}, 0),
2761        Cond::make(
2762            CompareSelect::make(x, 2, CompareSelectOperation::kGT),
2763            Cond::make(
2764                CompareSelect::make(x, 3, CompareSelectOperation::kGT),
2765                Block::make({
2766                    Cond::make(
2767                        CompareSelect::make(x, 4, CompareSelectOperation::kGT),
2768                        Block::make({
2769                            Store::make(
2770                                a, {1}, Add::make(Load::make(a, {1}), 1)),
2771                            Store::make(
2772                                a, {2}, Add::make(Load::make(a, {2}), 1)),
2773                            Store::make(
2774                                a, {3}, Add::make(Load::make(a, {3}), 1)),
2775                            Store::make(
2776                                a, {4}, Add::make(Load::make(a, {4}), 1)),
2777                            Store::make(
2778                                a, {1}, Add::make(Load::make(a, {1}), 1)),
2779                        }),
2780                        nullptr),
2781                    Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)),
2782                }),
2783                nullptr),
2784            nullptr)});
2785 
2786   /*
2787    * A[4] = 0;
2788    * if (x>2 ? 1 : 0) {
2789    *   if (x>3 ? 1 : 0) {
2790    *     if (x>4 ? 1 : 0) {
2791    *       A[1] = (A[1]) + 1;
2792    *       A[2] = (A[2]) + 1;
2793    *       A[3] = (A[3]) + 1;
2794    *       A[4] = (A[4]) + 1;
2795    *       A[1] = (A[1]) + 1;
2796    *     }
2797    *     A[2] = (A[2]) + 1;
2798    *   }
2799    * }
2800    */
2801 
2802   stmt = registerize(stmt);
2803 
2804   /*
2805    * int A_1 = 0;
2806    * if (x>2 ? 1 : 0) {
2807    *   if (x>3 ? 1 : 0) {
2808    *     int A_3 = A[2];
2809    *     if (x>4 ? 1 : 0) {
2810    *       int A_2 = A[1];
2811    *       A_2 = A_2 + 1;
2812    *       A_3 = A_3 + 1;
2813    *       A[3] = (A[3]) + 1;
2814    *       A_1 = A_1 + 1;
2815    *       A_2 = A_2 + 1;
2816    *       A[1] = A_2;
2817    *     }
2818    *     A_3 = A_3 + 1;
2819    *     A[2] = A_3;
2820    *   }
2821    * }
2822    * A[4] = A_1;
2823    */
2824 
2825   std::ostringstream oss;
2826   oss << *stmt;
2827 
2828   const std::string& verification_pattern =
2829       R"IR(
2830 # CHECK: int A_1 = 0;
2831 # CHECK: if (x>2 ? 1 : 0) {
2832 # CHECK:   if (x>3 ? 1 : 0) {
2833 # CHECK:     int A_3 = A[2];
2834 # CHECK:     if (x>4 ? 1 : 0) {
2835 # CHECK:       int A_2 = A[1];
2836 # CHECK:       A_2 = A_2 + 1;
2837 # CHECK:       A_3 = A_3 + 1;
2838 # CHECK:       A[3] = (A[3]) + 1;
2839 # CHECK:       A_1 = A_1 + 1;
2840 # CHECK:       A_2 = A_2 + 1;
2841 # CHECK:       A[1] = A_2;
2842 # CHECK:     }
2843 # CHECK:     A_3 = A_3 + 1;
2844 # CHECK:     A[2] = A_3;
2845 # CHECK:   }
2846 # CHECK: }
2847 # CHECK: A[4] = A_1;)IR";
2848 
2849   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2850 }
2851 
2852 // Can replace a simple scalar access with a local variable even when that
2853 // variable is an outer loop var.
TEST(Registerizer,RegisterizerNestedLoopSimple)2854 TEST(Registerizer, RegisterizerNestedLoopSimple) {
2855   BufHandle a("A", {1}, kInt);
2856   VarHandle x("x", kInt);
2857   VarHandle y("y", kInt);
2858   StmtPtr stmt = Block::make({For::make(
2859       y,
2860       0,
2861       10,
2862       For::make(
2863           x,
2864           0,
2865           10,
2866           Block::make(
2867               {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))});
2868 
2869   /*
2870    * for (int y = 0; y < 10; y++) {
2871    *   for (int x = 0; x < 10; x++) {
2872    *     A[y] = (A[y]) + x;
2873    *   }
2874    * }
2875    */
2876 
2877   stmt = registerize(stmt);
2878 
2879   /*
2880    * for (int y = 0; y < 10; y++) {
2881    *   int A_1 = A[y];
2882    *   for (int x = 0; x < 10; x++) {
2883    *     A_1 = A_1 + x;
2884    *   }
2885    * A[y] = A_1;
2886    * }
2887    */
2888 
2889   std::ostringstream oss;
2890   oss << *stmt;
2891 
2892   const std::string& verification_pattern =
2893       R"IR(
2894 # CHECK: for (int y
2895 # CHECK:   int A_1 = A[y];
2896 # CHECK:   for (int x
2897 # CHECK:     A_1 = A_1 + x;
2898 # CHECK:   }
2899 # CHECK:   A[y] = A_1;
2900 # CHECK: })IR";
2901 
2902   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2903 }
2904 
2905 // Test the positive case of the hiddenAccess split, where an internal
2906 // conditional access can be hoisted up through a loop to match an existing
2907 // access in a higher scope and the two can be registerized.
TEST(Registerizer,RegisterizerHiddenAccessYes)2908 TEST(Registerizer, RegisterizerHiddenAccessYes) {
2909   BufHandle a("A", {10}, kInt);
2910   BufHandle b("B", {10}, kInt);
2911   VarHandle x("x", kInt);
2912   VarHandle y("y", kInt);
2913   StmtPtr stmt = Block::make({Cond::make(
2914       CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2915       Block::make(
2916           {Store::make(a, {0}, 0),
2917            For::make(
2918                x,
2919                0,
2920                10,
2921                Block::make(
2922                    {Store::make(b, {x}, 0),
2923                     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
2924                     Cond::make(
2925                         CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
2926                         For::make(
2927                             y,
2928                             0,
2929                             10,
2930                             Store::make(
2931                                 a, {0}, Add::make(Load::make(a, {0}), 1))),
2932                         nullptr)}))}),
2933       nullptr)});
2934 
2935   /*
2936    * if (x==2 ? 1 : 0) {
2937    *   A[0] = 0;
2938    *   for (int x = 0; x < 10; x++) {
2939    *     B[x] = 0;
2940    *     if (x==3 ? 1 : 0) {
2941    *       for (int y = 0; y < 10; y++) {
2942    *         A[0] = (A[0]) + 1;
2943    *       }
2944    *     }
2945    *   }
2946    * }
2947    */
2948 
2949   stmt = registerize(stmt);
2950 
2951   /*
2952    * if (x==2 ? 1 : 0) {
2953    *   int A_1 = 0;
2954    *   for (int x = 0; x < 10; x++) {
2955    *     B[x] = 0;
2956    *     if (x==3 ? 1 : 0) {
2957    *       for (int y = 0; y < 10; y++) {
2958    *         A_1 = A_1 + 1;
2959    *       }
2960    *     }
2961    *   }
2962    *   A[0] = A_1;
2963    * }
2964    */
2965 
2966   std::ostringstream oss;
2967   oss << *stmt;
2968 
2969   const std::string& verification_pattern =
2970       R"IR(
2971 # CHECK: if (x==2
2972 # CHECK:   int A_1 = 0;
2973 # CHECK:   for (int x
2974 # CHECK:     B[x] = 0;
2975 # CHECK:     if (x==3
2976 # CHECK:       for (int y
2977 # CHECK:         A_1 = A_1 + 1;
2978 # CHECK:       }
2979 # CHECK:     }
2980 # CHECK:   }
2981 # CHECK:  A[0] = A_1;
2982 # CHECK: })IR";
2983 
2984   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2985 }
2986 
2987 // Test the negative case of the hiddenAccess split, where the hoisted access is
2988 // never unhidden at a higher scope and registerization occurs at the lower
2989 // scope.
TEST(Registerizer,RegisterizerHiddenAccessNo)2990 TEST(Registerizer, RegisterizerHiddenAccessNo) {
2991   BufHandle a("A", {10}, kInt);
2992   BufHandle b("B", {10}, kInt);
2993   VarHandle x("x", kInt);
2994   VarHandle y("y", kInt);
2995   StmtPtr stmt = Block::make({Cond::make(
2996       CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2997       Block::make({For::make(
2998           x,
2999           0,
3000           10,
3001           Block::make(
3002               {Store::make(b, {x}, 0),
3003                // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3004                Cond::make(
3005                    CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
3006                    For::make(
3007                        y,
3008                        0,
3009                        10,
3010                        Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3011                    nullptr)}))}),
3012       nullptr)});
3013 
3014   /*
3015    * if (x==2 ? 1 : 0) {
3016    *   A[0] = 0;
3017    *   for (int x = 0; x < 10; x++) {
3018    *     B[x] = 0;
3019    *     if (x==3 ? 1 : 0) {
3020    *       for (int y = 0; y < 10; y++) {
3021    *         A[0] = (A[0]) + 1;
3022    *       }
3023    *     }
3024    *   }
3025    * }
3026    */
3027 
3028   stmt = registerize(stmt);
3029 
3030   /*
3031    * if (x==2 ? 1 : 0) {
3032    *   for (int x = 0; x < 10; x++) {
3033    *     B[x] = 0;
3034    *     if (x==3 ? 1 : 0) {
3035    *       int A_1 = A[0];
3036    *       for (int y = 0; y < 10; y++) {
3037    *         A_1 = A_1 + 1;
3038    *       }
3039    *       A[0] = A_1;
3040    *     }
3041    *   }
3042    * }
3043    */
3044 
3045   std::ostringstream oss;
3046   oss << *stmt;
3047 
3048   const std::string& verification_pattern =
3049       R"IR(
3050 # CHECK: if (x==2
3051 # CHECK:   for (int x
3052 # CHECK:     B[x] = 0;
3053 # CHECK:     if (x==3
3054 # CHECK:       int A_1 = A[0];
3055 # CHECK:       for (int y
3056 # CHECK:         A_1 = A_1 + 1;
3057 # CHECK:       }
3058 # CHECK:       A[0] = A_1;
3059 # CHECK:     }
3060 # CHECK:   }
3061 # CHECK: })IR";
3062 
3063   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3064 }
3065 
3066 // In this case the conditional access must be hoisted by two loops, there are
3067 // two accesses here one is unhidden and the other isnt. A[0] can be
3068 // registerized but B[0] cannot.
TEST(Registerizer,RegisterizerHiddenAccessMultiLoop)3069 TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) {
3070   BufHandle a("A", {10}, kInt);
3071   BufHandle b("B", {10}, kInt);
3072   VarHandle x("x", kInt);
3073   VarHandle y("y", kInt);
3074   StmtPtr stmt = Block::make({Cond::make(
3075       CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
3076       Block::make(
3077           {Store::make(a, {0}, 0),
3078            For::make(
3079                x,
3080                0,
3081                10,
3082                For::make(
3083                    y,
3084                    0,
3085                    10,
3086                    Block::make({Cond::make(
3087                        CompareSelect::make(y, 3, CompareSelectOperation::kEQ),
3088                        Block::make(
3089                            {Store::make(
3090                                 a, {0}, Add::make(Load::make(a, {0}), 1)),
3091                             Store::make(
3092                                 b, {0}, Add::make(Load::make(b, {0}), 1))}),
3093                        nullptr)})))}),
3094       nullptr)});
3095 
3096   /*
3097    * if (x==2 ? 1 : 0) {
3098    *   A[0] = 0;
3099    *   for (int x = 0; x < 10; x++) {
3100    *     for (int y = 0; y < 10; y++) {
3101    *       if (y==3 ? 1 : 0) {
3102    *         A[0] = (A[0]) + 1;
3103    *         B[0] = (B[0]) + 1;
3104    *       }
3105    *     }
3106    *   }
3107    * }
3108    */
3109 
3110   stmt = registerize(stmt);
3111 
3112   /*
3113    * if (x==2 ? 1 : 0) {
3114    *   int A_1 = 0;
3115    *   for (int x = 0; x < 10; x++) {
3116    *     for (int y = 0; y < 10; y++) {
3117    *       if (y==3 ? 1 : 0) {
3118    *         A_1 = A_1 + 1;
3119    *         B[0] = (B[0]) + 1;
3120    *       }
3121    *     }
3122    *   }
3123    *   A[0] = A_1;
3124    * }
3125    */
3126 
3127   std::ostringstream oss;
3128   oss << *stmt;
3129 
3130   const std::string& verification_pattern =
3131       R"IR(
3132 # CHECK: if (x==2
3133 # CHECK:   int A_1 = 0;
3134 # CHECK:   for (int x
3135 # CHECK:     for (int y
3136 # CHECK:       if (y==3
3137 # CHECK:         A_1 = A_1 + 1;
3138 # CHECK:         B[0] = (B[0]) + 1;
3139 # CHECK:       }
3140 # CHECK:     }
3141 # CHECK:   }
3142 # CHECK:  A[0] = A_1;
3143 # CHECK: })IR";
3144 
3145   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3146 }
3147 
3148 // Accesses are registerized inside two conditions, but the immediate parent is
3149 // not a condition.
TEST(Registerizer,RegisterizerTwoConditionalLoops)3150 TEST(Registerizer, RegisterizerTwoConditionalLoops) {
3151   BufHandle a("A", {1}, kInt);
3152   VarHandle x("x", kInt);
3153   StmtPtr stmt = Block::make(
3154       {Cond::make(
3155            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
3156            For::make(
3157                x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3158            nullptr),
3159        Cond::make(
3160            CompareSelect::make(x, 5, CompareSelectOperation::kGT),
3161            For::make(
3162                x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3163            nullptr)});
3164 
3165   /*
3166    * if (x<5 ? 1 : 0) {
3167    *   for (int x = 0; x < 10; x++) {
3168    *     A[0] = (A[0]) + 1;
3169    *   }
3170    * }
3171    * if (x>5 ? 1 : 0) {
3172    *   for (int x = 0; x < 10; x++) {
3173    *     A[0] = (A[0]) + 1;
3174    *   }
3175    * }
3176    */
3177 
3178   stmt = registerize(stmt);
3179 
3180   /*
3181    * if (x<5 ? 1 : 0) {
3182    *   int A_1 = A[0];
3183    *   for (int x = 0; x < 10; x++) {
3184    *     A_1 = A_1 + 1;
3185    *   }
3186    *   A[0] = A_1;
3187    * }
3188    * if (x>5 ? 1 : 0) {
3189    *   int A_2 = A[0];
3190    *   for (int x = 0; x < 10; x++) {
3191    *     A_2 = A_2 + 1;
3192    *   }
3193    *   A[0] = A_2;
3194    * }
3195    */
3196 
3197   std::ostringstream oss;
3198   oss << *stmt;
3199 
3200   const std::string& verification_pattern =
3201       R"IR(
3202 # CHECK: if (x<5
3203 # CHECK:   int A_1 = A[0];
3204 # CHECK:   for (int x
3205 # CHECK:     A_1 = A_1 + 1;
3206 # CHECK:   }
3207 # CHECK:   A[0] = A_1;
3208 # CHECK: }
3209 # CHECK: if (x>5
3210 # CHECK:   int A_2 = A[0];
3211 # CHECK:   for (int x
3212 # CHECK:     A_2 = A_2 + 1;
3213 # CHECK:   }
3214 # CHECK:   A[0] = A_2;
3215 # CHECK: })IR";
3216 
3217   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3218 }
3219 
3220 // Accesses are registerized inside two conditions, cut in the middle.
TEST(Registerizer,RegisterizerTwoConditionalLoopsCut)3221 TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) {
3222   BufHandle a("A", {1}, kInt);
3223   VarHandle x("x", kInt);
3224   StmtPtr stmt = Block::make(
3225       {Cond::make(
3226            CompareSelect::make(x, 5, CompareSelectOperation::kLT),
3227            For::make(
3228                x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3229            nullptr),
3230        For::make(x, 0, 10, Store::make(a, {x}, 1)),
3231        Cond::make(
3232            CompareSelect::make(x, 5, CompareSelectOperation::kGT),
3233            For::make(
3234                x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3235            nullptr)});
3236 
3237   /*
3238    * if (x<5 ? 1 : 0) {
3239    *   for (int x = 0; x < 10; x++) {
3240    *     A[0] = (A[0]) + 1;
3241    *   }
3242    * }
3243    * for (int x = 0; x < 10; x++) {
3244    *   A[x] = 1;
3245    * }
3246    * if (x>5 ? 1 : 0) {
3247    *   for (int x = 0; x < 10; x++) {
3248    *     A[0] = (A[0]) + 1;
3249    *   }
3250    * }
3251    */
3252 
3253   stmt = registerize(stmt);
3254 
3255   /*
3256    * if (x<5 ? 1 : 0) {
3257    *   int A_1 = A[0];
3258    *   for (int x = 0; x < 10; x++) {
3259    *     A_1 = A_1 + 1;
3260    *   }
3261    *   A[0] = A_1;
3262    * }
3263    * for (int x = 0; x < 10; x++) {
3264    *   A[x] = 1;
3265    * }
3266    * if (x>5 ? 1 : 0) {
3267    *   int A_2 = A[0];
3268    *   for (int x = 0; x < 10; x++) {
3269    *     A_2 = A_2 + 1;
3270    *   }
3271    *   A[0] = A_2;
3272    * }
3273    */
3274 
3275   std::ostringstream oss;
3276   oss << *stmt;
3277 
3278   const std::string& verification_pattern =
3279       R"IR(
3280 # CHECK: if (x<5
3281 # CHECK:   int A_1 = A[0];
3282 # CHECK:   for (int x
3283 # CHECK:     A_1 = A_1 + 1;
3284 # CHECK:   }
3285 # CHECK:   A[0] = A_1;
3286 # CHECK: }
3287 # CHECK: for (int x
3288 # CHECK:  A[x] = 1;
3289 # CHECK: if (x>5
3290 # CHECK:   int A_2 = A[0];
3291 # CHECK:   for (int x
3292 # CHECK:     A_2 = A_2 + 1;
3293 # CHECK:   }
3294 # CHECK:   A[0] = A_2;
3295 # CHECK: })IR";
3296 
3297   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3298 }
3299 
3300 // references a Let var in a local scope which cannot be hoisted out of the
3301 // loop.
TEST(Registerizer,RegisterizerLoopLetVar)3302 TEST(Registerizer, RegisterizerLoopLetVar) {
3303   BufHandle a("A", {10}, kInt);
3304   VarHandle x("x", kInt);
3305   VarHandle y("y", kInt);
3306   StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make(
3307       x,
3308       0,
3309       10,
3310       Block::make(
3311           {Let::make(y, 30),
3312            Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}));
3313 
3314   /*
3315    * for (int x = 0; x < 10; x++) {
3316    *   int y = 30;
3317    *   A[y] = x + (A[y]);
3318    * }
3319    */
3320 
3321   std::ostringstream before;
3322   before << *stmt;
3323 
3324   // No change.
3325   stmt = registerize(stmt);
3326 
3327   std::ostringstream after;
3328   after << *stmt;
3329 
3330   ASSERT_EQ(before.str(), after.str());
3331 }
3332 
3333 // references a Let var in an outer scope that does not prevent hoisting the
3334 // initializer.
TEST(Registerizer,RegisterizerLoopLetVarOuter)3335 TEST(Registerizer, RegisterizerLoopLetVarOuter) {
3336   BufHandle a("A", {10}, kInt);
3337   VarHandle x("x", kInt);
3338   VarHandle y("y", kInt);
3339   StmtPtr stmt = Block::make(
3340       {Let::make(y, 30),
3341        For::make(
3342            x,
3343            0,
3344            10,
3345            Block::make(
3346                {Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))});
3347 
3348   /*
3349    * int y = 30;
3350    * for (int x = 0; x < 10; x++) {
3351    *   A[y] = x + (A[y]);
3352    * }
3353    */
3354 
3355   stmt = registerize(stmt);
3356 
3357   /*
3358    * int y = 30;
3359    * int A_1 = A[y];
3360    * for (int x = 0; x < 10; x++) {
3361    *   A_1 = A_1 + x;
3362    * }
3363    * A[y] = A_1;
3364    */
3365 
3366   std::ostringstream oss;
3367   oss << *stmt;
3368 
3369   const std::string& verification_pattern =
3370       R"IR(
3371 # CHECK: int y = 30;
3372 # CHECK: int A_1 = A[y];
3373 # CHECK: for (int x
3374 # CHECK:   A_1 = A_1 + x;
3375 # CHECK: A[y] = A_1;)IR";
3376 
3377   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3378 }
3379 
3380 // Okay so the registerizer generally goes after index flattening, but just in
3381 // case. Test multi index registerization.
TEST(Registerizer,RegisterizerMultiDim)3382 TEST(Registerizer, RegisterizerMultiDim) {
3383   BufHandle a("A", {3, 4, 5}, kInt);
3384   VarHandle x("x", kInt);
3385   StmtPtr stmt = Block::make(
3386       {Store::make(a, {0, 1, 2}, 0),
3387        For::make(
3388            x,
3389            0,
3390            10,
3391            Block::make({Store::make(
3392                a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))});
3393 
3394   /*
3395    * A[0, 1, 2] = 0;
3396    * for (int x = 0; x < 10; x++) {
3397    *   A[0, 1, 2] = (A[0, 1, 2]) + x;
3398    * }
3399    */
3400 
3401   stmt = registerize(stmt);
3402 
3403   /*
3404    * int A_1 = 0;
3405    * for (int x = 0; x < 10; x++) {
3406    *   A_1 = x + A_1;
3407    * }
3408    * A[0, 1, 2] = A_1;
3409    */
3410 
3411   std::ostringstream oss;
3412   oss << *stmt;
3413 
3414   const std::string& verification_pattern =
3415       R"IR(
3416 # CHECK: int A_1 = 0;
3417 # CHECK: for (int x = 0; x < 10; x++)
3418 # CHECK-NOT: A[
3419 # CHECK:   A_1 =
3420 # CHECK: A[0, 1, 2] = A_1;)IR";
3421 
3422   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3423 }
3424 
3425 // Wont registerize if only some dims match, but will still registerize distinct
3426 // elements.
TEST(Registerizer,RegisterizerMultiDimPartial)3427 TEST(Registerizer, RegisterizerMultiDimPartial) {
3428   BufHandle a("A", {3, 4, 5}, kInt);
3429   VarHandle x("x", kInt);
3430   StmtPtr stmt = Block::make(
3431       {Store::make(a, {0, 1, 2}, 0),
3432        For::make(
3433            x,
3434            0,
3435            10,
3436            Block::make({Store::make(
3437                a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))});
3438 
3439   /*
3440    * A[0, 1, 2] = 0;
3441    * for (int x = 0; x < 10; x++) {
3442    *   A[0, 2, 2] = (A[0, 1, 4]) + x;
3443    * }
3444    */
3445 
3446   stmt = registerize(stmt);
3447 
3448   /*
3449    * A[0, 1, 2] = 0;
3450    * int A_1 = A[0, 1, 4];
3451    * int A_2 = A[0, 2, 2];
3452    * for (int x = 0; x < 10; x++) {
3453    *   A_2 = A_1 + x;
3454    * }
3455    * A[0, 2, 2] = A_2;
3456    */
3457 
3458   std::ostringstream oss;
3459   oss << *stmt;
3460 
3461   const std::string& verification_pattern =
3462       R"IR(
3463 # CHECK: A[0, 1, 2] = 0;
3464 # CHECK: int A_1 = A[0, 1, 4];
3465 # CHECK: int A_2 = A[0, 2, 2];
3466 # CHECK: for (
3467 # CHECK:   A_2 = A_1 + x;
3468 # CHECK: A[0, 2, 2] = A_2;)IR";
3469 
3470   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3471 }
3472 
3473 // If they could overlap across all dimensions we cannot registerize.
TEST(Registerizer,RegisterizerMultiDimOverlap)3474 TEST(Registerizer, RegisterizerMultiDimOverlap) {
3475   BufHandle a("A", {3, 4, 5}, kInt);
3476   VarHandle x("x", kInt);
3477   VarHandle y("y", kInt);
3478   StmtPtr stmt = Block::make(
3479       {Store::make(a, {0, 1, 2}, 0),
3480        For::make(
3481            x,
3482            0,
3483            10,
3484            Block::make({Store::make(
3485                a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))});
3486   stmt = IRSimplifier::simplify(stmt);
3487 
3488   /*
3489    * A[0, 1, 2] = 0;
3490    * for (int x = 0; x < 10; x++) {
3491    *   A[0, x, 2] = (A[y, 2, 2]) + x;
3492    * }
3493    */
3494 
3495   std::ostringstream before;
3496   before << *stmt;
3497 
3498   // No change.
3499   stmt = registerize(stmt);
3500 
3501   std::ostringstream after;
3502   after << *stmt;
3503 
3504   ASSERT_EQ(before.str(), after.str());
3505 }
3506 
3507 // But, if one dimension is known to be distinct they do not overlap.
TEST(Registerizer,RegisterizerMultiDimPartialOverlap)3508 TEST(Registerizer, RegisterizerMultiDimPartialOverlap) {
3509   BufHandle a("A", {3, 4, 5}, kInt);
3510   VarHandle x("x", kInt);
3511   VarHandle y("y", kInt);
3512   StmtPtr stmt = Block::make(
3513       {Store::make(a, {0, 1, 2}, 0),
3514        For::make(
3515            x,
3516            0,
3517            10,
3518            Block::make({Store::make(
3519                a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))});
3520 
3521   /*
3522    * A[0, 1, 2] = 0;                          <---- 2nd dim overlaps with store.
3523    * for (int x = 0; x < 10; x++) {
3524    *   A[0, x, 2] = (A[y, 2, 4]) + x;           <---- 3rd dim has constant diff.
3525    * }
3526    */
3527 
3528   stmt = registerize(stmt);
3529 
3530   /*
3531    * A[0, 1, 2] = 0;
3532    * int A_1 = A[y, 2, 4];
3533    * for (int x = 0; x < 10; x++) {
3534    *   A[0, x, 2] = A_1 + x;
3535    * }
3536    */
3537 
3538   std::ostringstream oss;
3539   oss << *stmt;
3540 
3541   const std::string& verification_pattern =
3542       R"IR(
3543 # CHECK: A[0, 1, 2] = 0;
3544 # CHECK: int A_1 = A[y, 2, 4];
3545 # CHECK: for (
3546 # CHECK:   A[0, x, 2] = A_1 + x;
3547 # CHECK: })IR";
3548 
3549   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3550 }
3551 
3552 // A 3D reduction with different input dimensionality.
TEST(Registerizer,RegisterizerMultiDim3DReduction1)3553 TEST(Registerizer, RegisterizerMultiDim3DReduction1) {
3554   BufHandle a("A", {10}, kInt);
3555   BufHandle b("B", {10, 10}, kInt);
3556   BufHandle c("C", {10, 10, 10}, kInt);
3557   VarHandle x("x", kInt);
3558   VarHandle y("y", kInt);
3559   VarHandle z("z", kInt);
3560   StmtPtr stmt = For::make(
3561       x,
3562       0,
3563       10,
3564       For::make(
3565           y,
3566           0,
3567           10,
3568           For::make(
3569               z,
3570               0,
3571               10,
3572               Store::make(
3573                   c,
3574                   {x, y, z},
3575                   Add::make(
3576                       Load::make(c, {x, y, z}),
3577                       Mul::make(Load::make(b, {x, y}), Load::make(a, {x})))))));
3578 
3579   /*
3580    * for (int x = 0; x < 10; x++) {
3581    *   for (int y = 0; y < 10; y++) {
3582    *     for (int z = 0; z < 10; z++) {
3583    *       C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]);
3584    *     }
3585    *   }
3586    * }
3587    */
3588 
3589   // We can registerize the A and B access since they can be hoisted before
3590   // hitting a dependent loop var.
3591 
3592   stmt = registerize(stmt);
3593 
3594   /*
3595    * for (int x = 0; x < 10; x++) {
3596    *   int A_1 = A[x];
3597    *   for (int y = 0; y < 10; y++) {
3598    *     int B_1 = B[x, y];
3599    *     for (int z = 0; z < 10; z++) {
3600    *       C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
3601    *     }
3602    *   }
3603    * }
3604    */
3605 
3606   std::ostringstream oss;
3607   oss << *stmt;
3608 
3609   const std::string& verification_pattern =
3610       R"IR(
3611 # CHECK: for (int x
3612 # CHECK:   int A_1 = A[x];
3613 # CHECK:   for (int y
3614 # CHECK:     int B_1 = B[x, y];
3615 # CHECK:       for (int z
3616 # CHECK:         C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
3617 # CHECK: })IR";
3618 
3619   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3620 }
3621 
3622 // A 3D reduction with the same smaller dimensionality using different loop
3623 // vars.
TEST(Registerizer,RegisterizerMultiDim3DReduction2)3624 TEST(Registerizer, RegisterizerMultiDim3DReduction2) {
3625   BufHandle a("A", {10}, kInt);
3626   BufHandle b("B", {10}, kInt);
3627   BufHandle c("C", {10}, kInt);
3628   VarHandle x("x", kInt);
3629   VarHandle y("y", kInt);
3630   VarHandle z("z", kInt);
3631   StmtPtr stmt = For::make(
3632       x,
3633       0,
3634       10,
3635       // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3636       For::make(
3637           y,
3638           0,
3639           10,
3640           For::make(
3641               z,
3642               0,
3643               10,
3644               Store::make(
3645                   c,
3646                   {x},
3647                   Add::make(
3648                       Load::make(c, {x}),
3649                       Mul::make(Load::make(b, {y}), Load::make(a, {x})))))));
3650 
3651   /*
3652    * for (int x = 0; x < 10; x++) {
3653    *   for (int y = 0; y < 10; y++) {
3654    *     for (int z = 0; z < 10; z++) {
3655    *       C[x] = (C[x]) + (B[y]) * (A[x]);
3656    *     }
3657    *   }
3658    * }
3659    */
3660 
3661   // We can registerize all accesses, the A and C access can be hoisted to the
3662   // outer loop since they depend only on it's loop var while the B can only be
3663   // raised to the loop of y.
3664 
3665   stmt = registerize(stmt);
3666 
3667   /*
3668    * for (int x = 0; x < 10; x++) {
3669    *   int A_1 = A[x];
3670    *   int C_1 = C[x];
3671    *   for (int y = 0; y < 10; y++) {
3672    *     int B_1 = B[y];
3673    *     for (int z = 0; z < 10; z++) {
3674    *       C_1 = A_1 * B_1 + C_1;
3675    *     }
3676    *   }
3677    *   C[x] = C_1;
3678    * }
3679    */
3680 
3681   std::ostringstream oss;
3682   oss << *stmt;
3683 
3684   const std::string& verification_pattern =
3685       R"IR(
3686 # CHECK: for (int x
3687 # CHECK:   int A_1 = A[x];
3688 # CHECK:   int C_1 = C[x];
3689 # CHECK:   for (int y
3690 # CHECK:     int B_1 = B[y];
3691 # CHECK:       for (int z
3692 # CHECK:         C_1 = A_1 * B_1 + C_1;
3693 # CHECK:       }
3694 # CHECK:     }
3695 # CHECK:   C[x] = C_1;
3696 # CHECK: })IR";
3697 
3698   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3699 }
3700 
3701 } // namespace jit
3702 } // namespace torch
3703