1 #include <torch/csrc/jit/passes/onnx/helper.h>
2 #include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
3 #include <torch/csrc/jit/passes/remove_inplace_ops.h>
4 #include <torch/csrc/jit/passes/remove_mutation.h>
5
6 #include <torch/csrc/jit/frontend/error_report.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/dead_code_elimination.h>
9 #include <torch/csrc/jit/passes/onnx/helper.h>
10 #include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
11
12 #include <c10/util/irange.h>
13
14 #include <limits>
15
16 namespace torch::jit {
17
18 namespace {
19
// ATen ops treated as in-place by this pass (list mutation and index_put_).
// NOTE(review): this set is not referenced anywhere in the visible portion of
// this file — possibly dead; confirm against the rest of the TU before removal.
const std::set<c10::Symbol> inplace_ops =
    {aten::append, aten::index_put_, aten::pop, aten::insert, aten::Delete};
22
23 // InplaceConverter defines a set of functions that together enables the
24 // conversion from prim::GetAttr, prim::SetAttr, and ATen in-place operators to
25 // ONNX out-place operators.
// InplaceConverter defines a set of functions that together enables the
// conversion from prim::GetAttr, prim::SetAttr, and ATen in-place operators to
// ONNX out-place operators.
struct InplaceConverter {
  InplaceConverter(
      std::shared_ptr<Graph> graph,
      MutationRemover* mr,
      Module* model = nullptr)
      : graph_(std::move(graph)), mr_(mr), module_(model) {}

  // Entry point: runs the attr-to-inplace, inplace-to-outplace, and
  // alias-correction passes over graph_ in that order.
  void convertMutationForONNX();

 private:
  // Pass over block gathering the initial value for every attribute, and
  // caching the fully-qualified attribute name per GetAttr/SetAttr node.
  void gatherAttrNameInitialValueMap(
      Block* block,
      std::unordered_map<std::string, Value*>& attr_name_value_map,
      std::unordered_map<Node*, std::string>& attr_node_fullname_map);
  // Replaces prim::SetAttr nodes with aten::set_ and uses of prim::GetAttr
  // outputs with the first-seen alias of the attribute value.
  void replaceAttrWithInplaceOps(
      Block* block,
      const std::unordered_map<std::string, Value*>& attr_name_value_map,
      const std::unordered_map<Node*, std::string>& attr_node_fullname_map);

  void convertInplaceOpsAndTrackAlias();
  void convertInplaceOpsAndTrackAlias(Block* block);

  void correctAliasReferences();
  void correctAliasReferences(Block* block);
  void correctAliasReferences(Node* n);

  void convertGetSetAttrToInplaceOps(Block* block);

  // ValueTracker provides apis to record aliases for a single value,
  // and to retrieve the correct alias of any given value based on the location
  // in the graph it is used.
  struct ValueTracker {
    ValueTracker() : graph_(nullptr) {}

    void init(const std::shared_ptr<Graph>& graph);
    void recordSetValue(Value* old_v, Value* new_v);
    Value* findAliasForValueAtNode(Value* v, const Node* n) const;

    std::string toString() const;

   private:
    std::shared_ptr<Graph> graph_;

    // Map from aliases to root value.
    // A single value can have multiple aliases throughout the graph,
    // created by inplace operators, and preserved through loop carried
    // input/output. For each such value, its first occurrence will be set as
    // root value.
    std::unordered_map<Value*, Value*> alias_to_value_;

    // Sort the alias based on their order in graph.
    // A tie can happen when two distinct aliases belong to different blocks,
    // while having the same ancestor node. The unique id is used as tie
    // breaker, otherwise the two aliases will be considered equal to each
    // other. aliasComp must satisfy strict weak ordering.
    struct aliasComp {
      bool operator()(const Value* a, const Value* b) const {
        auto* n_a = a->node();
        auto* n_b = b->node();
        if (n_a == n_b) {
          return false;
        }
        auto a_b = n_a->isBefore(n_b);
        auto b_a = n_b->isBefore(n_a);
        if (a_b == b_a) {
          // Neither node strictly precedes the other: fall back to the
          // values' unique ids to keep the ordering strict-weak.
          return a->unique() < b->unique();
        }
        return a_b;
      }
    };
    // Map from root value to aliases sorted by their order in graph.
    std::unordered_map<Value*, std::set<Value*, aliasComp>>
        value_to_sorted_aliases_;
  };

  std::shared_ptr<Graph> graph_;
  MutationRemover* mr_;
  Module* module_;
  ValueTracker vt_;
};
106
isAncestor(const Block * a,const Block * b)107 bool isAncestor(const Block* a, const Block* b) {
108 while (b && b->owningNode()) {
109 if (a == b) {
110 return true;
111 }
112 b = b->owningNode()->owningBlock();
113 }
114 return a == b;
115 }
116
// Inserts a node producing an out-of-place copy of orig_data:
// aten::list for list-typed values, aten::clone for tensor/int/float/bool.
// The new node is placed either immediately before referenceNode
// (insertBefore == true) or prepended to referenceNode's owning block.
// Returns nullptr for any other input type.
Node* addDummyClone(
    Graph* graph,
    Value* orig_data,
    bool insertBefore,
    Node* referenceNode) {
  Node* newNode = nullptr;
  if (orig_data->type()->kind() == TypeKind::ListType) {
    newNode = graph->create(aten::list, /*num_outputs =*/1);
    newNode->addInput(orig_data);
    newNode->output()->setType(orig_data->type());
    if (insertBefore)
      newNode->insertBefore(referenceNode);
    else
      referenceNode->owningBlock()->prependNode(newNode);
  } else if (
      orig_data->type()->kind() == TypeKind::TensorType ||
      orig_data->type()->kind() == TypeKind::IntType ||
      orig_data->type()->kind() == TypeKind::FloatType ||
      orig_data->type()->kind() == TypeKind::BoolType) {
    // None constant serves as the memory_format argument of aten::clone.
    auto* noneNode = graph->create(prim::Constant);
    noneNode->output()->setType(NoneType::get());
    // For scripting mode, aten::clone requires input to be a TensorType
    // Hence if we encounter an IntType, FloatType, or BoolType,
    // we set the input to the appropriate TensorType
    // NOTE: this mutates orig_data's type in place (side effect on the graph).
    if (orig_data->type()->kind() == TypeKind::IntType &&
        insertBefore == false) {
      orig_data->setType(TensorType::fromNumberType(*IntType::get()));
    } else if (
        orig_data->type()->kind() == TypeKind::FloatType &&
        insertBefore == false) {
      orig_data->setType(TensorType::fromNumberType(*FloatType::get()));
    } else if (
        orig_data->type()->kind() == TypeKind::BoolType &&
        insertBefore == false) {
      orig_data->setType(TensorType::fromBoolType());
    }
    newNode = graph->create(aten::clone, /*num_outputs =*/1);
    newNode->addInput(orig_data);
    newNode->addInput(noneNode->output());
    newNode->output()->setType(orig_data->type());
    if (insertBefore)
      newNode->insertBefore(referenceNode);
    else
      referenceNode->owningBlock()->prependNode(newNode);
    // Place the None constant ahead of its (only) use.
    noneNode->insertBefore(newNode);
  }
  return newNode;
}
165
PrepareIndexPutForONNX(Node * node)166 std::pair<Value*, Value*> PrepareIndexPutForONNX(Node* node) {
167 TORCH_INTERNAL_ASSERT(
168 node->kind() == aten::index_put || node->kind() == aten::index_put_);
169 auto placeholder_node = EncapsulatePatternIntoSubblock(node).value();
170 node->destroy();
171 return std::make_pair(placeholder_node->input(0), placeholder_node->output());
172 }
173
// Rewrites aten::copy_ as aten::index_put_ and hands off to
// PrepareIndexPutForONNX. Returns {original data input, new output}.
std::pair<Value*, Value*> PrepareCopyForONNX(Node* node) {
  TORCH_INTERNAL_ASSERT(node->kind() == aten::copy_);
  // aten::copy_ can be viewed as a special case of index_put, where the
  // tensor indices input is empty.
  // Remove aten::copy_, and replace it with index_put.
  // 1. create an empty listConstruct node as indices input for index_put.
  // 2. create index_put node.

  // Tracing aten::copy_ broadcasts the rhs values.
  // 3. Apply broadcasting for scripting.
  WithInsertPoint guard(node);
  auto graph = node->owningGraph();
  // (1) empty Optional[Tensor] list == "no indices".
  auto dummy_list =
      graph->insertNode(graph->createList(OptionalType::ofTensor(), {}))
          ->output();

  // (3) broadcast the source to the destination's shape, as tracing would.
  auto expanded_value =
      graph->insert(aten::expand_as, {node->input(1), node->input(0)});
  expanded_value->node()->setSourceRange(node->sourceRange());
  expanded_value->copyMetadata(node->input(1));
  expanded_value->node()->copyMetadata(node);

  // (2) index_put_(self, [], expanded_value, accumulate) replaces copy_.
  auto index_put = graph->insert(
      aten::index_put_,
      {node->input(0), dummy_list, expanded_value, node->input(2)});
  index_put->node()->copyMetadata(node);
  index_put->copyMetadata(node->output());
  node->output()->replaceAllUsesWith(index_put);

  node->destroy();

  return PrepareIndexPutForONNX(index_put->node());
}
207
PrepareSetForONNX(Node * n)208 auto PrepareSetForONNX(Node* n) {
209 TORCH_INTERNAL_ASSERT(n->kind() == aten::set_);
210 auto clone_n = addDummyClone(n->owningGraph(), n->input(1), true, n);
211 TORCH_INTERNAL_ASSERT(nullptr != clone_n);
212 clone_n->copyMetadata(n);
213
214 auto orig_input = n->input(0);
215 n->output()->replaceAllUsesWith(clone_n->output());
216 n->destroy();
217 return std::make_pair(orig_input, clone_n->output());
218 }
219
// Converts a generic in-place ATen op (schema name ending in '_') into its
// out-of-place counterpart. When the op's self input comes from
// aten::select/aten::slice (i.e. a[i] = x patterns), additionally routes
// through copy_ -> index_put_. Returns {original data input, new output},
// or an empty pair if the node is not an in-place aten op.
std::pair<Value*, Value*> PrepareInplaceOpsInBlocksForONNX(Node* node) {
  if (!node->kind().is_aten())
    return {};

  auto name = node->schema().name();
  bool inplace_op = name.at(name.size() - 1) == '_';
  if (!inplace_op)
    return {};

  // Out-of-place variant: same qualified name minus the trailing underscore.
  auto new_schema = name.substr(0, name.size() - 1);

  // Captured before node->destroy() below; needed for the select/slice check.
  Node* input_node = node->inputs().at(0)->node();

  auto graph = node->owningGraph();
  auto new_node = graph->create(Symbol::fromQualString(new_schema), 1);
  for (Value* input : node->inputs()) {
    new_node->addInput(input);
  }
  new_node->output()->setType(node->output()->type());
  new_node->insertBefore(node);
  new_node->copyMetadata(node);
  node->replaceAllUsesWith(new_node);
  node->destroy();

  if (input_node->kind() == aten::select || input_node->kind() == aten::slice) {
    // Cases from a[i] = x. Convert to copy_ and eventually index_put_.
    WithInsertPoint guard(new_node);
    // non_blocking = false argument for aten::copy_.
    auto false_val_ = graph->insertConstant(false);

    auto new_copy = graph->create(aten::copy_, 1);
    new_copy->addInput(new_node->inputs().at(0));
    new_copy->addInput(new_node->output());
    new_copy->addInput(false_val_);
    new_copy->insertAfter(new_node);
    new_copy->copyMetadata(new_node);

    return PrepareCopyForONNX(new_copy);
  } else {
    // Direct aliasing, the node is a standalone inplace op.
    return std::make_pair(new_node->input(0), new_node->output());
  }
}
262
263 // aten::pop is inplace. The tensor list input is updated.
264 // This pass creates an aten::__getitem__ op to return the original output from
265 // aten::pop. Then it makes the original aten::pop operator return the updated
266 // tensor list, and replaces all later uses of that tensor list with this new
267 // output.
// aten::pop is inplace. The tensor list input is updated.
// This pass creates an aten::__getitem__ op to return the original output from
// aten::pop. Then it makes the original aten::pop operator return the updated
// tensor list, and replaces all later uses of that tensor list with this new
// output.
static std::pair<Value*, Value*> PrepareListPopForONNX(Node* n) {
  TORCH_INTERNAL_ASSERT(n->kind() == aten::pop);
  // %ten : Tensor = aten::pop(%seq, %pos)
  // Convert to
  //   %ten : Tensor = aten::__getitem__(%seq, %pos)
  //   %new_seq : Tensor[] = aten::pop(%seq, %pos)
  // And replace all uses of %seq afterwards with %new_seq
  Node* getitem_node =
      n->owningGraph()->create(aten::__getitem__, {n->inputs()});
  getitem_node->output()->setType(n->output()->type());
  getitem_node->insertBefore(n);
  getitem_node->copyMetadata(n);
  // All consumers of the popped element now read from __getitem__.
  n->output()->replaceAllUsesWith(getitem_node->output());
  // Repurpose pop's output to carry the updated list (same type as input 0).
  n->output()->setType(n->inputs().at(0)->type());

  return std::make_pair(n->input(0), n->output());
}
285
PrepareListDeleteForONNX(Node * n)286 static std::pair<Value*, Value*> PrepareListDeleteForONNX(Node* n) {
287 TORCH_INTERNAL_ASSERT(n->kind() == aten::Delete);
288 n->addOutput();
289 n->output()->setType(n->inputs().at(0)->type());
290
291 return std::make_pair(n->input(0), n->output());
292 }
293
PrepareListAppendAndInsertForONNX(Node * n)294 static std::pair<Value*, Value*> PrepareListAppendAndInsertForONNX(Node* n) {
295 TORCH_INTERNAL_ASSERT(n->kind() == aten::insert || n->kind() == aten::append);
296 if (n->outputs().empty()) {
297 n->addOutput();
298 n->output()->setType(n->inputs().at(0)->type());
299 }
300 return std::make_pair(n->input(0), n->output());
301 }
302
PrepareSetItemForONNX(Node * n)303 static std::pair<Value*, Value*> PrepareSetItemForONNX(Node* n) {
304 TORCH_INTERNAL_ASSERT(n->kind() == aten::_set_item);
305 // It seems the JIT does not always produce an output for _set_item.
306 // In particular it seems to for list but not for dict.
307 // So we add one if needed.
308 if (n->outputs().empty()) {
309 n->addOutput();
310 n->output()->setType(n->inputs().at(0)->type());
311 }
312 return std::make_pair(n->input(0), n->output());
313 }
314
315 // Remove Mutation pass does not handle mutation on block inputs.
316 // To fix this, insert a clone node following the graph input:
317 // Example for graph input node %0:
318 // Before:
319 // graph(%0 : Tensor):
320 // %5 : Tensor = aten::zero_(%0)
321 // ...
322 // After:
323 // graph(%0 : Tensor):
324 // %2 : None = prim::Constant()
325 // %3 : Tensor = aten::clone(%0, %2)
326 // %5 : Tensor = aten::zero_(%3)
327 // ...
328
// Recursively prepares `b` (and its nested blocks) for the RemoveMutation
// pass: any block input consumed by an in-place op is cloned, and later uses
// of the input are redirected to the clone (see comment block above).
static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
  // Depth-first: handle nested blocks before this one's inputs.
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      PrepareForRemoveMutations(mr, child_block);
    }
  }

  for (auto input : b->inputs()) {
    bool needsRestart = false;
    // Restart the scan after each rewrite: replaceAllUsesAfterNodeWith
    // invalidates the uses() list we are iterating.
    do {
      needsRestart = false;
      for (auto use : input->uses()) {
        Node* node = use.user;
        if (!mr.inplaceOpVariant(node)) {
          continue;
        }
        auto it =
            std::find(node->inputs().begin(), node->inputs().end(), input);
        if (it != node->inputs().end()) {
          int index = std::distance(node->inputs().begin(), it);
          TORCH_WARN(
              "ONNX Preprocess - Removing mutation from node ",
              node->kind().toQualString(),
              " on block input: '",
              (*it)->debugName(),
              "'. This changes graph semantics.");

          // Prepend a clone of the input to the block, mutate the clone
          // instead, and let everything after `node` see the clone.
          Node* newNode =
              addDummyClone(b->owningGraph(), input, false, b->return_node());
          TORCH_INTERNAL_ASSERT(nullptr != newNode);
          newNode->copyMetadata(node);
          node->replaceInput(index, newNode->output());
          input->replaceAllUsesAfterNodeWith(node, newNode->output());
          needsRestart = true;
          break;
        }
      }
    } while (needsRestart);
  }
}
369
// Graph-level driver: runs the block-input preparation over the whole graph.
static void PrepareForRemoveMutations(const std::shared_ptr<Graph>& graph) {
  MutationRemover mr(graph);
  PrepareForRemoveMutations(mr, graph->block());
  GRAPH_DUMP("After PrepareForRemoveMutations: ", graph);
}
375
376 // findSubModuleAttr function chases getAttr chains backwards to locate the
377 // submodules. For example: module M {
378 // attributes {
379 // A = <SubModule at ...>
380 // }
381 // ...
382 // %A = prim::GetAttr[name="A"](%self)
383 // ...
384 // %B = prim::GetAttr[name="B"](%A)
385 // ...
386 // %weight = prim::GetAttr[name="scale"](%B)
387 // ...
findSubModuleAttr(Value * input,std::string & name,Module & attrModule,const std::shared_ptr<Graph> & graph)388 std::deque<std::string> findSubModuleAttr(
389 Value* input,
390 std::string& name,
391 Module& attrModule,
392 const std::shared_ptr<Graph>& graph) {
393 Node* node = input->node();
394 std::deque<std::string> moduleNames;
395
396 // Loop starts from inner submodule and follows the chain until reaches the
397 // top module.
398
399 auto selfNode = graph->nodes().begin();
400 auto n = *selfNode;
401 while (node->outputs().at(0)->type() != n->output()->type()) {
402 if (node->kind() == prim::GetAttr) {
403 moduleNames.push_front(node->s(attr::name));
404 node = node->inputs()[0]->node();
405 } else {
406 break;
407 }
408 }
409 // Assign the inner module to attrModule.
410 for (auto& moduleName : moduleNames) {
411 attrModule = attrModule.attr(moduleName).toModule();
412 }
413 return moduleNames;
414 }
415
findArgumentAsInputParam(const std::shared_ptr<Graph> & graph,std::string & name,IValue & attr)416 Value* findArgumentAsInputParam(
417 const std::shared_ptr<Graph>& graph,
418 std::string& name,
419 IValue& attr) {
420 for (auto input : graph->inputs()) {
421 if (input->debugName() == name)
422 return input;
423 }
424 throw std::runtime_error(
425 "Attribute is not part of model parameters. Cannot handle SetAttr and GetAttr nodes for : " +
426 name);
427 }
428
init(const std::shared_ptr<Graph> & graph)429 void InplaceConverter::ValueTracker::init(const std::shared_ptr<Graph>& graph) {
430 alias_to_value_ = {};
431 value_to_sorted_aliases_ = {};
432 graph_ = graph;
433 }
434
toString() const435 std::string InplaceConverter::ValueTracker::toString() const {
436 std::stringstream ss;
437
438 // ss << "Current graph: " << graph_->toString() << std::endl;
439 ss << "Tracking " << value_to_sorted_aliases_.size() << " individual values."
440 << '\n';
441 ss << "value_to_sorted_aliases_: " << '\n';
442 size_t idx = 0;
443 for (const auto& it : value_to_sorted_aliases_) {
444 ss << "Value[" << idx << "]: " << it.first->debugName() << '\n';
445 ss << " Mapping to ";
446 for (auto v : it.second) {
447 ss << v->debugName() << " ";
448 }
449 ss << '\n';
450 idx++;
451 }
452
453 ss << "alias_to_value_: " << '\n';
454 for (auto it : alias_to_value_) {
455 ss << " Alias " << it.first->debugName();
456 ss << " map to " << it.second->debugName() << '\n';
457 }
458
459 return ss.str();
460 }
461
// Records that new_v is a fresh alias of old_v (the result of converting an
// in-place update). If the update happens inside an If/Loop subblock while
// the value is also visible outside, the alias is additionally registered as
// a subblock output so the update propagates outward; this recurses via
// recordSetValue on the block node's new output.
void InplaceConverter::ValueTracker::recordSetValue(
    Value* old_v,
    Value* new_v) {
  GRAPH_UPDATE(
      "Calling recordSetValue with old_v: ",
      old_v->debugName(),
      " new_v: ",
      new_v->debugName());
  GRAPH_UPDATE(this->toString());
  auto* n = new_v->node();
  auto* owning_block = n->owningBlock();

  // First time seeing old_v: it becomes its own root.
  if (alias_to_value_.find(old_v) == alias_to_value_.end()) {
    alias_to_value_[old_v] = old_v;
    value_to_sorted_aliases_[old_v] = {old_v};
  }

  auto root_v = alias_to_value_[old_v];
  alias_to_value_[new_v] = root_v;
  auto& sorted_alias = value_to_sorted_aliases_[root_v];
  sorted_alias.insert(new_v);

  // check if new_v is created inside if or loop subblock.
  auto* owning_blocknode = owning_block->owningNode();
  if (nullptr == owning_blocknode) {
    return;
  }
  auto owning_block_nkind = owning_blocknode->kind();
  if (owning_block_nkind != prim::Loop && owning_block_nkind != prim::If) {
    return;
  }

  // Is some alias of this value already a subblock output?
  bool registered = std::any_of(
      owning_block->outputs().begin(),
      owning_block->outputs().end(),
      [&sorted_alias](Value* out) {
        return std::any_of(
            sorted_alias.begin(), sorted_alias.end(), [&out](Value* alias) {
              return alias == out;
            });
      });

  // Does this value have an alias defined outside the current block node?
  bool from_outer_alias = std::any_of(
      sorted_alias.begin(),
      sorted_alias.end(),
      [&owning_blocknode](Value* alias) {
        return isAncestor(
            alias->node()->owningBlock(), owning_blocknode->owningBlock());
      });

  // The data of this value has been changed.
  // If this value has alias from outer block,
  // then the update must be reflected back to outside.
  // Thus it needs to be registered as a subblock output.
  // This step can be skipped if other alias of this value has already been
  // registered as subblock output.
  if (!registered && from_outer_alias) {
    if (owning_block_nkind == prim::Loop) {
      // Loop-carried dependency: add matching block output + block input,
      // and feed root_v in as the initial value.
      owning_block->registerOutput(new_v);
      auto new_block_in = owning_block->addInput();
      new_block_in->setType(new_v->type());
      sorted_alias.insert(new_block_in);
      alias_to_value_[new_block_in] = root_v;
      owning_blocknode->addInput(root_v);
    } else if (owning_block_nkind == prim::If) {
      // Both If branches must produce an output: the updated value in this
      // branch, the unchanged root in the other.
      for (auto* if_sub_block : owning_blocknode->blocks()) {
        if (owning_block == if_sub_block) {
          if_sub_block->registerOutput(new_v);
        } else {
          if_sub_block->registerOutput(root_v);
        }
      }
    }
    // The block node's new output is itself a new alias, visible in the
    // enclosing block — record it recursively.
    auto* new_blocknode_out = owning_blocknode->addOutput();
    new_blocknode_out->setType(new_v->type());
    recordSetValue(root_v, new_blocknode_out);
  }

  GRAPH_UPDATE(
      "After recordSetValue for in: ",
      old_v->debugName(),
      ", out: ",
      new_v->debugName(),
      ". tracker status:");
  GRAPH_UPDATE(this->toString());
}
548
549 // Based on current value aliases record, pass over graph and correct alias
550 // reference for all the nodes.
// Graph-level entry: correct alias references starting at the top block.
void InplaceConverter::correctAliasReferences() {
  correctAliasReferences(graph_->block());
}
554
// Corrects alias references for every node in `block`, recursing into
// If/Loop subblocks, then fixes the block's own outputs (return node).
void InplaceConverter::correctAliasReferences(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    it++; // node n can be destroyed

    correctAliasReferences(n);

    auto nkind = n->kind();
    if (nkind == prim::If || nkind == prim::Loop) {
      for (auto* sub_block : n->blocks()) {
        correctAliasReferences(sub_block);
      }
    }
  }
  // Block outputs are inputs of the return node; fix those too.
  correctAliasReferences(block->return_node());
}
571
572 // For every input of Node n, find the correct alias representing that input.
correctAliasReferences(Node * n)573 void InplaceConverter::correctAliasReferences(Node* n) {
574 for (size_t i = 0; i < n->inputs().size(); ++i) {
575 auto* in = n->input(i);
576 auto* alias = vt_.findAliasForValueAtNode(in, n);
577
578 if (alias != in) {
579 n->replaceInput(i, alias);
580 GRAPH_UPDATE(
581 "Replacing ",
582 in->debugName(),
583 " with ",
584 alias->debugName(),
585 " for ",
586 *n);
587 }
588 }
589 }
590
591 // Find the correct alias representing Value v at Node n.
findAliasForValueAtNode(Value * v,const Node * n) const592 Value* InplaceConverter::ValueTracker::findAliasForValueAtNode(
593 Value* v,
594 const Node* n) const {
595 GRAPH_UPDATE("Finding alias for value:", v->debugName(), " at node ", *n);
596 if (alias_to_value_.find(v) == alias_to_value_.end()) {
597 // This value was not affected by any inplace operator.
598 return v;
599 }
600
601 auto* root_v = alias_to_value_.find(v)->second;
602 TORCH_INTERNAL_ASSERT(
603 value_to_sorted_aliases_.find(root_v) != value_to_sorted_aliases_.end());
604 const auto& aliases = value_to_sorted_aliases_.find(root_v)->second;
605
606 // alias is accessible only if
607 // 1. alias owning block is ancestor of n.
608 // 2. alias owning node is before n.
609 // return the last alias that satisfies this condition.
610 Value* found_alias = nullptr;
611 for (auto* alias : aliases) {
612 auto* alias_n = alias->node();
613 if (alias_n->isBefore(n) &&
614 isAncestor(alias_n->owningBlock(), n->owningBlock())) {
615 found_alias = alias;
616 }
617 }
618
619 TORCH_INTERNAL_ASSERT(
620 nullptr != found_alias,
621 "More details: \n",
622 n->sourceRange().str(),
623 "Input ",
624 v->debugName(),
625 " of node ",
626 *n,
627 " was modified by in-place operation, but we cannot find its updated value. ",
628 "Please report a bug to PyTorch, and/or try to avoid using in-place operators on this value.");
629
630 return found_alias;
631 }
632
633 // Pass over block, and gather the initial value for any attribute.
634 // Also cache the full name of the attribute for every GetAttr/SetAttr node.
// Pass over block, and gather the initial value for any attribute.
// Also cache the full name of the attribute for every GetAttr/SetAttr node.
// Initial values come from (in priority order): a lifted graph input for
// parameters/buffers/objects, an inserted constant for materializable
// attributes, or a dummy None constant otherwise.
void InplaceConverter::gatherAttrNameInitialValueMap(
    Block* block,
    std::unordered_map<std::string, Value*>& attr_name_value_map,
    std::unordered_map<Node*, std::string>& attr_node_fullname_map) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    it++; // node n can be destroyed

    for (auto* sub_block : n->blocks()) {
      gatherAttrNameInitialValueMap(
          sub_block, attr_name_value_map, attr_node_fullname_map);
    }

    if (n->kind() != prim::GetAttr && n->kind() != prim::SetAttr)
      continue;

    auto name = n->s(attr::name);
    auto attrModule = *module_;
    Value* paramConst = nullptr;

    // Walk the GetAttr chain to locate the owning submodule.
    auto moduleNames =
        findSubModuleAttr(n->inputs().at(0), name, attrModule, graph_);

    // Fully-qualified dotted name, e.g. "submodule.attr".
    std::string fullName("");
    for (auto& name : moduleNames) {
      fullName += name + '.';
    }
    fullName += name;

    attr_node_fullname_map.insert({n, fullName});

    if (attr_name_value_map.find(fullName) == attr_name_value_map.end() &&
        attrModule.hasattr(name)) {
      auto attr = attrModule.attr(name);
      auto type = attrModule.type();
      auto slot = *type->findAttributeSlot(name);

      // Add model_parameters and model_buffers as model inputs. Order is
      // preserved based on the appearance in the graph.
      WithInsertPoint guard(graph_->nodes().front());
      if (type->is_parameter(slot) || type->is_buffer(slot) ||
          (attr.isObject() && !attr.toObjectRef().type()->is_module())) {
        paramConst = findArgumentAsInputParam(graph_, fullName, attr);
        attr_name_value_map.insert({fullName, paramConst});
      } else if (auto attrVal = tryInsertConstant(*graph_, attr)) {
        // TODO: Extend support for attribute of type List[Tensor] etc.
        for (size_t i = 0; i < type->getAttributes().size(); i++) {
          if (type->getAttributeName(i) == name) {
            paramConst = *attrVal;
            attr_name_value_map.insert({fullName, paramConst});
          }
        }
      } else {
        // If attribute is a custom class object, instead of primitive types,
        // Tensor, or List/Tuple/Dict of Tensors.
        GRAPH_DEBUG(
            attr.type()->cast<ClassType>() ? "" : "attribute: ",
            name,
            " is not materializable.");
      }
    }

    // Create dummy initial value, if initial value does not exist for this
    // attribute.
    if (attr_name_value_map.find(fullName) == attr_name_value_map.end()) {
      auto* noneNode = graph_->create(prim::Constant);
      noneNode->output()->setType(NoneType::get());
      noneNode->insertBefore(graph_->nodes().front());
      attr_name_value_map.insert({fullName, noneNode->output()});
    }
  }
}
707
708 // Replace prim::GetAttr and prim::SetAttr with ATen inplace operators.
709 // Example graph:
710 // clang-format off
711 // Before graph(%x.1 : Float(12, strides=[1], requires_grad=0, device=cpu)):
712 // %1 : __torch__.___torch_mangle_1.M = prim::CreateObject()
713 // ...
714 // %10 : Tensor = aten::arange(%6, %7, %7, %7, %7)
715 // = prim::SetAttr[name="_bias"](%1, %10)
716 // = prim::Loop(%5, %8)
717 // block0(%i.1 : int):
718 // %12 : bool = aten::eq(%i.1, %4)
719 // = prim::If(%12)
720 // block0():
721 // = prim::Loop(%3, %8)
722 // block0(%j : int):
723 // %14 : Tensor = prim::GetAttr[name="_bias"](%1)
724 // %15 : Tensor = aten::add_(%14, %2, %9)
725 // = prim::SetAttr[name="_bias"](%1, %15)
726 // -> (%8)
727 // -> ()
728 // block1():
729 // %16 : Tensor = aten::arange(%6, %7, %7, %7, %7)
730 // = prim::SetAttr[name="_bias"](%1, %16)
731 // -> ()
732 // -> (%8)
733 // %17 : Tensor = prim::GetAttr[name="_bias"](%1)
734 // %18 : Tensor = aten::add(%17, %x.1, %9)
735 // return (%18)
736 //
737 // After graph(%x.1 : Float(12, strides=[1], requires_grad=0, device=cpu)):
738 // %19 : Float(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value= 1 1 [ CPUFloatType{2} ]]()
739 // %1 : __torch__.___torch_mangle_1.M = prim::CreateObject()
740 // ...
741 // %10 : Tensor = aten::arange(%6, %7, %7, %7, %7)
742 // %28 : Tensor = aten::set_(%19, %10)
743 // = prim::Loop(%5, %8)
744 // block0(%i.1 : int):
745 // %12 : bool = aten::eq(%i.1, %4)
746 // = prim::If(%12)
747 // block0():
748 // = prim::Loop(%3, %8)
749 // block0(%j : int):
750 // %15 : Tensor = aten::add_(%19, %2, %9)
751 // %25 : Tensor = aten::set_(%19, %15)
752 // -> (%8)
753 // -> ()
754 // block1():
755 // %16 : Tensor = aten::arange(%6, %7, %7, %7, %7)
756 // %22 : Tensor = aten::set_(%19, %16)
757 // -> ()
758 // -> (%8)
759 // %18 : Tensor = aten::add(%19, %x.1, %9)
760 // return (%18)
761 // clang-format on
replaceAttrWithInplaceOps(Block * block,const std::unordered_map<std::string,Value * > & attr_name_value_map,const std::unordered_map<Node *,std::string> & attr_node_fullname_map)762 void InplaceConverter::replaceAttrWithInplaceOps(
763 Block* block,
764 const std::unordered_map<std::string, Value*>& attr_name_value_map,
765 const std::unordered_map<Node*, std::string>& attr_node_fullname_map) {
766 for (const auto& pair : attr_node_fullname_map) {
767 auto* n = pair.first;
768 auto fullName = pair.second;
769 auto find_init_val = attr_name_value_map.find(fullName);
770 TORCH_INTERNAL_ASSERT(find_init_val != attr_name_value_map.end());
771
772 TORCH_INTERNAL_ASSERT(
773 n->kind() == prim::GetAttr || n->kind() == prim::SetAttr);
774 if (n->kind() == prim::SetAttr) {
775 // Convert SetAttr to inplace op aten::set_.
776 WithInsertPoint guard(n);
777 auto* set_node = graph_->create(aten::set_, 1);
778 set_node->addInput(find_init_val->second);
779 set_node->addInput(n->input(1));
780 set_node->copyMetadata(n);
781 set_node->insertBefore(n);
782 } else if (n->kind() == prim::GetAttr) {
783 // Replace use of GetAttr with first seen alias (usually initial value) of
784 // that particular value. Correct alias at point of this node will be
785 // discovered and assigned in later pass.
786 n->output()->replaceAllUsesWith(find_init_val->second);
787 }
788
789 n->destroy();
790 }
791 }
792
convertGetSetAttrToInplaceOps(Block * block)793 void InplaceConverter::convertGetSetAttrToInplaceOps(Block* block) {
794 std::unordered_map<std::string, Value*> attr_name_value_map = {};
795 std::unordered_map<Node*, std::string> attr_node_fullname_map = {};
796 // First pass over graph, to gather all attribute names, and their initial
797 // values. Create dummy initial values for attributes if necessary. By the end
798 // of this pass, these dummy initial values should have zero uses, and can be
799 // safely removed. Otherwise it will imply an error in the model for using
800 // uninitialized values.
801 gatherAttrNameInitialValueMap(
802 block, attr_name_value_map, attr_node_fullname_map);
803 GRAPH_UPDATE("Graph after gatherAttrNameInitialValueMap", graph_->toString());
804
805 // Second pass over graph,
806 // replace GetAttr with first seen alias (usually initial value),
807 // and replace SetAttr with inplace op, updating new value onto first seen
808 // alias.
809 replaceAttrWithInplaceOps(block, attr_name_value_map, attr_node_fullname_map);
810 }
811
812 // Convert inplace ops to outplace version, and record the associated new alias
813 // in ValueTracker.
// Convert inplace ops to outplace version, and record the associated new alias
// in ValueTracker. Dispatches each node kind to its Prepare*ForONNX helper;
// helpers may destroy the current node, hence the iterator is advanced first.
void InplaceConverter::convertInplaceOpsAndTrackAlias(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    it++; // node n can be destroyed

    auto nkind = n->kind();
    if (nkind == prim::If || nkind == prim::Loop) {
      for (Block* sub_block : n->blocks()) {
        convertInplaceOpsAndTrackAlias(sub_block);
      }
    } else {
      Value *orig_data = nullptr, *new_out = nullptr;
      if (nkind == aten::copy_) {
        std::tie(orig_data, new_out) = PrepareCopyForONNX(n);
      } else if (nkind == aten::index_put || nkind == aten::index_put_) {
        std::tie(orig_data, new_out) = PrepareIndexPutForONNX(n);
        if (nkind == aten::index_put) {
          // special case, index_put is not inplace.
          continue;
        }
      } else if (nkind == aten::insert || nkind == aten::append) {
        std::tie(orig_data, new_out) = PrepareListAppendAndInsertForONNX(n);
      } else if (nkind == aten::set_) {
        std::tie(orig_data, new_out) = PrepareSetForONNX(n);
      } else if (mr_->inplaceOpVariant(n)) {
        std::tie(orig_data, new_out) = PrepareInplaceOpsInBlocksForONNX(n);
      } else if (nkind == aten::pop) {
        std::tie(orig_data, new_out) = PrepareListPopForONNX(n);
      } else if (nkind == aten::Delete) {
        std::tie(orig_data, new_out) = PrepareListDeleteForONNX(n);
      } else if (nkind == aten::_set_item) {
        std::tie(orig_data, new_out) = PrepareSetItemForONNX(n);
      } else {
        // Not inplace op.
        continue;
      }

      // Helpers may return {nullptr, nullptr} (e.g. non-inplace aten op).
      if (nullptr != orig_data && nullptr != new_out) {
        vt_.recordSetValue(orig_data, new_out);
      }
    }
  }
}
857
// Graph-level entry: convert in-place ops starting at the top block and log
// the resulting graph plus tracker state.
void InplaceConverter::convertInplaceOpsAndTrackAlias() {
  convertInplaceOpsAndTrackAlias(graph_->block());
  GRAPH_UPDATE(
      "Graph after convertInplaceOpsAndTrackAlias: ", graph_->toString());
  GRAPH_UPDATE(vt_.toString());
}
864
// Runs the full conversion: attributes -> in-place ops -> out-of-place ops,
// then a final pass rewiring every node to the correct alias.
void InplaceConverter::convertMutationForONNX() {
  // First pass to convert all prim::GetAttr and prim::SetAttr to ATen inplace
  // operators.
  convertGetSetAttrToInplaceOps(graph_->block());
  GRAPH_UPDATE("Graph after convertGetSetAttrToInplaceOps", graph_->toString());
  vt_.init(graph_);
  // Second pass to convert all inplace operators to outplace version, and
  // record the associated new alias in ValueTracker.
  convertInplaceOpsAndTrackAlias();
  // Third pass to check and correct alias reference for all the nodes.
  correctAliasReferences();
}
877
878 } // namespace
879
RemoveInplaceOpsForONNX(const std::shared_ptr<Graph> & graph,Module * model=nullptr)880 void RemoveInplaceOpsForONNX(
881 const std::shared_ptr<Graph>& graph,
882 Module* model = nullptr) {
883 ImplicitCastForBinaryInplaceOps(graph->block());
884 PrepareForRemoveMutations(graph);
885 MutationRemover mr(graph);
886 mr.removeTensorMutation();
887 mr.removeListMutation();
888 InplaceConverter ic(graph, &mr, model);
889 ic.convertMutationForONNX();
890 }
891
892 } // namespace torch::jit
893