1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/passes/inline_fork_wait.h>
3
4 namespace torch::jit {
5
InlineForkWait(Block * b,std::unordered_map<Value *,Value * > & future_remap)6 static void InlineForkWait(
7 Block* b,
8 std::unordered_map<Value*, Value*>& future_remap) {
9 auto nodes = b->nodes();
10
11 // Track the futures returned by prim::fork.
12 for (auto it = nodes.begin(); it != nodes.end(); it++) {
13 auto node = *it;
14 if (node->kind() != prim::fork) {
15 continue;
16 }
17 WithInsertPoint insert_guard(node);
18 auto graph = b->owningGraph();
19 auto subgraph = node->g(attr::Subgraph);
20
21 auto output = insertGraph(*graph, *subgraph, node->inputs());
22
23 future_remap[node->output()] = output.at(0);
24 }
25
26 // Remove aten::wait if its input future is returned by prim::fork.
27 auto reversed = b->nodes().reverse();
28 for (auto it = reversed.begin(); it != reversed.end(); it++) {
29 auto node = *it;
30 if (node->kind() == prim::fork) {
31 // Account for the case where the aten::wait call isn't present in
32 // the current graph.
33 node->output()->replaceAllUsesWith(future_remap.at(node->output()));
34 it.destroyCurrent();
35 } else if (node->kind() == aten::wait) {
36 AT_ASSERT(node->inputs().size() == 1);
37 AT_ASSERT(node->outputs().size() == 1);
38 // If the future does not map to a prim::fork, it could be
39 // returned from prim::rpc_async, which has side effect, so it shouldn't
40 // be dead code eliminated.
41 if (future_remap.count(node->input())) {
42 node->output()->replaceAllUsesWith(future_remap.at(node->input()));
43 it.destroyCurrent();
44 }
45 }
46 }
47
48 // Recursively inline fork/wait.
49 for (auto it = nodes.begin(); it != nodes.end(); it++) {
50 auto node = *it;
51 for (auto sub_b : node->blocks()) {
52 InlineForkWait(sub_b, future_remap);
53 }
54 }
55 }
56
InlineForkWait(const std::shared_ptr<Graph> & graph)57 void InlineForkWait(const std::shared_ptr<Graph>& graph) {
58 std::unordered_map<Value*, Value*> future_remap;
59 InlineForkWait(graph->block(), future_remap);
60 GRAPH_DUMP("After InlineForkWait: ", graph);
61 }
62
63 } // namespace torch::jit
64