1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# This is for PT2E quantization. 8 9import logging 10from dataclasses import dataclass 11from typing import List, Optional 12 13import torch 14 15from torch.ao.quantization.quantizer import Quantizer 16from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer 17from torch.ao.quantization.quantizer.xnnpack_quantizer import ( 18 get_symmetric_quantization_config, 19 XNNPACKQuantizer, 20) 21 22FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" 23logging.basicConfig(level=logging.INFO, format=FORMAT) 24 25 26@dataclass 27class EmbeddingQuantOptions: 28 is_per_channel: bool = True 29 group_size: int = -1 30 31 def __post_init__(self): 32 if self.group_size != -1: 33 raise RuntimeError( 34 "PT2E embedding quantizer does not support groupwise at the moment." 35 ) 36 37 38@dataclass 39class DynamicQuantLinearOptions: 40 is_per_channel: bool = True 41 is_qc4: bool = False 42 43 44@dataclass 45class PT2EQuantOptions: 46 quantize_embedding: Optional[EmbeddingQuantOptions] = None 47 quantize_linear: Optional[DynamicQuantLinearOptions] = None 48 49 50def get_pt2e_quantization_params( 51 pt2e_quantize: Optional[str] = None, 52 quantization_mode: Optional[str] = None, 53) -> Optional[PT2EQuantOptions]: 54 if pt2e_quantize is None: 55 return None 56 if quantization_mode: 57 raise ValueError("Cannot specify both quantization_mode and pt2e_quantize") 58 59 quantization_options = pt2e_quantize.split(",") 60 quantization_options = [option.strip() for option in quantization_options] 61 # This can really be improved significantly. 62 # Hopefully we dont release this in its current form. 63 # Just using this for quick experiments. 64 quant_options = None 65 if "embedding" in quantization_options: 66 quant_options = quant_options or PT2EQuantOptions() 67 quant_options.quantize_embedding = EmbeddingQuantOptions() 68 if ( 69 "xnnpack_dynamic" in quantization_options 70 and "xnnpack_dynamic_qc4" in quantization_options 71 ): 72 raise RuntimeError( 73 "For dynamic linear quantization via xnnpack quantizer you can chose only qc8 or qc4 option, not both." 74 ) 75 if ( 76 "xnnpack_dynamic" in quantization_options 77 or "xnnpack_dynamic_qc4" in quantization_options 78 ): 79 quant_options = quant_options or PT2EQuantOptions() 80 quant_options.quantize_linear = DynamicQuantLinearOptions() 81 if "xnnpack_dynamic_qc4" in quantization_options: 82 quant_options.quantize_linear.is_qc4 = True 83 84 return quant_options 85 86 87def get_pt2e_quantizers( 88 quant_params: Optional[PT2EQuantOptions], 89 so_library: Optional[str] = None, 90) -> List[Quantizer]: 91 """ 92 Get a list of quantizers from quantization params 93 Args: 94 quant_params: PT2E quantization options. 95 Returns: 96 A list of quantizers to pass into LlamaBuilder. 97 """ 98 99 def check_embedding_byte_registered(): 100 try: 101 _ = torch.ops.quantized_decomposed.embedding_byte.out 102 except AttributeError: 103 if so_library: 104 print(f"Loading library {so_library}") 105 torch.ops.load_library(so_library) 106 else: 107 raise RuntimeError( 108 "Need to specify shared library path to register quantized ops (and their out variants) into EXIR.\n" 109 "Follow the following steps to build the needed lib via cmake.\n" 110 'Use `python -c "import torch as _; print(_.__path__)"` to find where torch package is installed.\n' 111 "Set that as TORCH_PACKAGE_DIR.\n" 112 "Then from root executorch dir do the following:\n" 113 "rm -rf cmake-out && mkdir cmake-out && (cd cmake-out && cmake -DBUCK2=<path-to-buck2> -DCMAKE_PREFIX_PATH=$TORCH_PACKAGE_DIR -DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON ..) && cmake --build . -j16\n" 114 'To find the location of the lib: find cmake-out -name "libquantized_ops_aot_lib*"\n' 115 "Then specify the said library via -s <path to libquantized_ops_aot_lib.so\n" 116 ) 117 118 quantizers = [] 119 if quant_params is not None and quant_params.quantize_embedding is not None: 120 logging.info("Apply PT2E embedding quantization.") 121 check_embedding_byte_registered() 122 quantizers.append(EmbeddingQuantizer()) 123 if quant_params is not None and quant_params.quantize_linear is not None: 124 logging.info("Apply PT2E dynamic linear quantization.") 125 dynamic_quantizer = XNNPACKQuantizer() 126 assert quant_params.quantize_linear is not None 127 if not quant_params.quantize_linear.is_per_channel: 128 raise ValueError( 129 "At the moment only per channel weight quantization is supported." 130 ) 131 if quant_params.quantize_linear.is_qc4: 132 operator_config_dynamic = get_symmetric_quantization_config( 133 is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7 134 ) 135 else: 136 operator_config_dynamic = get_symmetric_quantization_config( 137 is_per_channel=True, is_dynamic=True 138 ) 139 dynamic_quantizer.set_global(operator_config_dynamic) 140 quantizers.append(dynamic_quantizer) 141 return quantizers 142 143 144def get_qnn_quantizer( 145 pt2e_quantize: str, 146 quantization_mode: Optional[str] = None, 147 is_qat: bool = False, 148): 149 try: 150 from executorch.backends.qualcomm.quantizer.custom_annotation import ( # pyre-fixme[21] 151 custom_annotate_llama_matmul_16a8w, 152 ) 153 154 # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` 155 from executorch.backends.qualcomm.quantizer.quantizer import ( 156 QnnQuantizer, 157 QuantDtype, 158 ) 159 from torch.ao.quantization.observer import MinMaxObserver 160 161 except ImportError: 162 raise ImportError( 163 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html" 164 ) 165 166 backend, quant_config = pt2e_quantize.split("_") 167 assert ( 168 backend == "qnn" 169 ), f"The quantization config is for backend {backend} instead of qnn." 170 qnn_quantizer = QnnQuantizer() # pyre-fixme[16] 171 qnn_quantizer.set_per_channel_conv_quant(enable=True) 172 qnn_quantizer.set_per_channel_linear_quant(enable=True) 173 # more custom quantization are supported including 16a4w etc. default to 8bit quantized 174 custom_annotations = () 175 if quant_config == "8a8w": 176 quant_dtype = QuantDtype.use_8a8w # pyre-fixme[16] 177 qnn_quantizer.set_quant_config(quant_dtype, is_qat=is_qat) 178 elif quant_config == "16a16w": 179 quant_dtype = QuantDtype.use_16a16w # pyre-fixme[16] 180 # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w 181 # TODO: enable it after the issue is fixed 182 logging.warning( 183 "Disable per channel quantization for linear and conv due to the error with QNN HTP 16a16w." 184 ) 185 qnn_quantizer.set_per_channel_conv_quant(enable=False) 186 qnn_quantizer.set_per_channel_linear_quant(enable=False) 187 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. 188 qnn_quantizer.set_quant_config( 189 quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver 190 ) 191 elif quant_config == "16a4w": 192 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. 193 quant_dtype = QuantDtype.use_16a4w 194 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. 195 qnn_quantizer.set_quant_config( 196 quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver 197 ) 198 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. 199 custom_annotations = (custom_annotate_llama_matmul_16a8w,) 200 else: 201 raise AssertionError( 202 f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w." 203 ) 204 205 assert ( 206 quantization_mode is None 207 ), "Currently qnn backend only supports QnnQuantizer via pt2e flow" 208 qnn_quantizer.add_custom_quant_annotations(custom_annotations) 209 210 return qnn_quantizer, quant_dtype 211 212 213def get_coreml_quantizer(pt2e_quantize: str): 214 try: 215 from coremltools.optimize.torch.quantization.quantization_config import ( 216 LinearQuantizerConfig, 217 QuantizationScheme, 218 ) 219 220 # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.coreml.quantizer`. 221 from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer 222 except ImportError: 223 raise ImportError( 224 "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html" 225 ) 226 227 if pt2e_quantize == "coreml_8a_c8w": 228 config = LinearQuantizerConfig.from_dict( 229 { 230 "global_config": { 231 "quantization_scheme": QuantizationScheme.affine, 232 "activation_dtype": torch.quint8, 233 "weight_dtype": torch.qint8, 234 "weight_per_channel": True, 235 } 236 } 237 ) 238 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`. 239 quantizer = CoreMLQuantizer(config) 240 241 elif pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w"): 242 raise NotImplementedError("4-bit Core ML quantizer is still under development") 243 244 elif pt2e_quantize == "coreml_baseline_8a_c8w": 245 config = get_symmetric_quantization_config( 246 is_per_channel=True, is_dynamic=False 247 ) 248 quantizer = XNNPACKQuantizer().set_global(config) 249 250 elif pt2e_quantize == "coreml_baseline_8a_c4w": 251 config = get_symmetric_quantization_config( 252 is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7 253 ) 254 quantizer = XNNPACKQuantizer().set_global(config) 255 256 else: 257 raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}") 258 259 return quantizer 260 261 262def get_vulkan_quantizer(pt2e_quantize: str): 263 from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( 264 get_weight_quantization_config, 265 VulkanQuantizer, 266 ) 267 268 if pt2e_quantize == "vulkan_8w": 269 config = get_weight_quantization_config( 270 is_per_channel=True, 271 weight_qmin=-128, 272 weight_qmax=127, 273 ) 274 else: 275 raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}") 276 277 quantizer = VulkanQuantizer().set_global(config) 278 return quantizer 279