xref: /aosp_15_r20/external/executorch/exir/passes/quant_fusion_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import torch
8from executorch.exir.dialects._ops import ops as exir_ops
9from executorch.exir.pass_base import ExportPass
10from torch.fx import GraphModule, subgraph_rewriter
11from torch.fx.passes.infra.pass_base import PassResult
12from torch.utils import _pytree as pytree
13
14from ._quant_patterns_and_replacements import get_quant_patterns_and_replacements
15
16
def _fuse_quantized_cat(model: GraphModule) -> None:
    """Fuse "dequantize -> cat -> quantize" into a cat that runs on quantized inputs.

    The rewrite is applied to a cat node only when all of the following hold:

    * every tensor in the cat's input list is produced by a
      ``dequantize_per_tensor`` node,
    * the quantization parameters (all positional args after the input) of each
      of those dequantize nodes equal those of the ``quantize_per_tensor`` node
      that consumes the cat, and
    * that quantize node is the only user of the cat.

    The graph is modified in place. The bypassed dequantize nodes are not erased
    here; they are only disconnected from the cat.

    Args:
        model: graph module whose graph is scanned and rewritten in place.
    """

    def _is_call_function(node) -> bool:
        # Guard against non-Node arguments before touching ``.op``/``.target``.
        return isinstance(node, torch.fx.Node) and node.op == "call_function"

    # get quantization parameters for the node, either for quantize or dequantize node
    def _get_qparams(node):
        assert node.target in (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        )
        # skip the input tensor; everything after it is a quantization parameter
        return list(node.args)[1:]

    # Iterate over a snapshot because matched quantize nodes are erased from the
    # graph while we walk it.
    for qnode in list(model.graph.nodes):
        if (
            qnode.op != "call_function"
            or qnode.target
            != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
        ):
            continue

        maybe_cat = qnode.args[0] if qnode.args else None
        if (
            not _is_call_function(maybe_cat)
            or maybe_cat.target != exir_ops.edge.aten.cat.default
            or not maybe_cat.args
        ):
            continue

        # The cat is rewired below to consume and produce quantized tensors.
        # That is only valid if the quantize node is its sole consumer: any
        # other user would otherwise start receiving quantized data.
        if len(maybe_cat.users) != 1:
            continue

        tensor_args = maybe_cat.args[0]
        if not isinstance(tensor_args, (tuple, list)) or not tensor_args:
            continue

        # make sure every input of the concat list is a dequantize node whose
        # qparams match the output qparams
        output_qparams = _get_qparams(qnode)
        inputs_match = all(
            _is_call_function(tensor_arg)
            and tensor_arg.target
            == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
            and _get_qparams(tensor_arg) == output_qparams
            for tensor_arg in tensor_args
        )
        if not inputs_match:
            continue

        # now we matched a pattern for quantized cat, e.g.
        # input1 (int8) -> dq1 -> cat -> q -> following_op
        # input2 (int8) -> dq2 -/

        # remove dq for inputs and q for output and run cat on the int8 input directly
        # input1 (int8) -> cat -> following_op
        # input2 (int8) -/

        # reroute the input of dq to the cat node; ``tensor_args`` is the
        # argument list captured before any rewiring, so it still holds the
        # dequantize nodes
        for dq_node in tensor_args:
            maybe_cat.replace_input_with(dq_node, dq_node.args[0])

        # remove q for output
        qnode.replace_all_uses_with(maybe_cat)
        model.graph.erase_node(qnode)
91
92
class QuantFusionPass(ExportPass):
    """Pass that rewrites reference quantized patterns into fused quantized operators.

    Args:
        _fix_node_meta_val: when True, ``call`` fills in ``node.meta["val"]`` for
            every ``call_function`` node that lacks it after the rewrite.
    """

    def __init__(self, _fix_node_meta_val: bool = False) -> None:
        super().__init__()
        # TODO This pass violates the IR spec because it produces a graph missing node.meta['val']
        self._fix_node_meta_val = _fix_node_meta_val

    def call(self, graph_module: GraphModule) -> PassResult:
        """Lower a quantized reference model (with reference quantized operator patterns)
        to executorch backend, that has a canonical set of quantized operators. This pass
        is a backend pass and should be applied on top of Edge dialect, ideally in
        `ExecutorchBackendConfig.passes`. See `test_quant_fusion_pass.py` for an example.

        Args:
            graph_module: module whose graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``; it is always
            reported as modified.
        """
        # linear, conv2d
        # dynamic_linear
        # add
        # batchnorm2d, relu, adaptive_avg_pool2d, reshape, squeeze, permute
        for (
            pattern,
            replacement,
            match_filters,
        ) in get_quant_patterns_and_replacements():
            subgraph_rewriter.replace_pattern_with_filters(
                graph_module, pattern, replacement, match_filters
            )

        _fuse_quantized_cat(graph_module)
        if self._fix_node_meta_val:
            for n in graph_module.graph.nodes:
                if n.op == "call_function" and "val" not in n.meta:
                    # Rebuild the missing value by running the node's target on
                    # the "val" entries of its inputs.
                    args, kwargs = pytree.tree_map_only(
                        torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
                    )
                    n.meta["val"] = n.target(*args, **kwargs)
        graph_module.graph.lint()
        graph_module.graph.eliminate_dead_code()
        # The graph was edited after the last pattern replacement (cat fusion and
        # dead-code elimination), so regenerate the module's code to keep
        # ``forward`` in sync with the graph.
        graph_module.recompile()
        return PassResult(graph_module, True)
129