#include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>

#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lower_tuples.h>

#include <algorithm>

namespace torch::jit {

namespace {

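// Returns true for the scope-carrying node kinds emitted by the tracer
// (TracedModuleForward and TracedFork) that this pass recurses into.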
bool isEligibleNode(Node* n) {
  return n->kind() == prim::TracedModuleForward ||
      n->kind() == prim::TracedFork;
}

// This pass does several things:
// 1) It looks at TracedModuleForward nodes and resolves the type of `self`
//    for that (to-be) method call. It adds an input of that type to the
//    block, and adds the TracedAttr value corresponding to that `self`
//    value as a Node input. This ensures `self` is an explicit Use on
//    the node, a property we take advantage of downstream (see the
//    example below).
// 2) Convert all references to prim::TracedAttr values to prim::GetAttr
//    calls in the tightest scope possible. Concretely, for each use of
//    a prim::TracedAttr value, we compare the scope of that attribute
//    to the scope of the Use. We emit GetAttr nodes for all atoms
//    that are not shared between the two. For example, if an
//    attribute `f.param` is referenced in scope `f`, we emit a
//    GetAttr[name="param"](%self) node in the `f` block, where
//    `self` is the previously-added `self` argument to the block.
// 3) Destroy all the prim::TracedAttr nodes, as they should have
//    no more uses.
//
// A quick example:
//
//
// Input graph:
//
//     graph(%self : ClassType<Module>,
//           %x : Float(3, 4)):
//       %1 : bool = prim::TracedAttr[scope="__module.training"]()
//       %2 : ClassType<Module> = prim::TracedAttr[scope="__module.f"]()
//       %3 : Float(4, 4) = prim::TracedAttr[scope="__module.f.param"]()
//       %4 : bool = prim::TracedAttr[scope="__module.f.training"]()
//       = prim::TracedModuleForward[scope="__module.f"](),
//         block0():
//           %6 : Float(3, 4) = aten::mm(%x, %3),
//           -> ()
//       return (%6)
//
// The diff after step (1)
//
//     -   = prim::TracedModuleForward[scope="__module.f"](),
//     -    block0():
//     +   = prim::TracedModuleForward[scope="__module.f"](%2),
//     +    block0(%self : ClassType<Module>):
//
// The diff after step (2)
//
//       graph(%self.1 : ClassType<Module>,
//             %x : Float(3, 4)):
//       +  %9 : ClassType<Module> = prim::GetAttr[name="f"](%self.1)
//         %1 : bool = prim::TracedAttr[scope="__module.training"]()
//           <....>
//         %4 : bool = prim::TracedAttr[scope="__module.f.training"]()
//       -   = prim::TracedModuleForward[scope="__module.f"](%2),
//       +   = prim::TracedModuleForward[scope="__module.f"](%9),
//           block0(%self : ClassType<Module>):
//       -      %6 : Float(3, 4) = aten::mm(%x, %3),
//       +      %8 : Tensor = prim::GetAttr[name="param"](%self)
//       +      %6 : Float(3, 4) = aten::mm(%x, %8),
//             -> ()
//         return (%6)
//
// The diff after step (3)
//
//       -  %1 : bool = prim::TracedAttr[scope="__module.training"]()
//       -  %2 : ClassType<Module> = prim::TracedAttr[scope="__module.f"]()
//       -  %3 : Float(4, 4) = prim::TracedAttr[scope="__module.f.param"]()
//       -  %4 : bool = prim::TracedAttr[scope="__module.f.training"]()
struct ConvertTracedAttrReferences {
  void run(const std::shared_ptr<Graph>& graph) {
    // Build a table mapping--for each TracedAttr node--the
    // qualified name of the attribute to the Value* output
    // of the Node.
    buildAttrMap(graph);
    // Step 1
    addSelfArgToTracedForwardNodes(graph->block());
    // Step 2
    convertAttrReferencesToLocalGetAttrs(
        graph->block(), "__module", graph->inputs()[0]);
    // Step 3
    destroyTracedAttrNodes(graph);
  }

 private:
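  // Populate attr_qualname_to_value from every prim::TracedAttr node at the
  // top level of the graph.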
  void buildAttrMap(const std::shared_ptr<Graph>& graph) {
    for (Node* n : graph->nodes()) {
      if (n->kind() == prim::TracedAttr) {
        attr_qualname_to_value[n->s(attr::scope)] = n->output();
      }
    }
  }

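  // Step 1: for each TracedModuleForward, pass the submodule's TracedAttr
  // value in as a Node input and add a corresponding `self` input to its
  // block, recursing into nested TracedModuleForward and TracedFork blocks.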
  void addSelfArgToTracedForwardNodes(Block* b) {
    for (Node* n : b->nodes()) {
      if (n->kind() == prim::TracedModuleForward) {
        n->addInput(attr_qualname_to_value.at(n->s(attr::scope)));
        n->blocks()[0]->addInput("self")->setType(
            attr_qualname_to_value.at(n->s(attr::scope))->type());
        addSelfArgToTracedForwardNodes(n->blocks()[0]);
      }
      if (n->kind() == prim::TracedFork) {
        addSelfArgToTracedForwardNodes(n->blocks()[0]);
      }
    }
  }

  // This is a recursive function that descends down all blocks in the Graph
  // (NB: not just TracedModuleForward blocks). Each descent has a
  // corresponding `prefix`, i.e. the qualified name of the scope this
  // Block represents (or the scope in which this block resides for
  // non-TracedModuleForward nodes). We use this prefix to make decisions
  // about whether to emit a GetAttr node for an attribute reference, or
  // to defer that emission to the caller (in the case where an attribute
  // reference does not reside in the `prefix` scope).
  std::vector<Value*> convertAttrReferencesToLocalGetAttrs(
      Block* b,
      const c10::QualifiedName& prefix,
      Value* self) {
    // Store away Value*'s which are references to TracedAttr's which are
    // not in the `prefix` scope. We pass this back to the caller, who
    // should add these Values as explicit inputs as well as inductively
    // make the same decision on those Values.
    std::vector<Value*> unresolved_tracedattrs;
    // To ensure we don't emit redundant GetAttr Nodes in a given scope,
    // we maintain this map of original TracedAttr Value* to the Value*
    // corresponding to the GetAttr for that attribute.
    // We don't rely on CSE here because we currently can't reason about
    // the correctness of CSE over GetAttr Nodes (I think).
    std::unordered_map<Value*, Value*> local_remaps;

    for (Node* n : b->nodes()) {
      // The only difference between these two branches is that for
      // TracedModuleForward we advance the scope, but for other
      // Nodes with Blocks we don't.
      if (n->kind() == prim::TracedModuleForward) {
        auto sub_unresolved = convertAttrReferencesToLocalGetAttrs(
            n->blocks()[0], n->s(attr::scope), n->blocks()[0]->inputs()[0]);
        for (Value* v : sub_unresolved) {
          n->addInput(v);
        }
      } else if (!n->blocks().empty()) {
        for (Block* sub_block : n->blocks()) {
          auto sub_unresolved =
              convertAttrReferencesToLocalGetAttrs(sub_block, prefix, self);
          for (Value* v : sub_unresolved) {
            n->addInput(v);
          }
        }
      }

      for (size_t inp_idx = 0; inp_idx < n->inputs().size(); ++inp_idx) {
        Value* inp = n->input(inp_idx);

        // Short circuit: if we've already emitted a new Value for this
        // attribute, just use that.
        if (local_remaps.count(inp)) {
          n->replaceInput(inp_idx, local_remaps[inp]);
          continue;
        }

        WithInsertPoint guard(b->param_node()->next());
        replaceTracedAttrInputOnNode(
            n, inp_idx, prefix, self, local_remaps, unresolved_tracedattrs);
      } // for (Value *inp : n->inputs())
    } // for (Node *n : b->nodes())
    return unresolved_tracedattrs;
  }

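  // If input `inp_idx` of `n` is a prim::TracedAttr value, rewrite it: either
  // emit GetAttr node(s) off `self` (when the attribute lives in or under
  // `prefix`), or lift it to a new Block input and report it to the caller
  // via `unresolved_tracedattrs`.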
  void replaceTracedAttrInputOnNode(
      Node* n,
      size_t inp_idx,
      const c10::QualifiedName& prefix,
      Value* self,
      std::unordered_map<Value*, Value*>& local_remaps,
      std::vector<Value*>& unresolved_tracedattrs) {
    auto inp = n->inputs()[inp_idx];
    auto inp_node = inp->node();
    auto prefix_atoms = prefix.atoms();
    if (inp_node->kind() == prim::TracedAttr) {
      auto attr_qualname = c10::QualifiedName(inp_node->s(attr::scope));
      if (prefix.isPrefixOf(attr_qualname)) {
        // Prefix case: the attribute resides in this scope or a
        // sub-scope. Continually emit GetAttr nodes until we've reached
        // the proper attribute.
        auto attr_atoms = attr_qualname.atoms();
        Value* replaced_value = self;
        for (const auto i : c10::irange(attr_atoms.size())) {
          if (i < prefix_atoms.size()) {
            TORCH_INTERNAL_ASSERT(attr_atoms[i] == prefix_atoms[i]);
          } else {
            replaced_value = n->owningBlock()->owningGraph()->insertGetAttr(
                replaced_value, attr_atoms[i]);
          } // if (i < prefix_atoms.size())
        } // for(const auto i : c10::irange(attr_atoms.size()))
        n->replaceInput(inp_idx, replaced_value);
        local_remaps[inp] = replaced_value;
      } else {
        // Non-prefix case: this is a use of an attribute somewhere
        // higher in the Module hierarchy. Add a captured input to
        // the block for this attribute and add to the vector of
        // Value*'s for the caller to handle.
        Value* remapped = n->owningBlock()->addInput()->copyMetadata(inp);
        n->replaceInput(inp_idx, remapped);
        unresolved_tracedattrs.push_back(inp);
        local_remaps[inp] = remapped;
      } // if (prefix.isPrefixOf(attr_qualname))
    } // if (inp_node->kind() == prim::TracedAttr)
  }

  // The previous pass should have deleted all uses of TracedAttr
  // nodes. Let's explicitly delete them here.
  void destroyTracedAttrNodes(const std::shared_ptr<Graph>& graph) {
    for (auto& kv : attr_qualname_to_value) {
      kv.second->node()->destroy();
    }
  }

  // For each prim::TracedAttr, record the `scope` value mapped
  // to the Value* in the graph for that attribute.
  std::unordered_map<std::string, Value*> attr_qualname_to_value;
};

// Iterate through all the nodes in program order and--for each use--
// if the Value referenced is not in a scope that dominates the node,
// add block and Node outputs to lift it into a scope in which
// it dominates the Use.
struct MakeDefsDominateUses {
  MakeDefsDominateUses() = default;

  void run(Block* b) {
    processNode(b->param_node(), b);
    for (Node* n : b->nodes()) {
      processNode(n, b);
    }
    processNode(b->return_node(), b);
  }

 private:
  void processNode(Node* n, Block* b) {
    for (size_t i = 0; i < n->inputs().size(); ++i) {
      Value* inp = n->inputs()[i];

      // If this input was already lifted to this level by a previously
      // processed Use, switch to the remapped value.
      Value* inp_remapped = inp;
      if (remap.count(inp_remapped)) {
        n->replaceInput(i, remap[inp_remapped]);
        inp_remapped = remap[inp_remapped];
      }

      // This conditional isn't strictly necessary, but saves a lot of
      // computation in the common case that we're using a local value.
      if (inp_remapped->node()->owningBlock() != b) {
        // Find the common ancestor block between this node and the node that
        // produced this input. For this input Use to be valid, the Value's
        // def must be present in this common ancestor block.
        Block* common_ancestor =
            n->findCommonAncestorBlockWith(inp_remapped->node());

        Value* v_itr = inp_remapped;
        Block* b_itr = inp_remapped->node()->owningBlock();

        // Starting from the initial def for this input, iterate to
        // wider and wider blocks, adding Block outputs and Node outputs
        // along the way. Then, log the lifted values in the remap table
        // so we can make subsequent Uses refer to the lifted value, if
        // the domination condition is met.
        while (b_itr != common_ancestor) {
          b_itr->registerOutput(v_itr);
          Value* remapped =
              b_itr->owningNode()->addOutput()->setType(v_itr->type());
          v_itr = remapped;
          b_itr = b_itr->owningNode()->owningBlock();
        }
        // From now on, references to `inp` will be replaced with
        // references to `v_itr`, the lifted Value
        remap[inp] = v_itr;
        n->replaceInput(i, remap[inp]);
      }
    }

    if (isEligibleNode(n)) {
      run(n->blocks()[0]);
    }
  }

  // This holds the mapping between a Value* we would see in a Use
  // and the lifted value, if present. We use this to ensure that
  // Uses refer to a Value* that is in a dominating scope.
  using RemappingTable = std::unordered_map<Value*, Value*>;
  RemappingTable remap;
};

// For all blocks except graph->block(), convert multiple block
// returns to a TupleConstruct. This is required for turning the
// blocks into Methods (and, in the case that self is nullptr, for
// properly inlining the blocks).
void convertReturnsToTuples(Block* b) {
  for (Node* n : b->nodes()) {
    if (n->kind() == prim::TracedFork) {
      convertReturnsToTuples(n->blocks()[0]);
    } else if (n->kind() == prim::TracedModuleForward) {
      TORCH_INTERNAL_ASSERT(n->blocks().size() == 1);
      convertReturnsToTuples(n->blocks()[0]);

      Graph* g = b->owningGraph();
      Block* sub_block = n->blocks()[0];
      if (sub_block->outputs().size() > 1) {
        {
          // Make block returns go through a Tuple
          WithInsertPoint guard(sub_block->return_node());
          Node* return_tup =
              g->insertNode(g->createTuple(sub_block->outputs()));
          while (!sub_block->outputs().empty()) {
            sub_block->eraseOutput(0);
          }
          sub_block->registerOutput(return_tup->output());
        }

        // Collapse the node's outputs into a single tuple output, then
        // unpack it so existing uses still see the individual values.
        std::vector<TypePtr> types;
        for (size_t i = 0; i < n->outputs().size(); ++i) {
          types.push_back(n->output(i)->type());
        }
        Value* tup_output = n->addOutput()->setType(TupleType::create(types));
        Node* tup_unpack = g->createTupleUnpack(tup_output)->insertAfter(n);
        for (size_t i = 0; i < tup_unpack->outputs().size(); ++i) {
          auto rev_idx = tup_unpack->outputs().size() - i - 1;
          n->output(rev_idx)->replaceAllUsesWith(tup_unpack->output(rev_idx));
          n->eraseOutput(rev_idx);
        }
      } else if (sub_block->outputs().empty()) {
        WithInsertPoint guard(sub_block->return_node());
        sub_block->registerOutput(g->insertNode(g->createNone())->output());
        n->addOutput()->setType(NoneType::get());
      }
    }
  }
}

// Lambda lift Values (i.e. add Graph inputs for the purpose of
// referencing values that dominate the block) and convert
// the block to a Graph. blocks()[0] on each TracedModuleForward
// (and TracedFork) then appears as the attr::Subgraph Graph attribute.
void lambdaLiftBlocksAndConvertToGraph(Block* b) {
  for (Node* n : b->nodes()) {
    if (isEligibleNode(n)) {
      lambdaLiftBlocksAndConvertToGraph(n->blocks()[0]);

      auto graph = std::make_shared<Graph>();
      std::unordered_map<Value*, Value*> remaps;
      graph->block()->cloneFrom(n->blocks()[0], [&](Value* v) {
        if (!remaps.count(v)) {
          remaps[v] = graph->addInput()->copyMetadata(v);
          n->addInput(v);
        }
        return remaps[v];
      });
      LintGraph(graph);
      n->g_(attr::Subgraph, graph);
      n->eraseBlock(0);
    }
  }
}

// Find a unique name to add this method as
// We try {method_name}, {method_name}1, {method_name}2, ...
std::string mangleMethodName(
    const std::string& method_name,
    const ClassTypePtr& mod_type) {
  for (size_t method_idx = 0;; method_idx++) {
    auto mangled = method_name;
    if (method_idx != 0) {
      mangled += std::to_string(method_idx);
    }
    bool found = false;
    for (Function* fn : mod_type->methods()) {
      if (fn->name() == mangled) {
        found = true;
        break;
      }
    }
    if (!found) {
      return mangled;
    }
  }
  TORCH_INTERNAL_ASSERT(false);
}

// Register the attr::Subgraph Graph values as Functions in the
// class compilation unit and register that Function as a method
// on the corresponding Module in the Module hierarchy. Note that we
// unique the methods by naming them forward, forward1, forward2...
void createMethodCalls(const std::shared_ptr<Graph>& g) {
  for (auto node_itr = g->nodes().begin(); node_itr != g->nodes().end();) {
    Node* n = *node_itr++;
    if (n->kind() == prim::TracedFork) {
      createMethodCalls(n->g(attr::Subgraph));
    } else if (n->kind() == prim::TracedModuleForward) {
      WithInsertPoint ip(n);

      ClassTypePtr callee_mod_type = n->input(0)->type()->expect<ClassType>();

      createMethodCalls(n->g(attr::Subgraph));

      auto mangled_method_name = mangleMethodName("forward", callee_mod_type);
      auto qualname = c10::QualifiedName(
          callee_mod_type->name().value(), mangled_method_name);
      Function* f = callee_mod_type->compilation_unit()->create_function(
          qualname, n->g(attr::Subgraph));
      callee_mod_type->addMethod(f);

      std::vector<NamedValue> nvs;
      for (Value* i : n->inputs()) {
        nvs.emplace_back(i->node()->sourceRange(), i);
      }
      auto schema = matchSchema(f->getSchema(), n->sourceRange(), *g, nvs, {});
      Value* retval = g->insertMethodCall(f->qualname().name(), schema);
      n->output()->replaceAllUsesWith(retval);
      n->destroy();
    }
  }
}

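// Inline the body of every TracedModuleForward directly into its owning
// graph (used when there is no Module to attach methods to), recursing
// into nested blocks first.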
void inlineScopeBlocks(Block* b) {
  for (auto n_itr = b->nodes().begin(); n_itr != b->nodes().end();) {
    Node* n = *n_itr++;
    for (Block* sub_b : n->blocks()) {
      inlineScopeBlocks(sub_b);
    }
    if (n->kind() == prim::TracedModuleForward) {
      // Convert the block to a graph so we can inline it
      auto graph = std::make_shared<Graph>();
      std::unordered_map<Value*, Value*> remaps;
      graph->block()->cloneFrom(n->blocks()[0], [&](Value* v) {
        remaps[v] = graph->block()->addInput()->copyMetadata(v);
        n->addInput(v);
        return remaps[v];
      });

      WithInsertPoint insert_point(n);
      AT_ASSERT(n->inputs().size() == graph->inputs().size());
      auto new_outputs = insertGraph(*n->owningGraph(), *graph, n->inputs());
      const auto& old_outputs = n->outputs();

      AT_ASSERT(new_outputs.size() == old_outputs.size());
      for (const auto i : c10::irange(old_outputs.size())) {
        old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
      }
      n->destroy();
    }
  }
}

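// Replace each prim::TracedFork node with a real prim::fork node, copying
// over its attributes (including attr::Subgraph), inputs, and output
// metadata.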
void convertTracedForksToRealForks(const std::shared_ptr<Graph>& g) {
  for (auto itr = g->nodes().begin(); itr != g->nodes().end();) {
    Node* n = *itr++;
    if (n->kind() == prim::TracedFork) {
      WithInsertPoint guard(n);
      Node* new_fork_node =
          g->insertNode(g->create(prim::fork, n->outputs().size()))
              ->copyAttributes(*n);
      for (Value* i : n->inputs()) {
        new_fork_node->addInput(i);
      }
      for (size_t i = 0; i < new_fork_node->outputs().size(); ++i) {
        new_fork_node->outputs()[i]->copyMetadata(n->outputs()[i]);
        n->outputs()[i]->replaceAllUsesWith(new_fork_node->outputs()[i]);
      }
      n->destroy();
    }
  }
}

// Run a few clean-up passes to make the graph a bit cleaner.
void runCleanupPasses(const std::shared_ptr<Graph>& g) {
  for (Node* n : g->nodes()) {
    if (n->kind() == prim::TracedFork) {
      auto subgraph = n->g(attr::Subgraph);
      if (getInlineEverythingMode()) {
        Inline(*subgraph);
      }
      convertTracedForksToRealForks(subgraph);
      LowerSimpleTuples(subgraph);
      EliminateDeadCode(subgraph);
      LintGraph(subgraph);
    }
  }
  if (getInlineEverythingMode()) {
    Inline(*g);
  }
  convertTracedForksToRealForks(g);
  LowerSimpleTuples(g);
  EliminateDeadCode(g);
  LintGraph(g);
}

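// Recursively run the clean-up passes over every method of `m` and of its
// submodules.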
void runCleanupPasses(Module* m) {
  auto methods = m->get_methods();
  for (auto module : m->children()) {
    runCleanupPasses(&module);
  }
  for (auto& method : methods) {
    runCleanupPasses(method.graph());
  }
}

} // namespace

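// Entry point: turn the TracedModuleForward/TracedFork/TracedAttr structure
// produced by the tracer into either method calls on the Module hierarchy
// (when `self` is provided) or a fully inlined, flat graph (when it is not).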
void FixupTraceScopeBlocks(std::shared_ptr<Graph>& graph, Module* self) {
  if (self) {
    ConvertTracedAttrReferences().run(graph);
  } else {
    for (Node* n : graph->nodes()) {
      TORCH_INTERNAL_ASSERT(n->kind() != prim::TracedAttr);
    }
  }
  MakeDefsDominateUses().run(graph->block());
  convertReturnsToTuples(graph->block());
  if (!self) {
    // We have no Module, so we're just going to inline everything.
    // This should give us a totally flat graph.
    inlineScopeBlocks(graph->block());
    // For TracedFork nodes
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    runCleanupPasses(graph);
  } else {
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    createMethodCalls(graph);
    runCleanupPasses(self);
    // `graph` isn't referenced in `self` yet, so we need to run
    // this separately
    runCleanupPasses(graph);
  }
}

} // namespace torch::jit