1 #include <torch/csrc/jit/passes/peephole.h>
2 #include <torch/csrc/jit/passes/peephole_non_tensor.h>
3
4 #include <ATen/core/jit_type.h>
5 #include <c10/util/irange.h>
6 #include <torch/csrc/jit/ir/ir_views.h>
7 #include <torch/csrc/jit/jit_log.h>
8
9 namespace torch::jit {
10
11 namespace {
12
13 /**
14 * Check whether the arithmetic node is binary between integers, and return a
15 * constant int value if there exists one.
16 *
17 * @pre node is integer arithmetic.
18 * @post if there's one constant in two operands, then the second operand is
19 * constant.
20 */
checkArithNode(Node & node)21 std::optional<int64_t> checkArithNode(Node& node) {
22 if (node.inputs().size() != 2 || node.input(0)->type() != IntType::get() ||
23 node.input(1)->type() != IntType::get()) {
24 return {};
25 }
26
27 if (node.kind() == aten::mul || node.kind() == aten::add) {
28 if (auto i = constant_as<int64_t>(node.input(0))) {
29 node.permuteInputs({1, 0});
30 return i;
31 }
32 }
33
34 return constant_as<int64_t>(node.input(1));
35 }
36
37 /**
38 * Remove a mul/floordiv node if it is multiplication or division by 1.
39 *
40 * @pre node is either aten::mul, aten::floordiv or aten::div
41 */
trySimplifyMulOrDiv(Node & node)42 bool trySimplifyMulOrDiv(Node& node) {
43 auto constant = checkArithNode(node);
44 if (!constant || *constant != 1) {
45 return false;
46 }
47
48 node.output()->replaceAllUsesWith(node.inputs()[0]);
49 return true;
50 }
51
52 /**
53 * Simplify an add/sub node with its input node, i.e. merge the constant parts
54 * together.
55 *
56 * @pre node is either aten::add or aten::sub
57 */
trySimplifyAddOrSub(Node & node)58 bool trySimplifyAddOrSub(Node& node) {
59 auto constant = checkArithNode(node);
60 if (!constant) {
61 return false;
62 }
63
64 if (constant == 0) {
65 node.output()->replaceAllUsesWith(node.input(0));
66 return true;
67 }
68
69 auto& dep = *node.inputs()[0]->node();
70 if (dep.kind() != aten::add && dep.kind() != aten::sub) {
71 return false;
72 }
73
74 auto delta = checkArithNode(dep);
75 if (!delta) {
76 return false;
77 }
78 auto merged =
79 dep.kind() == node.kind() ? *constant + *delta : *constant - *delta;
80
81 if (merged == 0) {
82 node.output()->replaceAllUsesWith(dep.inputs()[0]);
83 } else {
84 WithInsertPoint g(&node);
85 node.replaceInput(0, dep.inputs()[0]);
86 node.replaceInput(1, node.owningGraph()->insertConstant(merged));
87 }
88 return true;
89 }
90
91 } // namespace
92
93 struct PeepholeOptimizeNonTensorImpl {
PeepholeOptimizeNonTensorImpltorch::jit::PeepholeOptimizeNonTensorImpl94 PeepholeOptimizeNonTensorImpl(std::shared_ptr<Graph> graph)
95 : graph_(std::move(graph)) {}
96
runtorch::jit::PeepholeOptimizeNonTensorImpl97 bool run() {
98 return optimizeBlock(graph_->block());
99 }
100
optimizeBlocktorch::jit::PeepholeOptimizeNonTensorImpl101 bool optimizeBlock(Block* block) {
102 bool changed = false;
103 for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
104 auto* node = *it;
105
106 for (Block* sub_block : node->blocks()) {
107 changed |= optimizeBlock(sub_block);
108 }
109
110 if (node->kind() != prim::Constant) {
111 WithInsertPoint guard(node);
112 // Any Value whose type is None should be replaced with a Constant
113 // This can occur if a module has an optional attribute, and it is
114 // initialized as None.
115 for (Value* output : node->outputs()) {
116 if (output->type()->cast<NoneType>()) {
117 output->replaceAllUsesWith(graph_->insertConstant(IValue()));
118 changed = true;
119 }
120 }
121 }
122 // XXX: remember that if you want to simplify an expression by combining
123 // multiple nodes into a different one, then you need to check that they
124 // all belong to the given block
125 // TODO: this doesn't work with Scalar-Tensor ops! We should
126 // canonicalize those
127 if (node->kind() == prim::If) {
128 IfView n(node);
129 // this handles redundant short circuits like "x and True" or "x or
130 // False"
131 for (const auto i : c10::irange(n.outputs().size())) {
132 if (n.outputs().at(i)->type() != BoolType::get()) {
133 continue;
134 }
135 bool true_val =
136 constant_as<bool>(n.thenOutputs().at(i)).value_or(false);
137 bool false_val =
138 constant_as<bool>(n.elseOutputs().at(i)).value_or(true);
139 // if an if node's output equals its condition replace output with
140 // condition
141 if (true_val && !false_val) {
142 GRAPH_UPDATE(
143 "Replacing ",
144 n.outputs().at(i)->debugName(),
145 " (True or False) with ",
146 n.cond()->debugName());
147 n.outputs().at(i)->replaceAllUsesWith(n.cond());
148 changed = true;
149 }
150 }
151
152 // check for types that can be refined
153 for (size_t i = 0; i < n.outputs().size(); ++i) {
154 // common case of optional for now
155 bool inputs_non_optional =
156 !n.thenOutputs().at(i)->type()->cast<OptionalType>() &&
157 !n.elseOutputs().at(i)->type()->cast<OptionalType>();
158 auto output_optional =
159 n.outputs().at(i)->type()->cast<OptionalType>();
160 if (inputs_non_optional && output_optional) {
161 if (auto unif = unifyTypes(
162 n.thenOutputs().at(i)->type(),
163 n.elseOutputs().at(i)->type())) {
164 n.outputs().at(i)->setType(*unif);
165 changed = true;
166 }
167 }
168 }
169 } else if (
170 node->kind() == aten::__is__ || node->kind() == aten::__isnot__) {
171 // if we are comparing a None value with a value that can't be None
172 // replace the output with true if node is __isnot__ or false if node is
173 // __is__
174 AT_ASSERT(node->inputs().size() == 2);
175 for (size_t check_none_index : {0, 1}) {
176 bool input_must_be_none =
177 node->inputs().at(check_none_index)->mustBeNone();
178 bool other_must_not_be_none =
179 node->inputs().at(1 - check_none_index)->mustNotBeNone();
180 if (input_must_be_none && other_must_not_be_none) {
181 WithInsertPoint guard(node);
182 auto output = node->owningGraph()->insertConstant(
183 node->kind() == aten::__isnot__);
184 GRAPH_UPDATE(
185 "Folding ", getHeader(node), " to ", output->debugName());
186 node->output()->replaceAllUsesWith(output);
187 changed = true;
188 }
189 }
190 } else if (
191 node->kind() == prim::unchecked_unwrap_optional ||
192 node->kind() == aten::_unwrap_optional) {
193 // we are unwrapping an input that can't be None, remove the unwrap
194 auto input = node->input();
195 if (input->mustNotBeNone()) {
196 GRAPH_UPDATE(
197 "Unwrapping ",
198 getHeader(node),
199 " as ",
200 node->input(),
201 " can't be optional");
202 node->output()->replaceAllUsesWith(node->input());
203 changed = true;
204 }
205 } else if (node->kind() == prim::unchecked_cast) {
206 // unchecked_cast is not generated for tensor properties, so we are not
207 // losing anything by calling unshapedType here
208 auto input_type = unshapedType(node->input()->type());
209 auto output_type = unshapedType(node->output()->type());
210 if (input_type->isSubtypeOf(*output_type)) {
211 GRAPH_UPDATE(
212 "Removing ",
213 getHeader(node),
214 " as input type subtypes output type");
215 node->output()->replaceAllUsesWith(node->input());
216 changed = true;
217 }
218 } else if (
219 (node->kind() == aten::Int || node->kind() == aten::ceil) &&
220 node->inputs().size() == 1 &&
221 node->input()->type()->cast<IntType>()) {
222 GRAPH_UPDATE(
223 "Removing ", getHeader(node), " as input is already an integer");
224 node->output()->replaceAllUsesWith(node->input());
225 changed = true;
226 } else if (node->kind() == aten::ne || node->kind() == aten::eq) {
227 if (node->inputs().size() != 2 ||
228 node->inputs().at(0) != node->inputs().at(1)) {
229 continue;
230 }
231 auto inp_type = node->inputs().at(0)->type();
232 // only handling common immutable types here because other types like
233 // Tensor or list of Tensor might throw on aten::eq
234 auto immut_type = [&](const TypePtr& type) {
235 auto kind = type->kind();
236 static const std::vector<TypeKind> handled_immutable_types = {
237 TypeKind::BoolType,
238 TypeKind::IntType,
239 TypeKind::FloatType,
240 TypeKind::NoneType};
241 return (
242 std::find(
243 handled_immutable_types.begin(),
244 handled_immutable_types.end(),
245 kind) != handled_immutable_types.end());
246 };
247 bool non_throwing_type = false;
248 if (auto li_type = inp_type->cast<ListType>()) {
249 non_throwing_type = immut_type(li_type->getElementType());
250 } else if (auto di_type = inp_type->cast<DictType>()) {
251 non_throwing_type =
252 (immut_type(di_type->getKeyType()) &&
253 immut_type(di_type->getValueType()));
254 } else {
255 non_throwing_type = immut_type(inp_type);
256 }
257 if (non_throwing_type) {
258 WithInsertPoint guard(node);
259 node->output()->replaceAllUsesWith(
260 graph_->insertConstant(node->kind() == aten::eq));
261 changed = true;
262 }
263 } else if (
264 node->kind() == aten::mul || node->kind() == aten::floordiv ||
265 node->kind() == aten::div) {
266 changed |= trySimplifyMulOrDiv(*node);
267 } else if (node->kind() == aten::add || node->kind() == aten::sub) {
268 changed |= trySimplifyAddOrSub(*node);
269 }
270 }
271 return changed;
272 }
273
274 private:
275 std::shared_ptr<Graph> graph_;
276 };
277
PeepholeOptimizeNonTensor(const std::shared_ptr<Graph> & graph)278 bool PeepholeOptimizeNonTensor(const std::shared_ptr<Graph>& graph) {
279 PeepholeOptimizeNonTensorImpl peephole(graph);
280 bool changed = peephole.run();
281 GRAPH_DUMP("After PeepholeOptimize: ", graph);
282 return changed;
283 }
284
285 } // namespace torch::jit
286