# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import IntEnum, unique
from typing import Callable, Optional, Sequence, Set

import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum
from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
    RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.reduce_dynamic_range import ReduceDynamicRange
from executorch.backends.qualcomm._passes.replace_inf_buffer import ReplaceInfBuffer
from executorch.backends.transforms.decompose_sdpa import (
    DecomposeScaledDotProductAttention,
)

from torch._ops import OpOverload
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule

from .annotators import OP_ANNOTATOR

from .qconfig import (
    get_16a16w_qnn_ptq_config,
    get_16a4w_qnn_ptq_config,
    get_16a4w_qnn_qat_config,
    get_16a8w_qnn_ptq_config,
    get_8a8w_qnn_ptq_config,
    get_8a8w_qnn_qat_config,
    get_ptq_per_channel_quant_config,
    get_qat_per_channel_quant_config,
    QuantizationConfig,
)

# Alias retained to bypass a Meta-internal test error.
get_default_16bit_qnn_ptq_config = get_16a16w_qnn_ptq_config

__all__ = [
    "QnnQuantizer",
    "QuantDtype",
    "get_16a4w_qnn_ptq_config",
    "get_16a8w_qnn_ptq_config",
    "get_16a16w_qnn_ptq_config",
    "get_8a8w_qnn_ptq_config",
    "get_8a8w_qnn_qat_config",
    "get_16a4w_qnn_qat_config",
]


@unique
class QuantDtype(IntEnum):
    """
    Activation and weight bit widths: e.g. use_16a4w means 16-bit
    activations with 4-bit weights.
    """

    use_16a16w = 0
    use_16a8w = 1
    use_16a4w = 2
    use_8a8w = 3


quant_config_dict = {
    # PTQ
    (QuantDtype.use_16a16w, False): (
        get_16a16w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
    ),
    (QuantDtype.use_16a8w, False): (
        get_16a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
    ),
    (QuantDtype.use_16a4w, False): (
        get_16a4w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, False): (
        get_8a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(),
    ),
    # QAT
    (QuantDtype.use_16a4w, True): (
        get_16a4w_qnn_qat_config,
        get_qat_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, True): (
        get_8a8w_qnn_qat_config,
        get_qat_per_channel_quant_config(),
    ),
}
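
# Illustrative lookup (a sketch, not executed anywhere in this module): each
# entry maps a (QuantDtype, is_qat) pair to a per-tensor config factory plus
# a prebuilt per-channel config, e.g. for 16-bit-activation / 4-bit-weight PTQ:
#
#   quant_config_func, per_channel_config = quant_config_dict[
#       (QuantDtype.use_16a4w, False)
#   ]
#   quant_config = quant_config_func()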


class QnnQuantizer(Quantizer):
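    """
    Quantizer for the Qualcomm QNN backend, built on the PT2E quantization
    flow. A minimal usage sketch; capturing `model` into `graph_module` is
    elided because the export API varies across PyTorch releases, and
    `example_inputs` is a placeholder:

        from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

        quantizer = QnnQuantizer()
        quantizer.set_quant_config(QuantDtype.use_8a8w)
        quantizer.set_per_channel_conv_quant(True)

        prepared = prepare_pt2e(graph_module, quantizer)
        prepared(*example_inputs)  # calibration
        quantized = convert_pt2e(prepared)
    """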
    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

    def __init__(self):
        super().__init__()
        self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

        self.is_qat = False
        self.quant_dtype = QuantDtype.use_8a8w
        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
        # Ops in this set are annotated with the per-channel weight config
        # instead of the default per-tensor config.
        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()

        self.custom_quant_annotations: Sequence[Callable] = []
        self.discard_nodes: Set[str] = set()

    def _annotate(self, gm: GraphModule) -> None:
        for node in gm.graph.nodes:
            if node.name in self.discard_nodes:
                continue

            quant_config = self._get_quant_config(node.target)
            if quant_config:
                OP_ANNOTATOR[node.target](node, quant_config)

    def _annotate_custom_annotation(self, gm: GraphModule) -> None:
        for annotation_func in self.custom_quant_annotations:
            annotation_func(gm)

    def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
        """
        Priority:
            1. per-channel config, if op is in use_per_channel_weight_quant_ops
            2. default quant config, if op is in quant_ops
        """
        if isinstance(op, str):
            return None

        if op in self.use_per_channel_weight_quant_ops:
            return self.per_channel_quant_config

        if op in self.quant_ops:
            return self.quant_config

        print(f"No quant config is implemented for op: {op}")
        return None

    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
        if enable:
            self.use_per_channel_weight_quant_ops.update(ops)
        else:
            self.use_per_channel_weight_quant_ops.difference_update(ops)

    def add_custom_quant_annotations(
        self, custom_quant_annotations: Sequence[Callable]
    ) -> None:
        self.custom_quant_annotations = custom_quant_annotations

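    # A custom annotation callable receives the GraphModule and annotates
    # nodes itself. A hypothetical sketch (the op and config chosen here are
    # illustrative only, not part of this module):
    #
    #   def annotate_my_op(gm: GraphModule) -> None:
    #       for node in gm.graph.nodes:
    #           if node.target is torch.ops.aten.gelu.default:
    #               OP_ANNOTATOR[node.target](node, get_8a8w_qnn_ptq_config())
    #
    #   quantizer.add_custom_quant_annotations((annotate_my_op,))
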
    def add_discard_nodes(self, nodes: Sequence[str]) -> None:
        self.discard_nodes = set(nodes)

    def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
        for op in ops:
            # discard() tolerates ops that are not (or no longer) registered
            self.quant_ops.discard(op)

    def annotate(self, model: GraphModule) -> GraphModule:
        self._annotate(model)
        self._annotate_custom_annotation(model)

        return model

    def get_supported_ops(self) -> Set[OpOverload]:
        return self.SUPPORTED_OPS

    def set_quant_config(
        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
    ) -> None:
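        """
        Look up the (quant_dtype, is_qat) pair in quant_config_dict and make
        it the active per-tensor / per-channel config, e.g.:

            quantizer.set_quant_config(QuantDtype.use_16a4w, is_qat=True)

        `act_observer`, if given, is forwarded to the qconfig factory to
        override its default activation observer.
        """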
        if (quant_dtype, is_qat) not in quant_config_dict:
            raise RuntimeError(
                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not supported"
            )

        self.quant_dtype = quant_dtype
        self.is_qat = is_qat
        quant_config_func, self.per_channel_quant_config = quant_config_dict[
            (quant_dtype, is_qat)
        ]
        self.quant_config = (
            quant_config_func(act_observer) if act_observer else quant_config_func()
        )

    def set_per_channel_conv_quant(self, enable: bool) -> None:
        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
        self._update_per_channel_weight_quant_ops(conv_ops, enable)

    def set_per_channel_linear_quant(self, enable: bool) -> None:
        linear_ops = {
            torch.ops.aten.linear.default,
        }
        self._update_per_channel_weight_quant_ops(linear_ops, enable)

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
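        """
        Graph transforms run before annotation, recomposing or decomposing
        ops into patterns the QNN annotators can match (e.g.
        scaled_dot_product_attention, silu, and einsum are decomposed;
        pixel_unshuffle is recomposed).
        """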
        model = ReduceDynamicRange()(model).graph_module
        model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module
        model = DecomposeScaledDotProductAttention()(model).graph_module
        model = DecomposeSilu()(model).graph_module
        model = DecomposeEinsum()(model).graph_module
        model = ReplaceInfBuffer()(model).graph_module
        return model

    def validate(self, model: GraphModule) -> None:
        pass