# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This is for PT2E (PyTorch 2 Export) quantization.

import logging
from dataclasses import dataclass
from typing import List, Optional

import torch

from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


@dataclass
class EmbeddingQuantOptions:
    is_per_channel: bool = True
    group_size: int = -1

    def __post_init__(self):
        if self.group_size != -1:
            raise RuntimeError(
                "The PT2E embedding quantizer does not support groupwise quantization at the moment."
            )


@dataclass
class DynamicQuantLinearOptions:
    is_per_channel: bool = True
    is_qc4: bool = False


@dataclass
class PT2EQuantOptions:
    quantize_embedding: Optional[EmbeddingQuantOptions] = None
    quantize_linear: Optional[DynamicQuantLinearOptions] = None

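# Illustrative sketch (hypothetical helper, not part of the original module):
# building PT2EQuantOptions by hand instead of parsing a CLI string, here for
# per-channel embedding quantization plus 4-bit dynamic linear quantization.
def _example_build_pt2e_options() -> PT2EQuantOptions:
    return PT2EQuantOptions(
        quantize_embedding=EmbeddingQuantOptions(is_per_channel=True),
        quantize_linear=DynamicQuantLinearOptions(is_qc4=True),
    )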

def get_pt2e_quantization_params(
    pt2e_quantize: Optional[str] = None,
    quantization_mode: Optional[str] = None,
) -> Optional[PT2EQuantOptions]:
    if pt2e_quantize is None:
        return None
    if quantization_mode:
        raise ValueError("Cannot specify both quantization_mode and pt2e_quantize")

    quantization_options = pt2e_quantize.split(",")
    quantization_options = [option.strip() for option in quantization_options]
    # TODO: This ad hoc parsing is only meant for quick experiments;
    # improve it before release.
    quant_options = None
    if "embedding" in quantization_options:
        quant_options = quant_options or PT2EQuantOptions()
        quant_options.quantize_embedding = EmbeddingQuantOptions()
    if (
        "xnnpack_dynamic" in quantization_options
        and "xnnpack_dynamic_qc4" in quantization_options
    ):
        raise RuntimeError(
            "For dynamic linear quantization via the XNNPACK quantizer, choose either the qc8 or the qc4 option, not both."
        )
    if (
        "xnnpack_dynamic" in quantization_options
        or "xnnpack_dynamic_qc4" in quantization_options
    ):
        quant_options = quant_options or PT2EQuantOptions()
        quant_options.quantize_linear = DynamicQuantLinearOptions()
        if "xnnpack_dynamic_qc4" in quantization_options:
            quant_options.quantize_linear.is_qc4 = True

    return quant_options

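# Illustrative sketch (hypothetical helper, not part of the original module):
# parsing a comma-separated pt2e_quantize string. "embedding,xnnpack_dynamic_qc4"
# enables embedding quantization plus 4-bit dynamic linear quantization.
def _example_parse_pt2e_options() -> Optional[PT2EQuantOptions]:
    options = get_pt2e_quantization_params("embedding,xnnpack_dynamic_qc4")
    assert options is not None and options.quantize_embedding is not None
    assert options.quantize_linear is not None and options.quantize_linear.is_qc4
    return options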

def get_pt2e_quantizers(
    quant_params: Optional[PT2EQuantOptions],
    so_library: Optional[str] = None,
) -> List[Quantizer]:
    """
    Get a list of quantizers from quantization params.
    Args:
        quant_params: PT2E quantization options.
        so_library: Optional path to a shared library that registers the
            quantized ops (and their out variants) needed for embedding
            quantization.
    Returns:
        A list of quantizers to pass into LlamaBuilder.
    """

    def check_embedding_byte_registered():
        try:
            _ = torch.ops.quantized_decomposed.embedding_byte.out
        except AttributeError:
            if so_library:
                print(f"Loading library {so_library}")
                torch.ops.load_library(so_library)
            else:
                raise RuntimeError(
                    "Need to specify shared library path to register quantized ops (and their out variants) into EXIR.\n"
                    "Follow these steps to build the needed lib via cmake.\n"
                    'Use `python -c "import torch as _; print(_.__path__)"` to find where the torch package is installed.\n'
                    "Set that as TORCH_PACKAGE_DIR.\n"
                    "Then from the root executorch dir do the following:\n"
                    "rm -rf cmake-out && mkdir cmake-out && (cd cmake-out && cmake -DBUCK2=<path-to-buck2> -DCMAKE_PREFIX_PATH=$TORCH_PACKAGE_DIR -DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON ..) && cmake --build cmake-out -j16\n"
                    'To find the location of the lib: find cmake-out -name "libquantized_ops_aot_lib*"\n'
                    "Then specify the said library via -s <path to libquantized_ops_aot_lib.so>\n"
                )

    quantizers: List[Quantizer] = []
    if quant_params is not None and quant_params.quantize_embedding is not None:
        logging.info("Apply PT2E embedding quantization.")
        check_embedding_byte_registered()
        quantizers.append(EmbeddingQuantizer())
    if quant_params is not None and quant_params.quantize_linear is not None:
        logging.info("Apply PT2E dynamic linear quantization.")
        dynamic_quantizer = XNNPACKQuantizer()
        assert quant_params.quantize_linear is not None
        if not quant_params.quantize_linear.is_per_channel:
            raise ValueError(
                "At the moment only per-channel weight quantization is supported."
            )
        if quant_params.quantize_linear.is_qc4:
            operator_config_dynamic = get_symmetric_quantization_config(
                is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7
            )
        else:
            operator_config_dynamic = get_symmetric_quantization_config(
                is_per_channel=True, is_dynamic=True
            )
        dynamic_quantizer.set_global(operator_config_dynamic)
        quantizers.append(dynamic_quantizer)
    return quantizers

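# Illustrative sketch (hypothetical helper, not part of the original module):
# turning a pt2e_quantize string into the quantizer list consumed by the
# export flow. Embedding quantization would additionally need so_library to
# point at the quantized-ops AOT lib if those ops are not already registered.
def _example_collect_pt2e_quantizers() -> List[Quantizer]:
    quant_params = get_pt2e_quantization_params("xnnpack_dynamic")
    return get_pt2e_quantizers(quant_params)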

def get_qnn_quantizer(
    pt2e_quantize: str,
    quantization_mode: Optional[str] = None,
    is_qat: bool = False,
):
    try:
        from executorch.backends.qualcomm.quantizer.custom_annotation import (  # pyre-fixme[21]
            custom_annotate_llama_matmul_16a8w,
        )

        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer`
        from executorch.backends.qualcomm.quantizer.quantizer import (
            QnnQuantizer,
            QuantDtype,
        )
        from torch.ao.quantization.observer import MinMaxObserver

    except ImportError:
        raise ImportError(
            "Please install the Qualcomm backend following https://pytorch.org/executorch/main/build-run-qualcomm.html"
        )

    # Split into backend prefix and quant config, e.g. "qnn_16a4w".
    backend, quant_config = pt2e_quantize.split("_")
    assert (
        backend == "qnn"
    ), f"The quantization config is for backend {backend} instead of qnn."
    qnn_quantizer = QnnQuantizer()  # pyre-fixme[16]
    qnn_quantizer.set_per_channel_conv_quant(enable=True)
    qnn_quantizer.set_per_channel_linear_quant(enable=True)
    # More custom quantization schemes are supported, including 16a4w etc.;
    # default to 8-bit quantization.
    custom_annotations = ()
    if quant_config == "8a8w":
        quant_dtype = QuantDtype.use_8a8w  # pyre-fixme[16]
        qnn_quantizer.set_quant_config(quant_dtype, is_qat=is_qat)
    elif quant_config == "16a16w":
        quant_dtype = QuantDtype.use_16a16w  # pyre-fixme[16]
        # Due to an error with 16a16w on QNN HTP, per-channel linear (and conv)
        # quantization must be disabled when using 16a16w.
        # TODO: enable it after the issue is fixed.
        logging.warning(
            "Disabling per-channel quantization for linear and conv due to the error with QNN HTP 16a16w."
        )
        qnn_quantizer.set_per_channel_conv_quant(enable=False)
        qnn_quantizer.set_per_channel_linear_quant(enable=False)
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
        qnn_quantizer.set_quant_config(
            quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver
        )
    elif quant_config == "16a4w":
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
        quant_dtype = QuantDtype.use_16a4w
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
        qnn_quantizer.set_quant_config(
            quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver
        )
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
        custom_annotations = (custom_annotate_llama_matmul_16a8w,)
    else:
        raise AssertionError(
            f"No support for quant type {quant_config}. Supported types are 8a8w, 16a16w, and 16a4w."
        )

    assert (
        quantization_mode is None
    ), "Currently the qnn backend only supports QnnQuantizer via the pt2e flow"
    qnn_quantizer.add_custom_quant_annotations(custom_annotations)

    return qnn_quantizer, quant_dtype

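# Illustrative sketch (hypothetical helper, not part of the original module):
# requesting a QNN quantizer with 16-bit activations and 4-bit weights. The
# "qnn_" prefix selects the backend; the suffix selects the quant config.
# Requires the Qualcomm backend to be installed.
def _example_get_qnn_16a4w_quantizer():
    qnn_quantizer, quant_dtype = get_qnn_quantizer("qnn_16a4w", is_qat=False)
    return qnn_quantizer, quant_dtype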

def get_coreml_quantizer(pt2e_quantize: str):
    try:
        from coremltools.optimize.torch.quantization.quantization_config import (
            LinearQuantizerConfig,
            QuantizationScheme,
        )

        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.coreml.quantizer`.
        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
    except ImportError:
        raise ImportError(
            "Please install the Core ML backend following https://pytorch.org/executorch/main/build-run-coreml.html"
        )

    if pt2e_quantize == "coreml_8a_c8w":
        config = LinearQuantizerConfig.from_dict(
            {
                "global_config": {
                    "quantization_scheme": QuantizationScheme.affine,
                    "activation_dtype": torch.quint8,
                    "weight_dtype": torch.qint8,
                    "weight_per_channel": True,
                }
            }
        )
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`.
        quantizer = CoreMLQuantizer(config)

    elif pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w"):
        raise NotImplementedError("4-bit Core ML quantizer is still under development")

    elif pt2e_quantize == "coreml_baseline_8a_c8w":
        config = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=False
        )
        quantizer = XNNPACKQuantizer().set_global(config)

    elif pt2e_quantize == "coreml_baseline_8a_c4w":
        config = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7
        )
        quantizer = XNNPACKQuantizer().set_global(config)

    else:
        raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}")

    return quantizer

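# Illustrative sketch (hypothetical helper, not part of the original module):
# selecting the affine 8-bit-activation, per-channel 8-bit-weight Core ML
# quantizer. Requires coremltools and the Core ML backend to be installed.
def _example_get_coreml_quantizer():
    return get_coreml_quantizer("coreml_8a_c8w")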

def get_vulkan_quantizer(pt2e_quantize: str):
    from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
        get_weight_quantization_config,
        VulkanQuantizer,
    )

    if pt2e_quantize == "vulkan_8w":
        config = get_weight_quantization_config(
            is_per_channel=True,
            weight_qmin=-128,
            weight_qmax=127,
        )
    else:
        raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")

    quantizer = VulkanQuantizer().set_global(config)
    return quantizer
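

# Illustrative sketch (hypothetical helper, not part of the original module):
# requesting the weight-only 8-bit Vulkan quantizer.
def _example_get_vulkan_quantizer():
    return get_vulkan_quantizer("vulkan_8w")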
279