xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <functional>
2 #include <memory>
3 #include <string>
4 
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
7 #include <torch/csrc/jit/ir/ir.h>
8 #include <torch/csrc/jit/ir/ir_views.h>
9 
10 namespace torch::jit {
11 
12 // Transforms a Loop that has both a trip count specified and a loop
13 // body condition so that the iter count is no longer specified
14 // and it is recognizable as a python while loop.
canonicalizeModifiedLoop(Node * n)15 static void canonicalizeModifiedLoop(Node* n) {
16   LoopView loop(n);
17   if (loop.loopType() != LoopView::ModifiedLoop) {
18     return;
19   }
20 
21   auto g = n->owningGraph();
22   WithInsertPoint node_insert(n);
23   auto zero = g->insertConstant(0);
24   auto one = g->insertConstant(1);
25   auto max_trip_count = loop.maxTripCount();
26   auto condition = g->insert(aten::gt, {max_trip_count, zero});
27   loop.replaceMaxTripCount(
28       g->insertConstant(std::numeric_limits<int64_t>::max()));
29 
30   auto inp_condition = toIValue(loop.inputCond());
31   if (inp_condition == std::nullopt || inp_condition->toBool() == false) {
32     condition = g->insert(aten::__and__, {condition, loop.inputCond()});
33   }
34   loop.replaceInputCondition(condition);
35   n->addOutput()->setType(IntType::get());
36   WithInsertPoint loop_insert(loop.bodyBlock());
37   n->addInput(zero);
38   auto new_iter = loop.bodyBlock()->addInput()->setType(IntType::get());
39   // unset unique name for jitter, its replacement does not have a name
40   loop.currentTripCount()->setDebugName("")->replaceAllUsesWith(new_iter);
41   auto inc_iter = g->insert(aten::add, {new_iter, one});
42   loop.bodyBlock()->registerOutput(inc_iter);
43   auto less_than_max_trip = g->insert(aten::lt, {inc_iter, max_trip_count});
44   auto loop_continue = loop.nextCond();
45   auto new_condition =
46       g->insert(aten::__and__, {less_than_max_trip, loop_continue});
47   loop.bodyBlock()->eraseOutput(0);
48   loop.bodyBlock()->insertOutput(0, new_condition);
49 }
50 
canonicalizeModifiedLoops(Block * block)51 static void canonicalizeModifiedLoops(Block* block) {
52   for (Node* n : block->nodes()) {
53     for (Block* b : n->blocks()) {
54       canonicalizeModifiedLoops(b);
55     }
56     if (n->kind() == prim::Loop) {
57       canonicalizeModifiedLoop(n);
58     }
59   }
60 }
61 
62 // Transforms loops so that they can be represented as python
63 // for or while loops
CanonicalizeModifiedLoops(std::shared_ptr<Graph> & graph)64 TORCH_API void CanonicalizeModifiedLoops(std::shared_ptr<Graph>& graph) {
65   canonicalizeModifiedLoops(graph->block());
66 }
67 
68 } // namespace torch::jit
69