xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/add_if_then_else.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/add_if_then_else.h>
2 #include <torch/csrc/jit/runtime/graph_iterator.h>
3 
4 namespace torch::jit {
5 
6 namespace {
7 
hasNoNodes(Block * block)8 bool hasNoNodes(Block* block) {
9   auto nodes = block->nodes();
10   return nodes.begin() == nodes.end();
11 }
12 
hasTrivialSubBlocks(Node * node)13 bool hasTrivialSubBlocks(Node* node) {
14   const auto blocks = node->blocks();
15   TORCH_DCHECK_EQ(blocks.size(), 2);
16 
17   return hasNoNodes(blocks[0]) && hasNoNodes(blocks[1]);
18 }
19 
20 } // namespace
21 
AddIfThenElseOp(std::shared_ptr<Graph> & graph)22 bool AddIfThenElseOp(std::shared_ptr<Graph>& graph) {
23   std::vector<Node*> to_replace;
24   DepthFirstGraphNodeIterator graph_it(graph);
25   for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
26     if (node->kind() != prim::If) {
27       continue;
28     }
29     if (node->outputs().size() != 1) {
30       continue;
31     }
32     if (hasTrivialSubBlocks(node)) {
33       to_replace.push_back(node);
34     }
35   }
36 
37   for (auto* node : to_replace) {
38     auto* if_then_else_node = graph->create(prim::IfThenElse, 1);
39     if_then_else_node->addInput(node->input());
40     auto blocks = node->blocks();
41     if_then_else_node->addInput(blocks[0]->return_node()->input());
42     if_then_else_node->addInput(blocks[1]->return_node()->input());
43 
44     if_then_else_node->insertBefore(node);
45     if_then_else_node->output()->copyMetadata(node->output());
46 
47     node->output()->replaceAllUsesWith(if_then_else_node->output());
48     node->destroy();
49   }
50   return !to_replace.empty();
51 }
52 
53 } // namespace torch::jit
54