xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/inline_forked_closures.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/inline_forked_closures.h>
2 
3 #include <torch/csrc/jit/frontend/ir_emitter.h>
4 
5 namespace torch::jit {
6 
7 // Closure nodes are emitted as a tuple of (function %, context tuple %)
8 // Inside the closure the closure is then unpacked so that all closed over
9 // values are set. A function closing over a and b would look like:
10 // def foo(context):
11 //  a, b = context
12 //
13 // To fork the closure, we need to set each value in the context tuple
14 // as an explicit input to the fork node, and then within the closure
15 // subgraph, replace the context unpacking value with the new graph input.
16 // fork(foo) ->
17 // def foo(a, b):
inlineForkedClosure(Node * fork_closure,NodeKind genKind)18 static void inlineForkedClosure(Node* fork_closure, NodeKind genKind) {
19   Node* function_context_node = fork_closure->input()->node();
20 
21   if (function_context_node->inputs().size() != 2 ||
22       function_context_node->inputs().at(0)->node()->kind() != prim::Closure ||
23       function_context_node->inputs().at(1)->node()->kind() !=
24           prim::TupleConstruct) {
25     throw ErrorReport(fork_closure->sourceRange()) << "Cannot fork this value";
26   }
27 
28   Node* function = function_context_node->inputs().at(0)->node();
29   Node* context = function_context_node->inputs().at(1)->node();
30   auto fork_graph = function->g(attr::Subgraph)->copy();
31   auto g = fork_closure->owningGraph();
32   Node* fork_node = g->create(genKind, 1)
33                         ->insertAfter(fork_closure)
34                         ->setSourceRange(fork_closure->sourceRange());
35 
36   if (fork_graph->inputs().size() != 1 ||
37       !fork_graph->inputs().at(0)->type()->cast<TupleType>()) {
38     throw ErrorReport(fork_node->sourceRange())
39         << "Cannot fork lambda with parameters";
40   }
41   auto fork_graph_context = fork_graph->inputs().at(0);
42   AT_ASSERT(fork_graph_context->uses().size() == 1);
43   auto fork_graph_unpack = fork_graph_context->uses().at(0).user;
44 
45   for (size_t i = 0; i < context->inputs().size(); ++i) {
46     auto cont_input = context->inputs().at(i);
47     fork_node->addInput(cont_input);
48     auto inp = fork_graph->insertInput(i)->copyMetadata(cont_input);
49     fork_graph_unpack->outputs().at(i)->replaceAllUsesWith(inp);
50   }
51   fork_graph_unpack->destroy();
52   fork_graph->eraseInput(fork_graph->inputs().size() - 1);
53   fork_node->output()->copyMetadata(fork_closure->output());
54   fork_closure->output()->replaceAllUsesWith(fork_node->output());
55   fork_closure->destroy();
56   fork_node->g_(attr::Subgraph, fork_graph);
57   runCleanupPasses(fork_graph);
58 }
59 
inlineForkedClosures(Block * block)60 static void inlineForkedClosures(Block* block) {
61   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
62     Node* n = *it;
63     it++;
64     switch (n->kind()) {
65       case prim::forkClosure: {
66         inlineForkedClosure(n, prim::fork);
67       } break;
68       case prim::awaitableClosure: {
69         inlineForkedClosure(n, prim::awaitable);
70       } break;
71       default: {
72         for (Block* b : n->blocks()) {
73           inlineForkedClosures(b);
74         }
75       } break;
76     }
77   }
78 }
79 
inlineForkedClosures(std::shared_ptr<Graph> & to_clean)80 void inlineForkedClosures(std::shared_ptr<Graph>& to_clean) {
81   inlineForkedClosures(to_clean->block());
82 }
83 
84 } // namespace torch::jit
85