xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/utils/subgraph_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
2 
3 #include <torch/csrc/jit/passes/canonicalize.h>
4 
5 #include <ATen/core/symbol.h>
6 #include <c10/util/StringUtil.h>
7 #include <c10/util/irange.h>
8 #include <torch/csrc/jit/jit_log.h>
9 
10 #include <utility>
11 
12 namespace torch {
13 namespace jit {
14 namespace SubgraphUtils {
15 namespace {
16 
hasSubgraph(Node * n)17 bool hasSubgraph(Node* n) {
18   return n->hasAttribute(attr::Subgraph);
19 }
20 
gatherLastUses(at::ArrayRef<Value * > values)21 std::vector<std::optional<const Use>> gatherLastUses(
22     at::ArrayRef<Value*> values) {
23   return fmap(values, [&](Value* v) -> std::optional<const Use> {
24     return firstOrLastUse(v, /*find_first*/ false);
25   });
26 }
27 
28 // When merging a node into a subgraph, we wish to preserve all of the
29 // aliasing properties of the node's outputs. It is difficult to track
30 // the node or its contained nodes through all of the ir manipulation
31 // involved in merging; it is pretty easy to uniquely identify the value
32 // based on its uses. We can identify the value by its last use in the graph.
33 // Values which do not have uses or which do not have a last use
34 // outside of the subgraph to be merged into we do not need to track.
35 struct ValueMapper {
36   // `to_merge` is the node we're merginginto a subgraph, `existing_subgraph` is
37   // the subgraph node that we're merging into if it exists
ValueMappertorch::jit::SubgraphUtils::__anone4c04e190111::ValueMapper38   ValueMapper(
39       Node* to_merge,
40       AliasDb& db,
41       std::optional<Node*> existing_subgraph) {
42     last_uses_ = gatherLastUses(to_merge->outputs());
43     if (existing_subgraph) {
44       existing_last_uses_ = gatherLastUses((*existing_subgraph)->outputs());
45     }
46     WithInsertPoint guard(to_merge);
47     auto g = to_merge->owningGraph();
48     // temporary node to put the aliasing properties of the node before its
49     // merged and destroyed
50     placeholder_node_ = g->insertNode(g->create(prim::Uninitialized, 0));
51     for (size_t i = 0; i < to_merge->outputs().size(); ++i) {
52       Value* existing = to_merge->outputs().at(i);
53       Value* new_value = placeholder_node_->insertOutput(i)->copyMetadata(
54           to_merge->outputs().at(i));
55       db.replaceWithNewValue(existing, new_value);
56     }
57   }
58 
usesEqualtorch::jit::SubgraphUtils::__anone4c04e190111::ValueMapper59   bool usesEqual(const Use& a, const Use& b) {
60     return a.user == b.user && a.offset == b.offset;
61   }
62 
copyAliasingtorch::jit::SubgraphUtils::__anone4c04e190111::ValueMapper63   void copyAliasing(Node* merged_node, AliasDb& db) {
64     auto new_outputs = merged_node->outputs();
65     for (Value* v : new_outputs) {
66       auto maybe_last_use = firstOrLastUse(v, /*find_first*/ false);
67       // if it doesnt have a use it shouldnt have been added as output
68       TORCH_INTERNAL_ASSERT(maybe_last_use);
69       const Use last_use = *maybe_last_use;
70 
71       // existing outputs of the subgraph do not need to have alias db mappings
72       // updated
73       bool is_existing_value = false;
74       for (size_t i = 0; i < existing_last_uses_.size() && !is_existing_value;
75            ++i) {
76         is_existing_value = existing_last_uses_[i].has_value() &&
77             usesEqual(*existing_last_uses_[i], last_use);
78       }
79       if (is_existing_value) {
80         continue;
81       }
82 
83       size_t i = 0;
84       while (i < last_uses_.size() && last_uses_.at(i).has_value() &&
85              !usesEqual(*last_uses_.at(i), last_use)) {
86         ++i;
87       }
88       TORCH_INTERNAL_ASSERT(i != last_uses_.size());
89       db.replaceWithNewValue(placeholder_node_->outputs().at(i), v);
90     }
91     placeholder_node_->destroy();
92   }
93 
94   std::vector<std::optional<const Use>> last_uses_;
95   std::vector<std::optional<const Use>> existing_last_uses_;
96   Node* placeholder_node_;
97 };
98 
executeSubgraphMergeAndUpdateAliasing(Node * to_merge,std::optional<Node * > existing,AliasDb & db,const std::function<Node * (void)> & merge_fn)99 Node* executeSubgraphMergeAndUpdateAliasing(
100     Node* to_merge,
101     std::optional<Node*> existing,
102     AliasDb& db,
103     const std::function<Node*(void)>& merge_fn) {
104   // When we merge a node into a subgraph, the new subgraph outputs
105   // have the same aliasing properties as the original node's outputs.
106   // Here we create a placeholder node, transfer the aliasing properties
107   // to the placeholder, execute the merge, and transfer the aliasing
108   // properties to the appropriate fusion group outputs
109   ValueMapper vm(to_merge, db, existing);
110   Node* fusion_group = merge_fn();
111   vm.copyAliasing(fusion_group, db);
112   return fusion_group;
113 }
114 
115 // Combine the nodes in two subgraph together. The nodes will end up in
116 // `mergeTo`, and `mergeFrom` is destroyed.
mergeSubgraph(Node * mergeTo,Node * mergeFrom)117 void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
118   bool merge_from_is_after = mergeFrom->isAfter(mergeTo);
119   Node* nodeBeforeMergeFrom = mergeFrom->prev();
120   Node* nodeAfterMergeFrom = mergeFrom->next();
121 
122   unmergeSubgraph(mergeFrom);
123 
124   graph_node_list_iterator end_it;
125   graph_node_list_iterator it;
126 
127   if (merge_from_is_after) {
128     it = nodeBeforeMergeFrom->iterator();
129     end_it = nodeAfterMergeFrom->iterator();
130   } else {
131     end_it = nodeBeforeMergeFrom->reverseIterator();
132     it = nodeAfterMergeFrom->reverseIterator();
133   }
134   ++it;
135 
136   std::vector<Node*> merged_nodes;
137   while (it != end_it) {
138     Node* node = *it;
139     ++it;
140     mergeNodeIntoSubgraph(node, mergeTo);
141   }
142 }
143 
144 struct topo_cmp_value {
operator ()torch::jit::SubgraphUtils::__anone4c04e190111::topo_cmp_value145   bool operator()(Value* a, Value* b) const {
146     if (a->node() == b->node()) {
147       return a->unique() < b->unique();
148     }
149     return a->node()->isBefore(b->node());
150   }
151 };
152 
153 struct topo_cmp_node {
operator ()torch::jit::SubgraphUtils::__anone4c04e190111::topo_cmp_node154   bool operator()(Node* a, Node* b) const {
155     return a->isBefore(b);
156   }
157 };
158 
collectNodesToUnfuse(Node * start,std::set<Node *,topo_cmp_node> & s)159 void collectNodesToUnfuse(Node* start, std::set<Node*, topo_cmp_node>& s) {
160   if (start->kind() == prim::Return || start->kind() == prim::Param) {
161     GRAPH_DEBUG("reached the param or return node", getHeader(start));
162     return;
163   }
164 
165   if (s.count(start) != 0) {
166     // already visited, no need to visit descendants
167     return;
168   }
169 
170   GRAPH_DEBUG("collectNodesToUnfuse: inserting node ", getHeader(start));
171   s.insert(start);
172 
173   for (auto o : start->outputs()) {
174     for (auto use : o->uses()) {
175       collectNodesToUnfuse(use.user, s);
176     }
177   }
178 }
179 
buildAliasedSets(std::shared_ptr<Graph> subgraph)180 std::vector<std::set<Value*, topo_cmp_value>> buildAliasedSets(
181     std::shared_ptr<Graph> subgraph) {
182   auto outputs = subgraph->outputs();
183   AliasDb alias_db(std::move(subgraph));
184   TORCH_INTERNAL_ASSERT(outputs.size() > 1);
185   std::vector<std::set<Value*, topo_cmp_value>> res;
186   for (auto o : outputs) {
187     auto grouped = false;
188     for (auto& s : res) {
189       auto os = *s.begin();
190       auto aliased = alias_db.mayContainAlias(os, o);
191       GRAPH_DEBUG(
192           "comparing %",
193           o->debugName(),
194           " with %",
195           os->debugName(),
196           " result ",
197           aliased);
198       if (aliased) {
199         s.insert(o);
200         GRAPH_DEBUG("Grouping %", o->debugName(), " with %", os->debugName());
201         grouped = true;
202       }
203     }
204     if (!grouped) {
205       res.push_back({o});
206     }
207   }
208   return res;
209 }
210 
211 } // namespace
212 
getSubgraph(Node * n)213 std::shared_ptr<Graph> getSubgraph(Node* n) {
214   return n->g(attr::Subgraph);
215 }
216 
unmergeSubgraph(Node * subgraphNode)217 void unmergeSubgraph(Node* subgraphNode) {
218   // Inline the graph, replace uses of node outputs and destroy the node
219   auto outerGraph = subgraphNode->owningGraph();
220   WithInsertPoint guard(subgraphNode);
221   const auto subgraphOutputs = insertGraph(
222       *outerGraph, *getSubgraph(subgraphNode), subgraphNode->inputs());
223   AT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size());
224   for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
225     subgraphNode->outputs()[i]->replaceAllUsesWith(subgraphOutputs[i]);
226   }
227   subgraphNode->destroy();
228 }
229 
collectNestedUses(std::unordered_set<Value * > & closed_over_values,std::unordered_set<Value * > & new_values,std::unordered_map<Value *,Value * > & externalValuesMap,Node * input_node)230 static void collectNestedUses(
231     std::unordered_set<Value*>& closed_over_values,
232     std::unordered_set<Value*>& new_values,
233     std::unordered_map<Value*, Value*>& externalValuesMap,
234     Node* input_node) {
235   for (auto input : input_node->inputs()) {
236     if (externalValuesMap.count(input) == 0 && new_values.count(input) == 0) {
237       closed_over_values.insert(input);
238     }
239   }
240   if (input_node->kind() == prim::If) {
241     for (Block* block : input_node->blocks()) {
242       for (Node* node : block->nodes()) {
243         collectNestedUses(
244             closed_over_values, new_values, externalValuesMap, node);
245       }
246       for (Value* v : block->outputs()) {
247         if (externalValuesMap.count(v) == 0 && new_values.count(v) == 0) {
248           closed_over_values.insert(v);
249         }
250       }
251     }
252   } else if (input_node->kind() == prim::Loop) {
253     for (Value* v : input_node->inputs()) {
254       if (externalValuesMap.count(v) == 0 && new_values.count(v) == 0) {
255         closed_over_values.insert(v);
256       }
257     }
258     Block* block = input_node->blocks().at(0);
259     for (Value* v : block->inputs()) {
260       new_values.insert(v);
261     }
262     for (Node* node : block->nodes()) {
263       collectNestedUses(
264           closed_over_values, new_values, externalValuesMap, node);
265     }
266   } else if (!input_node->blocks().empty()) {
267     TORCH_INTERNAL_ASSERT(false, input_node, " kind not handled yet");
268   }
269   for (Value* output : input_node->outputs()) {
270     new_values.insert(output);
271   }
272 }
273 
closedOverValues(Node * toMerge,std::unordered_map<Value *,Value * > & externalValuesMap)274 static std::unordered_set<Value*> closedOverValues(
275     Node* toMerge,
276     std::unordered_map<Value*, Value*>& externalValuesMap) {
277   std::unordered_set<Value*> closed_over_values;
278   std::unordered_set<Value*> new_values;
279   collectNestedUses(closed_over_values, new_values, externalValuesMap, toMerge);
280   return closed_over_values;
281 }
282 
mergeNodeIntoSubgraph(Node * toMerge,Node * subgraphNode,bool destroyNode)283 void mergeNodeIntoSubgraph(
284     Node* toMerge,
285     Node* subgraphNode,
286     bool destroyNode) {
287   AT_ASSERT(hasSubgraph(subgraphNode) && toMerge != subgraphNode);
288   if (hasSubgraph(toMerge)) {
289     return mergeSubgraph(subgraphNode, toMerge);
290   }
291 
292   auto subgraph = getSubgraph(subgraphNode);
293 
294   // Map from values in the surrounding graph to inputs/outputs in the subgraph
295   std::unordered_map<Value*, Value*> externalValuesMap;
296 
297   AT_ASSERT(subgraphNode->inputs().size() == subgraph->inputs().size());
298   size_t idx = 0;
299   for (auto input : subgraphNode->inputs()) {
300     externalValuesMap[input] = subgraph->inputs()[idx];
301     idx++;
302   }
303 
304   for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
305     externalValuesMap[subgraphNode->outputs().at(i)] =
306         subgraph->outputs().at(i);
307   }
308 
309   // Add n's inputs to the group's input list if we don't already have them
310 
311   bool merging_node_after_subgraph = toMerge->isAfter(subgraphNode);
312   Node* guard_node = merging_node_after_subgraph ? *subgraph->nodes().end()
313                                                  : *subgraph->nodes().begin();
314   WithInsertPoint guard(guard_node);
315 
316   std::unordered_set<Value*> closedValues =
317       closedOverValues(toMerge, externalValuesMap);
318 
319   // There are currently downstream usage that relies on a fixed ordering
320   // of graph inputs. TODO: remove
321   std::vector<Value*> orderedClosedValues;
322   std::unordered_set<Value*> orderedSeenValues;
323   for (Value* input : toMerge->inputs()) {
324     orderedClosedValues.push_back(input);
325     orderedSeenValues.insert(input);
326   }
327   for (Value* closedValue : closedValues) {
328     if (!orderedSeenValues.count(closedValue)) {
329       orderedClosedValues.push_back(closedValue);
330       orderedSeenValues.insert(closedValue);
331     }
332   }
333 
334   for (auto input : orderedClosedValues) {
335     if (externalValuesMap.count(input) == 0) {
336       // Clone constants inside the subgraph instead of referencing them, to
337       // enable more optimizations
338       if (auto value = toIValue(input)) {
339         auto nv = subgraph->insertConstant(*value);
340         nv->copyMetadata(input);
341         externalValuesMap[input] = nv;
342       } else {
343         // The common case: this is a regular input, so just register it with
344         // the group node and inner subgraph
345         subgraphNode->addInput(input);
346         auto inputToGraph = subgraph->addInput();
347         inputToGraph->copyMetadata(input);
348         externalValuesMap[input] = inputToGraph;
349       }
350     }
351   }
352 
353   // Merge the node into the graph
354   auto mergedNode = subgraph->insertNode(subgraph->createClone(
355       toMerge, [&](Value* v) { return externalValuesMap[v]; }));
356 
357   if (!merging_node_after_subgraph) {
358     // If n's outputs were inputs to `group`, remove them since we just merged
359     // n in.
360     //
361     // i.e.,
362     // x = f(w); group(x, y, z) becomes group(w, y, z).
363     // x, y, z = f(w); group(x, y, z) becomes group(w).
364     auto inputs = subgraphNode->inputs();
365     for (size_t i = 0; i < toMerge->outputs().size(); ++i) {
366       auto it = std::find(inputs.begin(), inputs.end(), toMerge->outputs()[i]);
367       if (it != inputs.end()) {
368         size_t p = it - inputs.begin();
369         subgraphNode->removeInput(p);
370         subgraph->inputs()[p]->replaceAllUsesWith(mergedNode->outputs()[i]);
371         subgraph->eraseInput(p);
372       }
373     }
374   }
375 
376   // Add n's outputs to the group node and inner subgraph outputs.
377   for (const auto i : c10::irange(toMerge->outputs().size())) {
378     auto oldOutput = toMerge->outputs()[i];
379     auto newOutput = mergedNode->outputs()[i];
380     subgraph->registerOutput(newOutput);
381     auto groupOutput = subgraphNode->addOutput();
382     groupOutput->copyMetadata(oldOutput);
383     oldOutput->replaceAllUsesWith(groupOutput);
384   }
385   // Remove the original node now that the merge is complete
386   if (destroyNode) {
387     toMerge->destroy();
388   }
389 
390   // We wait till destroying `toMerge` before pruning subgraph outputs,
391   // since destroying `toMerge` could cause a subgraph output to no longer
392   // have any uses
393   const auto hasUsesOutsideSubgraph = [&](Value* v) {
394     return std::any_of(
395         v->uses().cbegin(), v->uses().cend(), [&](const Use& use) {
396           return use.user->isAfter(subgraphNode);
397         });
398   };
399 
400   for (int64_t i = subgraphNode->outputs().size() - 1; i >= 0; i--) {
401     if (!hasUsesOutsideSubgraph(subgraphNode->outputs().at(i))) {
402       subgraphNode->eraseOutput(i);
403       subgraph->eraseOutput(i);
404     }
405   }
406 }
407 
createSingletonSubgraph(Node * n,Symbol subgraphKind)408 Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
409   auto graph = n->owningGraph();
410   auto subgraph = graph->create(subgraphKind, 0);
411   subgraph->g_(attr::Subgraph, std::make_shared<Graph>(graph->current_scope()));
412   subgraph->insertBefore(n);
413   mergeNodeIntoSubgraph(n, subgraph);
414   return subgraph;
415 }
416 
mergeNodeIntoSubgraphAndUpdateAliasing(Node * to_merge,Node * subgraphNode,AliasDb & db)417 void mergeNodeIntoSubgraphAndUpdateAliasing(
418     Node* to_merge,
419     Node* subgraphNode,
420     AliasDb& db) {
421   executeSubgraphMergeAndUpdateAliasing(to_merge, subgraphNode, db, [&]() {
422     mergeNodeIntoSubgraph(to_merge, subgraphNode);
423     return subgraphNode;
424   });
425 }
426 
createSingletonSubgraphAndUpdateAliasing(Node * to_merge,Symbol subgraphKind,AliasDb & db)427 Node* createSingletonSubgraphAndUpdateAliasing(
428     Node* to_merge,
429     Symbol subgraphKind,
430     AliasDb& db) {
431   return executeSubgraphMergeAndUpdateAliasing(
432       to_merge, std::nullopt, db, [&]() {
433         return createSingletonSubgraph(to_merge, subgraphKind);
434       });
435 }
436 
unmergeOutputsAlisingInputs(Node * subgraphNode)437 bool unmergeOutputsAlisingInputs(Node* subgraphNode) {
438   GRAPH_DEBUG("unfuseOutputsAlisingInputs on ", getHeader(subgraphNode));
439   auto subgraph = subgraphNode->g(attr::Subgraph);
440   AliasDb alias_db(subgraph);
441 
442   std::set<Node*, topo_cmp_node> nodes;
443   for (auto o : subgraph->outputs()) {
444     if (alias_db.mayContainAlias(o, subgraph->inputs())) {
445       collectNodesToUnfuse(o->node(), nodes);
446     }
447   }
448 
449   // unfuse in the reverse topo order
450   for (auto it = nodes.rbegin(); it != nodes.rend(); it++) {
451     SubgraphUtils::unmergeNode(*it, subgraphNode);
452   }
453 
454   return !nodes.empty();
455 }
456 
unmergeAliasedOutputs(Node * subgraphNode)457 bool unmergeAliasedOutputs(Node* subgraphNode) {
458   GRAPH_DEBUG("unfuseAliasedOutputs on ", getHeader(subgraphNode));
459   if (subgraphNode->outputs().size() < 2) {
460     return false;
461   }
462 
463   auto subgraph = subgraphNode->g(attr::Subgraph);
464   GRAPH_DUMP("unfuseAliasedOutputs Subgraph ", subgraph);
465   auto sets = buildAliasedSets(std::move(subgraph));
466   GRAPH_DEBUG("buildAliasedSets sets.size() = ", sets.size());
467 
468   std::set<Node*, topo_cmp_node> nodes;
469 
470   for (auto i : c10::irange(sets.size())) {
471     if (sets[i].size() <= 1) {
472       GRAPH_DEBUG(
473           "Set ",
474           i,
475           " with leader ",
476           (*(sets[i].begin()))->debugName(),
477           " size = ",
478           sets[i].size());
479       continue;
480     }
481 
482     // we have at least two aliased outputs
483     // we skip the earliest node w.r.t. the topo order
484     // NB. after some nodes are unfused, the outputs of some other nodes
485     // may become the outputs of the subgraph and alias the remaining ones
486     // so we have to re-run this function until there are no more changes
487     auto it = ++sets[i].begin();
488     while (it != sets[i].end()) {
489       GRAPH_DEBUG(
490           "root aliased value ", (*it)->debugName(), " node ", *(*it)->node());
491       collectNodesToUnfuse((*it)->node(), nodes);
492       it++;
493     }
494   }
495 
496   // unfuse in the reverse topo order
497   for (auto it = nodes.rbegin(); it != nodes.rend(); it++) {
498     unmergeNode(*it, subgraphNode);
499   }
500 
501   return !nodes.empty();
502 }
503 
unmergeNode(Node * n,Node * subgraphNode)504 void unmergeNode(Node* n, Node* subgraphNode) {
505   // collect output indices
506   GRAPH_DEBUG("unfuseNode node ", getHeader(n));
507   auto subgraph = n->owningGraph();
508 
509   std::set<Value*> node_outputs(n->outputs().begin(), n->outputs().end());
510   std::set<size_t> output_indices;
511   std::set<Value*> node_inputs(n->inputs().begin(), n->inputs().end());
512 
513   std::unordered_map<Value*, Value*> local_map;
514   auto env = [&](Value* v) {
515     auto it = local_map.find(v);
516     if (it != local_map.end()) {
517       return it->second;
518     }
519     TORCH_INTERNAL_ASSERT(
520         false,
521         "all inputs should've been mapped. Couldn't map %",
522         v->debugName());
523     return v;
524   };
525 
526   for (auto i : c10::irange(subgraph->outputs().size())) {
527     if (node_outputs.count(subgraph->outputs().at(i)) != 0) {
528       output_indices.insert(i);
529     }
530 
531     if (node_inputs.count(subgraph->outputs().at(i)) != 0) {
532       GRAPH_DEBUG(
533           "output %",
534           subgraph->outputs().at(i)->debugName(),
535           " is already subgraph's output");
536       GRAPH_DEBUG(
537           "Mapping %",
538           subgraph->outputs().at(i)->debugName(),
539           " to %",
540           subgraphNode->outputs().at(i)->debugName());
541       local_map[subgraph->outputs().at(i)] = subgraphNode->outputs().at(i);
542       node_inputs.erase(subgraph->outputs().at(i));
543     }
544   }
545 
546   WithInsertPoint wip(subgraphNode->next());
547 
548   // these node inputs need to be added to subgraph's outputs
549   // put them in vmap
550   for (auto ni : node_inputs) {
551     if (local_map.count(ni) != 0) {
552       // this could happen if `n` uses two or more outputs
553       // of a constant node and we already cloned the constant
554       // into the outer graph and mapped its outputs
555       continue;
556     }
557 
558     Value* sno = nullptr;
559     if (ni->node()->kind() == prim::Constant) {
560       auto copy = subgraphNode->owningGraph()->createClone(ni->node(), env);
561       subgraphNode->owningGraph()->insertNode(copy);
562       // in case we have a multi-output const, map the rest of the outputs
563       // so when we get to clone `n`, `n`'s clone will use the outputs of this
564       // constant clone
565       for (auto i : c10::irange(n->outputs().size())) {
566         GRAPH_DEBUG(
567             "Mapping %",
568             ni->node()->output(i)->debugName(),
569             " to %",
570             copy->output(i)->debugName());
571         local_map[ni->node()->output(i)] = copy->output(i);
572       }
573     } else {
574       subgraph->registerOutput(ni);
575       sno = subgraphNode->addOutput();
576       sno->setType(ni->type());
577       GRAPH_DEBUG("Mapping %", ni->debugName(), " to %", sno->debugName());
578       local_map[ni] = sno;
579     }
580   }
581 
582   auto copy = subgraphNode->owningGraph()->createClone(n, env);
583   GRAPH_DEBUG("copy ", *copy);
584 
585   for (auto i : c10::irange(n->outputs().size())) {
586     auto oo = n->outputs()[i];
587     auto no = copy->outputs()[i];
588     no->copyMetadata(oo);
589     GRAPH_DEBUG("Mapping %", oo->debugName(), " to %", no->debugName());
590     local_map[oo] = no;
591   }
592 
593   subgraphNode->owningGraph()->insertNode(copy);
594 
595   for (auto it = output_indices.rbegin(); it != output_indices.rend(); it++) {
596     auto replace_val = local_map[subgraph->outputs().at(*it)];
597     subgraphNode->outputs().at(*it)->replaceAllUsesWith(replace_val);
598     subgraphNode->eraseOutput(*it);
599     subgraph->eraseOutput(*it);
600   }
601 
602   n->destroy();
603 }
604 
truncateStrWithHash(const std::string & s,size_t maxlen)605 static std::string truncateStrWithHash(const std::string& s, size_t maxlen) {
606   if (s.size() <= maxlen) {
607     return s;
608   }
609   std::string hash_str = std::to_string(c10::hash<std::string>{}(s));
610   // If hash-string plus '_' can fit into maxlen, then truncate the original
611   // string correspondingly so that the final string with the hash included fits
612   // into maxlen. If that's not possible, at least truncate the original string
613   // to maxlen (and append the hash to it).
614   size_t trunc_len =
615       (maxlen > hash_str.size() + 1) ? (maxlen - hash_str.size() - 1) : maxlen;
616   std::stringstream truncated;
617   truncated << s.substr(0, trunc_len);
618   truncated << "_" << hash_str;
619   return truncated.str();
620 }
621 
generateNameForGraph(const std::shared_ptr<Graph> & graph,size_t maxlen,const std::string & prefix)622 std::string generateNameForGraph(
623     const std::shared_ptr<Graph>& graph,
624     size_t maxlen,
625     const std::string& prefix) {
626   std::stringstream graph_name;
627   graph_name << prefix;
628   for (Node* node : graph->nodes()) {
629     if (!node->kind().is_aten()) {
630       continue;
631     }
632     graph_name << "_" << node->kind().toUnqualString();
633   }
634   return truncateStrWithHash(graph_name.str(), maxlen);
635 }
636 
637 } // namespace SubgraphUtils
638 } // namespace jit
639 } // namespace torch
640