xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/peephole.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/peephole.h>
2 
3 #include <ATen/core/jit_type.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/ir/alias_analysis.h>
6 #include <torch/csrc/jit/ir/ir_views.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/concat_opt.h>
9 #include <torch/csrc/jit/passes/dead_code_elimination.h>
10 #include <torch/csrc/jit/passes/peephole_alias_sensitive.h>
11 #include <torch/csrc/jit/passes/peephole_dict_idioms.h>
12 #include <torch/csrc/jit/passes/peephole_list_idioms.h>
13 #include <torch/csrc/jit/passes/peephole_non_tensor.h>
14 #include <torch/csrc/jit/runtime/graph_executor.h>
15 
16 namespace torch::jit {
17 
// Conservatively compare two optionals. If both are undefined, assume
// they aren't equal (i.e. "unknown == unknown" is treated as false).
template <typename T>
static bool mustBeEqual(const std::optional<T>& a, const std::optional<T>& b) {
  // Equal AND known: an empty optional never compares "must be equal".
  return a == b && a.has_value();
}
24 
25 struct PeepholeOptimizeImpl {
PeepholeOptimizeImpltorch::jit::PeepholeOptimizeImpl26   PeepholeOptimizeImpl(
27       std::shared_ptr<Graph> graph,
28       bool disable_shape_peepholes)
29       : graph_(std::move(graph)), shape_peepholes_(!disable_shape_peepholes) {}
30 
runtorch::jit::PeepholeOptimizeImpl31   bool run() {
32     bool changed = optimizeBlock(graph_->block());
33     changed |= PeepholeOptimizeListIdioms(graph_);
34     changed |= PeepholeOptimizeDictIdioms(graph_);
35     changed |= PeepholeOptimizeAliasSensitive(graph_, shape_peepholes_);
36     changed |= PeepholeOptimizeNonTensor(graph_);
37     changed |= CombineConcats(graph_);
38     return changed;
39   }
40 
41   // The intent for this optimization pass is to catch all of the small, easy to
42   // catch peephole optimizations you might be interested in doing.
43   //
44   // TODO: Decide what kind of fixed point strategy we will have
optimizeBlocktorch::jit::PeepholeOptimizeImpl45   bool optimizeBlock(Block* block) {
46     bool changed = false;
47     for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
48       auto* node = *it;
49 
50       for (Block* sub_block : node->blocks()) {
51         changed |= optimizeBlock(sub_block);
52       }
53 
54       // XXX: remember that if you want to simplify an expression by combining
55       // multiple nodes into a different one, then you need to check that they
56       // all belong to the given block
57       // TODO: this doesn't work with Scalar-Tensor ops! We should
58       // canonicalize those
59       if (node->matches(
60               "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")) {
61         // Eliminate no-op _grad_sum_to_size.
62         // TODO: this doesn't work with Scalar-Tensor ops! We should
63         // canonicalize those
64         if (node->input(1)->mustBeNone()) {
65           GRAPH_UPDATE(
66               getHeader(node),
67               " (x._grad_sum_to_size(x, None) == x) is replaced with ",
68               node->input(0)->debugName());
69           node->output()->replaceAllUsesWith(node->input(0));
70           changed = true;
71         } else {
72           auto uses = node->output()->uses();
73           for (Use u : uses) {
74             if (u.user->matches(
75                     "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)") &&
76                 u.user->input(1)->type()->isSubtypeOf(*ListType::ofInts())) {
77               GRAPH_UPDATE(
78                   getHeader(node),
79                   " (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ",
80                   node->inputs().at(0)->debugName());
81               u.user->replaceInput(0, node->inputs().at(0));
82               changed = true;
83             }
84           }
85         }
86       } else if (
87           node->matches(
88               "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
89               /*const_inputs=*/attr::size)) {
90         // x.expand(x.size()) == x
91         auto input_type =
92             node->namedInput(attr::self)->type()->cast<TensorType>();
93         if (input_type && shape_peepholes_) {
94           auto expanded_sizes = node->get<c10::List<int64_t>>(attr::size);
95           auto input_type_sizes = input_type->sizes().concrete_sizes();
96           if (expanded_sizes.has_value() && input_type_sizes &&
97               expanded_sizes->vec() == *input_type_sizes) {
98             GRAPH_UPDATE(
99                 getHeader(node),
100                 " (x.expand(x.size()) == x) is replaced with ",
101                 node->namedInput(attr::self)->debugName());
102             node->output()->replaceAllUsesWith(node->namedInput(attr::self));
103             changed = true;
104           }
105         }
106       } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
107         // x.t().t() == x
108         Node* input_node = node->input()->node();
109         if (input_node->matches("aten::t(Tensor self) -> Tensor")) {
110           GRAPH_UPDATE(
111               getHeader(node),
112               " (x.t().t() == x) is replaced with ",
113               input_node->input()->debugName());
114           node->output()->replaceAllUsesWith(input_node->input());
115           changed = true;
116         }
117       } else if (
118           node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor") &&
119           shape_peepholes_) {
120         // x.type_as(y) == x iff x.type() == y.type()
121         auto self_type = node->input(0)->type()->expect<TensorType>();
122         auto other_type = node->input(1)->type()->expect<TensorType>();
123         if (mustBeEqual(self_type->scalarType(), other_type->scalarType()) &&
124             mustBeEqual(self_type->device(), other_type->device())) {
125           GRAPH_UPDATE(
126               getHeader(node),
127               " (x.type_as(y) == x) is replaced with ",
128               node->input(0)->debugName());
129           node->output()->replaceAllUsesWith(node->input(0));
130           changed = true;
131         }
132       } else if (
133           node->kind() == aten::Float || node->kind() == aten::Int ||
134           node->kind() == aten::FloatImplicit ||
135           node->kind() == aten::IntImplicit ||
136           node->kind() == aten::ScalarImplicit) {
137         Node* input_node = node->input()->node();
138         if (input_node->kind() == prim::NumToTensor) {
139           GRAPH_UPDATE(
140               getHeader(node),
141               " (x.NumToTensor() == x) is replaced with ",
142               node->input()->debugName());
143           node->output()->replaceAllUsesWith(input_node->input());
144           changed = true;
145         }
146       } else if (
147           node->matches("aten::size(Tensor self) -> int[]") &&
148           shape_peepholes_) {
149         if (auto ptt = node->input()->type()->cast<TensorType>()) {
150           if (auto sizes = ptt->sizes().concrete_sizes()) {
151             GRAPH_UPDATE(
152                 getHeader(node),
153                 " (x.size()) is replaced with ",
154                 node->input()->debugName());
155             WithInsertPoint guard(node);
156             IValue ival(sizes);
157             auto const_sizes_val = node->owningGraph()->insertConstant(ival);
158             node->output()->replaceAllUsesWith(const_sizes_val);
159             changed = true;
160           }
161         }
162       } else if (
163           node->matches("aten::len.t(t[] a) -> int") &&
164           node->input()->node()->matches("aten::size(Tensor self) -> int[]") &&
165           shape_peepholes_) {
166         auto ptt = node->input()->node()->input()->type()->expect<TensorType>();
167         // only handle one use case for now to avoid modifying mutated lists
168         // TODO: canonicalize as aten::dim ?
169         if (ptt->sizes().size() && node->input()->uses().size() == 1) {
170           WithInsertPoint guard(node);
171           auto output = node->owningGraph()->insertConstant(
172               static_cast<int64_t>(*ptt->sizes().size()));
173           GRAPH_UPDATE(
174               "Replacing ",
175               getHeader(node),
176               " with a \"dim\" constant ",
177               output->debugName());
178           node->output()->replaceAllUsesWith(output);
179           changed = true;
180         }
181       } else if (
182           node->matches("aten::size(Tensor self, int dim) -> int") &&
183           shape_peepholes_) {
184         if (auto ptt = node->inputs().at(0)->type()->cast<TensorType>()) {
185           if (auto maybe_ndim = ptt->sizes().size()) {
186             auto ndim = static_cast<int64_t>(*maybe_ndim);
187             auto maybe_index = toIValue(node->inputs().at(1));
188             if (!maybe_index) {
189               continue;
190             }
191             int64_t index = maybe_index->toInt();
192             int64_t norm_index = index < 0 ? ndim + index : index;
193             if (norm_index >= 0 && norm_index < ndim &&
194                 ptt->sizes()[norm_index]) {
195               WithInsertPoint guard(node);
196               IValue ival(*ptt->sizes()[norm_index]);
197               auto const_sizes_val = node->owningGraph()->insertConstant(ival);
198               node->output()->replaceAllUsesWith(const_sizes_val);
199               GRAPH_UPDATE(
200                   getHeader(node),
201                   " (x.size(dim)) is replaced with constant ",
202                   const_sizes_val->debugName());
203               changed = true;
204             }
205           }
206         }
207       } else if (
208           node->matches("aten::is_floating_point(Tensor self) -> bool") &&
209           shape_peepholes_) {
210         auto ptt = node->inputs().at(0)->type()->cast<TensorType>();
211         if (auto maybe_dtype = ptt->scalarType()) {
212           c10::ScalarType dtype = *maybe_dtype;
213           WithInsertPoint guard(node);
214           IValue ival(at::isFloatingType(dtype));
215           auto new_constant = node->owningGraph()->insertConstant(ival);
216           node->output()->replaceAllUsesWith(new_constant);
217           GRAPH_UPDATE(
218               getHeader(node),
219               " (x.is_floating_point()) is replaced with ",
220               new_constant->debugName());
221           changed = true;
222         }
223       } else if (
224           node->matches("aten::is_complex(Tensor self) -> bool") &&
225           shape_peepholes_) {
226         auto ptt = node->inputs().at(0)->type()->cast<TensorType>();
227         if (auto maybe_dtype = ptt->scalarType()) {
228           c10::ScalarType dtype = *maybe_dtype;
229           WithInsertPoint guard(node);
230           IValue ival(at::isComplexType(dtype));
231           auto new_constant = node->owningGraph()->insertConstant(ival);
232           node->output()->replaceAllUsesWith(new_constant);
233           GRAPH_UPDATE(
234               getHeader(node),
235               " (x.is_complex()) is replaced with ",
236               new_constant->debugName());
237           changed = true;
238         }
239       } else if (
240           node->matches("prim::dtype(Tensor a) -> int") && shape_peepholes_) {
241         auto ptt = node->input()->type()->expect<TensorType>();
242         if (ptt->scalarType()) {
243           WithInsertPoint guard(node);
244           auto output = node->owningGraph()->insertConstant(
245               static_cast<int64_t>(*ptt->scalarType()));
246           GRAPH_UPDATE(
247               "Replacing ",
248               getHeader(node),
249               " with a type constant ",
250               output->debugName());
251           node->output()->replaceAllUsesWith(output);
252           changed = true;
253         }
254       } else if (
255           node->matches("prim::device(Tensor a) -> Device") &&
256           shape_peepholes_) {
257         auto ptt = node->input()->type()->expect<TensorType>();
258         if (ptt->device()) {
259           WithInsertPoint guard(node);
260           auto output = node->owningGraph()->insertConstant(*ptt->device());
261           GRAPH_UPDATE(
262               "Replacing ",
263               getHeader(node),
264               " with a device constant ",
265               output->debugName());
266           node->output()->replaceAllUsesWith(output);
267           changed = true;
268         }
269       } else if (
270           node->matches("aten::device(str type, int index) -> Device") &&
271           shape_peepholes_) {
272         auto string_type = node->inputs().at(0)->type()->expect<StringType>();
273         if (string_type) {
274           WithInsertPoint guard(node);
275           std::string type_str = node->inputs().at(0)->node()->s(attr::value);
276           auto maybe_index = toIValue(node->inputs().at(1));
277           int64_t index = 0;
278           if (maybe_index) {
279             index = maybe_index->toInt();
280           }
281           auto device = c10::Device(type_str + ":" + std::to_string(index));
282           auto output = node->owningGraph()->insertConstant(device);
283           GRAPH_UPDATE(
284               "Replacing ",
285               getHeader(node),
286               " with a device constant ",
287               output->debugName());
288           node->output()->replaceAllUsesWith(output);
289           changed = true;
290         }
291       } else if (
292           node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
293         auto ptt = node->input()->type()->expect<TensorType>();
294         if (auto dim = ptt->sizes().size()) {
295           WithInsertPoint guard(node);
296           auto output =
297               node->owningGraph()->insertConstant(static_cast<int64_t>(*dim));
298           GRAPH_UPDATE(
299               "Replacing ",
300               getHeader(node),
301               " with a \"dim\" constant ",
302               output->debugName());
303           node->output()->replaceAllUsesWith(output);
304           changed = true;
305         }
306       } else if (
307           node->matches("prim::is_cuda(Tensor a) -> bool") &&
308           shape_peepholes_) {
309         auto ptt = node->input()->type()->expect<TensorType>();
310         if (ptt->device()) {
311           WithInsertPoint guard(node);
312           auto output =
313               node->owningGraph()->insertConstant((*ptt->device()).is_cuda());
314           GRAPH_UPDATE(
315               "Replacing ",
316               getHeader(node),
317               " with a is_cuda constant ",
318               output->debugName());
319           node->output()->replaceAllUsesWith(output);
320           changed = true;
321         }
322       }
323     }
324     return changed;
325   }
326 
327  private:
328   std::shared_ptr<Graph> graph_;
329   bool shape_peepholes_;
330 };
331 
FuseAddMM(Block * block)332 static bool FuseAddMM(Block* block) {
333   bool changed = false;
334   for (Node* node : block->nodes()) {
335     // XXX: remember that if you want to simplify an expression by combining
336     // multiple nodes into a different one, then you need to check that they
337     // all belong to the given block
338     if (node->matches(
339             "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
340             /*const_inputs=*/attr::alpha)) {
341       // z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z
342       if (node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
343         // Look for mm from both sides of the add
344         for (const auto mm_side : c10::irange(2)) {
345           // Add will accept tensors of mismatched scalar types, as long as
346           // one of them is a scalar, but addmm will throw in that case, so we
347           // can only perform this fusion if we're sure that it is correct,
348           // and for that we need the add_mat_type. An alternative would be to
349           // insert a type_as conditional on the tensor shape being a scalar,
350           // but that might add overhead, and make analysis harder.
351           auto add_mat_type =
352               node->input(1 - mm_side)->type()->expect<TensorType>();
353           // if we don't have the rank, we can't tell if the bias is a scalar
354           if (!add_mat_type->sizes().size()) {
355             continue;
356           }
357 
358           if (node->input(mm_side)->node()->matches(
359                   "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
360             WithInsertPoint guard(node);
361 
362             auto* graph = node->owningGraph();
363             auto* mm_node = node->input(mm_side)->node();
364             auto* add_mat = node->input(1 - mm_side);
365             auto* mat1 = mm_node->input(0);
366             auto* mat2 = mm_node->input(1);
367 
368             // Attempts to find a matrix with a defined scalar type to type as
369             auto* type_as_mat = mat1;
370             if (!type_as_mat->type()->expectRef<TensorType>().scalarType()) {
371               type_as_mat = mat2;
372             }
373             auto mat_scalar_type =
374                 type_as_mat->type()->expectRef<TensorType>().scalarType();
375 
376             // we can't use type_as if we don't know the target type (mm), the
377             // bias needs to be coerced to
378             if (!mat_scalar_type) {
379               continue;
380             }
381 
382             // We insert the type_as if we're sure that the added element is a
383             // scalar, and we either don't know the type of the scalar, or
384             // know that it's mismatched.
385             if (add_mat_type->sizes().size() &&
386                 *add_mat_type->sizes().size() == 0 &&
387                 !mustBeEqual(add_mat_type->scalarType(), mat_scalar_type)) {
388               auto* type_as_node =
389                   graph->insertNode(graph->create(aten::type_as, 1));
390               type_as_node->addInput(add_mat);
391               type_as_node->addInput(type_as_mat);
392               add_mat = type_as_node->output();
393               if (add_mat_type->isComplete()) {
394                 auto new_type =
395                     add_mat_type->withScalarType(mat_scalar_type)->contiguous();
396                 add_mat->setType(new_type);
397               }
398             }
399 
400             auto* cOne = graph->insertConstant(1);
401             auto* addmm_node = graph->insertNode(graph->create(aten::addmm, 1));
402             addmm_node->addInput(add_mat);
403             addmm_node->addInput(mat1);
404             addmm_node->addInput(mat2);
405             addmm_node->addInput(cOne);
406             addmm_node->addInput(cOne);
407             auto* addmm_value = addmm_node->output();
408 
409             // Copy shape information from output node
410             addmm_value->copyMetadata(node->output());
411             GRAPH_UPDATE(
412                 "Fusing ",
413                 mm_node->input(0)->debugName(),
414                 ", ",
415                 mm_node->input(1)->debugName(),
416                 " and ",
417                 node->input(1 - mm_side)->debugName(),
418                 " into ",
419                 addmm_value->debugName());
420             node->output()->replaceAllUsesWith(addmm_value);
421             changed = true;
422             continue;
423           }
424         }
425       }
426     }
427     for (Block* b : node->blocks()) {
428       changed |= FuseAddMM(b);
429     }
430   }
431   return changed;
432 }
433 
434 // FuseAddMM is a separate pass from peephole optimize because it is currently
435 // used for exporting to ONNX.
436 // Today, fusing add + MM has no benefit within PyTorch running ATen
437 // ops. However, we rely on seeing the fused version of AddMM for ONNX export,
438 // since otherwise after ONNX translation we would see redundant Gemm ops with
439 // sub-optimal inputs.
440 // It won't be helpful for ATen until we're able to represent
441 //   torch.addmm(a, b, c, out=a).
442 // That's because addmm dispatches internally to gemm, which computes:
443 //   C = beta * C + alpha * A @ B
444 // but aten::addmm(a, b, c, 1, 1) is really:
445 //   D = beta * C + alpha * A @ B
446 // and because it works out of place on C, we're only trading off an
447 // explicit add for a copy inside the addmm function. Note that it
448 // doesn't even result in fewer reads, because mm won't even load C
449 // (because beta == 0 for it).
FuseAddMM(const std::shared_ptr<Graph> & graph)450 bool FuseAddMM(const std::shared_ptr<Graph>& graph) {
451   bool changed = FuseAddMM(graph->block());
452   GRAPH_DUMP("After FuseAddMM: ", graph);
453   return changed;
454 }
455 
PeepholeOptimize(const std::shared_ptr<Graph> & graph,bool addmm_fusion_enabled)456 bool PeepholeOptimize(
457     const std::shared_ptr<Graph>& graph,
458     bool addmm_fusion_enabled) {
459   PeepholeOptimizeImpl peephole(graph, addmm_fusion_enabled);
460   bool changed = peephole.run();
461   GRAPH_DUMP("After PeepholeOptimize: ", graph);
462   // Eliminate dead code created by any peephole passes we've just done
463   if (changed) {
464     EliminateDeadCode(graph->block());
465   }
466   return changed;
467 }
468 
469 } // namespace torch::jit
470