1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/passes/onnx/function_extraction.h>
3 #include <torch/csrc/jit/passes/onnx/naming.h>
4
5 namespace torch::jit::onnx {
6
7 namespace {
8
9 using scope_list = std::vector<ScopePtr>;
10
11 // Annotated attributes retrieved from module by inspecting module annotations.
12 // These attributes are not used inside the subgraph of ONNX local function
13 // because they are not created by PyTorch JIT tracing, but they may be used by
14 // consumers to determine whether or not to replace the function with a
15 // particular fused kernel.
16 static std::unordered_map<ScopePtr, Node*> scope_attr_map_;
17 static std::shared_ptr<Graph> scope_attr_graph_ = std::make_shared<Graph>();
18
19 static bool HasSameAttribute(
20 const Node* a,
21 const Node* b,
22 const c10::Symbol& attr);
23
24 struct FunctionExtractor {
25 public:
FunctionExtractortorch::jit::onnx::__anon8b66f1e50111::FunctionExtractor26 FunctionExtractor(
27 std::shared_ptr<Graph>& graph,
28 const std::unordered_set<std::string>& module_names,
29 const std::vector<std::string>& param_names)
30 : graph_(graph),
31 module_names_(module_names.begin(), module_names.end()),
32 param_names_(param_names.begin(), param_names.end()) {}
33 NodeAttrNameMap run();
34
35 private:
36 struct ScopeContext {
37 std::unordered_set<ScopePtr> children_;
38 ScopePtr scope_;
39 node_list nlist_;
40 value_list inputs_;
41 value_list outputs_;
42 std::unordered_map<Value*, Value*> env_to_subgraph_;
43
44 void PopulateInputsOutputs(
45 const std::unordered_set<std::string>& param_names);
46 bool IsIdenticalFuncion(const ScopeContext& other_ctx) const;
47 };
48
49 using ScopeCtxPtr = ScopeContext*;
50 using scope_ctx_map = std::unordered_map<ScopePtr, ScopeCtxPtr>;
51
52 struct FunctionContext {
53 FunctionContext(
54 ScopePtr key,
55 const scope_list& scopes,
56 scope_ctx_map& scope_ctxs);
57 void DebugPrint() const;
58 void SetAttrName(Node* ref_n, Symbol attr, const std::string& name);
59 std::optional<std::string> FindAttrName(Node* ref_n, Symbol attr);
60 std::optional<std::string> FindAttrName(Node* ref_const_n);
61
62 ScopePtr scope_key_;
63 scope_ctx_map scope_ctxs_;
64 std::unordered_map<
65 Node*,
66 std::unordered_map<Symbol, std::unordered_set<Node*>>>
67 attribute_map_;
68
69 // Passed later to serialization.
70 NodeAttrNameMap node_attr_to_name_;
71 };
72
73 using FunctionCtxPtr = FunctionContext*;
74 using func_ctx_map = std::unordered_map<ScopePtr, FunctionCtxPtr>;
75
76 static bool IsValidScope(const ScopePtr& s);
77 static std::optional<ScopePtr> InferScope(Node* n);
78 static bool IsAncestor(const ScopePtr& parent, ScopePtr child);
79 static std::optional<ScopePtr> FindCommonAncestor(ScopePtr a, ScopePtr b);
80 static std::optional<ScopePtr> FindCommonAncestor(const scope_list& scopes);
81 std::shared_ptr<Graph> ConstructFuncGraph(FunctionContext& ctx);
82
83 void ConvertScopeToFunction(
84 const ScopePtr& scope_key,
85 const scope_list& scope_list,
86 scope_ctx_map& scope_ctxs,
87 const std::shared_ptr<Graph>& graph);
88
89 static void HandleNoScopeNodes(
90 scope_ctx_map&,
91 const node_list& no_scope_nlist);
92 std::tuple<scope_ctx_map, node_list> PartitionNodesByScope(Block* b);
93 scope_ctx_map PartitionNodesByScope(const std::shared_ptr<Graph>& graph);
94 static std::unordered_map<ScopePtr, scope_list> PartitionIdenticalScopes(
95 scope_ctx_map& scope_ctxs);
96 static scope_list SortScopesByMaxDepth(
97 std::unordered_map<ScopePtr, scope_list>&);
98 Node* CreateFunctionDefNode(
99 FunctionContext& func_ctx,
100 const std::shared_ptr<Graph>& graph,
101 const std::string& domain_name,
102 const std::string& func_name);
103 Node* CreateFunctionNode(
104 FunctionContext& func_ctx,
105 ScopeContext& scope_ctx,
106 const std::shared_ptr<Graph>& graph,
107 const std::string& domain_name,
108 const std::string& func_name);
109
110 static void DebugPrintScopeContexts(const scope_ctx_map&);
111 static void DebugPrintGraphWithFunction(const std::shared_ptr<Graph>& g);
112 static void DebugPrintConstantDiff(const FunctionContext&);
113
114 std::shared_ptr<Graph> graph_;
115 std::unordered_set<std::string> module_names_;
116 std::unordered_set<std::string> param_names_;
117 // Track modules with same module name that are exported as different onnx
118 // local functions.
119 std::unordered_map<std::string, int> module_variant_count_;
120 func_ctx_map func_ctxs_;
121 };
122
FunctionContext(ScopePtr key,const scope_list & scopes,scope_ctx_map & scope_ctxs)123 FunctionExtractor::FunctionContext::FunctionContext(
124 ScopePtr key,
125 const scope_list& scopes,
126 scope_ctx_map& scope_ctxs)
127 : scope_key_(std::move(key)) {
128 GRAPH_UPDATE(
129 "Process function context for scope ",
130 scope_key_->name().toDisplayString());
131 TORCH_INTERNAL_ASSERT(!scopes.empty());
132 const auto& ref_ctx = scope_ctxs[scope_key_];
133 // NOTE: Function scopes must have same number and order of nodes.
134 GRAPH_DEBUG(
135 "Initialized function context for scope ",
136 scope_key_->name().toDisplayString());
137
138 for (const auto& scope : scopes) {
139 GRAPH_DEBUG(
140 "Process function context for scope ", scope->name().toDisplayString());
141 TORCH_INTERNAL_ASSERT(scope_ctxs.find(scope) != scope_ctxs.end());
142 scope_ctxs_[scope] = scope_ctxs[scope];
143 if (scope_key_ == scope) {
144 continue;
145 }
146 auto& scope_ctx = scope_ctxs[scope];
147
148 const auto& ns_a = ref_ctx->nlist_;
149 const auto& ns_b = scope_ctx->nlist_;
150 TORCH_INTERNAL_ASSERT(ns_a.size() == ns_b.size());
151
152 GRAPH_DEBUG("Process nodes of scope ", scope->name().toDisplayString());
153 for (const auto i : c10::irange(ns_a.size())) {
154 TORCH_INTERNAL_ASSERT(ns_a[i]->kind() == ns_b[i]->kind());
155 auto n_a = ns_a[i];
156 auto n_b = ns_b[i];
157 std::vector<c10::Symbol> diff_attrs;
158 std::vector<c10::Symbol> same_attrs;
159 auto n_a_attr_names = n_a->attributeNames();
160 auto n_b_attr_names = n_b->attributeNames();
161 std::sort(n_a_attr_names.begin(), n_a_attr_names.end());
162 std::sort(n_b_attr_names.begin(), n_b_attr_names.end());
163 std::set_difference(
164 n_a_attr_names.begin(),
165 n_a_attr_names.end(),
166 n_b_attr_names.begin(),
167 n_b_attr_names.end(),
168 std::inserter(diff_attrs, diff_attrs.begin()));
169 std::set_intersection(
170 n_a_attr_names.begin(),
171 n_a_attr_names.end(),
172 n_b_attr_names.begin(),
173 n_b_attr_names.end(),
174 std::inserter(same_attrs, same_attrs.begin()));
175 for (auto attr_name : diff_attrs) {
176 attribute_map_[n_a][attr_name].insert(n_b);
177 }
178
179 for (auto attr_name : same_attrs) {
180 if (!HasSameAttribute(n_a, n_b, attr_name)) {
181 attribute_map_[n_a][attr_name].insert(n_b);
182 }
183 }
184 }
185 GRAPH_DEBUG("Process scope complete. ", scope->name().toDisplayString());
186 }
187
188 GRAPH_DEBUG(
189 "Process function context complete. ",
190 scope_key_->name().toDisplayString());
191 DebugPrint();
192 }
193
DebugPrint() const194 void FunctionExtractor::FunctionContext::DebugPrint() const {
195 GRAPH_DEBUG("Scope name: ", scope_key_->name().toDisplayString());
196
197 for (const auto& it : attribute_map_) {
198 for (const auto& attr_it : it.second) {
199 GRAPH_DEBUG(
200 "Attribute value difference for attribute ",
201 attr_it.first.toDisplayString());
202 GRAPH_DEBUG(*it.first);
203 for (auto n : attr_it.second) {
204 GRAPH_DEBUG(*n);
205 }
206 }
207 }
208 }
209
SetAttrName(Node * ref_n,Symbol attr,const std::string & name)210 void FunctionExtractor::FunctionContext::SetAttrName(
211 Node* ref_n,
212 Symbol attr,
213 const std::string& name) {
214 auto v_it =
215 scope_ctxs_[scope_key_]->env_to_subgraph_.find(ref_n->outputs().at(0));
216 TORCH_INTERNAL_ASSERT(
217 v_it != scope_ctxs_[scope_key_]->env_to_subgraph_.end());
218 auto* n_in_def = v_it->second->node();
219 auto n_attr_it = node_attr_to_name_[n_in_def][attr.toUnqualString()] = name;
220 }
221
FindAttrName(Node * ref_n,Symbol attr)222 std::optional<std::string> FunctionExtractor::FunctionContext::FindAttrName(
223 Node* ref_n,
224 Symbol attr) {
225 auto v_it =
226 scope_ctxs_[scope_key_]->env_to_subgraph_.find(ref_n->outputs().at(0));
227 if (v_it == scope_ctxs_[scope_key_]->env_to_subgraph_.end()) {
228 return std::nullopt;
229 }
230 auto* n_in_def = v_it->second->node();
231 auto n_attr_it = node_attr_to_name_.find(n_in_def);
232 if (n_attr_it == node_attr_to_name_.end()) {
233 return std::nullopt;
234 }
235 auto name_it = n_attr_it->second.find(attr.toUnqualString());
236 if (name_it == n_attr_it->second.end()) {
237 return std::nullopt;
238 }
239 return name_it->second;
240 }
241
DebugPrintScopeContexts(const scope_ctx_map & scope_ctxs)242 void FunctionExtractor::DebugPrintScopeContexts(
243 const scope_ctx_map& scope_ctxs) {
244 for (auto& it : scope_ctxs) {
245 GRAPH_UPDATE(
246 "Scope name: ",
247 it.first->namesFromRoot(),
248 " ",
249 it.first->name().toDisplayString());
250 GRAPH_UPDATE("Children scopes: ", [&]() {
251 std::stringstream ss;
252 for (const auto& child_scope : it.second->children_) {
253 ss << child_scope->name().toDisplayString() << " ";
254 }
255 return ss.str();
256 }());
257 GRAPH_UPDATE("Node types: \n", [&]() {
258 std::stringstream ss;
259 for (auto n : it.second->nlist_) {
260 ss << " " << *n;
261 }
262 return ss.str();
263 }());
264 GRAPH_UPDATE("Node count: ", it.second->nlist_.size());
265 }
266 }
267
DebugPrintGraphWithFunction(const std::shared_ptr<Graph> & g)268 void FunctionExtractor::DebugPrintGraphWithFunction(
269 const std::shared_ptr<Graph>& g) {
270 GRAPH_UPDATE("Local function definitions:");
271 for (auto* n : g->nodes()) {
272 if (n->kind() == Symbol::onnx("LocalFunctionDef")) {
273 GRAPH_UPDATE(
274 n->s(attr::name),
275 " graph: ",
276 n->g(Symbol::attr("graph"))->toString());
277 }
278 }
279 GRAPH_UPDATE("Main graph: ", g->toString());
280 }
281
IsValidScope(const ScopePtr & s)282 bool FunctionExtractor::IsValidScope(const ScopePtr& s) {
283 return !s->isRoot() && !s->isBlank();
284 }
285
IsAncestor(const ScopePtr & parent,ScopePtr child)286 bool FunctionExtractor::IsAncestor(const ScopePtr& parent, ScopePtr child) {
287 if (!IsValidScope(parent) || !IsValidScope(child) ||
288 parent->getDepth() >= child->getDepth()) {
289 return false;
290 }
291 do {
292 child = child->parent();
293 if (parent == child) {
294 return true;
295 }
296 } while (IsValidScope(child));
297 return false;
298 }
299
FindCommonAncestor(ScopePtr a,ScopePtr b)300 std::optional<ScopePtr> FunctionExtractor::FindCommonAncestor(
301 ScopePtr a,
302 ScopePtr b) {
303 if (!IsValidScope(a) || !IsValidScope(b)) {
304 return std::nullopt;
305 }
306
307 auto diff =
308 static_cast<int64_t>(a->getDepth()) - static_cast<int64_t>(b->getDepth());
309 if (diff != 0) {
310 auto deeper_scope = diff > 0 ? a : b;
311 auto other_scope = diff > 0 ? b : a;
312 diff = std::abs(diff);
313 while (diff > 0) {
314 deeper_scope = deeper_scope->parent();
315 diff--;
316 }
317 a = deeper_scope;
318 b = other_scope;
319 }
320
321 while (IsValidScope(a) && IsValidScope(b)) {
322 if (a == b) {
323 return a;
324 } else {
325 a = a->parent();
326 b = b->parent();
327 }
328 }
329
330 return std::nullopt;
331 }
332
FindCommonAncestor(const scope_list & scopes)333 std::optional<ScopePtr> FunctionExtractor::FindCommonAncestor(
334 const scope_list& scopes) {
335 if (scopes.empty()) {
336 return std::nullopt;
337 }
338
339 std::optional<ScopePtr> common_ancestor = scopes.at(0);
340 for (const auto& scope : scopes) {
341 common_ancestor = FindCommonAncestor(common_ancestor.value(), scope);
342 if (!common_ancestor.has_value()) {
343 return std::nullopt;
344 }
345 }
346
347 return common_ancestor;
348 }
349
InferScope(Node * n)350 std::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
351 // The scope of node n is assigned based on the following rules.
352 // 1. If all uses of outputs of n belongs to the same scope,
353 // assign that scope, otherwise
354 // 2. If all nodes of inputs of n belongs to the same scope,
355 // assign that scope, otherwise
356 // 3. Find common ancestor of the scopes of uses of outputs of n,
357 // and the scopes of nodes of inputs of n.
358 scope_list input_scopes;
359 scope_list output_scopes;
360 for (auto input : n->inputs()) {
361 input_scopes.emplace_back(input->node()->scope());
362 }
363 for (auto output : n->outputs()) {
364 for (auto use : output->uses()) {
365 if (!IsValidScope(use.user->scope())) {
366 auto inferred_output_scope = InferScope(use.user);
367 if (inferred_output_scope.has_value() &&
368 IsValidScope(inferred_output_scope.value())) {
369 use.user->setScope(inferred_output_scope.value());
370 }
371 }
372 output_scopes.emplace_back(use.user->scope());
373 }
374 }
375 if (!output_scopes.empty() &&
376 std::all_of(
377 output_scopes.begin(),
378 output_scopes.end(),
379 [&output_scopes](const ScopePtr& scope) -> bool {
380 return IsValidScope(scope) && scope == output_scopes.at(0);
381 })) {
382 return output_scopes.at(0);
383 } else if (
384 !input_scopes.empty() &&
385 std::all_of(
386 input_scopes.begin(),
387 input_scopes.end(),
388 [&input_scopes](const ScopePtr& scope) -> bool {
389 return IsValidScope(scope) && scope == input_scopes.at(0);
390 })) {
391 return input_scopes.at(0);
392 } else {
393 scope_list scopes;
394 std::copy_if(
395 input_scopes.begin(),
396 input_scopes.end(),
397 std::back_inserter(scopes),
398 IsValidScope);
399 std::copy_if(
400 output_scopes.begin(),
401 output_scopes.end(),
402 std::back_inserter(scopes),
403 IsValidScope);
404 if (!scopes.empty()) {
405 auto common_ancestor = FindCommonAncestor(scopes);
406 if (common_ancestor.has_value() &&
407 IsValidScope(common_ancestor.value())) {
408 return common_ancestor.value();
409 }
410 }
411 }
412
413 return std::nullopt;
414 }
415
ConstructFuncGraph(FunctionContext & func_ctx)416 std::shared_ptr<Graph> FunctionExtractor::ConstructFuncGraph(
417 FunctionContext& func_ctx) {
418 auto& ctx = *func_ctx.scope_ctxs_[func_ctx.scope_key_];
419 const auto& nlist = ctx.nlist_;
420 const auto& scope = ctx.scope_;
421 auto& env = ctx.env_to_subgraph_;
422
423 auto g = std::make_shared<Graph>();
424 GRAPH_DEBUG("Constructing graph for ", scope->namesFromRoot());
425
426 // TODO: Update input names of function to match those in Module source code
427 // signature.
428 // This requires mapping between function node inputs and Module inputs.
429 // Due to the lack of such mapping, currently debugName is used as input
430 // names.
431 ctx.PopulateInputsOutputs(param_names_);
432 for (auto* v : ctx.inputs_) {
433 env[v] = g->addInput()->copyMetadata(v);
434 GRAPH_DEBUG(
435 "Add input value ",
436 env[v]->debugName(),
437 " for outer scope value ",
438 v->debugName(),
439 " from ",
440 *v->node());
441 }
442
443 for (auto* n : nlist) {
444 auto clone_n = g->createClone(n, [&](Value* v) {
445 TORCH_INTERNAL_ASSERT(env.find(v) != env.end());
446 return env[v];
447 });
448 for (const auto i : c10::irange(clone_n->outputs().size())) {
449 env[n->output(i)] = clone_n->output(i);
450 }
451 g->insertNode(clone_n);
452 }
453
454 // If values are used outside of this graph, set as graph output.
455 for (auto* v : ctx.outputs_) {
456 TORCH_INTERNAL_ASSERT(env.find(v) != env.end());
457 g->registerOutput(env[v]);
458 }
459
460 GRAPH_DEBUG(g->toString());
461 return g;
462 }
463
CreateFunctionDefNode(FunctionContext & func_ctx,const std::shared_ptr<Graph> & graph,const std::string & domain_name,const std::string & func_name)464 Node* FunctionExtractor::CreateFunctionDefNode(
465 FunctionContext& func_ctx,
466 const std::shared_ptr<Graph>& graph,
467 const std::string& domain_name,
468 const std::string& func_name) {
469 const auto func_def_nk = Symbol::onnx("LocalFunctionDef");
470 const auto func_g_attr = Symbol::attr("graph");
471 const auto func_name_attr = attr::name;
472 const auto func_domain_attr = Symbol::attr("domain");
473
474 auto func_graph = ConstructFuncGraph(func_ctx);
475
476 // create and insert local function definition node
477 auto func_def_n = graph->create(func_def_nk, 0);
478 func_def_n->g_(func_g_attr, func_graph);
479 func_def_n->s_(func_name_attr, func_name);
480 func_def_n->s_(func_domain_attr, domain_name);
481 graph->prependNode(func_def_n);
482
483 // set constants and attributes of different values as function attributes.
484 std::unordered_map<std::string, int> base_attr_name_count;
485 std::vector<std::string> final_attr_names;
486
487 auto adjust_attr_name = [&](std::string attr_name) {
488 if (base_attr_name_count.find(attr_name) != base_attr_name_count.end()) {
489 attr_name =
490 attr_name + "." + std::to_string(base_attr_name_count[attr_name]++);
491 } else {
492 base_attr_name_count[attr_name] = 1;
493 }
494 return attr_name;
495 };
496
497 for (const auto& n_it : func_ctx.attribute_map_) {
498 auto* n = n_it.first;
499 for (const auto& attr_it : n_it.second) {
500 const auto& attr = attr_it.first;
501 // Add prefix "inferred::" to name of inferred attribute.
502 // This is to differentiate from annotated attributes picked up
503 // from python module annotation.
504 auto attr_name = "inferred::" + std::string(n->kind().toUnqualString()) +
505 '_' + attr.toUnqualString();
506 auto final_attr_name = adjust_attr_name(attr_name);
507 final_attr_names.emplace_back(final_attr_name);
508 func_ctx.SetAttrName(n, attr, final_attr_name);
509 }
510 }
511
512 // Set annotated attributes
513 std::unordered_set<Symbol> annotated_attr_names;
514 bool first_iteration = true;
515 for (const auto& it : func_ctx.scope_ctxs_) {
516 auto scope = it.first;
517 auto annotated_attr_node = scope_attr_map_.find(scope);
518 if (annotated_attr_node != scope_attr_map_.end()) {
519 auto names = annotated_attr_node->second->attributeNames();
520 if (first_iteration) {
521 std::copy(
522 names.begin(),
523 names.end(),
524 std::inserter(annotated_attr_names, annotated_attr_names.end()));
525 first_iteration = false;
526 } else {
527 auto unseen_attr_name = std::find_if(
528 names.begin(),
529 names.end(),
530 [&annotated_attr_names](const Symbol& name) {
531 return annotated_attr_names.find(name) ==
532 annotated_attr_names.end();
533 });
534 TORCH_CHECK(
535 unseen_attr_name == names.end(),
536 "Found outstanding annotated attribute ",
537 *unseen_attr_name,
538 " from module ",
539 scope->name(),
540 ". Please ensure module instances of the same class have the same set of annotated attributes.");
541 }
542 }
543 }
544 for (auto attr_name : annotated_attr_names) {
545 final_attr_names.emplace_back(attr_name.toUnqualString());
546 }
547
548 func_def_n->ss_(Symbol::attr("attributes"), final_attr_names);
549
550 return func_def_n;
551 }
552
CreateFunctionNode(FunctionContext & func_ctx,ScopeContext & scope_ctx,const std::shared_ptr<Graph> & graph,const std::string & domain_name,const std::string & func_name)553 Node* FunctionExtractor::CreateFunctionNode(
554 FunctionContext& func_ctx,
555 ScopeContext& scope_ctx,
556 const std::shared_ptr<Graph>& graph,
557 const std::string& domain_name,
558 const std::string& func_name) {
559 const auto& func_scope = func_ctx.scope_key_;
560 GRAPH_DEBUG(
561 "Create and insert local function for scope: ",
562 func_scope->namesFromRoot());
563 scope_ctx.PopulateInputsOutputs(param_names_);
564 auto last_n = *scope_ctx.nlist_.rbegin();
565 auto func_n = graph->create(
566 Symbol::fromQualString(domain_name + "::" + func_name),
567 scope_ctx.outputs_.size());
568 func_n->copyMetadata(last_n);
569 for (auto* v : scope_ctx.inputs_) {
570 func_n->addInput(v);
571 }
572 for (const auto i : c10::irange(scope_ctx.outputs_.size())) {
573 func_n->output(i)->setType(scope_ctx.outputs_[i]->type());
574 scope_ctx.outputs_[i]->replaceAllUsesWith(func_n->output(i));
575 }
576
577 // set attributes of different values as function attributes.
578 auto copy_attr =
579 [](Node* a, Node* b, Symbol attr, const std::string& new_name) {
580 #define COPY_ATTR(kind) \
581 case AttributeKind::kind: { \
582 b->kind##_(Symbol::attr(new_name), a->kind(attr)); \
583 break; \
584 }
585 switch (a->kindOf(attr)) {
586 COPY_ATTR(f)
587 COPY_ATTR(fs)
588 COPY_ATTR(i)
589 COPY_ATTR(is)
590 COPY_ATTR(s)
591 COPY_ATTR(ss)
592 COPY_ATTR(t)
593 COPY_ATTR(ts)
594 #undef COPY_ATTR
595 case AttributeKind::ival:
596 case AttributeKind::g:
597 case AttributeKind::gs:
598 case AttributeKind::ty:
599 case AttributeKind::tys:
600 case AttributeKind::c:
601 default:
602 TORCH_INTERNAL_ASSERT(
603 false,
604 "Unexpected attribute type ",
605 static_cast<int>(a->kindOf(attr)),
606 " from node ",
607 *a);
608 break;
609 }
610 };
611
612 for (const auto& it : func_ctx.attribute_map_) {
613 auto* ref_n = it.first;
614 for (const auto& attr_it : it.second) {
615 const auto& attr = attr_it.first;
616 auto attr_name = func_ctx.FindAttrName(ref_n, attr).value();
617 copy_attr(ref_n, func_n, attr, attr_name);
618 for (auto* n : scope_ctx.nlist_) {
619 if (attr_it.second.find(n) != attr_it.second.end()) {
620 copy_attr(n, func_n, attr, attr_name);
621 break;
622 }
623 }
624 }
625 }
626
627 // annotated attributes
628 auto scope = scope_ctx.scope_;
629 auto annotated_attr_node = scope_attr_map_.find(scope);
630 if (annotated_attr_node != scope_attr_map_.end()) {
631 auto node = annotated_attr_node->second;
632 for (auto attr : node->attributeNames()) {
633 copy_attr(node, func_n, attr, attr.toUnqualString());
634 }
635 }
636
637 func_n->insertAfter(last_n);
638 return func_n;
639 }
640
ConvertScopeToFunction(const ScopePtr & scope_key,const scope_list & scope_list,scope_ctx_map & scope_ctxs,const std::shared_ptr<Graph> & graph)641 void FunctionExtractor::ConvertScopeToFunction(
642 const ScopePtr& scope_key,
643 const scope_list& scope_list,
644 scope_ctx_map& scope_ctxs,
645 const std::shared_ptr<Graph>& graph) {
646 // This function needs to be called always on inner most scopes.
647 // 1. Generate function context, this identifies different constants and
648 // attributes.
649 // 2. Create function definition node, and insert to main graph.
650 // 3. Create function node for each call, and replace subgraph nodes in parent
651 // functions.
652
653 func_ctxs_.insert(std::make_pair(
654 scope_key, new FunctionContext(scope_key, scope_list, scope_ctxs)));
655 auto& func_ctx = *func_ctxs_[scope_key];
656
657 const std::string module_class_name(
658 ONNXScopeName::className(func_ctx.scope_key_));
659 auto pos = module_class_name.rfind('.');
660 TORCH_INTERNAL_ASSERT(pos != std::string::npos);
661
662 auto construct_unique_module_name = [&](std::string module_name) {
663 auto module_name_variant = module_variant_count_.find(module_name);
664 if (module_name_variant != module_variant_count_.end()) {
665 module_variant_count_[module_name]++;
666 module_name += ("." + std::to_string(module_name_variant->second));
667 } else {
668 module_variant_count_[module_name] = 0;
669 }
670 return module_name;
671 };
672
673 const auto domain_name = module_class_name.substr(0, pos);
674 const auto func_name =
675 construct_unique_module_name(module_class_name.substr(pos + 1));
676
677 CreateFunctionDefNode(func_ctx, graph, domain_name, func_name);
678
679 // create and insert local function node to graph.
680 for (const auto& it : func_ctx.scope_ctxs_) {
681 auto scope = it.first;
682 auto& scope_ctx = *it.second;
683 auto func_n =
684 CreateFunctionNode(func_ctx, scope_ctx, graph, domain_name, func_name);
685
686 std::unordered_set<Node*> old_nodes(
687 scope_ctx.nlist_.begin(), scope_ctx.nlist_.end());
688
689 auto last_n = *scope_ctx.nlist_.rbegin();
690 // replace function body nodes in parent scopes with local function node.
691 for (auto& it : scope_ctxs) {
692 const auto& parent_scope = it.first;
693 auto& parent_ctx = *it.second;
694
695 if (!IsAncestor(parent_scope, scope)) {
696 continue;
697 }
698
699 auto& ctx_nlist = parent_ctx.nlist_;
700 GRAPH_DEBUG(
701 "Replace local function node in parent scope: ",
702 it.first->namesFromRoot(),
703 " nodes to remove: ",
704 old_nodes.size(),
705 " parent total nodes: ",
706 ctx_nlist.size());
707
708 // insert local function node
709 auto last_n_it = std::find(ctx_nlist.begin(), ctx_nlist.end(), last_n);
710 ctx_nlist.insert(last_n_it, func_n);
711
712 // remove replaced nodes from list
713 ctx_nlist.erase(
714 std::remove_if(
715 ctx_nlist.begin(),
716 ctx_nlist.end(),
717 [&old_nodes](Node* n) {
718 return old_nodes.find(n) != old_nodes.end();
719 }),
720 ctx_nlist.end());
721
722 GRAPH_DEBUG("Parent total nodes after remove: ", ctx_nlist.size());
723
724 // refresh inputs/outputs.
725 parent_ctx.PopulateInputsOutputs(param_names_);
726 }
727 }
728
729 for (const auto& it : func_ctx.scope_ctxs_) {
730 auto& scope_ctx = *it.second;
731 // delete replaced nodes in graph.
732 for (auto it = scope_ctx.nlist_.rbegin(); it != scope_ctx.nlist_.rend();) {
733 auto* n = *it;
734 it++;
735 GRAPH_DEBUG("Destroying node ", *n);
736 n->destroy();
737 }
738 }
739 }
740
IsIdenticalFuncion(const ScopeContext & other_ctx) const741 bool FunctionExtractor::ScopeContext::IsIdenticalFuncion(
742 const ScopeContext& other_ctx) const {
743 // Differentiate same function under different inputs.
744 // When constants are passed in place of inputs, it leads to different
745 // input count and node count. Likewise, due to different uses, output
746 // count can be different as well.
747 // For now export them as different functions.
748 // Covered by `test_local_function_overloads` in
749 // `test/onnx/test_utility_funs.py`.
750 if (&other_ctx == this) {
751 return true;
752 }
753 if (ONNXScopeName::className(this->scope_) !=
754 ONNXScopeName::className(other_ctx.scope_)) {
755 return false;
756 }
757 if (this->inputs_.size() != other_ctx.inputs_.size() ||
758 this->outputs_.size() != other_ctx.outputs_.size()) {
759 return false;
760 }
761 const auto& ns_a = this->nlist_;
762 const auto& ns_b = other_ctx.nlist_;
763 if (ns_a.size() != ns_b.size()) {
764 return false;
765 }
766 for (const auto i : c10::irange(ns_a.size())) {
767 if (ns_a[i]->kind() != ns_b[i]->kind()) {
768 return false;
769 }
770 }
771
772 return true;
773 }
774
PopulateInputsOutputs(const std::unordered_set<std::string> & param_names)775 void FunctionExtractor::ScopeContext::PopulateInputsOutputs(
776 const std::unordered_set<std::string>& param_names) {
777 inputs_.clear();
778 outputs_.clear();
779 const auto& nlist = this->nlist_;
780 std::unordered_set<Value*> v_set;
781 std::unordered_set<Node*> n_set;
782
783 value_list input_list;
784 value_list initializer_list;
785
786 // Add initializers after inputs.
787 for (auto* n : nlist) {
788 for (auto* v : n->inputs()) {
789 if (v_set.find(v) == v_set.end()) {
790 if (param_names.find(v->debugName()) != param_names.end()) {
791 initializer_list.emplace_back(v);
792 } else {
793 input_list.emplace_back(v);
794 }
795 v_set.insert(v);
796 }
797 }
798 for (auto* v : n->outputs()) {
799 v_set.insert(v);
800 }
801 n_set.insert(n);
802 }
803 for (auto* v : input_list) {
804 inputs_.emplace_back(v);
805 }
806 for (auto* v : initializer_list) {
807 inputs_.emplace_back(v);
808 }
809
810 for (auto* n : nlist) {
811 for (auto* v : n->outputs()) {
812 bool used_outside = false;
813 for (auto use : v->uses()) {
814 used_outside |= (n_set.find(use.user) == n_set.end());
815 }
816 if (used_outside) {
817 outputs_.emplace_back(v);
818 }
819 }
820 }
821 }
822
HandleNoScopeNodes(scope_ctx_map & scope_ctxs,const node_list & no_scope_nlist)823 void FunctionExtractor::HandleNoScopeNodes(
824 scope_ctx_map& scope_ctxs,
825 const node_list& no_scope_nlist) {
826 GRAPH_UPDATE("No scope node count: ", no_scope_nlist.size());
827 for (auto n : no_scope_nlist) {
828 TORCH_WARN(
829 "ONNX function extraction cannot determine the scope for node: ", *n);
830 }
831 TORCH_INTERNAL_ASSERT(
832 no_scope_nlist.empty(),
833 "ONNX function extraction cannot determine the scope for the above nodes.");
834 }
835
836 std::tuple<FunctionExtractor::scope_ctx_map, node_list> FunctionExtractor::
PartitionNodesByScope(Block * b)837 PartitionNodesByScope(Block* b) {
838 scope_ctx_map scope_ctxs = {};
839 node_list no_scope_nlist;
840
841 auto find_or_create_scope_ctx = [](scope_ctx_map& scope_ctxs,
842 const ScopePtr& scope) {
843 if (scope_ctxs.find(scope) == scope_ctxs.end()) {
844 scope_ctxs.insert(std::make_pair(scope, new ScopeContext()));
845 }
846 return scope_ctxs[scope];
847 };
848
849 auto record_node_scope = [&scope_ctxs, &find_or_create_scope_ctx](Node* n) {
850 const auto& scope = n->scope();
851 find_or_create_scope_ctx(scope_ctxs, scope)->scope_ = scope;
852 auto tmp_scope = scope;
853 while (IsValidScope(tmp_scope)) {
854 find_or_create_scope_ctx(scope_ctxs, tmp_scope)->nlist_.emplace_back(n);
855 if (IsValidScope(tmp_scope->parent())) {
856 find_or_create_scope_ctx(scope_ctxs, tmp_scope->parent())
857 ->children_.insert(tmp_scope);
858 }
859 tmp_scope = tmp_scope->parent();
860 }
861 };
862
863 for (auto* n : b->nodes()) {
864 auto scope = n->scope();
865 if (scope && IsValidScope(scope)) {
866 record_node_scope(n);
867 } else {
868 auto inferred_scope = InferScope(n);
869
870 if (inferred_scope.has_value() && IsValidScope(inferred_scope.value())) {
871 n->setScope(inferred_scope.value());
872 record_node_scope(n);
873 } else {
874 GRAPH_UPDATE("Cannot infer proper scope for node: ", *n);
875 no_scope_nlist.emplace_back(n);
876 }
877 }
878
879 for (auto* sub_b : n->blocks()) {
880 auto [subblock_scope_ctxs, subblock_no_scope_nlist] =
881 PartitionNodesByScope(sub_b);
882
883 for (auto& it : subblock_scope_ctxs) {
884 if (scope_ctxs.find(it.first) == scope_ctxs.end()) {
885 scope_ctxs.insert(std::make_pair(it.first, it.second));
886 } else {
887 for (auto* s_n : it.second->nlist_) {
888 scope_ctxs[it.first]->nlist_.emplace_back(s_n);
889 }
890 for (const auto& s_child_scope : it.second->children_) {
891 scope_ctxs[it.first]->children_.insert(s_child_scope);
892 }
893 }
894 }
895
896 no_scope_nlist.insert(
897 no_scope_nlist.end(),
898 subblock_no_scope_nlist.begin(),
899 subblock_no_scope_nlist.end());
900 }
901 }
902
903 for (auto& it : scope_ctxs) {
904 it.second->scope_ = it.first;
905 it.second->PopulateInputsOutputs(param_names_);
906 }
907
908 return std::tie(scope_ctxs, no_scope_nlist);
909 }
910
PartitionNodesByScope(const std::shared_ptr<Graph> & graph)911 FunctionExtractor::scope_ctx_map FunctionExtractor::PartitionNodesByScope(
912 const std::shared_ptr<Graph>& graph) {
913 scope_ctx_map scope_ctxs;
914 node_list no_scope_nlist;
915 std::tie(scope_ctxs, no_scope_nlist) = PartitionNodesByScope(graph->block());
916
917 HandleNoScopeNodes(scope_ctxs, no_scope_nlist);
918
919 return scope_ctxs;
920 }
921
922 std::unordered_map<ScopePtr, scope_list> FunctionExtractor::
PartitionIdenticalScopes(FunctionExtractor::scope_ctx_map & scope_ctxs)923 PartitionIdenticalScopes(FunctionExtractor::scope_ctx_map& scope_ctxs) {
924 std::unordered_map<ScopePtr, scope_list> identical_scope_map;
925
926 for (auto& it : scope_ctxs) {
927 auto scope = it.first;
928 const auto& scope_ctx = it.second;
929 bool unique = true;
930 for (auto& kv_it : identical_scope_map) {
931 auto key_scope = kv_it.first;
932 const auto& key_scope_ctx = scope_ctxs[key_scope];
933 auto& key_scope_vec = kv_it.second;
934 if (key_scope_ctx->IsIdenticalFuncion(*scope_ctx)) {
935 key_scope_vec.emplace_back(scope);
936 unique = false;
937 break;
938 }
939 }
940 if (unique) {
941 identical_scope_map[scope].emplace_back(scope);
942 }
943 }
944
945 return identical_scope_map;
946 }
947
HasSameAttribute(const Node * a,const Node * b,const c10::Symbol & attr)948 static bool HasSameAttribute(
949 const Node* a,
950 const Node* b,
951 const c10::Symbol& attr) {
952 if (!a->hasAttribute(attr) && !b->hasAttribute(attr)) {
953 return true;
954 }
955 if (!a->hasAttribute(attr) || !b->hasAttribute(attr)) {
956 return false;
957 }
958 auto a_kind = a->kindOf(attr);
959 auto b_kind = b->kindOf(attr);
960 if (a_kind != b_kind) {
961 return false;
962 }
963
964 #define COMP_ATTR(kind) \
965 case AttributeKind::kind: { \
966 const auto& a_v = a->kind(attr); \
967 const auto& b_v = b->kind(attr); \
968 return a_v == b_v; \
969 }
970
971 switch (a_kind) {
972 COMP_ATTR(f)
973 COMP_ATTR(fs)
974 COMP_ATTR(i)
975 COMP_ATTR(is)
976 COMP_ATTR(s)
977 COMP_ATTR(ss)
978 #undef COMP_ATTR
979 case AttributeKind::t: {
980 const auto& a_v = a->t(attr);
981 const auto& b_v = b->t(attr);
982 return a_v.equal(b_v);
983 }
984 case AttributeKind::ts: {
985 const auto& a_v = a->ts(attr);
986 const auto& b_v = b->ts(attr);
987 return std::equal(
988 a_v.begin(),
989 a_v.end(),
990 b_v.begin(),
991 b_v.end(),
992 [](const at::Tensor& a_t, const at::Tensor& b_t) {
993 return a_t.equal(b_t);
994 });
995 }
996 case AttributeKind::ival:
997 case AttributeKind::g:
998 case AttributeKind::gs:
999 case AttributeKind::ty:
1000 case AttributeKind::tys:
1001 case AttributeKind::c:
1002 default:
1003 TORCH_INTERNAL_ASSERT(
1004 false,
1005 "Unexpected attribute type ",
1006 static_cast<int>(a_kind),
1007 " from node ",
1008 *a);
1009 break;
1010 }
1011
1012 return true;
1013 }
1014
SortScopesByMaxDepth(std::unordered_map<ScopePtr,scope_list> & identical_scope_map)1015 scope_list FunctionExtractor::SortScopesByMaxDepth(
1016 std::unordered_map<ScopePtr, scope_list>& identical_scope_map) {
1017 std::unordered_map<ScopePtr, size_t> scope_max_depth;
1018 for (const auto& it : identical_scope_map) {
1019 const auto& scopes = it.second;
1020 size_t max_depth = 0;
1021 for (const auto& scope : scopes) {
1022 if (scope->getDepth() > max_depth) {
1023 max_depth = scope->getDepth();
1024 }
1025 }
1026 scope_max_depth[it.first] = max_depth;
1027 }
1028
1029 scope_list sorted_scopes;
1030 sorted_scopes.reserve(scope_max_depth.size());
1031 for (const auto& it : scope_max_depth) {
1032 sorted_scopes.emplace_back(it.first);
1033 }
1034 std::sort(
1035 sorted_scopes.begin(),
1036 sorted_scopes.end(),
1037 [&scope_max_depth](const ScopePtr& a, const ScopePtr& b) -> bool {
1038 return scope_max_depth[a] >= scope_max_depth[b];
1039 });
1040 return sorted_scopes;
1041 }
1042
run()1043 NodeAttrNameMap FunctionExtractor::run() {
1044 auto scope_ctxs = PartitionNodesByScope(graph_);
1045 DebugPrintScopeContexts(scope_ctxs);
1046 auto identical_scope_map = PartitionIdenticalScopes(scope_ctxs);
1047 // Deepest scope comes first, guaranteeing no other scope can be its child.
1048 auto sorted_scope_keys = SortScopesByMaxDepth(identical_scope_map);
1049 for (const auto& scope_key : sorted_scope_keys) {
1050 if (module_names_.find(ONNXScopeName::className(scope_key)) !=
1051 module_names_.end()) {
1052 ConvertScopeToFunction(
1053 scope_key, identical_scope_map[scope_key], scope_ctxs, graph_);
1054 }
1055 GRAPH_DEBUG("Main graph afterwards: ", graph_->toString());
1056 }
1057 DebugPrintGraphWithFunction(graph_);
1058
1059 // Construct return mappings
1060 NodeAttrNameMap node_attr_to_name;
1061
1062 for (const auto& it : func_ctxs_) {
1063 auto func_ref_map = it.second->node_attr_to_name_;
1064 node_attr_to_name.insert(func_ref_map.begin(), func_ref_map.end());
1065 }
1066
1067 // Clear
1068 for (auto& it : scope_ctxs) {
1069 delete it.second;
1070 }
1071 scope_ctxs.clear();
1072 for (auto& it : func_ctxs_) {
1073 delete it.second;
1074 }
1075 func_ctxs_.clear();
1076
1077 return node_attr_to_name;
1078 }
1079
1080 // Retrieves the node representing the most recent
1081 // ScopePtr. This function should only be invoked from module forward hook. At
1082 // this point, module forward call is completed, and the most recent ScopePtr
1083 // is popped from TracingState.
1084 // This function inspects the node, and its subblock, to find
1085 // the node associated with the most recent ScopePtr.
NodeOfMostRecentScope(Node * forward_node)1086 Node* NodeOfMostRecentScope(Node* forward_node) {
1087 TORCH_INTERNAL_ASSERT(
1088 forward_node->kind() == prim::TracedModuleForward,
1089 "forward_node got kind: ",
1090 forward_node->kind().toDisplayString());
1091 auto* block = forward_node->blocks()[0];
1092 for (auto* node : block->nodes().reverse()) {
1093 if (node->kind() == prim::TracedModuleForward) {
1094 Node* target_node = NodeOfMostRecentScope(node);
1095 if (scope_attr_map_.find(node->scope()) == scope_attr_map_.end()) {
1096 return target_node;
1097 }
1098 }
1099 }
1100 return forward_node;
1101 }
1102
1103 } // namespace
1104
1105 // FunctionExtractor runs in the following steps. Updates are made inplace to
1106 // the graph argument.
1107 // 1. Partition nodes into groups based on their scope information.
1108 // Each scope represents an individual nn.Module call. A ScopeContext object
1109 // is created for each group.
1110 // 2. Compare and find groups with the same subgraph pattern from step 1.
1111 // 3. Scopes are nested. Starting from the deepest scope, extract the
1112 // subgraph pattern, and define as local function node. Replace subgraph
1113 // pattern with a single node of the new local function node type. A
1114 // FunctionContext object is created for each function.
1115 // 4. Construct NodeAttrNameMap tracking mapping from attribute name of
1116 // IR Node inside function subgraph, to function attribute name.
ONNXFunctionExtraction(std::shared_ptr<Graph> & graph,const std::unordered_set<std::string> & module_names,const std::vector<std::string> & param_names)1117 NodeAttrNameMap ONNXFunctionExtraction(
1118 std::shared_ptr<Graph>& graph,
1119 const std::unordered_set<std::string>& module_names,
1120 const std::vector<std::string>& param_names) {
1121 GRAPH_UPDATE(
1122 "Export these module forward calls as functions: ",
1123 std::vector<std::string>{module_names.begin(), module_names.end()});
1124 FunctionExtractor fe(graph, module_names, param_names);
1125 return fe.run();
1126 }
1127
ONNXGetPreviousScope(std::shared_ptr<Graph> & graph)1128 Node* ONNXGetPreviousScope(std::shared_ptr<Graph>& graph) {
1129 auto* last_node = graph->nodes().back()->prev();
1130 auto* scope_node = NodeOfMostRecentScope(last_node);
1131 auto* attr_node = scope_attr_graph_->create(prim::TracedModuleForward);
1132 attr_node->setScope(scope_node->scope());
1133 TORCH_INTERNAL_ASSERT(
1134 scope_attr_map_.find(scope_node->scope()) == scope_attr_map_.end(),
1135 "Found duplicated scope. Scope ",
1136 scope_node->scope()->namesFromRoot(),
1137 " already processed.");
1138 scope_attr_map_[scope_node->scope()] = attr_node;
1139 return attr_node;
1140 }
1141
ONNXClearScopeRecords()1142 void ONNXClearScopeRecords() {
1143 scope_attr_map_.clear();
1144 scope_attr_graph_ = std::make_shared<Graph>();
1145 }
1146
ONNXTrackScopeAttributes(std::shared_ptr<Graph> & graph,std::map<std::string,IValue> & attributes)1147 void ONNXTrackScopeAttributes(
1148 std::shared_ptr<Graph>& graph,
1149 std::map<std::string, IValue>& attributes) {
1150 // Skip the "real" last node which is `return_node`.
1151 auto* last_node = graph->nodes().back()->prev();
1152 auto* scope_node = NodeOfMostRecentScope(last_node);
1153 auto* attr_node = scope_attr_graph_->create(prim::TracedModuleForward);
1154 attr_node->setScope(scope_node->scope());
1155 TORCH_INTERNAL_ASSERT(
1156 scope_attr_map_.find(scope_node->scope()) == scope_attr_map_.end());
1157 scope_attr_map_[scope_node->scope()] = attr_node;
1158
1159 for (const auto& it : attributes) {
1160 auto k = Symbol::attr(it.first);
1161 auto v = it.second;
1162 if (v.isTensor()) {
1163 attr_node->t_(k, v.toTensor());
1164 } else if (v.isInt()) {
1165 attr_node->i_(k, v.toInt());
1166 } else if (v.isDouble()) {
1167 attr_node->f_(k, v.toDouble());
1168 } else if (v.isBool()) {
1169 attr_node->i_(k, v.toBool());
1170 } else if (v.isString()) {
1171 attr_node->s_(k, v.toStringRef());
1172 } else if (v.isIntList()) {
1173 attr_node->is_(k, v.toIntList().vec());
1174 } else if (v.isBoolList()) {
1175 auto bool_list = v.toBoolList();
1176 attr_node->is_(
1177 k, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
1178 } else if (v.isDoubleList()) {
1179 attr_node->fs_(k, v.toDoubleList().vec());
1180 }
1181 }
1182 }
1183
1184 } // namespace torch::jit::onnx
1185