xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/peephole_alias_sensitive.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_alias_sensitive.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <unordered_set>

namespace torch::jit {

// This pass only does optimizations which require Alias Analysis.
// It is separated out from the Peephole pass so that Peephole does not have
// to maintain alias db correctness throughout the pass.
struct PeepholeOptimizeAliasSensitiveImpl {
  PeepholeOptimizeAliasSensitiveImpl(
      std::shared_ptr<Graph> graph,
      bool shape_peepholes)
      : graph_(std::move(graph)),
        aliasDb_(std::make_unique<AliasDb>(graph_)),
        shape_peepholes_(shape_peepholes) {}

  bool run() {
    return runBlock(graph_->block());
  }

 private:
  void replaceWithIValue(Value* v, const IValue& val) {
    WithInsertPoint guard(v->node());
    v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
  }

  bool isFloatingPoint(TensorType& t) {
    auto input_dtype = t.scalarType();
    return (
        shape_peepholes_ && input_dtype && at::isFloatingType(*input_dtype));
  }

  bool runBlock(Block* block) {
    bool changed = false;
    for (Node* node : block->nodes()) {
      for (Block* b : node->blocks()) {
        changed |= runBlock(b);
      }

      // dim(conv(x)) is extremely common and prevents Conv->BN fusion
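      // Illustrative sketch (hand-written IR, not taken from a real trace) of
      // the rewrite this branch performs:
      //   %y : Tensor = aten::conv2d(%x, %w, %b, %stride, %pad, %dil, %groups)
      //   %d : int = aten::dim(%y)
      // becomes
      //   %d : int = prim::Constant[value=4]()
      // so the aten::dim use no longer hangs off the conv output.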
      if (node->kind() == aten::conv1d || node->kind() == aten::conv2d ||
          node->kind() == aten::conv3d) {
        auto dim_uses = c10::filter(node->output()->uses(), [](const Use& use) {
          return use.user->kind() == aten::dim;
        });
        if (dim_uses.empty()) {
          continue;
        }
        auto kind = node->kind();
        int64_t output_size =
            kind == aten::conv1d ? 3 : (kind == aten::conv2d ? 4 : 5);
        // This is to handle potential resize_ calls, however unlikely.
        // If we add more checks related to resize_ in the graph,
        // factor this out like collectResizeSet in shape_analysis.
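        // For example (hypothetical graph, for illustration only): a later
        //   aten::resize_(%y, %new_sizes)
        // could change the rank of %y, so when writers exist the fold is only
        // attempted for dim uses that moveAfterTopologicallyValid can reorder
        // safely with respect to such writes.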
        if (!aliasDb_->hasWriters(node->output())) {
          for (const Use& dim_use : dim_uses) {
            replaceWithIValue(dim_use.user->output(), output_size);
          }
          changed = true;
        } else {
          for (const Use& dim_use : dim_uses) {
            if (aliasDb_->moveAfterTopologicallyValid(node, dim_use.user)) {
              replaceWithIValue(dim_use.user->output(), output_size);
              changed = true;
            }
          }
        }
        continue;
      } else if (
          node->matches(
              "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
              /*const_inputs=*/{attr::alpha, attr::other}) ||
          node->matches(
              "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
              /*const_inputs=*/{attr::alpha, attr::other})) {
        // x + 0 == x - 0 == x
        // if either scalar input is a float, then removing this operator could
        // remove type promotion and affect semantics
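        // For instance (illustrative): for an integer tensor x, x + 0.0
        // promotes the result to a floating-point dtype, so dropping the add
        // would change the output type. Hence the integer-only check below
        // when the tensor itself is not floating point.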
        if (!isFloatingPoint(node->input(0)->type()->expectRef<TensorType>())) {
          auto inps = node->inputs();
          if (!inps.at(1)->type()->isSubtypeOf(IntType::get()) ||
              !inps.at(2)->type()->isSubtypeOf(IntType::get())) {
            continue;
          }
        }

        if (node->get<at::Scalar>(attr::alpha)->toDouble() == 1 &&
            node->get<at::Scalar>(attr::other)->toDouble() == 0) {
          if (tryToReplaceOutputWithInput(node->input(0), node->output())) {
            GRAPH_UPDATE(
                getHeader(node),
                " (x + 0 == x - 0 == x) is replaced with ",
                node->input(0)->debugName());
            node->output()->replaceAllUsesWith(node->input(0));
            changed = true;
          }
        }
      } else if (
          node->matches(
              "aten::mul(Tensor self, Scalar other) -> Tensor",
              /*const_inputs=*/attr::other) ||
          node->matches(
              "aten::div(Tensor self, Scalar other) -> Tensor",
              /*const_inputs=*/attr::other)) {
        // x * 1 == x / 1 == x
        // if the node is a division or other isn't an integer, then removing
        // this operator could remove type promotion and affect semantics
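        // For instance (illustrative): for an integer tensor x, x / 1 uses
        // true division and yields a floating-point result, and x * 1.0 (a
        // float scalar) promotes the dtype, so neither op can be dropped when
        // the tensor itself is not floating point.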
        if (!isFloatingPoint(node->input(0)->type()->expectRef<TensorType>())) {
          if (node->kind() == aten::div ||
              !node->input(1)->type()->isSubtypeOf(IntType::get())) {
            continue;
          }
        }

        if (node->get<at::Scalar>(attr::other)->toDouble() == 1) {
          if (tryToReplaceOutputWithInput(node->input(0), node->output())) {
            GRAPH_UPDATE(
                getHeader(node),
                " (x * 1 == x / 1 == x) is replaced with ",
                node->input(0)->debugName());

            changed = true;
          }
        }
      }
    }
    return changed;
  }

  bool tryToReplaceOutputWithInput(Value* input, Value* output) {
    if (!aliasDb_->safeToChangeAliasingRelationship(input, output)) {
      return false;
    }
    // whenever we replace an output with an input, all of the aliasing
    // properties of the output are now present on the input.
    // For example, if the output aliases a graph output, the input will now
    // as well.
    // in order to avoid re-instantiating an alias db on each change, which
    // would be O(n^2), or inplace modifying it, which would involve
    // invalidating all of the memory dag caches, we just keep a set of values
    // which are "stale" (aliasing properties not up to date), and avoid doing
    // further optimizations on values which alias them
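    // Concretely (illustrative): after rewriting %b = aten::mul(%a, 1) so that
    // uses of %b become uses of %a, the cached aliasing info for %a and %b is
    // no longer trustworthy, so both are marked stale and later candidates
    // that may alias them are left alone rather than rebuilding the AliasDb.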
    if (aliasDb_->mayAlias({input, output}, stale_alias_values_)) {
      return false;
    }
    output->replaceAllUsesWith(input);
    stale_alias_values_.insert(input);
    stale_alias_values_.insert(output);
    return true;
  }

  ValueSet stale_alias_values_;
  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_;
  bool shape_peepholes_;
};

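// A minimal sketch of how the entry point below might be exercised in
// isolation (illustrative only; the IR string and variable names are made up,
// and parseIR is assumed to come from torch/csrc/jit/ir/irparser.h):
//
//   auto graph = std::make_shared<Graph>();
//   parseIR(
//       R"IR(
//     graph(%x : Float(2, 3)):
//       %one : int = prim::Constant[value=1]()
//       %y : Tensor = aten::mul(%x, %one)
//       %z : Tensor = aten::relu(%y)
//       return (%z))IR",
//       graph.get());
//   bool changed =
//       PeepholeOptimizeAliasSensitive(graph, /*shape_peepholes=*/true);
//
// With no writers in the graph, uses of %y should be redirected to %x and the
// aten::mul left for dead-code elimination (a sketch, not a tested snippet).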
bool PeepholeOptimizeAliasSensitive(
    const std::shared_ptr<Graph>& graph,
    bool shape_peepholes) {
  PeepholeOptimizeAliasSensitiveImpl opt(graph, shape_peepholes);
  return opt.run();
}

} // namespace torch::jit