1 #include <functional>
2 #include <memory>
3 #include <string>
4
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/jit/frontend/inline_loop_condition.h>
7 #include <torch/csrc/jit/ir/ir.h>
8
9 namespace torch::jit {
10
InlineBlockBeforeNode(Node * before_node,Block * block)11 void InlineBlockBeforeNode(Node* before_node, Block* block) {
12 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
13 auto block_node = *it++;
14 block_node->moveBefore(before_node);
15 }
16 }
17
18 // The loop node is initially emitted as:
19 // Loop(max_trip_count)
20 // block0(loop_counter) {
21 // <body>
22 // }
23 // block1 {
24 // <loop condition>
25 // -> (condition)
26 // }
27 // Here, we inline the loop condition and convert the loop to the form:
28 // Loop(max_trip_count, start_condition)
29 // block0(loop_counter, loop_carried_block*) {
30 // <body>
31 // BlockExit(continue_condition, loop_carried_block*)
32 // }
inlineLoopCondition(Node * n)33 static void inlineLoopCondition(Node* n) {
34 Block* body_block = n->blocks().at(0);
35
36 auto pre_header = n->blocks().at(1);
37 auto temp_block = n->addBlock();
38 temp_block->cloneFrom(pre_header, [](Value* v) { return v; });
39 InlineBlockBeforeNode(n, temp_block);
40 n->insertInput(/*start_condition_index*/ 1, temp_block->outputs().at(0));
41 n->eraseBlock(2);
42
43 InlineBlockBeforeNode(body_block->return_node(), pre_header);
44 body_block->return_node()->insertInput(0, pre_header->outputs().at(0));
45 n->eraseBlock(1);
46 }
47
inlineLoopCondition(Block * block)48 static void inlineLoopCondition(Block* block) {
49 for (Node* n : block->nodes()) {
50 for (Block* b : n->blocks()) {
51 inlineLoopCondition(b);
52 }
53 if (n->kind() == prim::Loop) {
54 inlineLoopCondition(n);
55 }
56 }
57 }
58
InlineLoopCondition(std::shared_ptr<Graph> & graph)59 void InlineLoopCondition(std::shared_ptr<Graph>& graph) {
60 inlineLoopCondition(graph->block());
61 }
62
63 } // namespace torch::jit
64