xref: /aosp_15_r20/external/executorch/exir/passes/remove_noop_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9from typing import List, Tuple
10
11import torch
12from executorch.exir.dialects._ops import ops as exir_ops
13from executorch.exir.pass_base import ExportPass, PassResult
14from torch.fx import GraphModule
15
# Dequantize overloads matched by this module. Each op is listed twice: once
# under `torch.ops` and once under `exir_ops.edge`.
# The annotation uses the variable-length form (`Tuple[X, ...]`) because the
# value holds four entries; `Tuple[X]` would describe a one-element tuple.
_DEQUANT_OPS: Tuple[torch._ops.OpOverload, ...] = (
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
)
# Quantize overloads matched by this module, listed under the same two
# namespaces as `_DEQUANT_OPS`.
_QUANT_OPS: Tuple[torch._ops.OpOverload, ...] = (
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
)
28
29
def eliminate_dq_q(
    graph_module: GraphModule,
    dequant_nodes: List[torch.fx.Node],
) -> None:
    """
    Bypass quantize nodes that directly undo one of the given dequantize nodes.

    For every node in ``dequant_nodes``, each user whose target is in
    ``_QUANT_OPS`` is inspected. When all arguments after the first are equal
    between the dequantize node and that user, every use of the user is
    redirected to the dequantize node's first argument. Users with differing
    trailing arguments are left alone.

    No node is erased here; this function only rewires uses.

    Args:
        graph_module: Accepted for interface reasons; not read by this function.
        dequant_nodes: Nodes whose target must be in ``_DEQUANT_OPS``.

    Raises:
        AssertionError: If a node in ``dequant_nodes`` has a target outside
            ``_DEQUANT_OPS``.
    """
    for dq_node in dequant_nodes:
        assert dq_node.target in _DEQUANT_OPS
        # Walk a snapshot of the users rather than the live collection.
        for consumer in tuple(dq_node.users):
            if consumer.target not in _QUANT_OPS:
                continue
            # Compare everything except the leading input argument, i.e. the
            # quantization parameters of both ops.
            dq_params = list(dq_node.args[1:])
            q_params = list(consumer.args[1:])
            if dq_params == q_params:
                dq_input = dq_node.args[0]
                consumer.replace_all_uses_with(dq_input)  # pyre-fixme[6]
44
45
class RemoveNoopPass(ExportPass):
    """
    Bypasses ``aten.to.dtype``, ``aten.dropout.default`` and
    ``aten.slice_copy.Tensor`` calls that pass their first argument through.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Redirect uses of pass-through nodes to their input, then clean up.

        A candidate node is treated as a no-op when its ``meta["val"]`` is the
        very same object as its input's ``meta["val"]``. A ``slice_copy`` is
        also treated as a no-op when every input dimension is a plain ``int``
        and the input and output shapes are equal.

        Returns:
            A ``PassResult`` wrapping ``graph_module``, always flagged as
            modified.
        """
        graph = graph_module.graph
        candidate_targets = (
            torch.ops.aten.to.dtype,
            torch.ops.aten.dropout.default,
            torch.ops.aten.slice_copy.Tensor,
        )

        # Dequant nodes that fed a removed op. They are re-examined after the
        # first dead-code sweep so redundant dq->q pairs can be dropped too.
        dq_producers = []

        for node in graph.nodes:
            if node.op != "call_function" or node.target not in candidate_targets:
                continue

            source = node.args[0]
            in_val = source.meta["val"]
            out_val = node.meta["val"]

            passes_through = in_val is out_val
            if not passes_through and node.target == torch.ops.aten.slice_copy.Tensor:
                # The shape comparison is only attempted when all input dims
                # are static.
                passes_through = (
                    all(isinstance(dim, int) for dim in in_val.size())
                    and in_val.shape == out_val.shape
                )
            if not passes_through:
                continue

            # In a quantized graph the whole dq->op->q pattern has to go, so
            # remember the dequant producer; otherwise dropping the op suffices.
            if source.target in _DEQUANT_OPS:
                dq_producers.append(source)
            node.replace_all_uses_with(source)

        graph.eliminate_dead_code()
        eliminate_dq_q(graph_module, dq_producers)
        graph.lint()
        graph.eliminate_dead_code()

        return PassResult(graph_module, True)
95
96
class RemoveToCopyPass(ExportPass):
    """
    Bypasses ``aten._to_copy.default`` calls whose output metadata matches
    that of their input.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Redirect uses of redundant ``_to_copy`` nodes to their input.

        A ``_to_copy`` node is considered redundant when the dtype, device,
        shape and stride recorded in its ``meta["val"]`` all equal those of its
        first argument's ``meta["val"]``.

        Returns:
            A ``PassResult`` wrapping ``graph_module``, always flagged as
            modified.
        """
        graph = graph_module.graph
        to_copy_targets = (torch.ops.aten._to_copy.default,)

        for node in graph.nodes:
            if node.op != "call_function" or node.target not in to_copy_targets:
                continue

            source = node.args[0]
            before = source.meta["val"]
            after = node.meta["val"]

            unchanged = (
                before.dtype == after.dtype
                and before.device == after.device
                and before.shape == after.shape
                and before.stride() == after.stride()
            )
            if unchanged:
                node.replace_all_uses_with(source)

        # The bypassed nodes now have no users; sweep them and validate.
        graph.eliminate_dead_code()
        graph.lint()

        return PassResult(graph_module, True)
124