#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <iostream>

namespace torch::jit {
namespace {
static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
    {aten::add_, aten::add},
    {aten::sub_, aten::sub},
    {aten::div_, aten::div},
    {aten::mul_, aten::mul},
    {aten::masked_fill_, aten::masked_fill},
    {aten::zero_, aten::zeros_like},
    {aten::fill_, aten::full_like}};

// This is a horrible no good awful hack to "fill in" the TensorOptions
// arguments of zeros_like and full_like so that the defaults are filled
// in.  Ugh.  Would be better to just run the frontend to get the correct
// arity here.
static const std::unordered_map<NodeKind, int> expectedInputCount = {
    {aten::zero_, 6},
    {aten::fill_, 7}};
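
// For reference, the schemas (paraphrasing native_functions.yaml) are
// roughly:
//   zeros_like(self, dtype, layout, device, pin_memory, memory_format)
//   full_like(self, fill_value, dtype, layout, device, pin_memory,
//             memory_format)
// i.e. 6 and 7 inputs respectively, which is where the counts above
// come from.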

bool isInplaceOp(const Node* node) {
  return inPlaceToOutOfPlace.count(node->kind()) != 0;
}

// Remove all in-place ops and replace them with out-of-place equivalents.
// e.g.
//   %foo = aten::add_(%foo, %n)
// becomes
//   %foo.2 = aten::add(%foo, %n)
//
// NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by
// another value. This is only to avoid breaking ONNX export; when alias
// analysis is done we can emit a warning if someone tries to export.
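// For example (hypothetical IR), if %foo is a view:
//   %foo : Tensor = aten::select(%bar, %dim, %idx)
//   %r : Tensor = aten::add_(%foo, %n)
// rewriting add_ to add silently drops the write that would have been
// visible through %bar.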
void RemoveInplaceOps(Block* block) {
  auto graph = block->owningGraph();
  auto it = block->nodes().begin();
  while (it != block->nodes().end()) {
    auto node = *it;
    // Advance the iterator now, since `node` may be destroyed below.
    ++it;
    for (auto block : node->blocks()) {
      RemoveInplaceOps(block);
    }

    if (isInplaceOp(node)) {
      // create a replacement out-of-place op
      auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
      newNode->insertBefore(node);
      newNode->copyMetadata(node);
      // copy inputs
      for (auto input : node->inputs()) {
        newNode->addInput(input);
      }

      // Pad the replacement op up to its expected arity with None
      // constants, standing in for the defaulted TensorOptions arguments.
      int additionalInputCount = 0;
      if (expectedInputCount.find(node->kind()) != expectedInputCount.end()) {
        additionalInputCount = expectedInputCount.at(node->kind()) -
            static_cast<int>(newNode->inputs().size());
      }

      for (int i = 0; i < additionalInputCount; ++i) {
        auto noneNode = graph->createNone();
        noneNode->insertBefore(newNode);
        newNode->addInput(noneNode->output());
      }

      // Copy output metadata, redirect all uses of the old node, and
      // rewrite any later uses of self to the new node's output
      newNode->output()->copyMetadata(node->output());
      node->replaceAllUsesWith(newNode);
      node->inputs()[0]->replaceAllUsesAfterNodeWith(
          newNode, newNode->output());
      node->destroy();
    }
  }
}
} // namespace

// Handles the special case of binary in-place ops, where the first input
// node has a lower type precedence than the second input node. When the
// in-place node is converted to a regular op, this information is lost and
// the resulting type is based on type precedence, just like regular ops.
// To avoid this loss of information, we add a cast node before the input
// node with the higher data type precedence, so that both the input types
// are the same.
// An example scenario would be:
// Before:
// graph(%0 : Float,
//       %1 : Half):
//   # Should result in a Half, but after translation to out-of-place,
//   # would become a Float b/c Half+Float -> Float.
//   %4 : Float = onnx::Cast[to=1](%1)
//   %5 : Float = onnx::Add(%4, %0)
//   ...
// After:
// graph(%0 : Float,
//       %1 : Half):
//   %4 : Half = onnx::Cast[to=10](%0)
//   %5 : Half = onnx::Add(%1, %4)
//   ...
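// (In onnx::Cast, the `to` attribute is an ONNX TensorProto data type:
// 1 is FLOAT, 10 is FLOAT16.)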

void ImplicitCastForBinaryInplaceOps(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      ImplicitCastForBinaryInplaceOps(child_block);
    }

    // Check the input types if the in-place operation is a binary op
    if ((it->kind() == aten::add_) || (it->kind() == aten::sub_) ||
        (it->kind() == aten::mul_) || (it->kind() == aten::div_)) {
      auto originalInputs = it->inputs();
      if (originalInputs.at(0) == originalInputs.at(1)) {
        continue;
      }

      // Warn if the mutated value is derived from tensor.shape (aten::size
      // wrapped in prim::NumToTensor), which the linked docs advise against
      // in tracing mode
      auto shape_node = originalInputs.at(0)->node();
      if ((shape_node->kind() == prim::NumToTensor) &&
          (shape_node->inputs().at(0)->node()->kind() == aten::size)) {
        std::cerr
            << "In-place op on output of tensor.shape. See https://pytorch.org/docs/main/onnx.html#"
            << "avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode"
            << '\n';
      }

      TensorTypePtr firstInp_tensor =
          originalInputs.at(0)->type()->cast<TensorType>();
      TensorTypePtr secondInp_tensor =
          originalInputs.at(1)->type()->cast<TensorType>();
      if (!(firstInp_tensor) || !(secondInp_tensor) ||
          !(firstInp_tensor->scalarType().has_value())) {
        continue;
      }
      // Insert aten::type_as to cast the second input to the first input's
      // dtype, preserving the result type the in-place op would have had
      auto newInputNode = it->owningGraph()->create(aten::type_as, 1);
      newInputNode->insertBefore(*it);
      newInputNode->addInput(originalInputs.at(1));
      newInputNode->addInput(originalInputs.at(0));
      it->replaceInput(1, newInputNode->outputs().at(0));
    }
  }
}

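// Entry point: run the implicit-cast fixup first, then the in-place
// rewrite, over the whole graph. A minimal usage sketch (hypothetical
// caller, assuming parseIR from torch/csrc/jit/ir/irparser.h):
//
//   auto graph = std::make_shared<Graph>();
//   parseIR(R"IR(
//     graph(%x : Tensor, %y : Tensor):
//       %1 : int = prim::Constant[value=1]()
//       %2 : Tensor = aten::add_(%x, %y, %1)
//       return (%2))IR", graph.get());
//   RemoveInplaceOps(graph);  // %2 is now produced by aten::add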
void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
  ImplicitCastForBinaryInplaceOps(graph->block());
  RemoveInplaceOps(graph->block());
}
} // namespace torch::jit