#include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>

#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lower_tuples.h>

#include <algorithm>

namespace torch::jit {

namespace {

bool isEligibleNode(Node* n) {
  return n->kind() == prim::TracedModuleForward ||
      n->kind() == prim::TracedFork;
}

// This pass does several things:
// 1) It looks at TracedModuleForward nodes and resolves the type of `self`
//    for that (to-be) method call. It adds an input of that type to the
//    block, and adds the TracedAttr value corresponding to that `self`
//    value as a Node input. This ensures `self` is an explicit Use on
//    the node, a property we take advantage of downstream.
// 2) It converts all references to prim::TracedAttr values to
//    prim::GetAttr calls in the tightest scope possible. Concretely, for
//    each use of a prim::TracedAttr value, we compare the scope of that
//    attribute to the scope of the Use. We emit GetAttr nodes for all
//    atoms that are not shared between the two. For example, if an
//    attribute `f.param` is referenced in scope `f`, we emit a
//    GetAttr[name="param"](%self) node in the `f` block, where
//    `self` is the previously-added `self` argument to the block.
// 3) It destroys all the prim::TracedAttr nodes, as they should have
//    no more uses.
//
// A quick example:
//
// Input graph:
//
//   graph(%self : ClassType<Module>,
//         %x : Float(3, 4)):
//     %1 : bool = prim::TracedAttr[scope="__module.training"]()
//     %2 : ClassType<Module> = prim::TracedAttr[scope="__module.f"]()
//     %3 : Float(4, 4) = prim::TracedAttr[scope="__module.f.param"]()
//     %4 : bool = prim::TracedAttr[scope="__module.f.training"]()
//      = prim::TracedModuleForward[scope="__module.f"](),
//       block0():
//         %6 : Float(3, 4) = aten::mm(%x, %3),
//        -> ()
//     return (%6)
//
// The diff after step (1):
//
//   -  = prim::TracedModuleForward[scope="__module.f"](),
//   -   block0():
//   +  = prim::TracedModuleForward[scope="__module.f"](%2),
//   +   block0(%self : ClassType<Module>):
//
// The diff after step (2):
//
//     graph(%self.1 : ClassType<Module>,
//           %x : Float(3, 4)):
//   +   %9 : ClassType<Module> = prim::GetAttr[name="f"](%self.1)
//       %1 : bool = prim::TracedAttr[scope="__module.training"]()
//       <....>
//       %4 : bool = prim::TracedAttr[scope="__module.f.training"]()
//   -    = prim::TracedModuleForward[scope="__module.f"](%2),
//   +    = prim::TracedModuleForward[scope="__module.f"](%9),
//         block0(%self : ClassType<Module>):
//   -       %6 : Float(3, 4) = aten::mm(%x, %3),
//   +       %8 : Tensor = prim::GetAttr[name="param"](%self)
//   +       %6 : Float(3, 4) = aten::mm(%x, %8),
//          -> ()
//       return (%6)
//
// The diff after step (3):
//
//   - %1 : bool = prim::TracedAttr[scope="__module.training"]()
//   - %2 : ClassType<Module> = prim::TracedAttr[scope="__module.f"]()
//   - %3 : Float(4, 4) = prim::TracedAttr[scope="__module.f.param"]()
//   - %4 : bool = prim::TracedAttr[scope="__module.f.training"]()
struct ConvertTracedAttrReferences {
  void run(const std::shared_ptr<Graph>& graph) {
    // Build a table mapping--for each TracedAttr node--the
    // qualified name of the attribute to the Value* output
    // of the Node.
    buildAttrMap(graph);
    // Step 1
    addSelfArgToTracedForwardNodes(graph->block());
    // Step 2
    convertAttrReferencesToLocalGetAttrs(
        graph->block(), "__module", graph->inputs()[0]);
    // Step 3
    destroyTracedAttrNodes(graph);
  }

 private:
  void buildAttrMap(const std::shared_ptr<Graph>& graph) {
    for (Node* n : graph->nodes()) {
      if (n->kind() == prim::TracedAttr) {
        attr_qualname_to_value[n->s(attr::scope)] = n->output();
      }
    }
  }

  void addSelfArgToTracedForwardNodes(Block* b) {
    for (Node* n : b->nodes()) {
      if (n->kind() == prim::TracedModuleForward) {
        n->addInput(attr_qualname_to_value.at(n->s(attr::scope)));
        n->blocks()[0]->addInput("self")->setType(
            attr_qualname_to_value.at(n->s(attr::scope))->type());
        addSelfArgToTracedForwardNodes(n->blocks()[0]);
      }
      if (n->kind() == prim::TracedFork) {
        addSelfArgToTracedForwardNodes(n->blocks()[0]);
      }
    }
  }

  // This is a recursive function that descends through all blocks in the
  // Graph (NB: not just TracedModuleForward blocks). Each descent has a
  // corresponding `prefix`, i.e. the qualified name of the scope this
  // Block represents (or the scope in which this block resides, for
  // non-TracedModuleForward nodes). We use this prefix to decide whether
  // to emit a GetAttr node for an attribute reference here, or to defer
  // that emission to the caller (in the case where the attribute
  // reference does not reside in the `prefix` scope).
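  //
  // An illustrative sketch (scope names follow the example above): a use
  // of %3 (scope "__module.f.param") inside the block with prefix
  // "__module.f" gets a local GetAttr[name="param"](%self). The same
  // attribute used at the root (prefix "__module") gets one GetAttr per
  // atom beyond the prefix: GetAttr[name="f"](%self), then
  // GetAttr[name="param"] on its result. A use whose scope is *not*
  // under the prefix is instead added as a block input and returned in
  // `unresolved_tracedattrs` for the caller to resolve.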
  std::vector<Value*> convertAttrReferencesToLocalGetAttrs(
      Block* b,
      const c10::QualifiedName& prefix,
      Value* self) {
    // Store away Value*'s which are references to TracedAttr's which are
    // not in the `prefix` scope. We pass this back to the caller, who
    // should add these Values as explicit inputs as well as inductively
    // make the same decision on those Values.
    std::vector<Value*> unresolved_tracedattrs;
    // To ensure we don't emit redundant GetAttr Nodes in a given scope,
    // we maintain this map of original TracedAttr Value* to the Value*
    // corresponding to the GetAttr for that attribute.
    // We don't rely on CSE here because the correctness of CSE over
    // GetAttr Nodes has not been established.
    std::unordered_map<Value*, Value*> local_remaps;

    for (Node* n : b->nodes()) {
      // The only difference between these two branches is that for
      // TracedModuleForward we advance the scope, but for other
      // Nodes with Blocks we don't.
      if (n->kind() == prim::TracedModuleForward) {
        auto sub_unresolved = convertAttrReferencesToLocalGetAttrs(
            n->blocks()[0], n->s(attr::scope), n->blocks()[0]->inputs()[0]);
        for (Value* v : sub_unresolved) {
          n->addInput(v);
        }
      } else if (!n->blocks().empty()) {
        for (Block* sub_block : n->blocks()) {
          auto sub_unresolved =
              convertAttrReferencesToLocalGetAttrs(sub_block, prefix, self);
          for (Value* v : sub_unresolved) {
            n->addInput(v);
          }
        }
      }

      for (size_t inp_idx = 0; inp_idx < n->inputs().size(); ++inp_idx) {
        Value* inp = n->input(inp_idx);

        // Short circuit: if we've already emitted a new Value for this
        // attribute, just use that.
        if (local_remaps.count(inp)) {
          n->replaceInput(inp_idx, local_remaps[inp]);
          continue;
        }

        WithInsertPoint guard(b->param_node()->next());
        replaceTracedAttrInputOnNode(
            n, inp_idx, prefix, self, local_remaps, unresolved_tracedattrs);
      } // for (size_t inp_idx = 0; ...)
    } // for (Node* n : b->nodes())
    return unresolved_tracedattrs;
  }

  void replaceTracedAttrInputOnNode(
      Node* n,
      size_t inp_idx,
      const c10::QualifiedName& prefix,
      Value* self,
      std::unordered_map<Value*, Value*>& local_remaps,
      std::vector<Value*>& unresolved_tracedattrs) {
    auto inp = n->inputs()[inp_idx];
    auto inp_node = inp->node();
    auto prefix_atoms = prefix.atoms();
    if (inp_node->kind() == prim::TracedAttr) {
      auto attr_qualname = c10::QualifiedName(inp_node->s(attr::scope));
      if (prefix.isPrefixOf(attr_qualname)) {
        // Prefix case: the attribute resides in this scope or a
        // sub-scope. Continually emit GetAttr nodes until we've reached
        // the proper attribute.
        auto attr_atoms = attr_qualname.atoms();
        Value* replaced_value = self;
        for (const auto i : c10::irange(attr_atoms.size())) {
          if (i < prefix_atoms.size()) {
            TORCH_INTERNAL_ASSERT(attr_atoms[i] == prefix_atoms[i]);
          } else {
            replaced_value = n->owningBlock()->owningGraph()->insertGetAttr(
                replaced_value, attr_atoms[i]);
          } // if (i < prefix_atoms.size())
        } // for (const auto i : c10::irange(attr_atoms.size()))
        n->replaceInput(inp_idx, replaced_value);
        local_remaps[inp] = replaced_value;
      } else {
        // Non-prefix case: this is a use of an attribute somewhere
        // higher in the Module hierarchy. Add a captured input to
        // the block for this attribute and add to the vector of
        // Value*'s for the caller to handle.
        Value* remapped = n->owningBlock()->addInput()->copyMetadata(inp);
        n->replaceInput(inp_idx, remapped);
        unresolved_tracedattrs.push_back(inp);
        local_remaps[inp] = remapped;
      } // if (prefix.isPrefixOf(attr_qualname))
    } // if (inp_node->kind() == prim::TracedAttr)
  }

  // The previous pass should have deleted all uses of TracedAttr
  // nodes. Let's explicitly delete them here.
  void destroyTracedAttrNodes(const std::shared_ptr<Graph>& graph) {
    for (auto& kv : attr_qualname_to_value) {
      kv.second->node()->destroy();
    }
  }

  // For each prim::TracedAttr, record the `scope` value mapped
  // to the Value* in the graph for that attribute.
  std::unordered_map<std::string, Value*> attr_qualname_to_value;
};

// Iterate through all the nodes in program order and--for each Use--
// if the Value referenced is not defined in a scope that dominates the
// node, add Block outputs and Node outputs to lift it into a scope in
// which it dominates the Use.
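//
// For example (an illustrative sketch): if %y is defined inside the
// block of a TracedModuleForward node but used by a node that follows
// it in the enclosing block, we register %y as an output of that block,
// add a matching output to the TracedModuleForward node itself, and
// rewrite the outer Use to refer to the new node output, which now
// dominates it.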
struct MakeDefsDominateUses {
  MakeDefsDominateUses() = default;

  void run(Block* b) {
    processNode(b->param_node(), b);
    for (Node* n : b->nodes()) {
      processNode(n, b);
    }
    processNode(b->return_node(), b);
  }

 private:
  void processNode(Node* n, Block* b) {
    for (size_t i = 0; i < n->inputs().size(); ++i) {
      Value* inp = n->inputs()[i];

      // Already lifted to this level by a previously processed Use;
      // switch to the remapped value.
      Value* inp_remapped = inp;
      if (remap.count(inp_remapped)) {
        n->replaceInput(i, remap[inp_remapped]);
        inp_remapped = remap[inp_remapped];
      }

      // This conditional isn't strictly necessary, but saves a lot of
      // computation in the common case that we're using a local value.
      if (inp_remapped->node()->owningBlock() != b) {
        // Find the common ancestor block between this node and the node
        // that produced this input. For this input Use to be valid, the
        // Value's def must be present in this common ancestor block.
        Block* common_ancestor =
            n->findCommonAncestorBlockWith(inp_remapped->node());
        Value* v_itr = inp_remapped;
        Block* b_itr = inp_remapped->node()->owningBlock();

        // Starting from the initial def for this input, iterate to
        // wider and wider blocks, adding Block outputs and Node outputs
        // along the way. Then, log the lifted values in the remap table
        // so we can make subsequent Uses refer to the lifted value, if
        // the domination condition is met.
        while (b_itr != common_ancestor) {
          b_itr->registerOutput(v_itr);
          Value* remapped =
              b_itr->owningNode()->addOutput()->setType(v_itr->type());
          v_itr = remapped;
          b_itr = b_itr->owningNode()->owningBlock();
        }
        // From now on, references to `inp` will be replaced with
        // references to `v_itr`, the lifted Value.
        remap[inp] = v_itr;
        n->replaceInput(i, remap[inp]);
      }
    }

    if (isEligibleNode(n)) {
      run(n->blocks()[0]);
    }
  }

  // This holds the mapping between a Value* we would see in a Use
  // and the lifted value, if present. We use this to ensure that
  // Uses refer to a Value* that is in a dominating scope.
  using RemappingTable = std::unordered_map<Value*, Value*>;
  RemappingTable remap;
};

// For all blocks except graph->block(), convert multiple block
// returns to a TupleConstruct. This is required for turning the
// blocks into Methods (and, in the case that self is nullptr, for
// properly inlining the blocks).
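//
// Sketch of the transformation: a TracedModuleForward whose block
// returns (%a, %b) is rewritten so that the block returns a single
// prim::TupleConstruct(%a, %b) value; the node itself gets one
// tuple-typed output, and a prim::TupleUnpack inserted after the node
// feeds the original Uses. A block with zero outputs is given a single
// None output instead.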
void convertReturnsToTuples(Block* b) {
  for (Node* n : b->nodes()) {
    if (n->kind() == prim::TracedFork) {
      convertReturnsToTuples(n->blocks()[0]);
    } else if (n->kind() == prim::TracedModuleForward) {
      TORCH_INTERNAL_ASSERT(n->blocks().size() == 1);
      convertReturnsToTuples(n->blocks()[0]);

      Graph* g = b->owningGraph();
      Block* sub_block = n->blocks()[0];
      if (sub_block->outputs().size() > 1) {
        {
          // Make block returns go through a Tuple.
          WithInsertPoint guard(sub_block->return_node());
          Node* return_tup =
              g->insertNode(g->createTuple(sub_block->outputs()));
          while (!sub_block->outputs().empty()) {
            sub_block->eraseOutput(0);
          }
          sub_block->registerOutput(return_tup->output());
        }

        // Make the node output a single tuple.
        std::vector<TypePtr> types;
        for (size_t i = 0; i < n->outputs().size(); ++i) {
          types.push_back(n->output(i)->type());
        }
        Value* tup_output = n->addOutput()->setType(TupleType::create(types));
        Node* tup_unpack = g->createTupleUnpack(tup_output)->insertAfter(n);
        for (size_t i = 0; i < tup_unpack->outputs().size(); ++i) {
          auto rev_idx = tup_unpack->outputs().size() - i - 1;
          n->output(rev_idx)->replaceAllUsesWith(tup_unpack->output(rev_idx));
          n->eraseOutput(rev_idx);
        }
      } else if (sub_block->outputs().empty()) {
        WithInsertPoint guard(sub_block->return_node());
        sub_block->registerOutput(g->insertNode(g->createNone())->output());
        n->addOutput()->setType(NoneType::get());
      }
    }
  }
}

// Lambda lift Values (i.e. add Graph inputs for the purpose of
// referencing values defined in dominating scopes) and convert
// each block to a Graph. blocks()[0] on each TracedModuleForward
// then appears as the Graph attribute attr::Subgraph.
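//
// Illustrative sketch: if the block body references a value %x defined
// in an enclosing scope, the cloned Graph gains a fresh input in its
// place (via the cloneFrom value mapper below) and %x is appended to
// the node's inputs, so the resulting standalone Graph captures
// nothing from its environment.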
void lambdaLiftBlocksAndConvertToGraph(Block* b) {
  for (Node* n : b->nodes()) {
    if (isEligibleNode(n)) {
      lambdaLiftBlocksAndConvertToGraph(n->blocks()[0]);

      auto graph = std::make_shared<Graph>();
      std::unordered_map<Value*, Value*> remaps;
      graph->block()->cloneFrom(n->blocks()[0], [&](Value* v) {
        if (!remaps.count(v)) {
          remaps[v] = graph->addInput()->copyMetadata(v);
          n->addInput(v);
        }
        return remaps[v];
      });
      LintGraph(graph);
      n->g_(attr::Subgraph, graph);
      n->eraseBlock(0);
    }
  }
}

// Find a unique name to add this method as.
// We try {method_name}, {method_name}1, {method_name}2, ...
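// For example, if `forward` and `forward1` are already taken on the
// type, this returns `forward2`.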
std::string mangleMethodName(
    const std::string& method_name,
    const ClassTypePtr& mod_type) {
  for (size_t method_idx = 0;; method_idx++) {
    auto mangled = method_name;
    if (method_idx != 0) {
      mangled += std::to_string(method_idx);
    }
    bool found = false;
    for (Function* fn : mod_type->methods()) {
      if (fn->name() == mangled) {
        found = true;
        break;
      }
    }
    if (!found) {
      return mangled;
    }
  }
  TORCH_INTERNAL_ASSERT(false);
}

// Register the attr::Subgraph Graph values as Functions in the
// class compilation unit and register that Function as a method
// on the corresponding Module in the Module hierarchy. Note that we
// unique the methods by naming them forward, forward1, forward2, ...
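//
// Roughly (an illustrative sketch), a node
//   %y = prim::TracedModuleForward[scope="__module.f"](%f_mod, %x)
// whose attr::Subgraph was registered as method `forward` on %f_mod's
// type becomes a method call
//   %y = prim::CallMethod[name="forward"](%f_mod, %x)
// emitted via insertMethodCall after schema matching.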
void createMethodCalls(const std::shared_ptr<Graph>& g) {
  for (auto node_itr = g->nodes().begin(); node_itr != g->nodes().end();) {
    Node* n = *node_itr++;
    if (n->kind() == prim::TracedFork) {
      createMethodCalls(n->g(attr::Subgraph));
    } else if (n->kind() == prim::TracedModuleForward) {
      WithInsertPoint ip(n);

      ClassTypePtr callee_mod_type = n->input(0)->type()->expect<ClassType>();

      createMethodCalls(n->g(attr::Subgraph));

      auto mangled_method_name = mangleMethodName("forward", callee_mod_type);
      auto qualname = c10::QualifiedName(
          callee_mod_type->name().value(), mangled_method_name);
      Function* f = callee_mod_type->compilation_unit()->create_function(
          qualname, n->g(attr::Subgraph));
      callee_mod_type->addMethod(f);

      std::vector<NamedValue> nvs;
      for (Value* i : n->inputs()) {
        nvs.emplace_back(i->node()->sourceRange(), i);
      }
      auto schema = matchSchema(f->getSchema(), n->sourceRange(), *g, nvs, {});
      Value* retval = g->insertMethodCall(f->qualname().name(), schema);
      n->output()->replaceAllUsesWith(retval);
      n->destroy();
    }
  }
}

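// Recursively inline the blocks of TracedModuleForward nodes into the
// owning graph: each block is cloned into a temporary standalone Graph
// (free values become Graph inputs and Node inputs), spliced in at the
// call site with insertGraph, and the original node is destroyed.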
void inlineScopeBlocks(Block* b) {
  for (auto n_itr = b->nodes().begin(); n_itr != b->nodes().end();) {
    Node* n = *n_itr++;
    for (Block* sub_b : n->blocks()) {
      inlineScopeBlocks(sub_b);
    }
    if (n->kind() == prim::TracedModuleForward) {
      // Convert the block to a graph so we can inline it.
      auto graph = std::make_shared<Graph>();
      std::unordered_map<Value*, Value*> remaps;
      graph->block()->cloneFrom(n->blocks()[0], [&](Value* v) {
        remaps[v] = graph->block()->addInput()->copyMetadata(v);
        n->addInput(v);
        return remaps[v];
      });

      WithInsertPoint insert_point(n);
      AT_ASSERT(n->inputs().size() == graph->inputs().size());
      auto new_outputs = insertGraph(*n->owningGraph(), *graph, n->inputs());
      const auto& old_outputs = n->outputs();

      AT_ASSERT(new_outputs.size() == old_outputs.size());
      for (const auto i : c10::irange(old_outputs.size())) {
        old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
      }
      n->destroy();
    }
  }
}

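// Replace each prim::TracedFork node with an equivalent prim::fork
// node, copying over attributes (including attr::Subgraph), inputs,
// and output metadata.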
void convertTracedForksToRealForks(const std::shared_ptr<Graph>& g) {
  for (auto itr = g->nodes().begin(); itr != g->nodes().end();) {
    Node* n = *itr++;
    if (n->kind() == prim::TracedFork) {
      WithInsertPoint guard(n);
      Node* new_fork_node =
          g->insertNode(g->create(prim::fork, n->outputs().size()))
              ->copyAttributes(*n);
      for (Value* i : n->inputs()) {
        new_fork_node->addInput(i);
      }
      for (size_t i = 0; i < new_fork_node->outputs().size(); ++i) {
        new_fork_node->outputs()[i]->copyMetadata(n->outputs()[i]);
        n->outputs()[i]->replaceAllUsesWith(new_fork_node->outputs()[i]);
      }
      n->destroy();
    }
  }
}

// Run a few clean-up passes to make the graph a bit cleaner.
void runCleanupPasses(const std::shared_ptr<Graph>& g) {
  for (Node* n : g->nodes()) {
    if (n->kind() == prim::TracedFork) {
      auto subgraph = n->g(attr::Subgraph);
      if (getInlineEverythingMode()) {
        Inline(*subgraph);
      }
      convertTracedForksToRealForks(subgraph);
      LowerSimpleTuples(subgraph);
      EliminateDeadCode(subgraph);
      LintGraph(subgraph);
    }
  }
  if (getInlineEverythingMode()) {
    Inline(*g);
  }
  convertTracedForksToRealForks(g);
  LowerSimpleTuples(g);
  EliminateDeadCode(g);
  LintGraph(g);
}

void runCleanupPasses(Module* m) {
  auto methods = m->get_methods();
  for (auto module : m->children()) {
    runCleanupPasses(&module);
  }
  for (auto& method : methods) {
    runCleanupPasses(method.graph());
  }
}

} // namespace

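// Entry point for the pass. With a Module (`self` non-null), the traced
// scope blocks become real methods registered on the corresponding
// modules in the hierarchy; without one, the scope blocks are inlined
// away entirely, leaving a flat graph (TracedFork nodes are still
// converted to subgraphs and real fork nodes by the cleanup passes).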
void FixupTraceScopeBlocks(std::shared_ptr<Graph>& graph, Module* self) {
  if (self) {
    ConvertTracedAttrReferences().run(graph);
  } else {
    for (Node* n : graph->nodes()) {
      TORCH_INTERNAL_ASSERT(n->kind() != prim::TracedAttr);
    }
  }
  MakeDefsDominateUses().run(graph->block());
  convertReturnsToTuples(graph->block());
  if (!self) {
    // We have no Module, so we're just going to inline everything.
    // This should give us a totally flat graph.
    inlineScopeBlocks(graph->block());
    // For TracedFork nodes
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    runCleanupPasses(graph);
  } else {
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    createMethodCalls(graph);
    runCleanupPasses(self);
    // `graph` isn't referenced in `self` yet, so we need to run
    // this separately.
    runCleanupPasses(graph);
  }
}

} // namespace torch::jit