1 #include <algorithm>
2 #include <iostream>
3 #include <random>
4 #include <stdexcept>
5 #include <typeinfo>
6 #include <unordered_map>
7 #include <unordered_set>
8 #include <vector>
9
10 #include <torch/csrc/jit/jit_log.h>
11 #include <torch/csrc/jit/jit_opt_limit.h>
12 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
13 #include <torch/csrc/jit/tensorexpr/loopnest.h>
14 #include <torch/csrc/jit/tensorexpr/loopnest_randomization.h>
15
16 namespace torch::jit::tensorexpr {
17
18 namespace randomization_helper {
19
max_transformations(int n_max_transforms)20 static int64_t max_transformations(int n_max_transforms) {
21 // Reuse the env variable PYTORCH_JIT_OPT_LIMIT to control the max number of
22 // transformations. Example - set the env variable
23 // PYTORCH_JIT_OPT_LIMIT="loopnest_randomization=10" to set max
24 // transformations to 10. This can be helpful in gradually reducing the
25 // number of transformations when we see an error.
26 if (!JIT_OPT_ALLOWED) {
27 return n_max_transforms;
28 }
29 int max_transforms = 1;
30 while (JIT_OPT_ALLOWED && max_transforms < n_max_transforms) {
31 max_transforms++;
32 }
33 return max_transforms;
34 }
35
GetAllPerfectlyNestedLoopNests(std::vector<ForPtr> loops)36 static std::vector<std::vector<ForPtr>> GetAllPerfectlyNestedLoopNests(
37 std::vector<ForPtr> loops) {
38 // Find the first set of loops that can be reordered
39 std::vector<std::vector<ForPtr>> all_nested_loops;
40 std::vector<ForPtr> nested_loops;
41 if (loops.empty()) {
42 return all_nested_loops;
43 }
44 nested_loops.push_back(loops[0]);
45 for (size_t i = 1; i < loops.size(); i++) {
46 auto last_loop = nested_loops.back();
47 auto next_loop = loops[i];
48 if (last_loop->body()->nstmts() == 1 &&
49 last_loop->body()->front() == next_loop) {
50 nested_loops.push_back(next_loop);
51 } else {
52 if (nested_loops.size() > 1) {
53 all_nested_loops.push_back(nested_loops);
54 }
55 nested_loops.clear();
56 nested_loops.push_back(next_loop);
57 }
58 }
59 return all_nested_loops;
60 }
61
// Picks `n` distinct elements of `objects` in random order.
//
// Returns a tuple of (selected elements, their indices in `objects`), where
// the i-th selected element is `objects[indices[i]]`. Both vectors are empty
// when `n` exceeds the number of objects (or when `n` is not positive).
// `random_engine` is always advanced by one shuffle of all positions, even
// when nothing is selected.
template <typename T>
std::tuple<std::vector<T>, std::vector<int>> select_n_randomly(
    std::vector<T>& objects,
    int n,
    std::default_random_engine& random_engine) {
  // Shuffle every position 0..size-1; the first `n` become the selection.
  std::vector<int> shuffled_positions(objects.size());
  std::iota(shuffled_positions.begin(), shuffled_positions.end(), 0);
  std::shuffle(
      shuffled_positions.begin(), shuffled_positions.end(), random_engine);

  std::vector<T> picked_objects;
  std::vector<int> picked_positions;
  const int available = static_cast<int>(shuffled_positions.size());
  if (n > 0 && n <= available) {
    picked_positions.assign(
        shuffled_positions.begin(), shuffled_positions.begin() + n);
    for (const int position : picked_positions) {
      picked_objects.push_back(objects[position]);
    }
  }
  return std::make_tuple(picked_objects, picked_positions);
}
83
find_factor(const ForPtr & loop)84 static int find_factor(const ForPtr& loop) {
85 // Find valid factors
86 ExprPtr loop_stop = loop->stop();
87 auto loop_imm = intValue(loop_stop);
88 if (loop_imm) {
89 int loop_bound = *loop_imm;
90 int factor = rand() % (loop_bound - 1) + 1;
91 return factor;
92 }
93 return -1;
94 }
95
printHistory(int index,std::string message)96 static void printHistory(int index, std::string message) {
97 message = "Random Transform Sequence - Transformations[" +
98 std::to_string(index) + "] = " + message;
99 GRAPH_DEBUG(message);
100 }
101
// Renders every element of `indices` with std::to_string and concatenates
// them, appending `sep` after each element (including the last one).
// An empty input yields an empty string.
template <typename T>
std::string join(std::vector<T> indices, char sep = ',') {
  std::string joined;
  for (auto it = indices.begin(); it != indices.end(); ++it) {
    joined.append(std::to_string(*it));
    joined.push_back(sep);
  }
  return joined;
}
110
// Concatenates the strings in `indices`, appending `sep` after each one
// (including the last one). An empty input yields an empty string.
static std::string join(
    const std::vector<std::string>& indices,
    char sep = ',') {
  std::string joined;
  for (std::size_t pos = 0; pos < indices.size(); ++pos) {
    joined += indices[pos];
    joined += sep;
  }
  return joined;
}
// Returns, as a decimal string, the position of the first element of
// `objects` that compares equal to `object`. When there is no such element
// the result is the size of `objects`.
template <typename T>
std::string indexOf(const std::vector<T>& objects, const T& object) {
  std::size_t position = 0;
  for (const auto& candidate : objects) {
    if (candidate == object) {
      break;
    }
    ++position;
  }
  return std::to_string(position);
}
125
126 } // namespace randomization_helper
127
loopnestRandomization(int64_t seed,LoopNest & l)128 void loopnestRandomization(int64_t seed, LoopNest& l) {
129 // This is to help with deterministic testing of randomized infrastructure.
130 // When seed value is 1, we perform preset loop transformations. This allows
131 // testing of interface.
132 if (seed == 1) {
133 l.simplify();
134 return;
135 }
136
137 std::default_random_engine random_engine(seed);
138 std::srand(seed);
139 // Set the maximum allowed number of transformations beyond which it is hard
140 // to track and debug. Arbitrarily choosing 20 as maximum number.
141 int max_allowed_transformations = 20;
142 int n_transforms = randomization_helper::max_transformations(
143 std::rand() % max_allowed_transformations);
144 std::string message = "";
145 // clang-format off
146 // Transformations list:
147 //
148 // StmtPtr simplify();
149 // bool computeInline(BufPtr b);
150 // void inlineIntermediateBufs(bool allow_duplicated_work);
151 // bool optimizeConditionals();
152 // static void splitWithTail(ForPtr f, int factor);
153 // static void splitWithMask(ForPtr f, int factor);
154 // static std::vector<ForPtr> distributeLoop(ForPtr loop, const std::unordered_set<StmtPtr>& pivots);
155 // static std::vector<ForPtr> distributeLoop(ForPtr loop);
156 // static std::vector<ForPtr> distributeLoopAndParents(ForPtr loop);
157 // static std::vector<ForPtr> distributeLoopOverInnerLoops(ForPtr loop);
158 // static std::vector<ForPtr> distributeLoopAndParentsOverInnerLoops(ForPtr loop);
159 // static bool fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);
160 // static void reorderAxis(ForPtr a, ForPtr b);
161 // static std::vector<ForPtr> reorder(const std::vector<ForPtr>& loops, const std::vector<size_t>& permutation);
162 // ForPtr tile(ForPtr x, ForPtr y, int x_factor, int y_factor);
163 // static void fullUnroll(ForPtr f);
164 // static bool normalize(ForPtr f);
165 // static bool flatten(const std::vector<ForPtr>& f, ForPtr* flattened);
166 // static void compressBuffer(BufPtr buf, StmtPtr stmt);
167 // static void compressAllBuffers(StmtPtr stmt);
168 // static void sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail);
169 // static void sliceHead(ForPtr f, int factor);
170 // static void sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail);
171 // static void sliceTail(ForPtr f, int factor);
172 // static AccessResult cacheAccesses(BufPtr producer, const std::string& name, StmtPtr consumer);
173 // static void computeAt(StmtPtr s, ForPtr at);
174 // static bool rfactor(StmtPtr s, ForPtr outer_reduction_for);
175 // static bool vectorize(ForPtr);
176 // void vectorizeInnerLoops();
177 // void eliminateDeadStores();
178 // void prepareForCodegen();
179 // clang-format on
180 enum TransformKind {
181 SIMPLIFY = 0,
182 COMPUTE_INLINE,
183 INLINE_ALL,
184 OPT_COND,
185 SPLIT_TAIL,
186 SPLIT_MASK,
187 DIST1,
188 DIST2,
189 DIST3,
190 DIST4,
191 DIST5,
192 FUSE_LOOPS,
193 REORDER_AXIS,
194 REORDER,
195 TILE,
196 FULL_UNROLL,
197 NORMALIZE,
198 FLATTEN,
199 COMPRESS_BUFFER,
200 COMPRESS_ALL_BUFFERS,
201 SLICE_HEAD,
202 SLICE_TAIL,
203 CACHE_ACCESSES,
204 COMPUTE_AT,
205 RFACTOR,
206 VECTORIZE,
207 VECTORIZE_INNER_LOOPS,
208 ELIMINATE_DEAD_STORES,
209 MAX_TRANSFORM,
210 };
211 bool can_inline = true;
212 try {
213 for (int n_transform = 0; n_transform < n_transforms; n_transform++) {
214 int transform = std::rand() % MAX_TRANSFORM;
215 switch (transform) {
216 case SIMPLIFY: {
217 message = "simplify();\n";
218 randomization_helper::printHistory(n_transform, message);
219 l.simplify();
220 break;
221 }
222 case COMPUTE_INLINE: {
223 if (can_inline) {
224 auto bufs = NodeFinder<Buf>::find(l.root_stmt());
225 if (!bufs.empty()) {
226 int buf_number = std::rand() % (int)bufs.size();
227 message =
228 "computeInline(" + bufs[buf_number]->name_hint() + ");\n";
229 randomization_helper::printHistory(n_transform, message);
230 l.computeInline(bufs[buf_number]);
231 }
232 }
233 break;
234 }
235 case INLINE_ALL: {
236 if (can_inline) {
237 int allow_dup = std::rand() % 2;
238 message =
239 "inlineIntermediateBufs(" + std::to_string(allow_dup) + ");\n";
240 randomization_helper::printHistory(n_transform, message);
241 l.inlineIntermediateBufs(allow_dup);
242 can_inline = false;
243 }
244 break;
245 }
246 case OPT_COND: {
247 message = "optimizeConditionals();\n";
248 randomization_helper::printHistory(n_transform, message);
249 l.optimizeConditionals();
250 break;
251 }
252 case SPLIT_TAIL: {
253 auto loops = NodeFinder<For>::find(l.root_stmt());
254 if (loops.empty()) {
255 break;
256 }
257 int loop_n = std::rand() % (int)loops.size();
258 auto loop = loops[loop_n];
259 int factor = (std::rand() % 20) + 1;
260 message = "splitWithTail(loops[" + std::to_string(loop_n) + "], " +
261 std::to_string(factor) + ");\n";
262 randomization_helper::printHistory(n_transform, message);
263 l.splitWithTail(loop, factor);
264 break;
265 }
266 case SPLIT_MASK: {
267 auto loops = NodeFinder<For>::find(l.root_stmt());
268 if (loops.empty()) {
269 break;
270 }
271 int loop_n = std::rand() % (int)loops.size();
272 auto loop = loops[loop_n];
273 int factor = (std::rand() % 20) + 1;
274 message = "splitWithMask(loops[" + std::to_string(loop_n) + "], " +
275 std::to_string(factor) + ")\n";
276 randomization_helper::printHistory(n_transform, message);
277 l.splitWithMask(loop, factor);
278 break;
279 }
280 case DIST1: {
281 auto loops = NodeFinder<For>::find(l.root_stmt());
282 if (loops.empty()) {
283 break;
284 }
285 int loop_n = std::rand() % (int)loops.size();
286 auto loop = loops[loop_n];
287 std::vector<StmtPtr> stmts(
288 loop->body()->begin(), loop->body()->end());
289 if (stmts.empty()) {
290 break;
291 }
292 int n_pivots = (std::rand() % (int)stmts.size()) + 1;
293 auto [pivots, chosen_indices] =
294 randomization_helper::select_n_randomly<StmtPtr>(
295 stmts, n_pivots, random_engine);
296 std::unordered_set<StmtPtr> pivots_set(pivots.begin(), pivots.end());
297 message = "distributeLoop(loops[" + std::to_string(loop_n) +
298 "], pivots=stmts(" + randomization_helper::join(chosen_indices) +
299 "))\n";
300 randomization_helper::printHistory(n_transform, message);
301 l.distributeLoop(loop, pivots_set);
302 break;
303 }
304 case DIST2: {
305 auto loops = NodeFinder<For>::find(l.root_stmt());
306
307 if (loops.empty()) {
308 break;
309 }
310 int loop_n = std::rand() % (int)loops.size();
311 auto loop = loops[loop_n];
312
313 message = "distributeLoop(loops[" + std::to_string(loop_n) + "])\n";
314 randomization_helper::printHistory(n_transform, message);
315 l.distributeLoop(loop);
316 break;
317 }
318 case DIST3: {
319 auto loops = NodeFinder<For>::find(l.root_stmt());
320
321 if (loops.empty()) {
322 break;
323 }
324 int loop_n = std::rand() % (int)loops.size();
325 auto loop = loops[loop_n];
326
327 message = "distributeLoopAndParents(loops[" + std::to_string(loop_n) +
328 "])\n";
329 randomization_helper::printHistory(n_transform, message);
330 l.distributeLoopAndParents(loop);
331 break;
332 }
333 case DIST4: {
334 auto loops = NodeFinder<For>::find(l.root_stmt());
335
336 if (loops.empty()) {
337 break;
338 }
339 int loop_n = std::rand() % (int)loops.size();
340 auto loop = loops[loop_n];
341
342 message = "distributeLoopOverInnerLoops(loops[" +
343 std::to_string(loop_n) + "])\n";
344 randomization_helper::printHistory(n_transform, message);
345 l.distributeLoopOverInnerLoops(loop);
346 break;
347 }
348 case DIST5: {
349 auto loops = NodeFinder<For>::find(l.root_stmt());
350
351 if (loops.empty()) {
352 break;
353 }
354 int loop_n = std::rand() % (int)loops.size();
355 auto loop = loops[loop_n];
356
357 message = "distributeLoopAndParentsOverInnerLoops(loops[" +
358 std::to_string(loop_n) + "])\n";
359 randomization_helper::printHistory(n_transform, message);
360 l.distributeLoopAndParentsOverInnerLoops(loop);
361 break;
362 }
363 case FUSE_LOOPS: {
364 // Get all the loops
365 auto loops = NodeFinder<For>::find(l.root_stmt());
366 if (loops.size() <= 1) {
367 break;
368 }
369
370 // Find a random number of loops to fuse
371 int num_loops_to_fuse =
372 std::max(2, (int)(std::rand() % (int)loops.size()));
373
374 auto [loops_to_fuse, chosen_indices] =
375 randomization_helper::select_n_randomly<ForPtr>(
376 loops, num_loops_to_fuse, random_engine);
377
378 message = "fuseLoops(loops[" +
379 randomization_helper::join(chosen_indices) + "], &fused_loop);\n";
380 randomization_helper::printHistory(n_transform, message);
381 // Fuse the loops
382 ForPtr fused_loop;
383 l.fuseLoops(loops_to_fuse, &fused_loop);
384 break;
385 }
386
387 case REORDER_AXIS: {
388 // Get all the loops
389 auto loops = NodeFinder<For>::find(l.root_stmt());
390 if (loops.size() <= 1) {
391 break;
392 }
393
394 // Find pairs of axes that can be reordered
395 std::vector<std::pair<ForPtr, ForPtr>> valid_pairs;
396 for (const auto i : c10::irange(loops.size())) {
397 for (const auto j : c10::irange(i + 1, loops.size())) {
398 if (LoopNest::findOuterFor(loops[i], loops[j])) {
399 valid_pairs.emplace_back(loops[i], loops[j]);
400 }
401 }
402 }
403
404 // Choose a pair randomly
405 if (valid_pairs.empty()) {
406 break;
407 }
408 int valid_pair_n = std::rand() % (int)valid_pairs.size();
409 auto loop_pair = valid_pairs.at(valid_pair_n);
410 auto first_loop = std::get<0>(loop_pair);
411 auto second_loop = std::get<1>(loop_pair);
412
413 std::string first_index =
414 randomization_helper::indexOf(loops, first_loop);
415 std::string second_index =
416 randomization_helper::indexOf(loops, second_loop);
417 message = "reorderAxis(loops[";
418 message += first_index;
419 message += "], loops[";
420 message += second_index + "]);\n";
421 randomization_helper::printHistory(n_transform, message);
422 // reorder the axis
423 l.reorderAxis(first_loop, second_loop);
424 break;
425 }
426
427 case REORDER: {
428 // Get all the loops
429 auto loops = NodeFinder<For>::find(l.root_stmt());
430 if (loops.size() <= 1) {
431 break;
432 }
433
434 // Find all perfectly nested loop nests
435 auto all_nested_loops =
436 randomization_helper::GetAllPerfectlyNestedLoopNests(loops);
437 if (all_nested_loops.empty()) {
438 break;
439 }
440
441 // Randomly pick a set of consecutive loops to reorder
442 int index = rand() % (int)all_nested_loops.size();
443 auto nested_loops = all_nested_loops.at(index);
444
445 // Create a random permutation for reordering
446 std::vector<size_t> permutation(nested_loops.size());
447 std::iota(permutation.begin(), permutation.end(), 0);
448 std::shuffle(permutation.begin(), permutation.end(), random_engine);
449
450 // Generate a good history message
451 std::vector<std::string> indices;
452 indices.reserve(nested_loops.size());
453 for (const auto& l : nested_loops) {
454 indices.push_back(randomization_helper::indexOf(loops, l));
455 }
456 message = "reorder(loops[" + randomization_helper::join(indices) +
457 "], permutation=[" + randomization_helper::join(permutation) +
458 "]);\n";
459 randomization_helper::printHistory(n_transform, message);
460 // reorder
461 l.reorder(nested_loops, permutation);
462 break;
463 }
464
465 case TILE: {
466 // Get all the loops
467 auto loops = NodeFinder<For>::find(l.root_stmt());
468 if (loops.size() <= 1) {
469 break;
470 }
471
472 // Tile needs two perfectly nested loops. To find such loops, we find
473 // all perfectly nested loop nests, randomly pick one of them, and
474 // randomly pick 2 consecutive loops in that loop nest.
475 // Find all perfectly nested loop nests
476 auto all_nested_loops =
477 randomization_helper::GetAllPerfectlyNestedLoopNests(loops);
478 if (all_nested_loops.empty()) {
479 break;
480 }
481
482 int index = rand() % (int)all_nested_loops.size();
483 auto nested_loops = all_nested_loops.at(index);
484 if (nested_loops.size() < 2) {
485 break;
486 }
487 int loop_number = rand() % ((int)nested_loops.size() - 1);
488 auto x_loop = nested_loops.at(loop_number);
489 auto y_loop = nested_loops.at(loop_number + 1);
490
491 int x_factor = randomization_helper::find_factor(x_loop);
492 int y_factor = randomization_helper::find_factor(y_loop);
493 if (x_factor == -1 || y_factor == -1) {
494 break;
495 }
496
497 std::string x_loop_index =
498 randomization_helper::indexOf(loops, x_loop);
499 std::string y_loop_index =
500 randomization_helper::indexOf(loops, y_loop);
501 message = "tile(loops[";
502 message += x_loop_index;
503 message += "], loops[";
504 message += y_loop_index + "], ";
505 message += std::to_string(x_factor);
506 message += ", " + std::to_string(y_factor) + ");\n";
507 randomization_helper::printHistory(n_transform, message);
508 // tile
509 l.tile(x_loop, y_loop, x_factor, y_factor);
510 break;
511 }
512
513 case FULL_UNROLL: {
514 auto loops = NodeFinder<For>::find(l.root_stmt());
515 if (loops.empty()) {
516 break;
517 }
518 int loop_n = std::rand() % (int)loops.size();
519 auto loop = loops[loop_n];
520
521 message = "fullUnroll(loops[" + std::to_string(loop_n) + "]);\n";
522 randomization_helper::printHistory(n_transform, message);
523 LoopNest::fullUnroll(loop);
524 break;
525 }
526
527 case NORMALIZE: {
528 auto loops = NodeFinder<For>::find(l.root_stmt());
529 if (loops.empty()) {
530 break;
531 }
532 int loop_n = std::rand() % (int)loops.size();
533 auto loop = loops[loop_n];
534
535 message = "normalize(loops[" + std::to_string(loop_n) + "]);\n";
536 randomization_helper::printHistory(n_transform, message);
537 l.normalize(loop);
538 break;
539 }
540
541 case FLATTEN: {
542 // Get all the loops
543 auto loops = NodeFinder<For>::find(l.root_stmt());
544 if (loops.size() <= 1) {
545 break;
546 }
547
548 // Find all perfectly nested loop nests
549 auto all_nested_loops =
550 randomization_helper::GetAllPerfectlyNestedLoopNests(loops);
551 if (all_nested_loops.empty()) {
552 break;
553 }
554
555 // Randomly pick a set of consecutive loops to flatten
556 int index = rand() % (int)all_nested_loops.size();
557 auto nested_loops = all_nested_loops.at(index);
558
559 // Generate a good history message
560 std::vector<std::string> indices;
561 indices.reserve(nested_loops.size());
562 for (const auto& l : nested_loops) {
563 indices.push_back(randomization_helper::indexOf(loops, l));
564 }
565 message =
566 "flatten(loops[" + randomization_helper::join(indices) + "]);\n";
567 randomization_helper::printHistory(n_transform, message);
568 // flatten
569 l.flatten(nested_loops);
570 break;
571 }
572
573 case COMPRESS_BUFFER: {
574 auto buffers = NodeFinder<Buf>::find(l.root_stmt());
575 int buffer_n = std::rand() % (int)buffers.size();
576 auto buffer = buffers[buffer_n];
577
578 message = "compressBuffer(buffers[" + std::to_string(buffer_n) +
579 "], l.root_stmt());\n";
580 randomization_helper::printHistory(n_transform, message);
581 l.compressBuffer(buffer, l.root_stmt());
582 break;
583 }
584
585 case COMPRESS_ALL_BUFFERS: {
586 message = "compressAllBuffers(l.root_stmt());\n";
587 randomization_helper::printHistory(n_transform, message);
588 l.compressAllBuffers(l.root_stmt());
589 break;
590 }
591
592 case SLICE_HEAD: {
593 // Get all the loops
594 auto loops = NodeFinder<For>::find(l.root_stmt());
595 if (loops.empty()) {
596 break;
597 }
598 int loop_n = std::rand() % (int)loops.size();
599 auto loop = loops[loop_n];
600
601 int factor = randomization_helper::find_factor(loop);
602 if (factor == -1) {
603 break;
604 }
605 message = "sliceHead(loops[" + std::to_string(loop_n) + "]);\n";
606 randomization_helper::printHistory(n_transform, message);
607 l.sliceHead(loop, factor);
608 break;
609 }
610
611 case SLICE_TAIL: {
612 // Get all the loops
613 auto loops = NodeFinder<For>::find(l.root_stmt());
614 if (loops.empty()) {
615 break;
616 }
617 int loop_n = std::rand() % (int)loops.size();
618 auto loop = loops[loop_n];
619
620 int factor = randomization_helper::find_factor(loop);
621 if (factor == -1) {
622 break;
623 }
624 message = "sliceTail(loops[" + std::to_string(loop_n) + "]);\n";
625 randomization_helper::printHistory(n_transform, message);
626 l.sliceTail(loop, factor);
627 break;
628 }
629
630 case CACHE_ACCESSES: {
631 // TODO - Implement cache_access
632 break;
633 }
634
635 case COMPUTE_AT: {
636 // To find valid compute at pairs, we need to collect the producer
637 // consumer pairs. For now, we do not collect all such pairs for
638 // simplicity. For now, we collect producer and the immediate parent
639 // loop of the consumer. We could collect all the consumer enclosing
640 // loops, but then we will have to clean up the ones that are shared
641 // with the producer encloser loop. Currently, we only test on the
642 // immediate parent loop.
643 auto buffers = BufFinder::find(l.root_stmt());
644 std::vector<std::pair<StmtPtr, ForPtr>> producer_consumer_pairs;
645
646 for (const auto& buffer : buffers) {
647 auto producers = l.getAllWritesToBuf(buffer);
648 auto consumers = StmtsReadingBuf::find(l.root_stmt(), buffer);
649 if (producers.size() != 1 || consumers.empty()) {
650 continue;
651 }
652
653 for (const auto& producer : producers) {
654 for (const auto& consumer : consumers) {
655 auto parent_loop = LoopNest::getParentLoop(consumer);
656 auto pc_pair = std::make_pair(producer, parent_loop);
657 producer_consumer_pairs.push_back(pc_pair);
658 }
659 }
660 }
661
662 if (producer_consumer_pairs.empty()) {
663 break;
664 }
665
666 // Choose a random pair
667 int pair_n = std::rand() % (int)producer_consumer_pairs.size();
668 auto pc_pair = producer_consumer_pairs.at(pair_n);
669 auto store = std::get<0>(pc_pair);
670 auto for_ptr = std::get<1>(pc_pair);
671
672 // TODO - come up with better message
673 message = "computeAt(....);\n";
674 randomization_helper::printHistory(n_transform, message);
675 l.computeAt(store, for_ptr);
676 break;
677 }
678
679 case RFACTOR: {
680 // TODO - Implement rfactor
681 break;
682 }
683
684 case VECTORIZE: {
685 auto loops = NodeFinder<For>::find(l.root_stmt());
686 std::vector<ForPtr> innermost_loops;
687
688 for (const auto& loop : loops) {
689 bool containsSubLoops = false;
690 if (BlockPtr body = to<Block>(loop->body())) {
691 for (const StmtPtr& stmt : *body) {
692 if (ForPtr f2 = to<For>(stmt)) {
693 containsSubLoops = true;
694 }
695 }
696 }
697
698 if (!containsSubLoops) {
699 innermost_loops.push_back(loop);
700 }
701 }
702
703 if (innermost_loops.empty()) {
704 break;
705 }
706 int loop_n = std::rand() % (int)innermost_loops.size();
707 auto loop = innermost_loops[loop_n];
708
709 message = "vectorize(loops[" + std::to_string(loop_n) + "]);\n";
710 randomization_helper::printHistory(n_transform, message);
711 l.vectorize(loop);
712 break;
713 }
714
715 case VECTORIZE_INNER_LOOPS: {
716 message = "vectorizeInnerLoops();\n";
717 randomization_helper::printHistory(n_transform, message);
718 l.vectorizeInnerLoops();
719 break;
720 }
721
722 case ELIMINATE_DEAD_STORES: {
723 message = "eliminateDeadStores();\n";
724 randomization_helper::printHistory(n_transform, message);
725 l.eliminateDeadStores();
726 break;
727 }
728
729 // TODO: Add remaining transforms
730 default:
731 break;
732 }
733 }
734 } catch (...) {
735 std::cout << "EXCEPTION THROWN!\n";
736 std::cout << "SEED: " << seed << "\n";
737 throw std::runtime_error("Random test failed");
738 }
739 message = "End of transformations;\n";
740 randomization_helper::printHistory(n_transforms, message);
741 return;
742 }
743
744 } // namespace torch::jit::tensorexpr
745