xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/inline_fork_wait.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>

namespace torch::jit {

InlineForkWait(Block * b,std::unordered_map<Value *,Value * > & future_remap)6 static void InlineForkWait(
7     Block* b,
8     std::unordered_map<Value*, Value*>& future_remap) {
9   auto nodes = b->nodes();
10 
11   // Track the futures returned by prim::fork.
12   for (auto it = nodes.begin(); it != nodes.end(); it++) {
13     auto node = *it;
14     if (node->kind() != prim::fork) {
15       continue;
16     }
17     WithInsertPoint insert_guard(node);
18     auto graph = b->owningGraph();
19     auto subgraph = node->g(attr::Subgraph);
20 
21     auto output = insertGraph(*graph, *subgraph, node->inputs());
22 
23     future_remap[node->output()] = output.at(0);
24   }
25 
26   // Remove aten::wait if its input future is returned by prim::fork.
27   auto reversed = b->nodes().reverse();
28   for (auto it = reversed.begin(); it != reversed.end(); it++) {
29     auto node = *it;
30     if (node->kind() == prim::fork) {
31       // Account for the case where the aten::wait call isn't present in
32       // the current graph.
33       node->output()->replaceAllUsesWith(future_remap.at(node->output()));
34       it.destroyCurrent();
35     } else if (node->kind() == aten::wait) {
36       AT_ASSERT(node->inputs().size() == 1);
37       AT_ASSERT(node->outputs().size() == 1);
38       // If the future does not map to a prim::fork, it could be
39       // returned from prim::rpc_async, which has side effect, so it shouldn't
40       // be dead code eliminated.
41       if (future_remap.count(node->input())) {
42         node->output()->replaceAllUsesWith(future_remap.at(node->input()));
43         it.destroyCurrent();
44       }
45     }
46   }
47 
48   // Recursively inline fork/wait.
49   for (auto it = nodes.begin(); it != nodes.end(); it++) {
50     auto node = *it;
51     for (auto sub_b : node->blocks()) {
52       InlineForkWait(sub_b, future_remap);
53     }
54   }
55 }
InlineForkWait(const std::shared_ptr<Graph> & graph)57 void InlineForkWait(const std::shared_ptr<Graph>& graph) {
58   std::unordered_map<Value*, Value*> future_remap;
59   InlineForkWait(graph->block(), future_remap);
60   GRAPH_DUMP("After InlineForkWait: ", graph);
61 }
62 
} // namespace torch::jit