xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/function_extraction.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/passes/onnx/function_extraction.h>
3 #include <torch/csrc/jit/passes/onnx/naming.h>
4 
5 namespace torch::jit::onnx {
6 
7 namespace {
8 
// Ordered list of scopes.
using scope_list = std::vector<ScopePtr>;

// Annotated attributes retrieved from module by inspecting module annotations.
// These attributes are not used inside the subgraph of ONNX local function
// because they are not created by PyTorch JIT tracing, but they may be used by
// consumers to determine whether or not to replace the function with a
// particular fused kernel.
//
// `scope_attr_map_` is keyed by module scope; the mapped node's attributes are
// copied onto the local function nodes created for that scope.
// NOTE(review): `scope_attr_graph_` presumably owns the nodes stored in
// `scope_attr_map_`; the code that populates both is not in this part of the
// file -- confirm.
static std::unordered_map<ScopePtr, Node*> scope_attr_map_;
static std::shared_ptr<Graph> scope_attr_graph_ = std::make_shared<Graph>();

// Forward declaration; the definition is not in this part of the file.
// It is used below to decide whether attribute `attr`, present on both `a`
// and `b`, has to be exposed as a per-call function attribute.
static bool HasSameAttribute(
    const Node* a,
    const Node* b,
    const c10::Symbol& attr);
23 
// Turns the nodes of selected module scopes in a graph into ONNX local
// functions: one `onnx::LocalFunctionDef` node holding the function body, plus
// one call node per scope that shares that body.
struct FunctionExtractor {
 public:
  // Keeps a reference to `graph` and copies the module and parameter names
  // into hash sets for lookup.
  FunctionExtractor(
      std::shared_ptr<Graph>& graph,
      const std::unordered_set<std::string>& module_names,
      const std::vector<std::string>& param_names)
      : graph_(graph),
        module_names_(module_names.begin(), module_names.end()),
        param_names_(param_names.begin(), param_names.end()) {}
  // Entry point of the pass. Defined outside this part of the file.
  NodeAttrNameMap run();

 private:
  // Bookkeeping for a single scope.
  struct ScopeContext {
    // Direct child scopes that were seen while partitioning nodes.
    std::unordered_set<ScopePtr> children_;
    // The scope this context describes.
    ScopePtr scope_;
    // Nodes recorded for this scope. While partitioning, a node is appended
    // to the list of its own scope and of every valid ancestor scope.
    node_list nlist_;
    // Values consumed from / exposed to nodes outside `nlist_`; rebuilt by
    // PopulateInputsOutputs().
    value_list inputs_;
    value_list outputs_;
    // Maps a value of the outer graph to its counterpart inside the function
    // graph; filled in by ConstructFuncGraph().
    std::unordered_map<Value*, Value*> env_to_subgraph_;

    // Recomputes `inputs_` and `outputs_` from `nlist_`. Inputs whose debug
    // name is in `param_names` are placed after all other inputs.
    void PopulateInputsOutputs(
        const std::unordered_set<std::string>& param_names);
    // True when both contexts can share one local function definition (same
    // class name, same input/output/node counts, same node kinds in order).
    bool IsIdenticalFuncion(const ScopeContext& other_ctx) const;
  };

  // NOTE(review): contexts are allocated with `new` where these maps are
  // filled; no matching `delete` is visible in this part of the file.
  using ScopeCtxPtr = ScopeContext*;
  using scope_ctx_map = std::unordered_map<ScopePtr, ScopeCtxPtr>;

  // State shared by all scopes that are exported as the same local function.
  struct FunctionContext {
    // `key` is the reference scope; `scopes` lists every scope that shares the
    // function; `scope_ctxs` supplies the context of each of them.
    FunctionContext(
        ScopePtr key,
        const scope_list& scopes,
        scope_ctx_map& scope_ctxs);
    // Logs the recorded attribute differences.
    void DebugPrint() const;
    // Records `name` as the function-attribute name for attribute `attr` of
    // the function-graph node that corresponds to `ref_n`.
    void SetAttrName(Node* ref_n, Symbol attr, const std::string& name);
    // Looks up a name recorded by SetAttrName(); nullopt when there is none.
    std::optional<std::string> FindAttrName(Node* ref_n, Symbol attr);
    // No definition of this overload is visible in this part of the file.
    std::optional<std::string> FindAttrName(Node* ref_const_n);

    // Scope whose nodes serve as the reference body of the function.
    ScopePtr scope_key_;
    // Contexts of all scopes that share this function (including the key).
    scope_ctx_map scope_ctxs_;
    // reference node -> attribute -> nodes of the other scopes whose value for
    // that attribute differs from (or is missing compared to) the reference.
    std::unordered_map<
        Node*,
        std::unordered_map<Symbol, std::unordered_set<Node*>>>
        attribute_map_;

    // Passed later to serialization.
    NodeAttrNameMap node_attr_to_name_;
  };

  using FunctionCtxPtr = FunctionContext*;
  using func_ctx_map = std::unordered_map<ScopePtr, FunctionCtxPtr>;

  // A scope is usable when it is neither the root nor blank.
  static bool IsValidScope(const ScopePtr& s);
  // Derives a scope for a node that has no valid scope from the scopes of its
  // users and producers.
  static std::optional<ScopePtr> InferScope(Node* n);
  // True when `parent` is a strict ancestor of `child`.
  static bool IsAncestor(const ScopePtr& parent, ScopePtr child);
  // Deepest valid scope that contains both / all of the given scopes.
  static std::optional<ScopePtr> FindCommonAncestor(ScopePtr a, ScopePtr b);
  static std::optional<ScopePtr> FindCommonAncestor(const scope_list& scopes);
  // Builds the body graph of the function from the reference scope of `ctx`.
  std::shared_ptr<Graph> ConstructFuncGraph(FunctionContext& ctx);

  // Exports the scopes in `scope_list` as one local function keyed by
  // `scope_key` and replaces their nodes in `graph` with call nodes.
  void ConvertScopeToFunction(
      const ScopePtr& scope_key,
      const scope_list& scope_list,
      scope_ctx_map& scope_ctxs,
      const std::shared_ptr<Graph>& graph);

  // Warns about every node without a scope and asserts that there are none.
  static void HandleNoScopeNodes(
      scope_ctx_map&,
      const node_list& no_scope_nlist);
  // Groups the nodes of a block / graph by scope.
  std::tuple<scope_ctx_map, node_list> PartitionNodesByScope(Block* b);
  scope_ctx_map PartitionNodesByScope(const std::shared_ptr<Graph>& graph);
  // The two helpers below are defined outside this part of the file.
  static std::unordered_map<ScopePtr, scope_list> PartitionIdenticalScopes(
      scope_ctx_map& scope_ctxs);
  static scope_list SortScopesByMaxDepth(
      std::unordered_map<ScopePtr, scope_list>&);
  // Creates the `onnx::LocalFunctionDef` node and prepends it to `graph`.
  Node* CreateFunctionDefNode(
      FunctionContext& func_ctx,
      const std::shared_ptr<Graph>& graph,
      const std::string& domain_name,
      const std::string& func_name);
  // Creates the `<domain_name>::<func_name>` call node for one scope.
  Node* CreateFunctionNode(
      FunctionContext& func_ctx,
      ScopeContext& scope_ctx,
      const std::shared_ptr<Graph>& graph,
      const std::string& domain_name,
      const std::string& func_name);

  // Logging helpers.
  static void DebugPrintScopeContexts(const scope_ctx_map&);
  static void DebugPrintGraphWithFunction(const std::shared_ptr<Graph>& g);
  static void DebugPrintConstantDiff(const FunctionContext&);

  // Graph being processed.
  std::shared_ptr<Graph> graph_;
  // Module names passed to the constructor.
  std::unordered_set<std::string> module_names_;
  // Debug names of values treated as parameters (initializers).
  std::unordered_set<std::string> param_names_;
  // Track modules with same module name that are exported as different onnx
  // local functions.
  std::unordered_map<std::string, int> module_variant_count_;
  // Function contexts created so far, keyed by their reference scope.
  func_ctx_map func_ctxs_;
};
122 
FunctionContext(ScopePtr key,const scope_list & scopes,scope_ctx_map & scope_ctxs)123 FunctionExtractor::FunctionContext::FunctionContext(
124     ScopePtr key,
125     const scope_list& scopes,
126     scope_ctx_map& scope_ctxs)
127     : scope_key_(std::move(key)) {
128   GRAPH_UPDATE(
129       "Process function context for scope ",
130       scope_key_->name().toDisplayString());
131   TORCH_INTERNAL_ASSERT(!scopes.empty());
132   const auto& ref_ctx = scope_ctxs[scope_key_];
133   // NOTE: Function scopes must have same number and order of nodes.
134   GRAPH_DEBUG(
135       "Initialized function context for scope ",
136       scope_key_->name().toDisplayString());
137 
138   for (const auto& scope : scopes) {
139     GRAPH_DEBUG(
140         "Process function context for scope ", scope->name().toDisplayString());
141     TORCH_INTERNAL_ASSERT(scope_ctxs.find(scope) != scope_ctxs.end());
142     scope_ctxs_[scope] = scope_ctxs[scope];
143     if (scope_key_ == scope) {
144       continue;
145     }
146     auto& scope_ctx = scope_ctxs[scope];
147 
148     const auto& ns_a = ref_ctx->nlist_;
149     const auto& ns_b = scope_ctx->nlist_;
150     TORCH_INTERNAL_ASSERT(ns_a.size() == ns_b.size());
151 
152     GRAPH_DEBUG("Process nodes of scope ", scope->name().toDisplayString());
153     for (const auto i : c10::irange(ns_a.size())) {
154       TORCH_INTERNAL_ASSERT(ns_a[i]->kind() == ns_b[i]->kind());
155       auto n_a = ns_a[i];
156       auto n_b = ns_b[i];
157       std::vector<c10::Symbol> diff_attrs;
158       std::vector<c10::Symbol> same_attrs;
159       auto n_a_attr_names = n_a->attributeNames();
160       auto n_b_attr_names = n_b->attributeNames();
161       std::sort(n_a_attr_names.begin(), n_a_attr_names.end());
162       std::sort(n_b_attr_names.begin(), n_b_attr_names.end());
163       std::set_difference(
164           n_a_attr_names.begin(),
165           n_a_attr_names.end(),
166           n_b_attr_names.begin(),
167           n_b_attr_names.end(),
168           std::inserter(diff_attrs, diff_attrs.begin()));
169       std::set_intersection(
170           n_a_attr_names.begin(),
171           n_a_attr_names.end(),
172           n_b_attr_names.begin(),
173           n_b_attr_names.end(),
174           std::inserter(same_attrs, same_attrs.begin()));
175       for (auto attr_name : diff_attrs) {
176         attribute_map_[n_a][attr_name].insert(n_b);
177       }
178 
179       for (auto attr_name : same_attrs) {
180         if (!HasSameAttribute(n_a, n_b, attr_name)) {
181           attribute_map_[n_a][attr_name].insert(n_b);
182         }
183       }
184     }
185     GRAPH_DEBUG("Process scope complete. ", scope->name().toDisplayString());
186   }
187 
188   GRAPH_DEBUG(
189       "Process function context complete. ",
190       scope_key_->name().toDisplayString());
191   DebugPrint();
192 }
193 
DebugPrint() const194 void FunctionExtractor::FunctionContext::DebugPrint() const {
195   GRAPH_DEBUG("Scope name: ", scope_key_->name().toDisplayString());
196 
197   for (const auto& it : attribute_map_) {
198     for (const auto& attr_it : it.second) {
199       GRAPH_DEBUG(
200           "Attribute value difference for attribute ",
201           attr_it.first.toDisplayString());
202       GRAPH_DEBUG(*it.first);
203       for (auto n : attr_it.second) {
204         GRAPH_DEBUG(*n);
205       }
206     }
207   }
208 }
209 
SetAttrName(Node * ref_n,Symbol attr,const std::string & name)210 void FunctionExtractor::FunctionContext::SetAttrName(
211     Node* ref_n,
212     Symbol attr,
213     const std::string& name) {
214   auto v_it =
215       scope_ctxs_[scope_key_]->env_to_subgraph_.find(ref_n->outputs().at(0));
216   TORCH_INTERNAL_ASSERT(
217       v_it != scope_ctxs_[scope_key_]->env_to_subgraph_.end());
218   auto* n_in_def = v_it->second->node();
219   auto n_attr_it = node_attr_to_name_[n_in_def][attr.toUnqualString()] = name;
220 }
221 
FindAttrName(Node * ref_n,Symbol attr)222 std::optional<std::string> FunctionExtractor::FunctionContext::FindAttrName(
223     Node* ref_n,
224     Symbol attr) {
225   auto v_it =
226       scope_ctxs_[scope_key_]->env_to_subgraph_.find(ref_n->outputs().at(0));
227   if (v_it == scope_ctxs_[scope_key_]->env_to_subgraph_.end()) {
228     return std::nullopt;
229   }
230   auto* n_in_def = v_it->second->node();
231   auto n_attr_it = node_attr_to_name_.find(n_in_def);
232   if (n_attr_it == node_attr_to_name_.end()) {
233     return std::nullopt;
234   }
235   auto name_it = n_attr_it->second.find(attr.toUnqualString());
236   if (name_it == n_attr_it->second.end()) {
237     return std::nullopt;
238   }
239   return name_it->second;
240 }
241 
DebugPrintScopeContexts(const scope_ctx_map & scope_ctxs)242 void FunctionExtractor::DebugPrintScopeContexts(
243     const scope_ctx_map& scope_ctxs) {
244   for (auto& it : scope_ctxs) {
245     GRAPH_UPDATE(
246         "Scope name: ",
247         it.first->namesFromRoot(),
248         " ",
249         it.first->name().toDisplayString());
250     GRAPH_UPDATE("Children scopes: ", [&]() {
251       std::stringstream ss;
252       for (const auto& child_scope : it.second->children_) {
253         ss << child_scope->name().toDisplayString() << " ";
254       }
255       return ss.str();
256     }());
257     GRAPH_UPDATE("Node types: \n", [&]() {
258       std::stringstream ss;
259       for (auto n : it.second->nlist_) {
260         ss << "  " << *n;
261       }
262       return ss.str();
263     }());
264     GRAPH_UPDATE("Node count: ", it.second->nlist_.size());
265   }
266 }
267 
DebugPrintGraphWithFunction(const std::shared_ptr<Graph> & g)268 void FunctionExtractor::DebugPrintGraphWithFunction(
269     const std::shared_ptr<Graph>& g) {
270   GRAPH_UPDATE("Local function definitions:");
271   for (auto* n : g->nodes()) {
272     if (n->kind() == Symbol::onnx("LocalFunctionDef")) {
273       GRAPH_UPDATE(
274           n->s(attr::name),
275           " graph: ",
276           n->g(Symbol::attr("graph"))->toString());
277     }
278   }
279   GRAPH_UPDATE("Main graph: ", g->toString());
280 }
281 
IsValidScope(const ScopePtr & s)282 bool FunctionExtractor::IsValidScope(const ScopePtr& s) {
283   return !s->isRoot() && !s->isBlank();
284 }
285 
IsAncestor(const ScopePtr & parent,ScopePtr child)286 bool FunctionExtractor::IsAncestor(const ScopePtr& parent, ScopePtr child) {
287   if (!IsValidScope(parent) || !IsValidScope(child) ||
288       parent->getDepth() >= child->getDepth()) {
289     return false;
290   }
291   do {
292     child = child->parent();
293     if (parent == child) {
294       return true;
295     }
296   } while (IsValidScope(child));
297   return false;
298 }
299 
FindCommonAncestor(ScopePtr a,ScopePtr b)300 std::optional<ScopePtr> FunctionExtractor::FindCommonAncestor(
301     ScopePtr a,
302     ScopePtr b) {
303   if (!IsValidScope(a) || !IsValidScope(b)) {
304     return std::nullopt;
305   }
306 
307   auto diff =
308       static_cast<int64_t>(a->getDepth()) - static_cast<int64_t>(b->getDepth());
309   if (diff != 0) {
310     auto deeper_scope = diff > 0 ? a : b;
311     auto other_scope = diff > 0 ? b : a;
312     diff = std::abs(diff);
313     while (diff > 0) {
314       deeper_scope = deeper_scope->parent();
315       diff--;
316     }
317     a = deeper_scope;
318     b = other_scope;
319   }
320 
321   while (IsValidScope(a) && IsValidScope(b)) {
322     if (a == b) {
323       return a;
324     } else {
325       a = a->parent();
326       b = b->parent();
327     }
328   }
329 
330   return std::nullopt;
331 }
332 
FindCommonAncestor(const scope_list & scopes)333 std::optional<ScopePtr> FunctionExtractor::FindCommonAncestor(
334     const scope_list& scopes) {
335   if (scopes.empty()) {
336     return std::nullopt;
337   }
338 
339   std::optional<ScopePtr> common_ancestor = scopes.at(0);
340   for (const auto& scope : scopes) {
341     common_ancestor = FindCommonAncestor(common_ancestor.value(), scope);
342     if (!common_ancestor.has_value()) {
343       return std::nullopt;
344     }
345   }
346 
347   return common_ancestor;
348 }
349 
InferScope(Node * n)350 std::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
351   // The scope of node n is assigned based on the following rules.
352   // 1. If all uses of outputs of n belongs to the same scope,
353   //    assign that scope, otherwise
354   // 2. If all nodes of inputs of n belongs to the same scope,
355   //    assign that scope, otherwise
356   // 3. Find common ancestor of the scopes of uses of outputs of n,
357   //    and the scopes of nodes of inputs of n.
358   scope_list input_scopes;
359   scope_list output_scopes;
360   for (auto input : n->inputs()) {
361     input_scopes.emplace_back(input->node()->scope());
362   }
363   for (auto output : n->outputs()) {
364     for (auto use : output->uses()) {
365       if (!IsValidScope(use.user->scope())) {
366         auto inferred_output_scope = InferScope(use.user);
367         if (inferred_output_scope.has_value() &&
368             IsValidScope(inferred_output_scope.value())) {
369           use.user->setScope(inferred_output_scope.value());
370         }
371       }
372       output_scopes.emplace_back(use.user->scope());
373     }
374   }
375   if (!output_scopes.empty() &&
376       std::all_of(
377           output_scopes.begin(),
378           output_scopes.end(),
379           [&output_scopes](const ScopePtr& scope) -> bool {
380             return IsValidScope(scope) && scope == output_scopes.at(0);
381           })) {
382     return output_scopes.at(0);
383   } else if (
384       !input_scopes.empty() &&
385       std::all_of(
386           input_scopes.begin(),
387           input_scopes.end(),
388           [&input_scopes](const ScopePtr& scope) -> bool {
389             return IsValidScope(scope) && scope == input_scopes.at(0);
390           })) {
391     return input_scopes.at(0);
392   } else {
393     scope_list scopes;
394     std::copy_if(
395         input_scopes.begin(),
396         input_scopes.end(),
397         std::back_inserter(scopes),
398         IsValidScope);
399     std::copy_if(
400         output_scopes.begin(),
401         output_scopes.end(),
402         std::back_inserter(scopes),
403         IsValidScope);
404     if (!scopes.empty()) {
405       auto common_ancestor = FindCommonAncestor(scopes);
406       if (common_ancestor.has_value() &&
407           IsValidScope(common_ancestor.value())) {
408         return common_ancestor.value();
409       }
410     }
411   }
412 
413   return std::nullopt;
414 }
415 
ConstructFuncGraph(FunctionContext & func_ctx)416 std::shared_ptr<Graph> FunctionExtractor::ConstructFuncGraph(
417     FunctionContext& func_ctx) {
418   auto& ctx = *func_ctx.scope_ctxs_[func_ctx.scope_key_];
419   const auto& nlist = ctx.nlist_;
420   const auto& scope = ctx.scope_;
421   auto& env = ctx.env_to_subgraph_;
422 
423   auto g = std::make_shared<Graph>();
424   GRAPH_DEBUG("Constructing graph for ", scope->namesFromRoot());
425 
426   // TODO: Update input names of function to match those in Module source code
427   // signature.
428   // This requires mapping between function node inputs and Module inputs.
429   // Due to the lack of such mapping, currently debugName is used as input
430   // names.
431   ctx.PopulateInputsOutputs(param_names_);
432   for (auto* v : ctx.inputs_) {
433     env[v] = g->addInput()->copyMetadata(v);
434     GRAPH_DEBUG(
435         "Add input value ",
436         env[v]->debugName(),
437         " for outer scope value ",
438         v->debugName(),
439         " from ",
440         *v->node());
441   }
442 
443   for (auto* n : nlist) {
444     auto clone_n = g->createClone(n, [&](Value* v) {
445       TORCH_INTERNAL_ASSERT(env.find(v) != env.end());
446       return env[v];
447     });
448     for (const auto i : c10::irange(clone_n->outputs().size())) {
449       env[n->output(i)] = clone_n->output(i);
450     }
451     g->insertNode(clone_n);
452   }
453 
454   // If values are used outside of this graph, set as graph output.
455   for (auto* v : ctx.outputs_) {
456     TORCH_INTERNAL_ASSERT(env.find(v) != env.end());
457     g->registerOutput(env[v]);
458   }
459 
460   GRAPH_DEBUG(g->toString());
461   return g;
462 }
463 
CreateFunctionDefNode(FunctionContext & func_ctx,const std::shared_ptr<Graph> & graph,const std::string & domain_name,const std::string & func_name)464 Node* FunctionExtractor::CreateFunctionDefNode(
465     FunctionContext& func_ctx,
466     const std::shared_ptr<Graph>& graph,
467     const std::string& domain_name,
468     const std::string& func_name) {
469   const auto func_def_nk = Symbol::onnx("LocalFunctionDef");
470   const auto func_g_attr = Symbol::attr("graph");
471   const auto func_name_attr = attr::name;
472   const auto func_domain_attr = Symbol::attr("domain");
473 
474   auto func_graph = ConstructFuncGraph(func_ctx);
475 
476   // create and insert local function definition node
477   auto func_def_n = graph->create(func_def_nk, 0);
478   func_def_n->g_(func_g_attr, func_graph);
479   func_def_n->s_(func_name_attr, func_name);
480   func_def_n->s_(func_domain_attr, domain_name);
481   graph->prependNode(func_def_n);
482 
483   // set constants and attributes of different values as function attributes.
484   std::unordered_map<std::string, int> base_attr_name_count;
485   std::vector<std::string> final_attr_names;
486 
487   auto adjust_attr_name = [&](std::string attr_name) {
488     if (base_attr_name_count.find(attr_name) != base_attr_name_count.end()) {
489       attr_name =
490           attr_name + "." + std::to_string(base_attr_name_count[attr_name]++);
491     } else {
492       base_attr_name_count[attr_name] = 1;
493     }
494     return attr_name;
495   };
496 
497   for (const auto& n_it : func_ctx.attribute_map_) {
498     auto* n = n_it.first;
499     for (const auto& attr_it : n_it.second) {
500       const auto& attr = attr_it.first;
501       // Add prefix "inferred::" to name of inferred attribute.
502       // This is to differentiate from annotated attributes picked up
503       // from python module annotation.
504       auto attr_name = "inferred::" + std::string(n->kind().toUnqualString()) +
505           '_' + attr.toUnqualString();
506       auto final_attr_name = adjust_attr_name(attr_name);
507       final_attr_names.emplace_back(final_attr_name);
508       func_ctx.SetAttrName(n, attr, final_attr_name);
509     }
510   }
511 
512   // Set annotated attributes
513   std::unordered_set<Symbol> annotated_attr_names;
514   bool first_iteration = true;
515   for (const auto& it : func_ctx.scope_ctxs_) {
516     auto scope = it.first;
517     auto annotated_attr_node = scope_attr_map_.find(scope);
518     if (annotated_attr_node != scope_attr_map_.end()) {
519       auto names = annotated_attr_node->second->attributeNames();
520       if (first_iteration) {
521         std::copy(
522             names.begin(),
523             names.end(),
524             std::inserter(annotated_attr_names, annotated_attr_names.end()));
525         first_iteration = false;
526       } else {
527         auto unseen_attr_name = std::find_if(
528             names.begin(),
529             names.end(),
530             [&annotated_attr_names](const Symbol& name) {
531               return annotated_attr_names.find(name) ==
532                   annotated_attr_names.end();
533             });
534         TORCH_CHECK(
535             unseen_attr_name == names.end(),
536             "Found outstanding annotated attribute ",
537             *unseen_attr_name,
538             " from module ",
539             scope->name(),
540             ". Please ensure module instances of the same class have the same set of annotated attributes.");
541       }
542     }
543   }
544   for (auto attr_name : annotated_attr_names) {
545     final_attr_names.emplace_back(attr_name.toUnqualString());
546   }
547 
548   func_def_n->ss_(Symbol::attr("attributes"), final_attr_names);
549 
550   return func_def_n;
551 }
552 
CreateFunctionNode(FunctionContext & func_ctx,ScopeContext & scope_ctx,const std::shared_ptr<Graph> & graph,const std::string & domain_name,const std::string & func_name)553 Node* FunctionExtractor::CreateFunctionNode(
554     FunctionContext& func_ctx,
555     ScopeContext& scope_ctx,
556     const std::shared_ptr<Graph>& graph,
557     const std::string& domain_name,
558     const std::string& func_name) {
559   const auto& func_scope = func_ctx.scope_key_;
560   GRAPH_DEBUG(
561       "Create and insert local function for scope: ",
562       func_scope->namesFromRoot());
563   scope_ctx.PopulateInputsOutputs(param_names_);
564   auto last_n = *scope_ctx.nlist_.rbegin();
565   auto func_n = graph->create(
566       Symbol::fromQualString(domain_name + "::" + func_name),
567       scope_ctx.outputs_.size());
568   func_n->copyMetadata(last_n);
569   for (auto* v : scope_ctx.inputs_) {
570     func_n->addInput(v);
571   }
572   for (const auto i : c10::irange(scope_ctx.outputs_.size())) {
573     func_n->output(i)->setType(scope_ctx.outputs_[i]->type());
574     scope_ctx.outputs_[i]->replaceAllUsesWith(func_n->output(i));
575   }
576 
577   // set attributes of different values as function attributes.
578   auto copy_attr =
579       [](Node* a, Node* b, Symbol attr, const std::string& new_name) {
580 #define COPY_ATTR(kind)                                \
581   case AttributeKind::kind: {                          \
582     b->kind##_(Symbol::attr(new_name), a->kind(attr)); \
583     break;                                             \
584   }
585         switch (a->kindOf(attr)) {
586           COPY_ATTR(f)
587           COPY_ATTR(fs)
588           COPY_ATTR(i)
589           COPY_ATTR(is)
590           COPY_ATTR(s)
591           COPY_ATTR(ss)
592           COPY_ATTR(t)
593           COPY_ATTR(ts)
594 #undef COPY_ATTR
595           case AttributeKind::ival:
596           case AttributeKind::g:
597           case AttributeKind::gs:
598           case AttributeKind::ty:
599           case AttributeKind::tys:
600           case AttributeKind::c:
601           default:
602             TORCH_INTERNAL_ASSERT(
603                 false,
604                 "Unexpected attribute type ",
605                 static_cast<int>(a->kindOf(attr)),
606                 " from node ",
607                 *a);
608             break;
609         }
610       };
611 
612   for (const auto& it : func_ctx.attribute_map_) {
613     auto* ref_n = it.first;
614     for (const auto& attr_it : it.second) {
615       const auto& attr = attr_it.first;
616       auto attr_name = func_ctx.FindAttrName(ref_n, attr).value();
617       copy_attr(ref_n, func_n, attr, attr_name);
618       for (auto* n : scope_ctx.nlist_) {
619         if (attr_it.second.find(n) != attr_it.second.end()) {
620           copy_attr(n, func_n, attr, attr_name);
621           break;
622         }
623       }
624     }
625   }
626 
627   // annotated attributes
628   auto scope = scope_ctx.scope_;
629   auto annotated_attr_node = scope_attr_map_.find(scope);
630   if (annotated_attr_node != scope_attr_map_.end()) {
631     auto node = annotated_attr_node->second;
632     for (auto attr : node->attributeNames()) {
633       copy_attr(node, func_n, attr, attr.toUnqualString());
634     }
635   }
636 
637   func_n->insertAfter(last_n);
638   return func_n;
639 }
640 
ConvertScopeToFunction(const ScopePtr & scope_key,const scope_list & scope_list,scope_ctx_map & scope_ctxs,const std::shared_ptr<Graph> & graph)641 void FunctionExtractor::ConvertScopeToFunction(
642     const ScopePtr& scope_key,
643     const scope_list& scope_list,
644     scope_ctx_map& scope_ctxs,
645     const std::shared_ptr<Graph>& graph) {
646   // This function needs to be called always on inner most scopes.
647   // 1. Generate function context, this identifies different constants and
648   // attributes.
649   // 2. Create function definition node, and insert to main graph.
650   // 3. Create function node for each call, and replace subgraph nodes in parent
651   // functions.
652 
653   func_ctxs_.insert(std::make_pair(
654       scope_key, new FunctionContext(scope_key, scope_list, scope_ctxs)));
655   auto& func_ctx = *func_ctxs_[scope_key];
656 
657   const std::string module_class_name(
658       ONNXScopeName::className(func_ctx.scope_key_));
659   auto pos = module_class_name.rfind('.');
660   TORCH_INTERNAL_ASSERT(pos != std::string::npos);
661 
662   auto construct_unique_module_name = [&](std::string module_name) {
663     auto module_name_variant = module_variant_count_.find(module_name);
664     if (module_name_variant != module_variant_count_.end()) {
665       module_variant_count_[module_name]++;
666       module_name += ("." + std::to_string(module_name_variant->second));
667     } else {
668       module_variant_count_[module_name] = 0;
669     }
670     return module_name;
671   };
672 
673   const auto domain_name = module_class_name.substr(0, pos);
674   const auto func_name =
675       construct_unique_module_name(module_class_name.substr(pos + 1));
676 
677   CreateFunctionDefNode(func_ctx, graph, domain_name, func_name);
678 
679   // create and insert local function node to graph.
680   for (const auto& it : func_ctx.scope_ctxs_) {
681     auto scope = it.first;
682     auto& scope_ctx = *it.second;
683     auto func_n =
684         CreateFunctionNode(func_ctx, scope_ctx, graph, domain_name, func_name);
685 
686     std::unordered_set<Node*> old_nodes(
687         scope_ctx.nlist_.begin(), scope_ctx.nlist_.end());
688 
689     auto last_n = *scope_ctx.nlist_.rbegin();
690     // replace function body nodes in parent scopes with local function node.
691     for (auto& it : scope_ctxs) {
692       const auto& parent_scope = it.first;
693       auto& parent_ctx = *it.second;
694 
695       if (!IsAncestor(parent_scope, scope)) {
696         continue;
697       }
698 
699       auto& ctx_nlist = parent_ctx.nlist_;
700       GRAPH_DEBUG(
701           "Replace local function node in parent scope: ",
702           it.first->namesFromRoot(),
703           " nodes to remove: ",
704           old_nodes.size(),
705           " parent total nodes: ",
706           ctx_nlist.size());
707 
708       // insert local function node
709       auto last_n_it = std::find(ctx_nlist.begin(), ctx_nlist.end(), last_n);
710       ctx_nlist.insert(last_n_it, func_n);
711 
712       // remove replaced nodes from list
713       ctx_nlist.erase(
714           std::remove_if(
715               ctx_nlist.begin(),
716               ctx_nlist.end(),
717               [&old_nodes](Node* n) {
718                 return old_nodes.find(n) != old_nodes.end();
719               }),
720           ctx_nlist.end());
721 
722       GRAPH_DEBUG("Parent total nodes after remove: ", ctx_nlist.size());
723 
724       // refresh inputs/outputs.
725       parent_ctx.PopulateInputsOutputs(param_names_);
726     }
727   }
728 
729   for (const auto& it : func_ctx.scope_ctxs_) {
730     auto& scope_ctx = *it.second;
731     // delete replaced nodes in graph.
732     for (auto it = scope_ctx.nlist_.rbegin(); it != scope_ctx.nlist_.rend();) {
733       auto* n = *it;
734       it++;
735       GRAPH_DEBUG("Destroying node ", *n);
736       n->destroy();
737     }
738   }
739 }
740 
IsIdenticalFuncion(const ScopeContext & other_ctx) const741 bool FunctionExtractor::ScopeContext::IsIdenticalFuncion(
742     const ScopeContext& other_ctx) const {
743   // Differentiate same function under different inputs.
744   // When constants are passed in place of inputs, it leads to different
745   // input count and node count. Likewise, due to different uses, output
746   // count can be different as well.
747   // For now export them as different functions.
748   // Covered by `test_local_function_overloads` in
749   // `test/onnx/test_utility_funs.py`.
750   if (&other_ctx == this) {
751     return true;
752   }
753   if (ONNXScopeName::className(this->scope_) !=
754       ONNXScopeName::className(other_ctx.scope_)) {
755     return false;
756   }
757   if (this->inputs_.size() != other_ctx.inputs_.size() ||
758       this->outputs_.size() != other_ctx.outputs_.size()) {
759     return false;
760   }
761   const auto& ns_a = this->nlist_;
762   const auto& ns_b = other_ctx.nlist_;
763   if (ns_a.size() != ns_b.size()) {
764     return false;
765   }
766   for (const auto i : c10::irange(ns_a.size())) {
767     if (ns_a[i]->kind() != ns_b[i]->kind()) {
768       return false;
769     }
770   }
771 
772   return true;
773 }
774 
PopulateInputsOutputs(const std::unordered_set<std::string> & param_names)775 void FunctionExtractor::ScopeContext::PopulateInputsOutputs(
776     const std::unordered_set<std::string>& param_names) {
777   inputs_.clear();
778   outputs_.clear();
779   const auto& nlist = this->nlist_;
780   std::unordered_set<Value*> v_set;
781   std::unordered_set<Node*> n_set;
782 
783   value_list input_list;
784   value_list initializer_list;
785 
786   // Add initializers after inputs.
787   for (auto* n : nlist) {
788     for (auto* v : n->inputs()) {
789       if (v_set.find(v) == v_set.end()) {
790         if (param_names.find(v->debugName()) != param_names.end()) {
791           initializer_list.emplace_back(v);
792         } else {
793           input_list.emplace_back(v);
794         }
795         v_set.insert(v);
796       }
797     }
798     for (auto* v : n->outputs()) {
799       v_set.insert(v);
800     }
801     n_set.insert(n);
802   }
803   for (auto* v : input_list) {
804     inputs_.emplace_back(v);
805   }
806   for (auto* v : initializer_list) {
807     inputs_.emplace_back(v);
808   }
809 
810   for (auto* n : nlist) {
811     for (auto* v : n->outputs()) {
812       bool used_outside = false;
813       for (auto use : v->uses()) {
814         used_outside |= (n_set.find(use.user) == n_set.end());
815       }
816       if (used_outside) {
817         outputs_.emplace_back(v);
818       }
819     }
820   }
821 }
822 
HandleNoScopeNodes(scope_ctx_map & scope_ctxs,const node_list & no_scope_nlist)823 void FunctionExtractor::HandleNoScopeNodes(
824     scope_ctx_map& scope_ctxs,
825     const node_list& no_scope_nlist) {
826   GRAPH_UPDATE("No scope node count: ", no_scope_nlist.size());
827   for (auto n : no_scope_nlist) {
828     TORCH_WARN(
829         "ONNX function extraction cannot determine the scope for node: ", *n);
830   }
831   TORCH_INTERNAL_ASSERT(
832       no_scope_nlist.empty(),
833       "ONNX function extraction cannot determine the scope for the above nodes.");
834 }
835 
836 std::tuple<FunctionExtractor::scope_ctx_map, node_list> FunctionExtractor::
PartitionNodesByScope(Block * b)837     PartitionNodesByScope(Block* b) {
838   scope_ctx_map scope_ctxs = {};
839   node_list no_scope_nlist;
840 
841   auto find_or_create_scope_ctx = [](scope_ctx_map& scope_ctxs,
842                                      const ScopePtr& scope) {
843     if (scope_ctxs.find(scope) == scope_ctxs.end()) {
844       scope_ctxs.insert(std::make_pair(scope, new ScopeContext()));
845     }
846     return scope_ctxs[scope];
847   };
848 
849   auto record_node_scope = [&scope_ctxs, &find_or_create_scope_ctx](Node* n) {
850     const auto& scope = n->scope();
851     find_or_create_scope_ctx(scope_ctxs, scope)->scope_ = scope;
852     auto tmp_scope = scope;
853     while (IsValidScope(tmp_scope)) {
854       find_or_create_scope_ctx(scope_ctxs, tmp_scope)->nlist_.emplace_back(n);
855       if (IsValidScope(tmp_scope->parent())) {
856         find_or_create_scope_ctx(scope_ctxs, tmp_scope->parent())
857             ->children_.insert(tmp_scope);
858       }
859       tmp_scope = tmp_scope->parent();
860     }
861   };
862 
863   for (auto* n : b->nodes()) {
864     auto scope = n->scope();
865     if (scope && IsValidScope(scope)) {
866       record_node_scope(n);
867     } else {
868       auto inferred_scope = InferScope(n);
869 
870       if (inferred_scope.has_value() && IsValidScope(inferred_scope.value())) {
871         n->setScope(inferred_scope.value());
872         record_node_scope(n);
873       } else {
874         GRAPH_UPDATE("Cannot infer proper scope for node: ", *n);
875         no_scope_nlist.emplace_back(n);
876       }
877     }
878 
879     for (auto* sub_b : n->blocks()) {
880       auto [subblock_scope_ctxs, subblock_no_scope_nlist] =
881           PartitionNodesByScope(sub_b);
882 
883       for (auto& it : subblock_scope_ctxs) {
884         if (scope_ctxs.find(it.first) == scope_ctxs.end()) {
885           scope_ctxs.insert(std::make_pair(it.first, it.second));
886         } else {
887           for (auto* s_n : it.second->nlist_) {
888             scope_ctxs[it.first]->nlist_.emplace_back(s_n);
889           }
890           for (const auto& s_child_scope : it.second->children_) {
891             scope_ctxs[it.first]->children_.insert(s_child_scope);
892           }
893         }
894       }
895 
896       no_scope_nlist.insert(
897           no_scope_nlist.end(),
898           subblock_no_scope_nlist.begin(),
899           subblock_no_scope_nlist.end());
900     }
901   }
902 
903   for (auto& it : scope_ctxs) {
904     it.second->scope_ = it.first;
905     it.second->PopulateInputsOutputs(param_names_);
906   }
907 
908   return std::tie(scope_ctxs, no_scope_nlist);
909 }
910 
PartitionNodesByScope(const std::shared_ptr<Graph> & graph)911 FunctionExtractor::scope_ctx_map FunctionExtractor::PartitionNodesByScope(
912     const std::shared_ptr<Graph>& graph) {
913   scope_ctx_map scope_ctxs;
914   node_list no_scope_nlist;
915   std::tie(scope_ctxs, no_scope_nlist) = PartitionNodesByScope(graph->block());
916 
917   HandleNoScopeNodes(scope_ctxs, no_scope_nlist);
918 
919   return scope_ctxs;
920 }
921 
922 std::unordered_map<ScopePtr, scope_list> FunctionExtractor::
PartitionIdenticalScopes(FunctionExtractor::scope_ctx_map & scope_ctxs)923     PartitionIdenticalScopes(FunctionExtractor::scope_ctx_map& scope_ctxs) {
924   std::unordered_map<ScopePtr, scope_list> identical_scope_map;
925 
926   for (auto& it : scope_ctxs) {
927     auto scope = it.first;
928     const auto& scope_ctx = it.second;
929     bool unique = true;
930     for (auto& kv_it : identical_scope_map) {
931       auto key_scope = kv_it.first;
932       const auto& key_scope_ctx = scope_ctxs[key_scope];
933       auto& key_scope_vec = kv_it.second;
934       if (key_scope_ctx->IsIdenticalFuncion(*scope_ctx)) {
935         key_scope_vec.emplace_back(scope);
936         unique = false;
937         break;
938       }
939     }
940     if (unique) {
941       identical_scope_map[scope].emplace_back(scope);
942     }
943   }
944 
945   return identical_scope_map;
946 }
947 
HasSameAttribute(const Node * a,const Node * b,const c10::Symbol & attr)948 static bool HasSameAttribute(
949     const Node* a,
950     const Node* b,
951     const c10::Symbol& attr) {
952   if (!a->hasAttribute(attr) && !b->hasAttribute(attr)) {
953     return true;
954   }
955   if (!a->hasAttribute(attr) || !b->hasAttribute(attr)) {
956     return false;
957   }
958   auto a_kind = a->kindOf(attr);
959   auto b_kind = b->kindOf(attr);
960   if (a_kind != b_kind) {
961     return false;
962   }
963 
964 #define COMP_ATTR(kind)              \
965   case AttributeKind::kind: {        \
966     const auto& a_v = a->kind(attr); \
967     const auto& b_v = b->kind(attr); \
968     return a_v == b_v;               \
969   }
970 
971   switch (a_kind) {
972     COMP_ATTR(f)
973     COMP_ATTR(fs)
974     COMP_ATTR(i)
975     COMP_ATTR(is)
976     COMP_ATTR(s)
977     COMP_ATTR(ss)
978 #undef COMP_ATTR
979     case AttributeKind::t: {
980       const auto& a_v = a->t(attr);
981       const auto& b_v = b->t(attr);
982       return a_v.equal(b_v);
983     }
984     case AttributeKind::ts: {
985       const auto& a_v = a->ts(attr);
986       const auto& b_v = b->ts(attr);
987       return std::equal(
988           a_v.begin(),
989           a_v.end(),
990           b_v.begin(),
991           b_v.end(),
992           [](const at::Tensor& a_t, const at::Tensor& b_t) {
993             return a_t.equal(b_t);
994           });
995     }
996     case AttributeKind::ival:
997     case AttributeKind::g:
998     case AttributeKind::gs:
999     case AttributeKind::ty:
1000     case AttributeKind::tys:
1001     case AttributeKind::c:
1002     default:
1003       TORCH_INTERNAL_ASSERT(
1004           false,
1005           "Unexpected attribute type ",
1006           static_cast<int>(a_kind),
1007           " from node ",
1008           *a);
1009       break;
1010   }
1011 
1012   return true;
1013 }
1014 
SortScopesByMaxDepth(std::unordered_map<ScopePtr,scope_list> & identical_scope_map)1015 scope_list FunctionExtractor::SortScopesByMaxDepth(
1016     std::unordered_map<ScopePtr, scope_list>& identical_scope_map) {
1017   std::unordered_map<ScopePtr, size_t> scope_max_depth;
1018   for (const auto& it : identical_scope_map) {
1019     const auto& scopes = it.second;
1020     size_t max_depth = 0;
1021     for (const auto& scope : scopes) {
1022       if (scope->getDepth() > max_depth) {
1023         max_depth = scope->getDepth();
1024       }
1025     }
1026     scope_max_depth[it.first] = max_depth;
1027   }
1028 
1029   scope_list sorted_scopes;
1030   sorted_scopes.reserve(scope_max_depth.size());
1031   for (const auto& it : scope_max_depth) {
1032     sorted_scopes.emplace_back(it.first);
1033   }
1034   std::sort(
1035       sorted_scopes.begin(),
1036       sorted_scopes.end(),
1037       [&scope_max_depth](const ScopePtr& a, const ScopePtr& b) -> bool {
1038         return scope_max_depth[a] >= scope_max_depth[b];
1039       });
1040   return sorted_scopes;
1041 }
1042 
run()1043 NodeAttrNameMap FunctionExtractor::run() {
1044   auto scope_ctxs = PartitionNodesByScope(graph_);
1045   DebugPrintScopeContexts(scope_ctxs);
1046   auto identical_scope_map = PartitionIdenticalScopes(scope_ctxs);
1047   // Deepest scope comes first, guaranteeing no other scope can be its child.
1048   auto sorted_scope_keys = SortScopesByMaxDepth(identical_scope_map);
1049   for (const auto& scope_key : sorted_scope_keys) {
1050     if (module_names_.find(ONNXScopeName::className(scope_key)) !=
1051         module_names_.end()) {
1052       ConvertScopeToFunction(
1053           scope_key, identical_scope_map[scope_key], scope_ctxs, graph_);
1054     }
1055     GRAPH_DEBUG("Main graph afterwards: ", graph_->toString());
1056   }
1057   DebugPrintGraphWithFunction(graph_);
1058 
1059   // Construct return mappings
1060   NodeAttrNameMap node_attr_to_name;
1061 
1062   for (const auto& it : func_ctxs_) {
1063     auto func_ref_map = it.second->node_attr_to_name_;
1064     node_attr_to_name.insert(func_ref_map.begin(), func_ref_map.end());
1065   }
1066 
1067   // Clear
1068   for (auto& it : scope_ctxs) {
1069     delete it.second;
1070   }
1071   scope_ctxs.clear();
1072   for (auto& it : func_ctxs_) {
1073     delete it.second;
1074   }
1075   func_ctxs_.clear();
1076 
1077   return node_attr_to_name;
1078 }
1079 
1080 // Retrieves the node representing the most recent
1081 // ScopePtr. This function should only be invoked from module forward hook. At
1082 // this point, module forward call is completed, and the most recent ScopePtr
1083 // is popped from TracingState.
1084 // This function inspects the node, and its subblock, to find
1085 // the node associated with the most recent ScopePtr.
NodeOfMostRecentScope(Node * forward_node)1086 Node* NodeOfMostRecentScope(Node* forward_node) {
1087   TORCH_INTERNAL_ASSERT(
1088       forward_node->kind() == prim::TracedModuleForward,
1089       "forward_node got kind: ",
1090       forward_node->kind().toDisplayString());
1091   auto* block = forward_node->blocks()[0];
1092   for (auto* node : block->nodes().reverse()) {
1093     if (node->kind() == prim::TracedModuleForward) {
1094       Node* target_node = NodeOfMostRecentScope(node);
1095       if (scope_attr_map_.find(node->scope()) == scope_attr_map_.end()) {
1096         return target_node;
1097       }
1098     }
1099   }
1100   return forward_node;
1101 }
1102 
1103 } // namespace
1104 
1105 // FunctionExtractor runs in the following steps. Updates are made inplace to
1106 // the graph argument.
1107 //    1. Partition nodes into groups based on their scope information.
1108 //    Each scope represents an individual nn.Module call. A ScopeContext object
1109 //    is created for each group.
1110 //    2. Compare and find groups with the same subgraph pattern from step 1.
1111 //    3. Scopes are nested. Starting from the deepest scope, extract the
1112 //    subgraph pattern, and define as local function node. Replace subgraph
1113 //    pattern with a single node of the new local function node type. A
1114 //    FunctionContext object is created for each function.
1115 //    4. Construct NodeAttrNameMap tracking mapping from attribute name of
1116 //    IR Node inside function subgraph, to function attribute name.
ONNXFunctionExtraction(std::shared_ptr<Graph> & graph,const std::unordered_set<std::string> & module_names,const std::vector<std::string> & param_names)1117 NodeAttrNameMap ONNXFunctionExtraction(
1118     std::shared_ptr<Graph>& graph,
1119     const std::unordered_set<std::string>& module_names,
1120     const std::vector<std::string>& param_names) {
1121   GRAPH_UPDATE(
1122       "Export these module forward calls as functions: ",
1123       std::vector<std::string>{module_names.begin(), module_names.end()});
1124   FunctionExtractor fe(graph, module_names, param_names);
1125   return fe.run();
1126 }
1127 
ONNXGetPreviousScope(std::shared_ptr<Graph> & graph)1128 Node* ONNXGetPreviousScope(std::shared_ptr<Graph>& graph) {
1129   auto* last_node = graph->nodes().back()->prev();
1130   auto* scope_node = NodeOfMostRecentScope(last_node);
1131   auto* attr_node = scope_attr_graph_->create(prim::TracedModuleForward);
1132   attr_node->setScope(scope_node->scope());
1133   TORCH_INTERNAL_ASSERT(
1134       scope_attr_map_.find(scope_node->scope()) == scope_attr_map_.end(),
1135       "Found duplicated scope. Scope ",
1136       scope_node->scope()->namesFromRoot(),
1137       " already processed.");
1138   scope_attr_map_[scope_node->scope()] = attr_node;
1139   return attr_node;
1140 }
1141 
ONNXClearScopeRecords()1142 void ONNXClearScopeRecords() {
1143   scope_attr_map_.clear();
1144   scope_attr_graph_ = std::make_shared<Graph>();
1145 }
1146 
ONNXTrackScopeAttributes(std::shared_ptr<Graph> & graph,std::map<std::string,IValue> & attributes)1147 void ONNXTrackScopeAttributes(
1148     std::shared_ptr<Graph>& graph,
1149     std::map<std::string, IValue>& attributes) {
1150   // Skip the "real" last node which is `return_node`.
1151   auto* last_node = graph->nodes().back()->prev();
1152   auto* scope_node = NodeOfMostRecentScope(last_node);
1153   auto* attr_node = scope_attr_graph_->create(prim::TracedModuleForward);
1154   attr_node->setScope(scope_node->scope());
1155   TORCH_INTERNAL_ASSERT(
1156       scope_attr_map_.find(scope_node->scope()) == scope_attr_map_.end());
1157   scope_attr_map_[scope_node->scope()] = attr_node;
1158 
1159   for (const auto& it : attributes) {
1160     auto k = Symbol::attr(it.first);
1161     auto v = it.second;
1162     if (v.isTensor()) {
1163       attr_node->t_(k, v.toTensor());
1164     } else if (v.isInt()) {
1165       attr_node->i_(k, v.toInt());
1166     } else if (v.isDouble()) {
1167       attr_node->f_(k, v.toDouble());
1168     } else if (v.isBool()) {
1169       attr_node->i_(k, v.toBool());
1170     } else if (v.isString()) {
1171       attr_node->s_(k, v.toStringRef());
1172     } else if (v.isIntList()) {
1173       attr_node->is_(k, v.toIntList().vec());
1174     } else if (v.isBoolList()) {
1175       auto bool_list = v.toBoolList();
1176       attr_node->is_(
1177           k, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
1178     } else if (v.isDoubleList()) {
1179       attr_node->fs_(k, v.toDoubleList().vec());
1180     }
1181   }
1182 }
1183 
1184 } // namespace torch::jit::onnx
1185