#include <torch/csrc/jit/passes/peephole.h>

#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/concat_opt.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole_alias_sensitive.h>
#include <torch/csrc/jit/passes/peephole_dict_idioms.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/peephole_non_tensor.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

namespace torch::jit {

// Conservatively compare two optionals. If both are undefined, assume
// they aren't equal.
template <typename T>
static bool mustBeEqual(const std::optional<T>& a, const std::optional<T>& b) {
  return a == b && a.has_value();
}
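// e.g. mustBeEqual<int>(3, 3) is true, while both
// mustBeEqual<int>(std::nullopt, std::nullopt) and
// mustBeEqual<int>(3, std::nullopt) are false: two unknown values can't be
// proven equal.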

struct PeepholeOptimizeImpl {
  PeepholeOptimizeImpl(
      std::shared_ptr<Graph> graph,
      bool disable_shape_peepholes)
      : graph_(std::move(graph)), shape_peepholes_(!disable_shape_peepholes) {}

  bool run() {
    bool changed = optimizeBlock(graph_->block());
    changed |= PeepholeOptimizeListIdioms(graph_);
    changed |= PeepholeOptimizeDictIdioms(graph_);
    changed |= PeepholeOptimizeAliasSensitive(graph_, shape_peepholes_);
    changed |= PeepholeOptimizeNonTensor(graph_);
    changed |= CombineConcats(graph_);
    return changed;
  }

  // The intent of this optimization pass is to catch all of the small,
  // easy-to-catch peephole optimizations you might be interested in doing.
  //
  // TODO: Decide what kind of fixed point strategy we will have
  bool optimizeBlock(Block* block) {
    bool changed = false;
    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
      auto* node = *it;

      for (Block* sub_block : node->blocks()) {
        changed |= optimizeBlock(sub_block);
      }

      // XXX: remember that if you want to simplify an expression by combining
      // multiple nodes into a different one, then you need to check that they
      // all belong to the given block.
      // TODO: this doesn't work with Scalar-Tensor ops! We should
      // canonicalize those.
      if (node->matches(
              "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")) {
        // Eliminate no-op _grad_sum_to_size.
        if (node->input(1)->mustBeNone()) {
          GRAPH_UPDATE(
              getHeader(node),
              " (x._grad_sum_to_size(None) == x) is replaced with ",
              node->input(0)->debugName());
          node->output()->replaceAllUsesWith(node->input(0));
          changed = true;
        } else {
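          // Summing to an intermediate size y and then to a final size z is
          // the same as summing straight to z, so any later
          // _grad_sum_to_size with a concrete int[] target can skip this
          // node.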
          auto uses = node->output()->uses();
          for (Use u : uses) {
            if (u.user->matches(
                    "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)") &&
                u.user->input(1)->type()->isSubtypeOf(*ListType::ofInts())) {
              GRAPH_UPDATE(
                  getHeader(node),
                  " (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ",
                  node->inputs().at(0)->debugName());
              u.user->replaceInput(0, node->inputs().at(0));
              changed = true;
            }
          }
        }
      } else if (
          node->matches(
              "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
              /*const_inputs=*/attr::size)) {
        // x.expand(x.size()) == x
        auto input_type =
            node->namedInput(attr::self)->type()->cast<TensorType>();
        if (input_type && shape_peepholes_) {
          auto expanded_sizes = node->get<c10::List<int64_t>>(attr::size);
          auto input_type_sizes = input_type->sizes().concrete_sizes();
          if (expanded_sizes.has_value() && input_type_sizes &&
              expanded_sizes->vec() == *input_type_sizes) {
            GRAPH_UPDATE(
                getHeader(node),
                " (x.expand(x.size()) == x) is replaced with ",
                node->namedInput(attr::self)->debugName());
            node->output()->replaceAllUsesWith(node->namedInput(attr::self));
            changed = true;
          }
        }
      } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
        // x.t().t() == x
        Node* input_node = node->input()->node();
        if (input_node->matches("aten::t(Tensor self) -> Tensor")) {
          GRAPH_UPDATE(
              getHeader(node),
              " (x.t().t() == x) is replaced with ",
              input_node->input()->debugName());
          node->output()->replaceAllUsesWith(input_node->input());
          changed = true;
        }
      } else if (
          node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor") &&
          shape_peepholes_) {
        // x.type_as(y) == x iff x.type() == y.type()
        auto self_type = node->input(0)->type()->expect<TensorType>();
        auto other_type = node->input(1)->type()->expect<TensorType>();
        if (mustBeEqual(self_type->scalarType(), other_type->scalarType()) &&
            mustBeEqual(self_type->device(), other_type->device())) {
          GRAPH_UPDATE(
              getHeader(node),
              " (x.type_as(y) == x) is replaced with ",
              node->input(0)->debugName());
          node->output()->replaceAllUsesWith(node->input(0));
          changed = true;
        }
      } else if (
          node->kind() == aten::Float || node->kind() == aten::Int ||
          node->kind() == aten::FloatImplicit ||
          node->kind() == aten::IntImplicit ||
          node->kind() == aten::ScalarImplicit) {
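        // aten::Int(prim::NumToTensor(x)) == x (and likewise for the other
        // scalar-extraction ops), so the round trip can be removed.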
        Node* input_node = node->input()->node();
        if (input_node->kind() == prim::NumToTensor) {
          GRAPH_UPDATE(
              getHeader(node),
              " (x.NumToTensor() == x) is replaced with ",
              input_node->input()->debugName());
          node->output()->replaceAllUsesWith(input_node->input());
          changed = true;
        }
      } else if (
          node->matches("aten::size(Tensor self) -> int[]") &&
          shape_peepholes_) {
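        // Fold x.size() to a constant int[] when every dimension is
        // statically known.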
        if (auto ptt = node->input()->type()->cast<TensorType>()) {
          if (auto sizes = ptt->sizes().concrete_sizes()) {
            GRAPH_UPDATE(
                getHeader(node),
                " (x.size()) is replaced with ",
                node->input()->debugName());
            WithInsertPoint guard(node);
            IValue ival(sizes);
            auto const_sizes_val = node->owningGraph()->insertConstant(ival);
            node->output()->replaceAllUsesWith(const_sizes_val);
            changed = true;
          }
        }
      } else if (
          node->matches("aten::len.t(t[] a) -> int") &&
          node->input()->node()->matches("aten::size(Tensor self) -> int[]") &&
          shape_peepholes_) {
        auto ptt = node->input()->node()->input()->type()->expect<TensorType>();
        // only handle the single-use case for now, to avoid modifying
        // mutated lists
        // TODO: canonicalize as aten::dim?
        if (ptt->sizes().size() && node->input()->uses().size() == 1) {
          WithInsertPoint guard(node);
          auto output = node->owningGraph()->insertConstant(
              static_cast<int64_t>(*ptt->sizes().size()));
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a \"dim\" constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      } else if (
          node->matches("aten::size(Tensor self, int dim) -> int") &&
          shape_peepholes_) {
        if (auto ptt = node->inputs().at(0)->type()->cast<TensorType>()) {
          if (auto maybe_ndim = ptt->sizes().size()) {
            auto ndim = static_cast<int64_t>(*maybe_ndim);
            auto maybe_index = toIValue(node->inputs().at(1));
            if (!maybe_index) {
              continue;
            }
            int64_t index = maybe_index->toInt();
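            // Normalize negative indices: e.g. dim = -1 on a 4-D tensor
            // refers to dimension 3.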
            int64_t norm_index = index < 0 ? ndim + index : index;
            if (norm_index >= 0 && norm_index < ndim &&
                ptt->sizes()[norm_index]) {
              WithInsertPoint guard(node);
              IValue ival(*ptt->sizes()[norm_index]);
              auto const_sizes_val = node->owningGraph()->insertConstant(ival);
              node->output()->replaceAllUsesWith(const_sizes_val);
              GRAPH_UPDATE(
                  getHeader(node),
                  " (x.size(dim)) is replaced with constant ",
                  const_sizes_val->debugName());
              changed = true;
            }
          }
        }
      } else if (
          node->matches("aten::is_floating_point(Tensor self) -> bool") &&
          shape_peepholes_) {
        auto ptt = node->inputs().at(0)->type()->cast<TensorType>();
        if (auto maybe_dtype = ptt->scalarType()) {
          c10::ScalarType dtype = *maybe_dtype;
          WithInsertPoint guard(node);
          IValue ival(at::isFloatingType(dtype));
          auto new_constant = node->owningGraph()->insertConstant(ival);
          node->output()->replaceAllUsesWith(new_constant);
          GRAPH_UPDATE(
              getHeader(node),
              " (x.is_floating_point()) is replaced with ",
              new_constant->debugName());
          changed = true;
        }
      } else if (
          node->matches("aten::is_complex(Tensor self) -> bool") &&
          shape_peepholes_) {
        auto ptt = node->inputs().at(0)->type()->cast<TensorType>();
        if (auto maybe_dtype = ptt->scalarType()) {
          c10::ScalarType dtype = *maybe_dtype;
          WithInsertPoint guard(node);
          IValue ival(at::isComplexType(dtype));
          auto new_constant = node->owningGraph()->insertConstant(ival);
          node->output()->replaceAllUsesWith(new_constant);
          GRAPH_UPDATE(
              getHeader(node),
              " (x.is_complex()) is replaced with ",
              new_constant->debugName());
          changed = true;
        }
      } else if (
          node->matches("prim::dtype(Tensor a) -> int") && shape_peepholes_) {
        auto ptt = node->input()->type()->expect<TensorType>();
        if (ptt->scalarType()) {
          WithInsertPoint guard(node);
          auto output = node->owningGraph()->insertConstant(
              static_cast<int64_t>(*ptt->scalarType()));
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a type constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      } else if (
          node->matches("prim::device(Tensor a) -> Device") &&
          shape_peepholes_) {
        auto ptt = node->input()->type()->expect<TensorType>();
        if (ptt->device()) {
          WithInsertPoint guard(node);
          auto output = node->owningGraph()->insertConstant(*ptt->device());
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a device constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      } else if (
          node->matches("aten::device(str type, int index) -> Device") &&
          shape_peepholes_) {
        auto string_type = node->inputs().at(0)->type()->expect<StringType>();
        if (string_type) {
          WithInsertPoint guard(node);
          std::string type_str = node->inputs().at(0)->node()->s(attr::value);
          auto maybe_index = toIValue(node->inputs().at(1));
          int64_t index = 0;
          if (maybe_index) {
            index = maybe_index->toInt();
          }
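          // e.g. type "cuda" with index 0 yields Device("cuda:0").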
          auto device = c10::Device(type_str + ":" + std::to_string(index));
          auto output = node->owningGraph()->insertConstant(device);
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a device constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      } else if (
          node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
        auto ptt = node->input()->type()->expect<TensorType>();
        if (auto dim = ptt->sizes().size()) {
          WithInsertPoint guard(node);
          auto output =
              node->owningGraph()->insertConstant(static_cast<int64_t>(*dim));
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a \"dim\" constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      } else if (
          node->matches("prim::is_cuda(Tensor a) -> bool") &&
          shape_peepholes_) {
        auto ptt = node->input()->type()->expect<TensorType>();
        if (ptt->device()) {
          WithInsertPoint guard(node);
          auto output =
              node->owningGraph()->insertConstant((*ptt->device()).is_cuda());
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with an is_cuda constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      }
    }
    return changed;
  }

 private:
  std::shared_ptr<Graph> graph_;
  bool shape_peepholes_;
};

static bool FuseAddMM(Block* block) {
  bool changed = false;
  for (Node* node : block->nodes()) {
    // XXX: remember that if you want to simplify an expression by combining
    // multiple nodes into a different one, then you need to check that they
    // all belong to the given block.
    if (node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
            /*const_inputs=*/attr::alpha)) {
      // z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z
      if (node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
        // Look for mm on both sides of the add
        for (const auto mm_side : c10::irange(2)) {
          // Add will accept tensors of mismatched scalar types, as long as
          // one of them is a scalar, but addmm will throw in that case, so we
          // can only perform this fusion if we're sure that it is correct,
          // and for that we need the add_mat_type. An alternative would be to
          // insert a type_as conditional on the tensor shape being a scalar,
          // but that might add overhead and make analysis harder.
          auto add_mat_type =
              node->input(1 - mm_side)->type()->expect<TensorType>();
          // if we don't have the rank, we can't tell if the bias is a scalar
          if (!add_mat_type->sizes().size()) {
            continue;
          }

          if (node->input(mm_side)->node()->matches(
                  "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
            WithInsertPoint guard(node);

            auto* graph = node->owningGraph();
            auto* mm_node = node->input(mm_side)->node();
            auto* add_mat = node->input(1 - mm_side);
            auto* mat1 = mm_node->input(0);
            auto* mat2 = mm_node->input(1);
            // Attempt to find a matrix with a defined scalar type to use
            // with type_as
            auto* type_as_mat = mat1;
            if (!type_as_mat->type()->expectRef<TensorType>().scalarType()) {
              type_as_mat = mat2;
            }
            auto mat_scalar_type =
                type_as_mat->type()->expectRef<TensorType>().scalarType();

            // We can't use type_as if we don't know the target type (the
            // dtype of the mm output) that the bias needs to be coerced to.
            if (!mat_scalar_type) {
              continue;
            }

            // We insert the type_as if we're sure that the added element is a
            // scalar, and we either don't know the type of the scalar, or
            // know that it's mismatched.
            if (add_mat_type->sizes().size() &&
                *add_mat_type->sizes().size() == 0 &&
                !mustBeEqual(add_mat_type->scalarType(), mat_scalar_type)) {
              auto* type_as_node =
                  graph->insertNode(graph->create(aten::type_as, 1));
              type_as_node->addInput(add_mat);
              type_as_node->addInput(type_as_mat);
              add_mat = type_as_node->output();
              if (add_mat_type->isComplete()) {
                auto new_type =
                    add_mat_type->withScalarType(mat_scalar_type)->contiguous();
                add_mat->setType(new_type);
              }
            }

            auto* cOne = graph->insertConstant(1);
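            // addmm takes (self, mat1, mat2, *, beta, alpha); with
            // beta == alpha == 1 this reproduces add(self, mm(mat1, mat2)).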
            auto* addmm_node =
                graph->insertNode(graph->create(aten::addmm, 1));
            addmm_node->addInput(add_mat);
            addmm_node->addInput(mat1);
            addmm_node->addInput(mat2);
            addmm_node->addInput(cOne);
            addmm_node->addInput(cOne);
            auto* addmm_value = addmm_node->output();

            // Copy shape information from the output node
            addmm_value->copyMetadata(node->output());
            GRAPH_UPDATE(
                "Fusing ",
                mm_node->input(0)->debugName(),
                ", ",
                mm_node->input(1)->debugName(),
                " and ",
                node->input(1 - mm_side)->debugName(),
                " into ",
                addmm_value->debugName());
            node->output()->replaceAllUsesWith(addmm_value);
            changed = true;
            continue;
          }
        }
      }
    }
    for (Block* b : node->blocks()) {
      changed |= FuseAddMM(b);
    }
  }
  return changed;
}

// FuseAddMM is kept separate from the peephole pass because it is currently
// used only for exporting to ONNX.
// Today, fusing add + mm has no benefit within PyTorch running ATen
// ops. However, we rely on seeing the fused version of addmm for ONNX export,
// since otherwise after ONNX translation we would see redundant Gemm ops with
// sub-optimal inputs.
// It won't be helpful for ATen until we're able to represent
//   torch.addmm(a, b, c, out=a).
// That's because addmm dispatches internally to gemm, which computes:
//   C = beta * C + alpha * A @ B
// but aten::addmm(a, b, c, 1, 1) is really:
//   D = beta * C + alpha * A @ B
// and because it works out-of-place on C, we're only trading off an
// explicit add for a copy inside the addmm function. Note that it
// doesn't even result in fewer reads, because mm won't even load C
// (it lowers to gemm with beta == 0).
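//
// A minimal sketch of the rewrite in IR form (value names illustrative):
//   %mm  = aten::mm(%x, %y)
//   %out = aten::add(%z, %mm, alpha=1)
// becomes
//   %out = aten::addmm(%z, %x, %y, beta=1, alpha=1)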
bool FuseAddMM(const std::shared_ptr<Graph>& graph) {
  bool changed = FuseAddMM(graph->block());
  GRAPH_DUMP("After FuseAddMM: ", graph);
  return changed;
}

bool PeepholeOptimize(
    const std::shared_ptr<Graph>& graph,
    bool disable_shape_peepholes) {
  PeepholeOptimizeImpl peephole(graph, disable_shape_peepholes);
  bool changed = peephole.run();
  GRAPH_DUMP("After PeepholeOptimize: ", graph);
  // Eliminate dead code created by any peephole passes we've just done
  if (changed) {
    EliminateDeadCode(graph->block());
  }
  return changed;
}
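
// Typical use, sketched:
//   std::shared_ptr<Graph> g = ...; // a TorchScript graph
//   bool changed = PeepholeOptimize(g, /*disable_shape_peepholes=*/false);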

} // namespace torch::jit