from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
)
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
from torch.fx import Node


@dataclass(eq=True, frozen=True)
class QuantizationConfig:
    input_activation: Optional[QuantizationSpec]
    output_activation: Optional[QuantizationSpec]
    weight: Optional[QuantizationSpec]
    bias: Optional[QuantizationSpec | Callable]


def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec:
    def _derive_bias_qparams_fn(
        obs_or_fqs: List,
    ) -> Tuple[Tensor, Tensor]:
        assert (
            len(obs_or_fqs) == 2
        ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
        act_obs_or_fq = obs_or_fqs[0]
        weight_obs_or_fq = obs_or_fqs[1]
        weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
        act_scale, act_zp = act_obs_or_fq.calculate_qparams()
        (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors(
            act_scale, weight_scale
        )
        derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
        derived_zero = torch.zeros(derived_scale.size()).to(torch.int32)
        return (derived_scale, derived_zero)

    input_act = node.args[0]
    assert isinstance(input_act, Node)
    weight = node.args[1]
    assert isinstance(weight, Node)

    return DerivedQuantizationSpec(
        derived_from=[(input_act, node), (weight, node)],
        derive_qparams_fn=_derive_bias_qparams_fn,
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        ch_axis=0,
        qscheme=torch.per_channel_symmetric,
    )


def get_8a8w_qnn_ptq_config(
    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        ch_axis=0,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config
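

# A minimal illustrative sketch (hypothetical helper, not used elsewhere in this
# module): the derived bias qparams produced by _derived_bias_quant_spec above
# reduce to scale = act_scale * weight_scale, broadcast over the weight's output
# channels, with a zero-point of 0 for the int32 bias.
def _example_derived_bias_scale() -> Tuple[Tensor, Tensor]:
    act_scale = torch.tensor(0.02)  # per-tensor activation scale
    weight_scale = torch.tensor([0.5, 0.25])  # per-channel weight scales
    bias_scale = (act_scale * weight_scale).to(torch.float32)  # tensor([0.0100, 0.0050])
    bias_zero = torch.zeros(bias_scale.size()).to(torch.int32)  # tensor([0, 0], dtype=int32)
    return bias_scale, bias_zero
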

# 4-bit quantization only supports specific ops.
def get_16a4w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-20}
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a8w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-20}
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a16w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-20}
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int16,
        quant_min=torch.iinfo(torch.int16).min + 1,
        quant_max=torch.iinfo(torch.int16).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    # torch does not support uint16 quantization, use int32 to bypass
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_ptq_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO: accept "int4" temporarily. Remove "int4" once torch supports the torch.int4 dtype.
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"

    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

    # torch does not support uint16 quantization, use int32 to bypass
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


# TODO: merge the QAT and PTQ configs into one function and use a bool flag to switch between them.
def get_8a8w_qnn_qat_config(
    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
) -> QuantizationConfig:
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        ch_axis=0,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a4w_qnn_qat_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_qat_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO: accept "int4" temporarily. Remove "int4" once torch supports the torch.int4 dtype.
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"

    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

    # torch does not support uint16 quantization, use int32 to bypass
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer=MovingAveragePerChannelMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config
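

# A minimal usage sketch (assumption: run directly as a script; nothing below is
# required by the quantizer itself). It only builds a few of the configs defined
# above and prints their activation/weight ranges; wiring the configs into a
# quantizer and prepare_pt2e/convert_pt2e happens elsewhere and is not shown.
if __name__ == "__main__":
    for name, config in (
        ("8a8w_ptq", get_8a8w_qnn_ptq_config()),
        ("16a4w_ptq", get_16a4w_qnn_ptq_config()),
        ("per_channel_ptq", get_ptq_per_channel_quant_config()),
    ):
        act, weight = config.input_activation, config.weight
        print(
            f"{name}: act dtype={act.dtype}, act range=[{act.quant_min}, {act.quant_max}], "
            f"weight dtype={weight.dtype}, weight range=[{weight.quant_min}, {weight.quant_max}]"
        )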