xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/inline_loop_condition.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <functional>
2 #include <memory>
3 #include <string>
4 
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/jit/frontend/inline_loop_condition.h>
7 #include <torch/csrc/jit/ir/ir.h>
8 
9 namespace torch::jit {
10 
InlineBlockBeforeNode(Node * before_node,Block * block)11 void InlineBlockBeforeNode(Node* before_node, Block* block) {
12   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
13     auto block_node = *it++;
14     block_node->moveBefore(before_node);
15   }
16 }
17 
18 // The loop node is initially emitted as:
19 // Loop(max_trip_count)
20 //    block0(loop_counter) {
21 //      <body>
22 //    }
23 //    block1 {
24 //      <loop condition>
25 //      -> (condition)
26 //    }
27 // Here, we inline the loop condition and convert the loop to the form:
28 // Loop(max_trip_count, start_condition)
29 //    block0(loop_counter, loop_carried_block*) {
30 //      <body>
31 //       BlockExit(continue_condition, loop_carried_block*)
32 //    }
inlineLoopCondition(Node * n)33 static void inlineLoopCondition(Node* n) {
34   Block* body_block = n->blocks().at(0);
35 
36   auto pre_header = n->blocks().at(1);
37   auto temp_block = n->addBlock();
38   temp_block->cloneFrom(pre_header, [](Value* v) { return v; });
39   InlineBlockBeforeNode(n, temp_block);
40   n->insertInput(/*start_condition_index*/ 1, temp_block->outputs().at(0));
41   n->eraseBlock(2);
42 
43   InlineBlockBeforeNode(body_block->return_node(), pre_header);
44   body_block->return_node()->insertInput(0, pre_header->outputs().at(0));
45   n->eraseBlock(1);
46 }
47 
inlineLoopCondition(Block * block)48 static void inlineLoopCondition(Block* block) {
49   for (Node* n : block->nodes()) {
50     for (Block* b : n->blocks()) {
51       inlineLoopCondition(b);
52     }
53     if (n->kind() == prim::Loop) {
54       inlineLoopCondition(n);
55     }
56   }
57 }
58 
InlineLoopCondition(std::shared_ptr<Graph> & graph)59 void InlineLoopCondition(std::shared_ptr<Graph>& graph) {
60   inlineLoopCondition(graph->block());
61 }
62 
63 } // namespace torch::jit
64