1 #include <torch/csrc/jit/passes/add_if_then_else.h> 2 #include <torch/csrc/jit/runtime/graph_iterator.h> 3 4 namespace torch::jit { 5 6 namespace { 7 hasNoNodes(Block * block)8bool hasNoNodes(Block* block) { 9 auto nodes = block->nodes(); 10 return nodes.begin() == nodes.end(); 11 } 12 hasTrivialSubBlocks(Node * node)13bool hasTrivialSubBlocks(Node* node) { 14 const auto blocks = node->blocks(); 15 TORCH_DCHECK_EQ(blocks.size(), 2); 16 17 return hasNoNodes(blocks[0]) && hasNoNodes(blocks[1]); 18 } 19 20 } // namespace 21 AddIfThenElseOp(std::shared_ptr<Graph> & graph)22bool AddIfThenElseOp(std::shared_ptr<Graph>& graph) { 23 std::vector<Node*> to_replace; 24 DepthFirstGraphNodeIterator graph_it(graph); 25 for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) { 26 if (node->kind() != prim::If) { 27 continue; 28 } 29 if (node->outputs().size() != 1) { 30 continue; 31 } 32 if (hasTrivialSubBlocks(node)) { 33 to_replace.push_back(node); 34 } 35 } 36 37 for (auto* node : to_replace) { 38 auto* if_then_else_node = graph->create(prim::IfThenElse, 1); 39 if_then_else_node->addInput(node->input()); 40 auto blocks = node->blocks(); 41 if_then_else_node->addInput(blocks[0]->return_node()->input()); 42 if_then_else_node->addInput(blocks[1]->return_node()->input()); 43 44 if_then_else_node->insertBefore(node); 45 if_then_else_node->output()->copyMetadata(node->output()); 46 47 node->output()->replaceAllUsesWith(if_then_else_node->output()); 48 node->destroy(); 49 } 50 return !to_replace.empty(); 51 } 52 53 } // namespace torch::jit 54