xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/lift_closures.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/lift_closures.h>
2 
3 #include <torch/csrc/jit/frontend/ir_emitter.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 
6 #include <utility>
7 
8 namespace torch::jit {
9 
10 // Closures are initially emitted as prim::Closure nodes with a single block.
11 // Here, we convert the block to a subgraph, adding all closed over variables
12 // as a context tuple input to the closure node.
13 // At this point the closure has already undergone conversion to SSA,
14 // so closed over variables will just be value * that are not set in the
15 // closure block.
16 // Within the closure subgraph, the context tuple is unpacked and the unpacked
17 // values are used for closed over values.
liftClosure(Node * closure)18 static void liftClosure(Node* closure) {
19   auto block = closure->blocks().at(0);
20   auto subgraph = std::make_shared<Graph>();
21   // closures/forks can be nested, so use closure owning graph
22   auto g = closure->owningGraph();
23   Node* pack_context =
24       g->create(prim::TupleConstruct, {}, 1)->insertAfter(closure);
25   Value* context = subgraph->addInput("context");
26   // cannot use createTupleUnpack because the type is not known yet
27   Node* unpack_context =
28       subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
29 
30   std::unordered_map<Value*, Value*> captures;
31   auto env = [&](Value* v) -> Value* {
32     auto it = captures.find(v);
33     if (it != captures.end()) {
34       return it->second;
35     }
36     pack_context->addInput(v);
37     Value* r = unpack_context->addOutput()->copyMetadata(v);
38     captures[v] = r;
39     return r;
40   };
41   subgraph->block()->cloneFrom(block, env);
42   auto context_type = TupleType::create(
43       fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
44   context->setType(context_type);
45   pack_context->output()->setType(context_type);
46   auto closure_tuple =
47       g->create(prim::TupleConstruct, {}, 1)->insertAfter(pack_context);
48   closure->output()->replaceAllUsesWith(closure_tuple->output());
49   closure_tuple->addInput(closure->output());
50   closure_tuple->addInput(pack_context->output());
51   closure_tuple->output()->setType(
52       TupleType::create({closure->output()->type(), std::move(context_type)}));
53   closure->eraseBlock(0);
54   closure->g_(attr::Subgraph, std::move(subgraph));
55   runCleanupPasses(closure->g(attr::Subgraph));
56 }
57 
liftClosures(Block * block)58 static void liftClosures(Block* block) {
59   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
60     Node* n = *it;
61     it++;
62     switch (n->kind()) {
63       case prim::Closure: {
64         liftClosure(n);
65       } break;
66       default: {
67         for (Block* b : n->blocks()) {
68           liftClosures(b);
69         }
70       }
71     }
72   }
73 }
74 
liftClosures(const std::shared_ptr<Graph> & to_clean)75 void liftClosures(const std::shared_ptr<Graph>& to_clean) {
76   liftClosures(to_clean->block());
77 }
78 
79 } // namespace torch::jit
80