xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <algorithm>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
9 
10 #include <torch/csrc/jit/jit_log.h>
11 #include <torch/csrc/jit/jit_opt_limit.h>
12 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
13 #include <torch/csrc/jit/tensorexpr/loopnest.h>
14 #include <torch/csrc/jit/tensorexpr/loopnest_randomization.h>
15 
16 namespace torch::jit::tensorexpr {
17 
18 namespace randomization_helper {
19 
max_transformations(int n_max_transforms)20 static int64_t max_transformations(int n_max_transforms) {
21   // Reuse the env variable PYTORCH_JIT_OPT_LIMIT to control the max number of
22   // transformations.  Example - set the env variable
23   // PYTORCH_JIT_OPT_LIMIT="loopnest_randomization=10" to set max
24   // transformations to 10.  This can be helpful in gradually reducing the
25   // number of transformations when we see an error.
26   if (!JIT_OPT_ALLOWED) {
27     return n_max_transforms;
28   }
29   int max_transforms = 1;
30   while (JIT_OPT_ALLOWED && max_transforms < n_max_transforms) {
31     max_transforms++;
32   }
33   return max_transforms;
34 }
35 
GetAllPerfectlyNestedLoopNests(std::vector<ForPtr> loops)36 static std::vector<std::vector<ForPtr>> GetAllPerfectlyNestedLoopNests(
37     std::vector<ForPtr> loops) {
38   // Find the first set of loops that can be reordered
39   std::vector<std::vector<ForPtr>> all_nested_loops;
40   std::vector<ForPtr> nested_loops;
41   if (loops.empty()) {
42     return all_nested_loops;
43   }
44   nested_loops.push_back(loops[0]);
45   for (size_t i = 1; i < loops.size(); i++) {
46     auto last_loop = nested_loops.back();
47     auto next_loop = loops[i];
48     if (last_loop->body()->nstmts() == 1 &&
49         last_loop->body()->front() == next_loop) {
50       nested_loops.push_back(next_loop);
51     } else {
52       if (nested_loops.size() > 1) {
53         all_nested_loops.push_back(nested_loops);
54       }
55       nested_loops.clear();
56       nested_loops.push_back(next_loop);
57     }
58   }
59   return all_nested_loops;
60 }
61 
// Picks `n` distinct elements of `objects` at random.
//
// Returns a tuple of (selected elements, their positions in `objects`), both
// in selection order. When `objects` holds fewer than `n` elements both
// vectors are empty. The full index permutation is shuffled before the size
// check, so `random_engine` is advanced even when nothing is selected.
template <typename T>
std::tuple<std::vector<T>, std::vector<int>> select_n_randomly(
    std::vector<T>& objects,
    int n,
    std::default_random_engine& random_engine) {
  // Random permutation of every valid position in `objects`.
  std::vector<int> order(objects.size());
  std::iota(order.begin(), order.end(), 0);
  std::shuffle(order.begin(), order.end(), random_engine);

  std::vector<T> picked;
  std::vector<int> picked_positions;
  if (static_cast<int>(order.size()) >= n) {
    // The first `n` entries of the permutation are the selection.
    int taken = 0;
    while (taken < n) {
      const int position = order[taken];
      picked_positions.push_back(position);
      picked.push_back(objects[position]);
      ++taken;
    }
  }
  return std::make_tuple(picked, picked_positions);
}
83 
find_factor(const ForPtr & loop)84 static int find_factor(const ForPtr& loop) {
85   // Find valid factors
86   ExprPtr loop_stop = loop->stop();
87   auto loop_imm = intValue(loop_stop);
88   if (loop_imm) {
89     int loop_bound = *loop_imm;
90     int factor = rand() % (loop_bound - 1) + 1;
91     return factor;
92   }
93   return -1;
94 }
95 
printHistory(int index,std::string message)96 static void printHistory(int index, std::string message) {
97   message = "Random Transform Sequence - Transformations[" +
98       std::to_string(index) + "] = " + message;
99   GRAPH_DEBUG(message);
100 }
101 
// Joins the std::to_string representation of every element of `indices`,
// placing `sep` between consecutive elements (no leading or trailing
// separator). Returns an empty string for an empty vector.
template <typename T>
std::string join(const std::vector<T>& indices, char sep = ',') {
  std::string s;
  bool first = true;
  for (const auto& index : indices) {
    // Separator goes between elements only, never after the last one.
    if (!first) {
      s += sep;
    }
    first = false;
    s += std::to_string(index);
  }
  return s;
}

// Overload for elements that are already strings: joins them with `sep`
// between consecutive elements (no leading or trailing separator).
static std::string join(
    const std::vector<std::string>& indices,
    char sep = ',') {
  std::string s;
  bool first = true;
  for (const auto& index : indices) {
    if (!first) {
      s += sep;
    }
    first = false;
    s += index;
  }
  return s;
}
// Returns, as a decimal string, the position of the first element of
// `objects` equal to `object`. When no element matches, the result is the
// size of `objects`.
template <typename T>
std::string indexOf(const std::vector<T>& objects, const T& object) {
  const auto match = std::find(objects.begin(), objects.end(), object);
  const auto offset = std::distance(objects.begin(), match);
  return std::to_string(offset);
}
125 
126 } // namespace randomization_helper
127 
loopnestRandomization(int64_t seed,LoopNest & l)128 void loopnestRandomization(int64_t seed, LoopNest& l) {
129   // This is to help with deterministic testing of randomized infrastructure.
130   // When seed value is 1, we perform preset loop transformations. This allows
131   // testing of interface.
132   if (seed == 1) {
133     l.simplify();
134     return;
135   }
136 
137   std::default_random_engine random_engine(seed);
138   std::srand(seed);
139   // Set the maximum allowed number of transformations beyond which it is hard
140   // to track and debug. Arbitrarily choosing 20 as maximum number.
141   int max_allowed_transformations = 20;
142   int n_transforms = randomization_helper::max_transformations(
143       std::rand() % max_allowed_transformations);
144   std::string message = "";
145   // clang-format off
146   //   Transformations list:
147   //
148   //       StmtPtr simplify();
149   //       bool computeInline(BufPtr b);
150   //       void inlineIntermediateBufs(bool allow_duplicated_work);
151   //       bool optimizeConditionals();
152   //       static void splitWithTail(ForPtr f, int factor);
153   //       static void splitWithMask(ForPtr f, int factor);
154   //       static std::vector<ForPtr> distributeLoop(ForPtr loop, const std::unordered_set<StmtPtr>& pivots);
155   //       static std::vector<ForPtr> distributeLoop(ForPtr loop);
156   //       static std::vector<ForPtr> distributeLoopAndParents(ForPtr loop);
157   //       static std::vector<ForPtr> distributeLoopOverInnerLoops(ForPtr loop);
158   //       static std::vector<ForPtr> distributeLoopAndParentsOverInnerLoops(ForPtr loop);
159   //       static bool fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);
160   //       static void reorderAxis(ForPtr a, ForPtr b);
161   //       static std::vector<ForPtr> reorder(const std::vector<ForPtr>& loops, const std::vector<size_t>& permutation);
162   //       ForPtr tile(ForPtr x, ForPtr y, int x_factor, int y_factor);
163   //       static void fullUnroll(ForPtr f);
164   //       static bool normalize(ForPtr f);
165   //       static bool flatten(const std::vector<ForPtr>& f, ForPtr* flattened);
166   //       static void compressBuffer(BufPtr buf, StmtPtr stmt);
167   //       static void compressAllBuffers(StmtPtr stmt);
168   //       static void sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail);
169   //       static void sliceHead(ForPtr f, int factor);
170   //       static void sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail);
171   //       static void sliceTail(ForPtr f, int factor);
172   //       static AccessResult cacheAccesses(BufPtr producer, const std::string& name, StmtPtr consumer);
173   //       static void computeAt(StmtPtr s, ForPtr at);
174   //       static bool rfactor(StmtPtr s, ForPtr outer_reduction_for);
175   //       static bool vectorize(ForPtr);
176   //       void vectorizeInnerLoops();
177   //       void eliminateDeadStores();
178   //       void prepareForCodegen();
179   // clang-format on
180   enum TransformKind {
181     SIMPLIFY = 0,
182     COMPUTE_INLINE,
183     INLINE_ALL,
184     OPT_COND,
185     SPLIT_TAIL,
186     SPLIT_MASK,
187     DIST1,
188     DIST2,
189     DIST3,
190     DIST4,
191     DIST5,
192     FUSE_LOOPS,
193     REORDER_AXIS,
194     REORDER,
195     TILE,
196     FULL_UNROLL,
197     NORMALIZE,
198     FLATTEN,
199     COMPRESS_BUFFER,
200     COMPRESS_ALL_BUFFERS,
201     SLICE_HEAD,
202     SLICE_TAIL,
203     CACHE_ACCESSES,
204     COMPUTE_AT,
205     RFACTOR,
206     VECTORIZE,
207     VECTORIZE_INNER_LOOPS,
208     ELIMINATE_DEAD_STORES,
209     MAX_TRANSFORM,
210   };
211   bool can_inline = true;
212   try {
213     for (int n_transform = 0; n_transform < n_transforms; n_transform++) {
214       int transform = std::rand() % MAX_TRANSFORM;
215       switch (transform) {
216         case SIMPLIFY: {
217           message = "simplify();\n";
218           randomization_helper::printHistory(n_transform, message);
219           l.simplify();
220           break;
221         }
222         case COMPUTE_INLINE: {
223           if (can_inline) {
224             auto bufs = NodeFinder<Buf>::find(l.root_stmt());
225             if (!bufs.empty()) {
226               int buf_number = std::rand() % (int)bufs.size();
227               message =
228                   "computeInline(" + bufs[buf_number]->name_hint() + ");\n";
229               randomization_helper::printHistory(n_transform, message);
230               l.computeInline(bufs[buf_number]);
231             }
232           }
233           break;
234         }
235         case INLINE_ALL: {
236           if (can_inline) {
237             int allow_dup = std::rand() % 2;
238             message =
239                 "inlineIntermediateBufs(" + std::to_string(allow_dup) + ");\n";
240             randomization_helper::printHistory(n_transform, message);
241             l.inlineIntermediateBufs(allow_dup);
242             can_inline = false;
243           }
244           break;
245         }
246         case OPT_COND: {
247           message = "optimizeConditionals();\n";
248           randomization_helper::printHistory(n_transform, message);
249           l.optimizeConditionals();
250           break;
251         }
252         case SPLIT_TAIL: {
253           auto loops = NodeFinder<For>::find(l.root_stmt());
254           if (loops.empty()) {
255             break;
256           }
257           int loop_n = std::rand() % (int)loops.size();
258           auto loop = loops[loop_n];
259           int factor = (std::rand() % 20) + 1;
260           message = "splitWithTail(loops[" + std::to_string(loop_n) + "], " +
261               std::to_string(factor) + ");\n";
262           randomization_helper::printHistory(n_transform, message);
263           l.splitWithTail(loop, factor);
264           break;
265         }
266         case SPLIT_MASK: {
267           auto loops = NodeFinder<For>::find(l.root_stmt());
268           if (loops.empty()) {
269             break;
270           }
271           int loop_n = std::rand() % (int)loops.size();
272           auto loop = loops[loop_n];
273           int factor = (std::rand() % 20) + 1;
274           message = "splitWithMask(loops[" + std::to_string(loop_n) + "], " +
275               std::to_string(factor) + ")\n";
276           randomization_helper::printHistory(n_transform, message);
277           l.splitWithMask(loop, factor);
278           break;
279         }
280         case DIST1: {
281           auto loops = NodeFinder<For>::find(l.root_stmt());
282           if (loops.empty()) {
283             break;
284           }
285           int loop_n = std::rand() % (int)loops.size();
286           auto loop = loops[loop_n];
287           std::vector<StmtPtr> stmts(
288               loop->body()->begin(), loop->body()->end());
289           if (stmts.empty()) {
290             break;
291           }
292           int n_pivots = (std::rand() % (int)stmts.size()) + 1;
293           auto [pivots, chosen_indices] =
294               randomization_helper::select_n_randomly<StmtPtr>(
295                   stmts, n_pivots, random_engine);
296           std::unordered_set<StmtPtr> pivots_set(pivots.begin(), pivots.end());
297           message = "distributeLoop(loops[" + std::to_string(loop_n) +
298               "], pivots=stmts(" + randomization_helper::join(chosen_indices) +
299               "))\n";
300           randomization_helper::printHistory(n_transform, message);
301           l.distributeLoop(loop, pivots_set);
302           break;
303         }
304         case DIST2: {
305           auto loops = NodeFinder<For>::find(l.root_stmt());
306 
307           if (loops.empty()) {
308             break;
309           }
310           int loop_n = std::rand() % (int)loops.size();
311           auto loop = loops[loop_n];
312 
313           message = "distributeLoop(loops[" + std::to_string(loop_n) + "])\n";
314           randomization_helper::printHistory(n_transform, message);
315           l.distributeLoop(loop);
316           break;
317         }
318         case DIST3: {
319           auto loops = NodeFinder<For>::find(l.root_stmt());
320 
321           if (loops.empty()) {
322             break;
323           }
324           int loop_n = std::rand() % (int)loops.size();
325           auto loop = loops[loop_n];
326 
327           message = "distributeLoopAndParents(loops[" + std::to_string(loop_n) +
328               "])\n";
329           randomization_helper::printHistory(n_transform, message);
330           l.distributeLoopAndParents(loop);
331           break;
332         }
333         case DIST4: {
334           auto loops = NodeFinder<For>::find(l.root_stmt());
335 
336           if (loops.empty()) {
337             break;
338           }
339           int loop_n = std::rand() % (int)loops.size();
340           auto loop = loops[loop_n];
341 
342           message = "distributeLoopOverInnerLoops(loops[" +
343               std::to_string(loop_n) + "])\n";
344           randomization_helper::printHistory(n_transform, message);
345           l.distributeLoopOverInnerLoops(loop);
346           break;
347         }
348         case DIST5: {
349           auto loops = NodeFinder<For>::find(l.root_stmt());
350 
351           if (loops.empty()) {
352             break;
353           }
354           int loop_n = std::rand() % (int)loops.size();
355           auto loop = loops[loop_n];
356 
357           message = "distributeLoopAndParentsOverInnerLoops(loops[" +
358               std::to_string(loop_n) + "])\n";
359           randomization_helper::printHistory(n_transform, message);
360           l.distributeLoopAndParentsOverInnerLoops(loop);
361           break;
362         }
363         case FUSE_LOOPS: {
364           // Get all the loops
365           auto loops = NodeFinder<For>::find(l.root_stmt());
366           if (loops.size() <= 1) {
367             break;
368           }
369 
370           // Find a random number of loops to fuse
371           int num_loops_to_fuse =
372               std::max(2, (int)(std::rand() % (int)loops.size()));
373 
374           auto [loops_to_fuse, chosen_indices] =
375               randomization_helper::select_n_randomly<ForPtr>(
376                   loops, num_loops_to_fuse, random_engine);
377 
378           message = "fuseLoops(loops[" +
379               randomization_helper::join(chosen_indices) + "], &fused_loop);\n";
380           randomization_helper::printHistory(n_transform, message);
381           // Fuse the loops
382           ForPtr fused_loop;
383           l.fuseLoops(loops_to_fuse, &fused_loop);
384           break;
385         }
386 
387         case REORDER_AXIS: {
388           // Get all the loops
389           auto loops = NodeFinder<For>::find(l.root_stmt());
390           if (loops.size() <= 1) {
391             break;
392           }
393 
394           // Find pairs of axes that can be reordered
395           std::vector<std::pair<ForPtr, ForPtr>> valid_pairs;
396           for (const auto i : c10::irange(loops.size())) {
397             for (const auto j : c10::irange(i + 1, loops.size())) {
398               if (LoopNest::findOuterFor(loops[i], loops[j])) {
399                 valid_pairs.emplace_back(loops[i], loops[j]);
400               }
401             }
402           }
403 
404           // Choose a pair randomly
405           if (valid_pairs.empty()) {
406             break;
407           }
408           int valid_pair_n = std::rand() % (int)valid_pairs.size();
409           auto loop_pair = valid_pairs.at(valid_pair_n);
410           auto first_loop = std::get<0>(loop_pair);
411           auto second_loop = std::get<1>(loop_pair);
412 
413           std::string first_index =
414               randomization_helper::indexOf(loops, first_loop);
415           std::string second_index =
416               randomization_helper::indexOf(loops, second_loop);
417           message = "reorderAxis(loops[";
418           message += first_index;
419           message += "], loops[";
420           message += second_index + "]);\n";
421           randomization_helper::printHistory(n_transform, message);
422           // reorder the axis
423           l.reorderAxis(first_loop, second_loop);
424           break;
425         }
426 
427         case REORDER: {
428           // Get all the loops
429           auto loops = NodeFinder<For>::find(l.root_stmt());
430           if (loops.size() <= 1) {
431             break;
432           }
433 
434           // Find all perfectly nested loop nests
435           auto all_nested_loops =
436               randomization_helper::GetAllPerfectlyNestedLoopNests(loops);
437           if (all_nested_loops.empty()) {
438             break;
439           }
440 
441           // Randomly pick a set of consecutive loops to reorder
442           int index = rand() % (int)all_nested_loops.size();
443           auto nested_loops = all_nested_loops.at(index);
444 
445           // Create a random permutation for reordering
446           std::vector<size_t> permutation(nested_loops.size());
447           std::iota(permutation.begin(), permutation.end(), 0);
448           std::shuffle(permutation.begin(), permutation.end(), random_engine);
449 
450           // Generate a good history message
451           std::vector<std::string> indices;
452           indices.reserve(nested_loops.size());
453           for (const auto& l : nested_loops) {
454             indices.push_back(randomization_helper::indexOf(loops, l));
455           }
456           message = "reorder(loops[" + randomization_helper::join(indices) +
457               "], permutation=[" + randomization_helper::join(permutation) +
458               "]);\n";
459           randomization_helper::printHistory(n_transform, message);
460           // reorder
461           l.reorder(nested_loops, permutation);
462           break;
463         }
464 
465         case TILE: {
466           // Get all the loops
467           auto loops = NodeFinder<For>::find(l.root_stmt());
468           if (loops.size() <= 1) {
469             break;
470           }
471 
472           // Tile needs two perfectly nested loops. To find such loops, we find
473           // all perfectly nested loop nests, randomly pick one of them, and
474           // randomly pick 2 consecutive loops in that loop nest.
475           // Find all perfectly nested loop nests
476           auto all_nested_loops =
477               randomization_helper::GetAllPerfectlyNestedLoopNests(loops);
478           if (all_nested_loops.empty()) {
479             break;
480           }
481 
482           int index = rand() % (int)all_nested_loops.size();
483           auto nested_loops = all_nested_loops.at(index);
484           if (nested_loops.size() < 2) {
485             break;
486           }
487           int loop_number = rand() % ((int)nested_loops.size() - 1);
488           auto x_loop = nested_loops.at(loop_number);
489           auto y_loop = nested_loops.at(loop_number + 1);
490 
491           int x_factor = randomization_helper::find_factor(x_loop);
492           int y_factor = randomization_helper::find_factor(y_loop);
493           if (x_factor == -1 || y_factor == -1) {
494             break;
495           }
496 
497           std::string x_loop_index =
498               randomization_helper::indexOf(loops, x_loop);
499           std::string y_loop_index =
500               randomization_helper::indexOf(loops, y_loop);
501           message = "tile(loops[";
502           message += x_loop_index;
503           message += "], loops[";
504           message += y_loop_index + "], ";
505           message += std::to_string(x_factor);
506           message += ", " + std::to_string(y_factor) + ");\n";
507           randomization_helper::printHistory(n_transform, message);
508           // tile
509           l.tile(x_loop, y_loop, x_factor, y_factor);
510           break;
511         }
512 
513         case FULL_UNROLL: {
514           auto loops = NodeFinder<For>::find(l.root_stmt());
515           if (loops.empty()) {
516             break;
517           }
518           int loop_n = std::rand() % (int)loops.size();
519           auto loop = loops[loop_n];
520 
521           message = "fullUnroll(loops[" + std::to_string(loop_n) + "]);\n";
522           randomization_helper::printHistory(n_transform, message);
523           LoopNest::fullUnroll(loop);
524           break;
525         }
526 
527         case NORMALIZE: {
528           auto loops = NodeFinder<For>::find(l.root_stmt());
529           if (loops.empty()) {
530             break;
531           }
532           int loop_n = std::rand() % (int)loops.size();
533           auto loop = loops[loop_n];
534 
535           message = "normalize(loops[" + std::to_string(loop_n) + "]);\n";
536           randomization_helper::printHistory(n_transform, message);
537           l.normalize(loop);
538           break;
539         }
540 
541         case FLATTEN: {
542           // Get all the loops
543           auto loops = NodeFinder<For>::find(l.root_stmt());
544           if (loops.size() <= 1) {
545             break;
546           }
547 
548           // Find all perfectly nested loop nests
549           auto all_nested_loops =
550               randomization_helper::GetAllPerfectlyNestedLoopNests(loops);
551           if (all_nested_loops.empty()) {
552             break;
553           }
554 
555           // Randomly pick a set of consecutive loops to flatten
556           int index = rand() % (int)all_nested_loops.size();
557           auto nested_loops = all_nested_loops.at(index);
558 
559           // Generate a good history message
560           std::vector<std::string> indices;
561           indices.reserve(nested_loops.size());
562           for (const auto& l : nested_loops) {
563             indices.push_back(randomization_helper::indexOf(loops, l));
564           }
565           message =
566               "flatten(loops[" + randomization_helper::join(indices) + "]);\n";
567           randomization_helper::printHistory(n_transform, message);
568           // flatten
569           l.flatten(nested_loops);
570           break;
571         }
572 
573         case COMPRESS_BUFFER: {
574           auto buffers = NodeFinder<Buf>::find(l.root_stmt());
575           int buffer_n = std::rand() % (int)buffers.size();
576           auto buffer = buffers[buffer_n];
577 
578           message = "compressBuffer(buffers[" + std::to_string(buffer_n) +
579               "], l.root_stmt());\n";
580           randomization_helper::printHistory(n_transform, message);
581           l.compressBuffer(buffer, l.root_stmt());
582           break;
583         }
584 
585         case COMPRESS_ALL_BUFFERS: {
586           message = "compressAllBuffers(l.root_stmt());\n";
587           randomization_helper::printHistory(n_transform, message);
588           l.compressAllBuffers(l.root_stmt());
589           break;
590         }
591 
592         case SLICE_HEAD: {
593           // Get all the loops
594           auto loops = NodeFinder<For>::find(l.root_stmt());
595           if (loops.empty()) {
596             break;
597           }
598           int loop_n = std::rand() % (int)loops.size();
599           auto loop = loops[loop_n];
600 
601           int factor = randomization_helper::find_factor(loop);
602           if (factor == -1) {
603             break;
604           }
605           message = "sliceHead(loops[" + std::to_string(loop_n) + "]);\n";
606           randomization_helper::printHistory(n_transform, message);
607           l.sliceHead(loop, factor);
608           break;
609         }
610 
611         case SLICE_TAIL: {
612           // Get all the loops
613           auto loops = NodeFinder<For>::find(l.root_stmt());
614           if (loops.empty()) {
615             break;
616           }
617           int loop_n = std::rand() % (int)loops.size();
618           auto loop = loops[loop_n];
619 
620           int factor = randomization_helper::find_factor(loop);
621           if (factor == -1) {
622             break;
623           }
624           message = "sliceTail(loops[" + std::to_string(loop_n) + "]);\n";
625           randomization_helper::printHistory(n_transform, message);
626           l.sliceTail(loop, factor);
627           break;
628         }
629 
630         case CACHE_ACCESSES: {
631           // TODO - Implement cache_access
632           break;
633         }
634 
635         case COMPUTE_AT: {
636           // To find valid compute at pairs, we need to collect the producer
637           // consumer pairs. For now, we do not collect all such pairs for
638           // simplicity. For now, we collect producer and the immediate parent
639           // loop of the consumer. We could collect all the consumer enclosing
640           // loops, but then we will have to clean up the ones that are shared
641           // with the producer encloser loop. Currently, we only test on the
642           // immediate parent loop.
643           auto buffers = BufFinder::find(l.root_stmt());
644           std::vector<std::pair<StmtPtr, ForPtr>> producer_consumer_pairs;
645 
646           for (const auto& buffer : buffers) {
647             auto producers = l.getAllWritesToBuf(buffer);
648             auto consumers = StmtsReadingBuf::find(l.root_stmt(), buffer);
649             if (producers.size() != 1 || consumers.empty()) {
650               continue;
651             }
652 
653             for (const auto& producer : producers) {
654               for (const auto& consumer : consumers) {
655                 auto parent_loop = LoopNest::getParentLoop(consumer);
656                 auto pc_pair = std::make_pair(producer, parent_loop);
657                 producer_consumer_pairs.push_back(pc_pair);
658               }
659             }
660           }
661 
662           if (producer_consumer_pairs.empty()) {
663             break;
664           }
665 
666           // Choose a random pair
667           int pair_n = std::rand() % (int)producer_consumer_pairs.size();
668           auto pc_pair = producer_consumer_pairs.at(pair_n);
669           auto store = std::get<0>(pc_pair);
670           auto for_ptr = std::get<1>(pc_pair);
671 
672           // TODO - come up with better message
673           message = "computeAt(....);\n";
674           randomization_helper::printHistory(n_transform, message);
675           l.computeAt(store, for_ptr);
676           break;
677         }
678 
679         case RFACTOR: {
680           // TODO - Implement rfactor
681           break;
682         }
683 
684         case VECTORIZE: {
685           auto loops = NodeFinder<For>::find(l.root_stmt());
686           std::vector<ForPtr> innermost_loops;
687 
688           for (const auto& loop : loops) {
689             bool containsSubLoops = false;
690             if (BlockPtr body = to<Block>(loop->body())) {
691               for (const StmtPtr& stmt : *body) {
692                 if (ForPtr f2 = to<For>(stmt)) {
693                   containsSubLoops = true;
694                 }
695               }
696             }
697 
698             if (!containsSubLoops) {
699               innermost_loops.push_back(loop);
700             }
701           }
702 
703           if (innermost_loops.empty()) {
704             break;
705           }
706           int loop_n = std::rand() % (int)innermost_loops.size();
707           auto loop = innermost_loops[loop_n];
708 
709           message = "vectorize(loops[" + std::to_string(loop_n) + "]);\n";
710           randomization_helper::printHistory(n_transform, message);
711           l.vectorize(loop);
712           break;
713         }
714 
715         case VECTORIZE_INNER_LOOPS: {
716           message = "vectorizeInnerLoops();\n";
717           randomization_helper::printHistory(n_transform, message);
718           l.vectorizeInnerLoops();
719           break;
720         }
721 
722         case ELIMINATE_DEAD_STORES: {
723           message = "eliminateDeadStores();\n";
724           randomization_helper::printHistory(n_transform, message);
725           l.eliminateDeadStores();
726           break;
727         }
728 
729         // TODO: Add remaining transforms
730         default:
731           break;
732       }
733     }
734   } catch (...) {
735     std::cout << "EXCEPTION THROWN!\n";
736     std::cout << "SEED: " << seed << "\n";
737     throw std::runtime_error("Random test failed");
738   }
739   message = "End of transformations;\n";
740   randomization_helper::printHistory(n_transforms, message);
741   return;
742 }
743 
744 } // namespace torch::jit::tensorexpr
745