#include <torch/csrc/jit/passes/inline_forked_closures.h>

#include <torch/csrc/jit/frontend/ir_emitter.h>

namespace torch::jit {

// Closure nodes are emitted as a tuple of (function %, context tuple %).
// Inside the closure, the context tuple is unpacked so that all closed-over
// values are available. A function closing over a and b would look like:
// def foo(context):
//   a, b = context
//
// To fork the closure, we need to pass each value in the context tuple
// as an explicit input to the fork node, and then, within the closure
// subgraph, replace each context-unpacking value with the corresponding
// new graph input:
// fork(foo) ->
// def foo(a, b):
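//
// As a rough sketch (value names and formatting are illustrative, not exact
// printer output), the pass rewrites IR of the shape
//   %closure = prim::Closure()   (carries foo's subgraph)
//   %context = prim::TupleConstruct(%a, %b)
//   %packed  = prim::TupleConstruct(%closure, %context)
//   %result  = prim::forkClosure(%packed)
// into a single fork (or awaitable) node that carries the rewritten closure
// subgraph as its Subgraph attribute:
//   %result  = prim::fork(%a, %b)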
static void inlineForkedClosure(Node* fork_closure, NodeKind genKind) {
  Node* function_context_node = fork_closure->input()->node();

  // The forked value must be a (prim::Closure, prim::TupleConstruct) pair.
  if (function_context_node->inputs().size() != 2 ||
      function_context_node->inputs().at(0)->node()->kind() != prim::Closure ||
      function_context_node->inputs().at(1)->node()->kind() !=
          prim::TupleConstruct) {
    throw ErrorReport(fork_closure->sourceRange()) << "Cannot fork this value";
  }

  Node* function = function_context_node->inputs().at(0)->node();
  Node* context = function_context_node->inputs().at(1)->node();
  auto fork_graph = function->g(attr::Subgraph)->copy();
  auto g = fork_closure->owningGraph();
  Node* fork_node = g->create(genKind, 1)
                        ->insertAfter(fork_closure)
                        ->setSourceRange(fork_closure->sourceRange());

  // The closure subgraph must take exactly one input: the context tuple.
  if (fork_graph->inputs().size() != 1 ||
      !fork_graph->inputs().at(0)->type()->cast<TupleType>()) {
    throw ErrorReport(fork_node->sourceRange())
        << "Cannot fork lambda with parameters";
  }
  auto fork_graph_context = fork_graph->inputs().at(0);
  AT_ASSERT(fork_graph_context->uses().size() == 1);
  auto fork_graph_unpack = fork_graph_context->uses().at(0).user;

  // Pass each captured value directly to the fork node, and thread it into
  // the subgraph as a fresh input that replaces the unpacked tuple element.
  for (size_t i = 0; i < context->inputs().size(); ++i) {
    auto cont_input = context->inputs().at(i);
    fork_node->addInput(cont_input);
    auto inp = fork_graph->insertInput(i)->copyMetadata(cont_input);
    fork_graph_unpack->outputs().at(i)->replaceAllUsesWith(inp);
  }
  // The original context tuple input (now the last input) and its unpack node
  // are no longer used; remove them.
  fork_graph_unpack->destroy();
  fork_graph->eraseInput(fork_graph->inputs().size() - 1);
  // Replace the closure node with the new fork node and attach the rewritten
  // subgraph.
  fork_node->output()->copyMetadata(fork_closure->output());
  fork_closure->output()->replaceAllUsesWith(fork_node->output());
  fork_closure->destroy();
  fork_node->g_(attr::Subgraph, fork_graph);
  runCleanupPasses(fork_graph);
}
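// Recursively walks `block` (and the blocks of its nodes) and replaces each
// prim::forkClosure / prim::awaitableClosure node with an inlined prim::fork /
// prim::awaitable node.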
static void inlineForkedClosures(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    // Advance before processing: inlineForkedClosure may destroy `n`,
    // which would invalidate an iterator still pointing at it.
    it++;
    switch (n->kind()) {
      case prim::forkClosure: {
        inlineForkedClosure(n, prim::fork);
      } break;
      case prim::awaitableClosure: {
        inlineForkedClosure(n, prim::awaitable);
      } break;
      default: {
        // Recurse into nested blocks (e.g. if/loop bodies).
        for (Block* b : n->blocks()) {
          inlineForkedClosures(b);
        }
      } break;
    }
  }
}
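// Public entry point: inlines all forked and awaitable closures in the graph.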
void inlineForkedClosures(std::shared_ptr<Graph>& to_clean) {
  inlineForkedClosures(to_clean->block());
}

} // namespace torch::jit