#include <functional>
#include <limits>
#include <memory>
#include <string>

#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
9
10 namespace torch::jit {
11
12 // Transforms a Loop that has both a trip count specified and a loop
13 // body condition so that the iter count is no longer specified
14 // and it is recognizable as a python while loop.
canonicalizeModifiedLoop(Node * n)15 static void canonicalizeModifiedLoop(Node* n) {
16 LoopView loop(n);
17 if (loop.loopType() != LoopView::ModifiedLoop) {
18 return;
19 }
20
21 auto g = n->owningGraph();
22 WithInsertPoint node_insert(n);
23 auto zero = g->insertConstant(0);
24 auto one = g->insertConstant(1);
25 auto max_trip_count = loop.maxTripCount();
26 auto condition = g->insert(aten::gt, {max_trip_count, zero});
27 loop.replaceMaxTripCount(
28 g->insertConstant(std::numeric_limits<int64_t>::max()));
29
30 auto inp_condition = toIValue(loop.inputCond());
31 if (inp_condition == std::nullopt || inp_condition->toBool() == false) {
32 condition = g->insert(aten::__and__, {condition, loop.inputCond()});
33 }
34 loop.replaceInputCondition(condition);
35 n->addOutput()->setType(IntType::get());
36 WithInsertPoint loop_insert(loop.bodyBlock());
37 n->addInput(zero);
38 auto new_iter = loop.bodyBlock()->addInput()->setType(IntType::get());
39 // unset unique name for jitter, its replacement does not have a name
40 loop.currentTripCount()->setDebugName("")->replaceAllUsesWith(new_iter);
41 auto inc_iter = g->insert(aten::add, {new_iter, one});
42 loop.bodyBlock()->registerOutput(inc_iter);
43 auto less_than_max_trip = g->insert(aten::lt, {inc_iter, max_trip_count});
44 auto loop_continue = loop.nextCond();
45 auto new_condition =
46 g->insert(aten::__and__, {less_than_max_trip, loop_continue});
47 loop.bodyBlock()->eraseOutput(0);
48 loop.bodyBlock()->insertOutput(0, new_condition);
49 }
50
canonicalizeModifiedLoops(Block * block)51 static void canonicalizeModifiedLoops(Block* block) {
52 for (Node* n : block->nodes()) {
53 for (Block* b : n->blocks()) {
54 canonicalizeModifiedLoops(b);
55 }
56 if (n->kind() == prim::Loop) {
57 canonicalizeModifiedLoop(n);
58 }
59 }
60 }
61
62 // Transforms loops so that they can be represented as python
63 // for or while loops
CanonicalizeModifiedLoops(std::shared_ptr<Graph> & graph)64 TORCH_API void CanonicalizeModifiedLoops(std::shared_ptr<Graph>& graph) {
65 canonicalizeModifiedLoops(graph->block());
66 }
67
68 } // namespace torch::jit
69