1 #include <torch/csrc/jit/passes/remove_inplace_ops.h>
2 #include <iostream>
3
4 namespace torch::jit {
5 namespace {
6 static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
7 {aten::add_, aten::add},
8 {aten::sub_, aten::sub},
9 {aten::div_, aten::div},
10 {aten::mul_, aten::mul},
11 {aten::masked_fill_, aten::masked_fill},
12 {aten::zero_, aten::zeros_like},
13 {aten::fill_, aten::full_like}};
14
15 // This is a horrible no good awful hack to "fill in" the TensorOptions
16 // arguments of zeros_like and full_like so that the defaults are filled
17 // in. Ugh. Would be better to just run the frontend to get the correct
18 // arity here.
19 static const std::unordered_map<NodeKind, int> expectedInputCount = {
20 {aten::zero_, 6},
21 {aten::fill_, 7}};
22
isInplaceOp(const Node * node)23 bool isInplaceOp(const Node* node) {
24 return inPlaceToOutOfPlace.count(node->kind()) != 0;
25 }
26
27 // Remove all in-place ops and replace them with out-of-place equivalents.
28 // e.g.
29 // %foo = aten::add_(%foo, %n)
30 // becomes
31 // %foo.2 = aten::add(%foo, %n)
32 //
33 // NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by
34 // another value. This is only to avoid breaking ONNX export; when alias
35 // analysis is done we can emit a warning if someone tries to export.
RemoveInplaceOps(Block * block)36 void RemoveInplaceOps(Block* block) {
37 auto graph = block->owningGraph();
38 auto it = block->nodes().begin();
39 while (it != block->nodes().end()) {
40 auto node = *it;
41 ++it;
42 for (auto block : node->blocks()) {
43 RemoveInplaceOps(block);
44 }
45
46 if (isInplaceOp(node)) {
47 // create a replacement out of place op
48 auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
49 newNode->insertBefore(node);
50 newNode->copyMetadata(node);
51 // copy inputs
52 for (auto input : node->inputs()) {
53 newNode->addInput(input);
54 }
55
56 int additionalInputCount = 0;
57 if (expectedInputCount.find(node->kind()) != expectedInputCount.end()) {
58 additionalInputCount = expectedInputCount.at(node->kind()) -
59 static_cast<int>(newNode->inputs().size());
60 }
61
62 for (int i = 0; i < additionalInputCount; ++i) {
63 auto noneNode = graph->createNone();
64 noneNode->insertBefore(newNode);
65 newNode->addInput(noneNode->output());
66 }
67
68 // Create a new output node and replace all uses of self with it
69 newNode->output()->copyMetadata(node->output());
70 node->replaceAllUsesWith(newNode);
71 node->inputs()[0]->replaceAllUsesAfterNodeWith(
72 newNode, newNode->output());
73 node->destroy();
74 }
75 }
76 }
77 } // namespace
78
79 // Handles special case of binary inplace ops, where the first input node
80 // has a lower type precedence than the second input node. When the
81 // inplace node is converted to a regular op, this information is lost and
82 // the resulting type is based on type precedence, just like regular ops.
83 // To avoid this loss of information, we add a cast node before the input
84 // node with the higher data type precedence, so that both the input types
85 // are the same.
86 // An example scenario would be:
87 // Before:
88 // graph(%0 : Float),
89 // %1 : Half):
90 // # Should result in a Half, but after translation to out-of-place,
91 // # would become a Float b/c Half+Float -> Float.
92 // %4 : Float = onnx::Cast[to=1](%1)
93 // %5 : Float = onnx::Add(%4, %0)
94 // ...
95 // After:
96 // graph(%0 : Float),
97 // %1 : Half):
98 // %4 : Half = onnx::Cast[to=10](%0)
99 // %5 : Half = onnx::Add(%1, %4)
100 // ...
101
ImplicitCastForBinaryInplaceOps(Block * b)102 void ImplicitCastForBinaryInplaceOps(Block* b) {
103 for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
104 for (auto* child_block : it->blocks()) {
105 ImplicitCastForBinaryInplaceOps(child_block);
106 }
107
108 // Check type if inplace operation is a binary node
109 if ((it->kind() == aten::add_) || (it->kind() == aten::sub_) ||
110 (it->kind() == aten::mul_) || (it->kind() == aten::div_)) {
111 auto originalInputs = it->inputs();
112 if (originalInputs.at(0) == originalInputs.at(1)) {
113 continue;
114 }
115
116 auto shape_node = originalInputs.at(0)->node();
117 if ((shape_node->kind() == prim::NumToTensor) &&
118 (shape_node->inputs().at(0)->node()->kind() == aten::size)) {
119 std::cerr
120 << "In-place op on output of tensor.shape. See https://pytorch.org/docs/main/onnx.html#"
121 << "avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode"
122 << '\n';
123 }
124
125 TensorTypePtr firstInp_tensor =
126 originalInputs.at(0)->type()->cast<TensorType>();
127 TensorTypePtr secondInp_tensor =
128 originalInputs.at(1)->type()->cast<TensorType>();
129 if (!(firstInp_tensor) || !(secondInp_tensor) ||
130 !(firstInp_tensor->scalarType().has_value())) {
131 continue;
132 }
133 auto newInputNode = it->owningGraph()->create(aten::type_as, 1);
134 newInputNode->insertBefore(*it);
135 newInputNode->addInput(originalInputs.at(1));
136 newInputNode->addInput(originalInputs.at(0));
137 it->replaceInput(1, newInputNode->outputs().at(0));
138 }
139 }
140 }
141
RemoveInplaceOps(const std::shared_ptr<Graph> & graph)142 void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
143 ImplicitCastForBinaryInplaceOps(graph->block());
144 RemoveInplaceOps(graph->block());
145 }
146 } // namespace torch::jit
147