xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/onnx/helper.h>
2 #include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
3 #include <torch/csrc/jit/passes/remove_inplace_ops.h>
4 #include <torch/csrc/jit/passes/remove_mutation.h>
5 
6 #include <torch/csrc/jit/frontend/error_report.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/dead_code_elimination.h>
9 #include <torch/csrc/jit/passes/onnx/helper.h>
10 #include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
11 
12 #include <c10/util/irange.h>
13 
14 #include <limits>
15 
16 namespace torch::jit {
17 
18 namespace {
19 
// In-place list/tensor mutators recognized by this pass; each has a
// dedicated Prepare*ForONNX handler below that rewrites it into an
// out-of-place form.
const std::set<c10::Symbol> inplace_ops =
    {aten::append, aten::index_put_, aten::pop, aten::insert, aten::Delete};
22 
23 // InplaceConverter defines a set of functions that together enables the
24 // conversion from prim::GetAttr, prim::SetAttr, and ATen in-place operators to
25 // ONNX out-place operators.
struct InplaceConverter {
  // `graph` is the graph being converted, `mr` identifies in-place op
  // variants, and `model` (optional) supplies attribute values for
  // GetAttr/SetAttr conversion.
  InplaceConverter(
      std::shared_ptr<Graph> graph,
      MutationRemover* mr,
      Module* model = nullptr)
      : graph_(std::move(graph)), mr_(mr), module_(model) {}

  // Entry point: converts prim::GetAttr/SetAttr and ATen in-place ops in
  // graph_ into out-of-place equivalents suitable for ONNX export.
  void convertMutationForONNX();

 private:
  // Attribute pass 1: record each attribute's initial value, and the full
  // (dotted) attribute name for every GetAttr/SetAttr node.
  void gatherAttrNameInitialValueMap(
      Block* block,
      std::unordered_map<std::string, Value*>& attr_name_value_map,
      std::unordered_map<Node*, std::string>& attr_node_fullname_map);
  // Attribute pass 2: replace GetAttr uses with the initial value, and
  // SetAttr with an aten::set_ in-place update of it.
  void replaceAttrWithInplaceOps(
      Block* block,
      const std::unordered_map<std::string, Value*>& attr_name_value_map,
      const std::unordered_map<Node*, std::string>& attr_node_fullname_map);

  // Converts ATen in-place ops to out-of-place versions, recording each
  // produced alias in vt_.
  void convertInplaceOpsAndTrackAlias();
  void convertInplaceOpsAndTrackAlias(Block* block);

  // Final pass: rewrite every node input to refer to the correct alias at
  // its use site, based on vt_'s records.
  void correctAliasReferences();
  void correctAliasReferences(Block* block);
  void correctAliasReferences(Node* n);

  void convertGetSetAttrToInplaceOps(Block* block);

  // ValueTracker provides apis to record aliases for a single value,
  // and to retrieve the correct alias of any given value based on the location
  // in the graph it is used.
  struct ValueTracker {
    ValueTracker() : graph_(nullptr) {}

    void init(const std::shared_ptr<Graph>& graph);
    void recordSetValue(Value* old_v, Value* new_v);
    Value* findAliasForValueAtNode(Value* v, const Node* n) const;

    std::string toString() const;

   private:
    std::shared_ptr<Graph> graph_;

    // Map from aliases to root value.
    // A single value can have multiple aliases throughout the graph,
    // created by inplace operators, and preserved through loop carried
    // input/output. For each such value, its first occurrence will be set as
    // root value.
    std::unordered_map<Value*, Value*> alias_to_value_;

    // Sort the alias based on their order in graph.
    // A tie can happen when two distinct aliases belong to different blocks,
    // while having the same ancestor node. The unique id is used as tie
    // breaker, otherwise the two aliases will be considered equal to each
    // other. aliasComp must satisfy strict weak ordering.
    struct aliasComp {
      bool operator()(const Value* a, const Value* b) const {
        auto* n_a = a->node();
        auto* n_b = b->node();
        // Values produced by the same node compare equivalent.
        if (n_a == n_b) {
          return false;
        }
        auto a_b = n_a->isBefore(n_b);
        auto b_a = n_b->isBefore(n_a);
        // Neither node is strictly before the other (e.g. different blocks):
        // fall back to the unique id as a deterministic tie breaker.
        if (a_b == b_a) {
          return a->unique() < b->unique();
        }
        return a_b;
      }
    };
    // Map from root value to aliases sorted by their order in graph.
    std::unordered_map<Value*, std::set<Value*, aliasComp>>
        value_to_sorted_aliases_;
  };

  std::shared_ptr<Graph> graph_;
  MutationRemover* mr_;
  Module* module_;
  ValueTracker vt_;
};
106 
isAncestor(const Block * a,const Block * b)107 bool isAncestor(const Block* a, const Block* b) {
108   while (b && b->owningNode()) {
109     if (a == b) {
110       return true;
111     }
112     b = b->owningNode()->owningBlock();
113   }
114   return a == b;
115 }
116 
addDummyClone(Graph * graph,Value * orig_data,bool insertBefore,Node * referenceNode)117 Node* addDummyClone(
118     Graph* graph,
119     Value* orig_data,
120     bool insertBefore,
121     Node* referenceNode) {
122   Node* newNode = nullptr;
123   if (orig_data->type()->kind() == TypeKind::ListType) {
124     newNode = graph->create(aten::list, /*num_outputs =*/1);
125     newNode->addInput(orig_data);
126     newNode->output()->setType(orig_data->type());
127     if (insertBefore)
128       newNode->insertBefore(referenceNode);
129     else
130       referenceNode->owningBlock()->prependNode(newNode);
131   } else if (
132       orig_data->type()->kind() == TypeKind::TensorType ||
133       orig_data->type()->kind() == TypeKind::IntType ||
134       orig_data->type()->kind() == TypeKind::FloatType ||
135       orig_data->type()->kind() == TypeKind::BoolType) {
136     auto* noneNode = graph->create(prim::Constant);
137     noneNode->output()->setType(NoneType::get());
138     // For scripting mode, aten::clone requires input to be a TensorType
139     // Hence if we encounter an IntType, FloatType, or BoolType,
140     // we set the input to the appropriate TensorType
141     if (orig_data->type()->kind() == TypeKind::IntType &&
142         insertBefore == false) {
143       orig_data->setType(TensorType::fromNumberType(*IntType::get()));
144     } else if (
145         orig_data->type()->kind() == TypeKind::FloatType &&
146         insertBefore == false) {
147       orig_data->setType(TensorType::fromNumberType(*FloatType::get()));
148     } else if (
149         orig_data->type()->kind() == TypeKind::BoolType &&
150         insertBefore == false) {
151       orig_data->setType(TensorType::fromBoolType());
152     }
153     newNode = graph->create(aten::clone, /*num_outputs =*/1);
154     newNode->addInput(orig_data);
155     newNode->addInput(noneNode->output());
156     newNode->output()->setType(orig_data->type());
157     if (insertBefore)
158       newNode->insertBefore(referenceNode);
159     else
160       referenceNode->owningBlock()->prependNode(newNode);
161     noneNode->insertBefore(newNode);
162   }
163   return newNode;
164 }
165 
PrepareIndexPutForONNX(Node * node)166 std::pair<Value*, Value*> PrepareIndexPutForONNX(Node* node) {
167   TORCH_INTERNAL_ASSERT(
168       node->kind() == aten::index_put || node->kind() == aten::index_put_);
169   auto placeholder_node = EncapsulatePatternIntoSubblock(node).value();
170   node->destroy();
171   return std::make_pair(placeholder_node->input(0), placeholder_node->output());
172 }
173 
PrepareCopyForONNX(Node * node)174 std::pair<Value*, Value*> PrepareCopyForONNX(Node* node) {
175   TORCH_INTERNAL_ASSERT(node->kind() == aten::copy_);
176   // aten::copy_ can be viewed as a special case of index_put, where the
177   // tensor indices input is empty.
178   // Remove aten::copy_, and replace it with index_put.
179   // 1. create an empty listConstruct node as indices input for index_put.
180   // 2. create index_put node.
181 
182   // Tracing aten::copy_ broadcasts the rhs values.
183   // 3. Apply broadcasting for scripting.
184   WithInsertPoint guard(node);
185   auto graph = node->owningGraph();
186   auto dummy_list =
187       graph->insertNode(graph->createList(OptionalType::ofTensor(), {}))
188           ->output();
189 
190   auto expanded_value =
191       graph->insert(aten::expand_as, {node->input(1), node->input(0)});
192   expanded_value->node()->setSourceRange(node->sourceRange());
193   expanded_value->copyMetadata(node->input(1));
194   expanded_value->node()->copyMetadata(node);
195 
196   auto index_put = graph->insert(
197       aten::index_put_,
198       {node->input(0), dummy_list, expanded_value, node->input(2)});
199   index_put->node()->copyMetadata(node);
200   index_put->copyMetadata(node->output());
201   node->output()->replaceAllUsesWith(index_put);
202 
203   node->destroy();
204 
205   return PrepareIndexPutForONNX(index_put->node());
206 }
207 
PrepareSetForONNX(Node * n)208 auto PrepareSetForONNX(Node* n) {
209   TORCH_INTERNAL_ASSERT(n->kind() == aten::set_);
210   auto clone_n = addDummyClone(n->owningGraph(), n->input(1), true, n);
211   TORCH_INTERNAL_ASSERT(nullptr != clone_n);
212   clone_n->copyMetadata(n);
213 
214   auto orig_input = n->input(0);
215   n->output()->replaceAllUsesWith(clone_n->output());
216   n->destroy();
217   return std::make_pair(orig_input, clone_n->output());
218 }
219 
// Generic handler for in-place ATen ops inside blocks: rewrites `aten::foo_`
// into its out-of-place `aten::foo` counterpart. If the op's first input
// comes from aten::select/aten::slice (i.e. `a[i].foo_()`-style patterns),
// an extra aten::copy_ is emitted to write the result back, then lowered via
// PrepareCopyForONNX. Returns {original data input, new output}, or an empty
// pair if `node` is not an in-place ATen op.
std::pair<Value*, Value*> PrepareInplaceOpsInBlocksForONNX(Node* node) {
  if (!node->kind().is_aten())
    return {};

  // In-place ATen ops are identified by a trailing underscore in the schema
  // name.
  auto name = node->schema().name();
  bool inplace_op = name.at(name.size() - 1) == '_';
  if (!inplace_op)
    return {};

  auto new_schema = name.substr(0, name.size() - 1);

  // Capture the producer of input 0 now: `node` is destroyed below.
  Node* input_node = node->inputs().at(0)->node();

  auto graph = node->owningGraph();
  auto new_node = graph->create(Symbol::fromQualString(new_schema), 1);
  for (Value* input : node->inputs()) {
    new_node->addInput(input);
  }
  new_node->output()->setType(node->output()->type());
  new_node->insertBefore(node);
  new_node->copyMetadata(node);
  node->replaceAllUsesWith(new_node);
  node->destroy();

  if (input_node->kind() == aten::select || input_node->kind() == aten::slice) {
    // Cases from a[i] = x. Convert to copy_ and eventually index_put_.
    WithInsertPoint guard(new_node);
    auto false_val_ = graph->insertConstant(false);

    // copy_(self, src, non_blocking=false): write the out-of-place result
    // back into the selected/sliced view.
    auto new_copy = graph->create(aten::copy_, 1);
    new_copy->addInput(new_node->inputs().at(0));
    new_copy->addInput(new_node->output());
    new_copy->addInput(false_val_);
    new_copy->insertAfter(new_node);
    new_copy->copyMetadata(new_node);

    return PrepareCopyForONNX(new_copy);
  } else {
    // Direct aliasing, the node is a standalone inplace op.
    return std::make_pair(new_node->input(0), new_node->output());
  }
}
262 
263 // aten::pop is inplace. The tensor list input is updated.
264 // This pass creates an aten::__getitem__ op to return the original output from
265 // aten::pop. Then it makes the original aten::pop operator return the updated
266 // tensor list, and replaces all later uses of that tensor list with this new
267 // output.
PrepareListPopForONNX(Node * n)268 static std::pair<Value*, Value*> PrepareListPopForONNX(Node* n) {
269   TORCH_INTERNAL_ASSERT(n->kind() == aten::pop);
270   //   %ten : Tensor = aten::pop(%seq, %pos)
271   // Convert to
272   //   %ten : Tensor = aten::__getitem__(%seq, %pos)
273   //   %new_seq : Tensor[] = aten::pop(%seq, %pos)
274   // And replace all uses of %seq afterwards with %new_seq
275   Node* getitem_node =
276       n->owningGraph()->create(aten::__getitem__, {n->inputs()});
277   getitem_node->output()->setType(n->output()->type());
278   getitem_node->insertBefore(n);
279   getitem_node->copyMetadata(n);
280   n->output()->replaceAllUsesWith(getitem_node->output());
281   n->output()->setType(n->inputs().at(0)->type());
282 
283   return std::make_pair(n->input(0), n->output());
284 }
285 
PrepareListDeleteForONNX(Node * n)286 static std::pair<Value*, Value*> PrepareListDeleteForONNX(Node* n) {
287   TORCH_INTERNAL_ASSERT(n->kind() == aten::Delete);
288   n->addOutput();
289   n->output()->setType(n->inputs().at(0)->type());
290 
291   return std::make_pair(n->input(0), n->output());
292 }
293 
PrepareListAppendAndInsertForONNX(Node * n)294 static std::pair<Value*, Value*> PrepareListAppendAndInsertForONNX(Node* n) {
295   TORCH_INTERNAL_ASSERT(n->kind() == aten::insert || n->kind() == aten::append);
296   if (n->outputs().empty()) {
297     n->addOutput();
298     n->output()->setType(n->inputs().at(0)->type());
299   }
300   return std::make_pair(n->input(0), n->output());
301 }
302 
PrepareSetItemForONNX(Node * n)303 static std::pair<Value*, Value*> PrepareSetItemForONNX(Node* n) {
304   TORCH_INTERNAL_ASSERT(n->kind() == aten::_set_item);
305   // It seems the JIT does not always produce an output for _set_item.
306   // In particular it seems to for list but not for dict.
307   // So we add one if needed.
308   if (n->outputs().empty()) {
309     n->addOutput();
310     n->output()->setType(n->inputs().at(0)->type());
311   }
312   return std::make_pair(n->input(0), n->output());
313 }
314 
315 // Remove Mutation pass does not handle mutation on block inputs.
316 // To fix this, insert a clone node following the graph input:
317 // Example for graph input node %0:
318 // Before:
319 // graph(%0 : Tensor):
320 //   %5 : Tensor = aten::zero_(%0)
321 //   ...
322 // After:
323 // graph(%0 : Tensor):
324 //   %2 : None = prim::Constant()
325 //   %3 : Tensor = aten::clone(%0, %2)
326 //   %5 : Tensor = aten::zero_(%3)
327 //   ...
328 
static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
  // Recurse into sub-blocks first.
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      PrepareForRemoveMutations(mr, child_block);
    }
  }

  // For each block input mutated by an in-place op variant, prepend a clone
  // of the input to the block and mutate the clone instead (see the example
  // in the comment above).
  for (auto input : b->inputs()) {
    bool needsRestart = false;
    do {
      needsRestart = false;
      for (auto use : input->uses()) {
        Node* node = use.user;
        if (!mr.inplaceOpVariant(node)) {
          continue;
        }
        auto it =
            std::find(node->inputs().begin(), node->inputs().end(), input);
        if (it != node->inputs().end()) {
          int index = std::distance(node->inputs().begin(), it);
          TORCH_WARN(
              "ONNX Preprocess - Removing mutation from node ",
              node->kind().toQualString(),
              " on block input: '",
              (*it)->debugName(),
              "'. This changes graph semantics.");

          // insertBefore=false prepends the clone to the block; redirect
          // this use and every later use of the input to the clone output.
          Node* newNode =
              addDummyClone(b->owningGraph(), input, false, b->return_node());
          TORCH_INTERNAL_ASSERT(nullptr != newNode);
          newNode->copyMetadata(node);
          node->replaceInput(index, newNode->output());
          input->replaceAllUsesAfterNodeWith(node, newNode->output());
          // The use list was just modified while iterating over it; restart
          // the scan of this input's uses from the beginning.
          needsRestart = true;
          break;
        }
      }
    } while (needsRestart);
  }
}
369 
PrepareForRemoveMutations(const std::shared_ptr<Graph> & graph)370 static void PrepareForRemoveMutations(const std::shared_ptr<Graph>& graph) {
371   MutationRemover mr(graph);
372   PrepareForRemoveMutations(mr, graph->block());
373   GRAPH_DUMP("After PrepareForRemoveMutations: ", graph);
374 }
375 
376 // findSubModuleAttr function chases getAttr chains backwards to locate the
377 // submodules. For example: module M {
378 //   attributes {
379 //     A = <SubModule at ...>
380 //   }
381 //   ...
382 //   %A = prim::GetAttr[name="A"](%self)
383 //   ...
384 //   %B = prim::GetAttr[name="B"](%A)
385 //   ...
386 //   %weight = prim::GetAttr[name="scale"](%B)
387 //   ...
findSubModuleAttr(Value * input,std::string & name,Module & attrModule,const std::shared_ptr<Graph> & graph)388 std::deque<std::string> findSubModuleAttr(
389     Value* input,
390     std::string& name,
391     Module& attrModule,
392     const std::shared_ptr<Graph>& graph) {
393   Node* node = input->node();
394   std::deque<std::string> moduleNames;
395 
396   // Loop starts from inner submodule and follows the chain until reaches the
397   // top module.
398 
399   auto selfNode = graph->nodes().begin();
400   auto n = *selfNode;
401   while (node->outputs().at(0)->type() != n->output()->type()) {
402     if (node->kind() == prim::GetAttr) {
403       moduleNames.push_front(node->s(attr::name));
404       node = node->inputs()[0]->node();
405     } else {
406       break;
407     }
408   }
409   // Assign the inner module to attrModule.
410   for (auto& moduleName : moduleNames) {
411     attrModule = attrModule.attr(moduleName).toModule();
412   }
413   return moduleNames;
414 }
415 
findArgumentAsInputParam(const std::shared_ptr<Graph> & graph,std::string & name,IValue & attr)416 Value* findArgumentAsInputParam(
417     const std::shared_ptr<Graph>& graph,
418     std::string& name,
419     IValue& attr) {
420   for (auto input : graph->inputs()) {
421     if (input->debugName() == name)
422       return input;
423   }
424   throw std::runtime_error(
425       "Attribute is not part of model parameters. Cannot handle SetAttr and GetAttr nodes for : " +
426       name);
427 }
428 
init(const std::shared_ptr<Graph> & graph)429 void InplaceConverter::ValueTracker::init(const std::shared_ptr<Graph>& graph) {
430   alias_to_value_ = {};
431   value_to_sorted_aliases_ = {};
432   graph_ = graph;
433 }
434 
toString() const435 std::string InplaceConverter::ValueTracker::toString() const {
436   std::stringstream ss;
437 
438   // ss << "Current graph: " << graph_->toString() << std::endl;
439   ss << "Tracking " << value_to_sorted_aliases_.size() << " individual values."
440      << '\n';
441   ss << "value_to_sorted_aliases_: " << '\n';
442   size_t idx = 0;
443   for (const auto& it : value_to_sorted_aliases_) {
444     ss << "Value[" << idx << "]: " << it.first->debugName() << '\n';
445     ss << "  Mapping to ";
446     for (auto v : it.second) {
447       ss << v->debugName() << " ";
448     }
449     ss << '\n';
450     idx++;
451   }
452 
453   ss << "alias_to_value_: " << '\n';
454   for (auto it : alias_to_value_) {
455     ss << "  Alias " << it.first->debugName();
456     ss << " map to " << it.second->debugName() << '\n';
457   }
458 
459   return ss.str();
460 }
461 
// Records that `new_v` is an updated alias of `old_v` (produced by an
// in-place update). If the update happens inside an If/Loop sub-block while
// the value is also aliased outside that block, the new value is registered
// as a sub-block output (plus a loop-carried input for loops), and the
// resulting block-node output is recorded recursively as another alias of
// the root value.
void InplaceConverter::ValueTracker::recordSetValue(
    Value* old_v,
    Value* new_v) {
  GRAPH_UPDATE(
      "Calling recordSetValue with old_v: ",
      old_v->debugName(),
      " new_v: ",
      new_v->debugName());
  GRAPH_UPDATE(this->toString());
  auto* n = new_v->node();
  auto* owning_block = n->owningBlock();

  // First time this value is seen: it becomes its own root.
  if (alias_to_value_.find(old_v) == alias_to_value_.end()) {
    alias_to_value_[old_v] = old_v;
    value_to_sorted_aliases_[old_v] = {old_v};
  }

  auto root_v = alias_to_value_[old_v];
  alias_to_value_[new_v] = root_v;
  auto& sorted_alias = value_to_sorted_aliases_[root_v];
  sorted_alias.insert(new_v);

  // check if new_v is created inside if or loop subblock.
  auto* owning_blocknode = owning_block->owningNode();
  if (nullptr == owning_blocknode) {
    return;
  }
  auto owning_block_nkind = owning_blocknode->kind();
  if (owning_block_nkind != prim::Loop && owning_block_nkind != prim::If) {
    return;
  }

  // Is some alias of this value already registered as a sub-block output?
  bool registered = std::any_of(
      owning_block->outputs().begin(),
      owning_block->outputs().end(),
      [&sorted_alias](Value* out) {
        return std::any_of(
            sorted_alias.begin(), sorted_alias.end(), [&out](Value* alias) {
              return alias == out;
            });
      });

  // Does any alias originate from a block enclosing the If/Loop node?
  bool from_outer_alias = std::any_of(
      sorted_alias.begin(),
      sorted_alias.end(),
      [&owning_blocknode](Value* alias) {
        return isAncestor(
            alias->node()->owningBlock(), owning_blocknode->owningBlock());
      });

  // The data of this value has been changed.
  // If this value has alias from outer block,
  // then the update must be reflected back to outside.
  // Thus it needs to be registered as a subblock output.
  // This step can be skipped if other alias of this value has already been
  // registered as subblock output.
  if (!registered && from_outer_alias) {
    if (owning_block_nkind == prim::Loop) {
      // Loop-carried dependency: add block output + matching block input,
      // and feed the root value into the loop node itself.
      owning_block->registerOutput(new_v);
      auto new_block_in = owning_block->addInput();
      new_block_in->setType(new_v->type());
      sorted_alias.insert(new_block_in);
      alias_to_value_[new_block_in] = root_v;
      owning_blocknode->addInput(root_v);
    } else if (owning_block_nkind == prim::If) {
      // Both If branches must produce an output: the mutating branch yields
      // new_v, the other branch passes the root value through unchanged.
      for (auto* if_sub_block : owning_blocknode->blocks()) {
        if (owning_block == if_sub_block) {
          if_sub_block->registerOutput(new_v);
        } else {
          if_sub_block->registerOutput(root_v);
        }
      }
    }
    auto* new_blocknode_out = owning_blocknode->addOutput();
    new_blocknode_out->setType(new_v->type());
    // The If/Loop node's new output is itself another alias of the root.
    recordSetValue(root_v, new_blocknode_out);
  }

  GRAPH_UPDATE(
      "After recordSetValue for in: ",
      old_v->debugName(),
      ", out: ",
      new_v->debugName(),
      ". tracker status:");
  GRAPH_UPDATE(this->toString());
}
548 
549 // Based on current value aliases record, pass over graph and correct alias
550 // reference for all the nodes.
void InplaceConverter::correctAliasReferences() {
  // Start the correction walk from the graph's top-level block.
  correctAliasReferences(graph_->block());
}
554 
correctAliasReferences(Block * block)555 void InplaceConverter::correctAliasReferences(Block* block) {
556   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
557     Node* n = *it;
558     it++; // node n can be destroyed
559 
560     correctAliasReferences(n);
561 
562     auto nkind = n->kind();
563     if (nkind == prim::If || nkind == prim::Loop) {
564       for (auto* sub_block : n->blocks()) {
565         correctAliasReferences(sub_block);
566       }
567     }
568   }
569   correctAliasReferences(block->return_node());
570 }
571 
572 // For every input of Node n, find the correct alias representing that input.
correctAliasReferences(Node * n)573 void InplaceConverter::correctAliasReferences(Node* n) {
574   for (size_t i = 0; i < n->inputs().size(); ++i) {
575     auto* in = n->input(i);
576     auto* alias = vt_.findAliasForValueAtNode(in, n);
577 
578     if (alias != in) {
579       n->replaceInput(i, alias);
580       GRAPH_UPDATE(
581           "Replacing ",
582           in->debugName(),
583           " with ",
584           alias->debugName(),
585           " for ",
586           *n);
587     }
588   }
589 }
590 
591 // Find the correct alias representing Value v at Node n.
findAliasForValueAtNode(Value * v,const Node * n) const592 Value* InplaceConverter::ValueTracker::findAliasForValueAtNode(
593     Value* v,
594     const Node* n) const {
595   GRAPH_UPDATE("Finding alias for value:", v->debugName(), " at node ", *n);
596   if (alias_to_value_.find(v) == alias_to_value_.end()) {
597     // This value was not affected by any inplace operator.
598     return v;
599   }
600 
601   auto* root_v = alias_to_value_.find(v)->second;
602   TORCH_INTERNAL_ASSERT(
603       value_to_sorted_aliases_.find(root_v) != value_to_sorted_aliases_.end());
604   const auto& aliases = value_to_sorted_aliases_.find(root_v)->second;
605 
606   // alias is accessible only if
607   // 1. alias owning block is ancestor of n.
608   // 2. alias owning node is before n.
609   // return the last alias that satisfies this condition.
610   Value* found_alias = nullptr;
611   for (auto* alias : aliases) {
612     auto* alias_n = alias->node();
613     if (alias_n->isBefore(n) &&
614         isAncestor(alias_n->owningBlock(), n->owningBlock())) {
615       found_alias = alias;
616     }
617   }
618 
619   TORCH_INTERNAL_ASSERT(
620       nullptr != found_alias,
621       "More details: \n",
622       n->sourceRange().str(),
623       "Input ",
624       v->debugName(),
625       " of node ",
626       *n,
627       " was modified by in-place operation, but we cannot find its updated value. ",
628       "Please report a bug to PyTorch, and/or try to avoid using in-place operators on this value.");
629 
630   return found_alias;
631 }
632 
633 // Pass over block, and gather the initial value for any attribute.
634 // Also cache the full name of the attribute for every GetAttr/SetAttr node.
void InplaceConverter::gatherAttrNameInitialValueMap(
    Block* block,
    std::unordered_map<std::string, Value*>& attr_name_value_map,
    std::unordered_map<Node*, std::string>& attr_node_fullname_map) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    it++; // node n can be destroyed

    // Recurse into sub-blocks first.
    for (auto* sub_block : n->blocks()) {
      gatherAttrNameInitialValueMap(
          sub_block, attr_name_value_map, attr_node_fullname_map);
    }

    if (n->kind() != prim::GetAttr && n->kind() != prim::SetAttr)
      continue;

    auto name = n->s(attr::name);
    auto attrModule = *module_;
    Value* paramConst = nullptr;

    // Resolve the submodule chain owning the attribute; attrModule is
    // updated to the innermost submodule by findSubModuleAttr.
    auto moduleNames =
        findSubModuleAttr(n->inputs().at(0), name, attrModule, graph_);

    // Build the dotted full name, e.g. "sub1.sub2.weight".
    std::string fullName("");
    for (auto& name : moduleNames) {
      fullName += name + '.';
    }
    fullName += name;

    attr_node_fullname_map.insert({n, fullName});

    // Materialize the attribute's initial value once per full name.
    if (attr_name_value_map.find(fullName) == attr_name_value_map.end() &&
        attrModule.hasattr(name)) {
      auto attr = attrModule.attr(name);
      auto type = attrModule.type();
      auto slot = *type->findAttributeSlot(name);

      // Add model_parameters and model_buffers as model inputs. Order is
      // preserved based on the appearance in the graph.
      WithInsertPoint guard(graph_->nodes().front());
      if (type->is_parameter(slot) || type->is_buffer(slot) ||
          (attr.isObject() && !attr.toObjectRef().type()->is_module())) {
        // Parameters/buffers/custom objects are expected to already exist
        // as graph inputs; look them up by debug name.
        paramConst = findArgumentAsInputParam(graph_, fullName, attr);
        attr_name_value_map.insert({fullName, paramConst});
      } else if (auto attrVal = tryInsertConstant(*graph_, attr)) {
        // Attribute value could be inlined as a graph constant.
        // TODO: Extend support for attribute of type List[Tensor] etc.
        for (size_t i = 0; i < type->getAttributes().size(); i++) {
          if (type->getAttributeName(i) == name) {
            paramConst = *attrVal;
            attr_name_value_map.insert({fullName, paramConst});
          }
        }
      } else {
        // If attribute is a custom class object, instead of primitive types,
        // Tensor, or List/Tuple/Dict of Tensors.
        GRAPH_DEBUG(
            attr.type()->cast<ClassType>() ? "" : "attribute: ",
            name,
            " is not materializable.");
      }
    }

    // Create dummy initial value, if initial value does not exist for this
    // attribute.
    if (attr_name_value_map.find(fullName) == attr_name_value_map.end()) {
      auto* noneNode = graph_->create(prim::Constant);
      noneNode->output()->setType(NoneType::get());
      noneNode->insertBefore(graph_->nodes().front());
      attr_name_value_map.insert({fullName, noneNode->output()});
    }
  }
}
707 
708 // Replace prim::GetAttr and prim::SetAttr with ATen inplace operators.
709 // Example graph:
710 // clang-format off
711 //  Before graph(%x.1 : Float(12, strides=[1], requires_grad=0, device=cpu)):
712 //    %1 : __torch__.___torch_mangle_1.M = prim::CreateObject()
713 //    ...
714 //    %10 : Tensor = aten::arange(%6, %7, %7, %7, %7)
715 //     = prim::SetAttr[name="_bias"](%1, %10)
716 //     = prim::Loop(%5, %8)
717 //      block0(%i.1 : int):
718 //        %12 : bool = aten::eq(%i.1, %4)
719 //         = prim::If(%12)
720 //          block0():
721 //             = prim::Loop(%3, %8)
722 //              block0(%j : int):
723 //                %14 : Tensor = prim::GetAttr[name="_bias"](%1)
724 //                %15 : Tensor = aten::add_(%14, %2, %9)
725 //                 = prim::SetAttr[name="_bias"](%1, %15)
726 //                -> (%8)
727 //            -> ()
728 //          block1():
729 //            %16 : Tensor = aten::arange(%6, %7, %7, %7, %7)
730 //             = prim::SetAttr[name="_bias"](%1, %16)
731 //            -> ()
732 //        -> (%8)
733 //    %17 : Tensor = prim::GetAttr[name="_bias"](%1)
734 //    %18 : Tensor = aten::add(%17, %x.1, %9)
735 //    return (%18)
736 //
737 //  After graph(%x.1 : Float(12, strides=[1], requires_grad=0, device=cpu)):
738 //    %19 : Float(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value= 1  1 [ CPUFloatType{2} ]]()
739 //    %1 : __torch__.___torch_mangle_1.M = prim::CreateObject()
740 //    ...
741 //    %10 : Tensor = aten::arange(%6, %7, %7, %7, %7)
742 //    %28 : Tensor = aten::set_(%19, %10)
743 //     = prim::Loop(%5, %8)
744 //      block0(%i.1 : int):
745 //        %12 : bool = aten::eq(%i.1, %4)
746 //         = prim::If(%12)
747 //          block0():
748 //             = prim::Loop(%3, %8)
749 //              block0(%j : int):
750 //                %15 : Tensor = aten::add_(%19, %2, %9)
751 //                %25 : Tensor = aten::set_(%19, %15)
752 //                -> (%8)
753 //            -> ()
754 //          block1():
755 //            %16 : Tensor = aten::arange(%6, %7, %7, %7, %7)
756 //            %22 : Tensor = aten::set_(%19, %16)
757 //            -> ()
758 //        -> (%8)
759 //    %18 : Tensor = aten::add(%19, %x.1, %9)
760 //    return (%18)
761 // clang-format on
replaceAttrWithInplaceOps(Block * block,const std::unordered_map<std::string,Value * > & attr_name_value_map,const std::unordered_map<Node *,std::string> & attr_node_fullname_map)762 void InplaceConverter::replaceAttrWithInplaceOps(
763     Block* block,
764     const std::unordered_map<std::string, Value*>& attr_name_value_map,
765     const std::unordered_map<Node*, std::string>& attr_node_fullname_map) {
766   for (const auto& pair : attr_node_fullname_map) {
767     auto* n = pair.first;
768     auto fullName = pair.second;
769     auto find_init_val = attr_name_value_map.find(fullName);
770     TORCH_INTERNAL_ASSERT(find_init_val != attr_name_value_map.end());
771 
772     TORCH_INTERNAL_ASSERT(
773         n->kind() == prim::GetAttr || n->kind() == prim::SetAttr);
774     if (n->kind() == prim::SetAttr) {
775       // Convert SetAttr to inplace op aten::set_.
776       WithInsertPoint guard(n);
777       auto* set_node = graph_->create(aten::set_, 1);
778       set_node->addInput(find_init_val->second);
779       set_node->addInput(n->input(1));
780       set_node->copyMetadata(n);
781       set_node->insertBefore(n);
782     } else if (n->kind() == prim::GetAttr) {
783       // Replace use of GetAttr with first seen alias (usually initial value) of
784       // that particular value. Correct alias at point of this node will be
785       // discovered and assigned in later pass.
786       n->output()->replaceAllUsesWith(find_init_val->second);
787     }
788 
789     n->destroy();
790   }
791 }
792 
convertGetSetAttrToInplaceOps(Block * block)793 void InplaceConverter::convertGetSetAttrToInplaceOps(Block* block) {
794   std::unordered_map<std::string, Value*> attr_name_value_map = {};
795   std::unordered_map<Node*, std::string> attr_node_fullname_map = {};
796   // First pass over graph, to gather all attribute names, and their initial
797   // values. Create dummy initial values for attributes if necessary. By the end
798   // of this pass, these dummy initial values should have zero uses, and can be
799   // safely removed. Otherwise it will imply an error in the model for using
800   // uninitialized values.
801   gatherAttrNameInitialValueMap(
802       block, attr_name_value_map, attr_node_fullname_map);
803   GRAPH_UPDATE("Graph after gatherAttrNameInitialValueMap", graph_->toString());
804 
805   // Second pass over graph,
806   // replace GetAttr with first seen alias (usually initial value),
807   // and replace SetAttr with inplace op, updating new value onto first seen
808   // alias.
809   replaceAttrWithInplaceOps(block, attr_name_value_map, attr_node_fullname_map);
810 }
811 
812 // Convert inplace ops to outplace version, and record the associated new alias
813 // in ValueTracker.
convertInplaceOpsAndTrackAlias(Block * block)814 void InplaceConverter::convertInplaceOpsAndTrackAlias(Block* block) {
815   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
816     Node* n = *it;
817     it++; // node n can be destroyed
818 
819     auto nkind = n->kind();
820     if (nkind == prim::If || nkind == prim::Loop) {
821       for (Block* sub_block : n->blocks()) {
822         convertInplaceOpsAndTrackAlias(sub_block);
823       }
824     } else {
825       Value *orig_data = nullptr, *new_out = nullptr;
826       if (nkind == aten::copy_) {
827         std::tie(orig_data, new_out) = PrepareCopyForONNX(n);
828       } else if (nkind == aten::index_put || nkind == aten::index_put_) {
829         std::tie(orig_data, new_out) = PrepareIndexPutForONNX(n);
830         if (nkind == aten::index_put) {
831           // special case, index_put is not inplace.
832           continue;
833         }
834       } else if (nkind == aten::insert || nkind == aten::append) {
835         std::tie(orig_data, new_out) = PrepareListAppendAndInsertForONNX(n);
836       } else if (nkind == aten::set_) {
837         std::tie(orig_data, new_out) = PrepareSetForONNX(n);
838       } else if (mr_->inplaceOpVariant(n)) {
839         std::tie(orig_data, new_out) = PrepareInplaceOpsInBlocksForONNX(n);
840       } else if (nkind == aten::pop) {
841         std::tie(orig_data, new_out) = PrepareListPopForONNX(n);
842       } else if (nkind == aten::Delete) {
843         std::tie(orig_data, new_out) = PrepareListDeleteForONNX(n);
844       } else if (nkind == aten::_set_item) {
845         std::tie(orig_data, new_out) = PrepareSetItemForONNX(n);
846       } else {
847         // Not inplace op.
848         continue;
849       }
850 
851       if (nullptr != orig_data && nullptr != new_out) {
852         vt_.recordSetValue(orig_data, new_out);
853       }
854     }
855   }
856 }
857 
convertInplaceOpsAndTrackAlias()858 void InplaceConverter::convertInplaceOpsAndTrackAlias() {
859   convertInplaceOpsAndTrackAlias(graph_->block());
860   GRAPH_UPDATE(
861       "Graph after convertInplaceOpsAndTrackAlias: ", graph_->toString());
862   GRAPH_UPDATE(vt_.toString());
863 }
864 
convertMutationForONNX()865 void InplaceConverter::convertMutationForONNX() {
866   // First pass to convert all prim::GetAttr and prim::SetAttr to ATen inplace
867   // operators.
868   convertGetSetAttrToInplaceOps(graph_->block());
869   GRAPH_UPDATE("Graph after convertGetSetAttrToInplaceOps", graph_->toString());
870   vt_.init(graph_);
871   // Second pass to convert all inplace operators to outplace version, and
872   // record the associated new alias in ValueTracker.
873   convertInplaceOpsAndTrackAlias();
874   // Third pass to check and correct alias reference for all the nodes.
875   correctAliasReferences();
876 }
877 
878 } // namespace
879 
RemoveInplaceOpsForONNX(const std::shared_ptr<Graph> & graph,Module * model=nullptr)880 void RemoveInplaceOpsForONNX(
881     const std::shared_ptr<Graph>& graph,
882     Module* model = nullptr) {
883   ImplicitCastForBinaryInplaceOps(graph->block());
884   PrepareForRemoveMutations(graph);
885   MutationRemover mr(graph);
886   mr.removeTensorMutation();
887   mr.removeListMutation();
888   InplaceConverter ic(graph, &mr, model);
889   ic.convertMutationForONNX();
890 }
891 
892 } // namespace torch::jit
893