1 #include <ATen/core/jit_type.h>
2 #include <torch/csrc/jit/ir/alias_analysis.h>
3 #include <torch/csrc/jit/ir/ir_views.h>
4 #include <torch/csrc/jit/jit_log.h>
5 #include <torch/csrc/jit/passes/dead_code_elimination.h>
6 #include <torch/csrc/jit/passes/peephole.h>
7 #include <torch/csrc/jit/passes/peephole_alias_sensitive.h>
8 #include <torch/csrc/jit/runtime/graph_executor.h>
9 #include <unordered_set>
10
11 namespace torch::jit {
12
// This pass only does optimizations which require Alias Analysis.
// It is separated out from the Peephole Pass so that Peephole does not have
// to maintain alias-db correctness throughout the pass.
16 struct PeepholeOptimizeAliasSensitiveImpl {
PeepholeOptimizeAliasSensitiveImpltorch::jit::PeepholeOptimizeAliasSensitiveImpl17 PeepholeOptimizeAliasSensitiveImpl(
18 std::shared_ptr<Graph> graph,
19 bool shape_peepholes)
20 : graph_(std::move(graph)),
21 aliasDb_(std::make_unique<AliasDb>(graph_)),
22 shape_peepholes_(shape_peepholes) {}
23
runtorch::jit::PeepholeOptimizeAliasSensitiveImpl24 bool run() {
25 return runBlock(graph_->block());
26 }
27
28 private:
replaceWithIValuetorch::jit::PeepholeOptimizeAliasSensitiveImpl29 void replaceWithIValue(Value* v, const IValue& val) {
30 WithInsertPoint guard(v->node());
31 v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
32 }
33
isFloatingPointtorch::jit::PeepholeOptimizeAliasSensitiveImpl34 bool isFloatingPoint(TensorType& t) {
35 auto input_dtype = t.scalarType();
36 return (
37 shape_peepholes_ && input_dtype && at::isFloatingType(*input_dtype));
38 }
39
runBlocktorch::jit::PeepholeOptimizeAliasSensitiveImpl40 bool runBlock(Block* block) {
41 bool changed = false;
42 for (Node* node : block->nodes()) {
43 for (Block* b : node->blocks()) {
44 changed |= runBlock(b);
45 }
46
47 // dim(conv(x)) extremely common and prevents Conv->BN fusion
48 if (node->kind() == aten::conv1d || node->kind() == aten::conv2d ||
49 node->kind() == aten::conv3d) {
50 auto dim_uses = c10::filter(node->output()->uses(), [](const Use& use) {
51 return use.user->kind() == aten::dim;
52 });
53 if (dim_uses.empty()) {
54 continue;
55 }
56 auto kind = node->kind();
57 int64_t output_size =
58 kind == aten::conv1d ? 3 : (kind == aten::conv2d ? 4 : 5);
59 // This is to handle potential resize_ calls, however unlikely.
60 // If we add more checks related to resize_ in the graph,
61 // factor this out like collectResizeSet in shape_analysis.
62 if (!aliasDb_->hasWriters(node->output())) {
63 for (const Use& dim_use : dim_uses) {
64 replaceWithIValue(dim_use.user->output(), output_size);
65 }
66 changed = true;
67 } else {
68 for (const Use& dim_use : dim_uses) {
69 if (aliasDb_->moveAfterTopologicallyValid(node, dim_use.user)) {
70 replaceWithIValue(dim_use.user->output(), output_size);
71 changed = true;
72 }
73 }
74 }
75 continue;
76 } else if (
77 node->matches(
78 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
79 /*const_inputs=*/{attr::alpha, attr::other}) ||
80 node->matches(
81 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
82 /*const_inputs=*/{attr::alpha, attr::other})) {
83 // x + 0 == x - 0 == x
84 // if either scalar input is a float, than removing this operator could
85 // remove type promotion and affect semantics
86 if (!isFloatingPoint(node->input(0)->type()->expectRef<TensorType>())) {
87 auto inps = node->inputs();
88 if (!inps.at(1)->type()->isSubtypeOf(IntType::get()) ||
89 !inps.at(2)->type()->isSubtypeOf(IntType::get())) {
90 continue;
91 }
92 }
93
94 if (node->get<at::Scalar>(attr::alpha)->toDouble() == 1 &&
95 node->get<at::Scalar>(attr::other)->toDouble() == 0) {
96 if (tryToReplaceOutputWithInput(node->input(0), node->output())) {
97 GRAPH_UPDATE(
98 getHeader(node),
99 " (x + 0 == x - 0 == x) is replaced with ",
100 node->input(0)->debugName());
101 node->output()->replaceAllUsesWith(node->input(0));
102 changed = true;
103 }
104 }
105 } else if (
106 node->matches(
107 "aten::mul(Tensor self, Scalar other) -> Tensor",
108 /*const_inputs=*/attr::other) ||
109 node->matches(
110 "aten::div(Tensor self, Scalar other) -> Tensor",
111 /*const_inputs=*/attr::other)) {
112 // x * 1 == x / 1 == x
113 // is the node is a division or other isn't an integer, than removing
114 // this operator could remove type promotion and affect semantics
115 if (!isFloatingPoint(node->input(0)->type()->expectRef<TensorType>())) {
116 if (node->kind() == aten::div ||
117 !node->input(1)->type()->isSubtypeOf(IntType::get())) {
118 continue;
119 }
120 }
121
122 if (node->get<at::Scalar>(attr::other)->toDouble() == 1) {
123 if (tryToReplaceOutputWithInput(node->input(0), node->output())) {
124 GRAPH_UPDATE(
125 getHeader(node),
126 " (x * 1 == x / 1 == x) is replaced with ",
127 node->input(0)->debugName());
128
129 changed = true;
130 }
131 }
132 }
133 }
134 return changed;
135 }
136
tryToReplaceOutputWithInputtorch::jit::PeepholeOptimizeAliasSensitiveImpl137 bool tryToReplaceOutputWithInput(Value* input, Value* output) {
138 if (!aliasDb_->safeToChangeAliasingRelationship(input, output)) {
139 return false;
140 }
141 // whenever we replace an output with an input, all of the aliasing
142 // properties of the output are now present on the input.
143 // For example, if the output aliases a graph output, the input will now
144 // as well.
145 // in order to avoid re-instantiating an alias db on each change, which
146 // would be O(n^2), or inplace modifying it, which would involve
147 // invalidating all of the memory dag caches, we just keep a set of values
148 // which are "stale" (aliasing properties not up to date), and avoid doing
149 // further optimizations on values which alias them
150 if (aliasDb_->mayAlias({input, output}, stale_alias_values_)) {
151 return false;
152 }
153 output->replaceAllUsesWith(input);
154 stale_alias_values_.insert(input);
155 stale_alias_values_.insert(output);
156 return true;
157 }
158
159 ValueSet stale_alias_values_;
160 std::shared_ptr<Graph> graph_;
161 std::unique_ptr<AliasDb> aliasDb_;
162 bool shape_peepholes_;
163 };
164
PeepholeOptimizeAliasSensitive(const std::shared_ptr<Graph> & graph,bool shape_peepholes)165 bool PeepholeOptimizeAliasSensitive(
166 const std::shared_ptr<Graph>& graph,
167 bool shape_peepholes) {
168 PeepholeOptimizeAliasSensitiveImpl opt(graph, shape_peepholes);
169 return opt.run();
170 }
171
172 } // namespace torch::jit
173