1*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import dataclass 2*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Callable, Dict, List, Optional, Tuple 3*523fa7a6SAndroid Build Coastguard Worker 4*523fa7a6SAndroid Build Coastguard Workerimport torch 5*523fa7a6SAndroid Build Coastguard Workerfrom torch import Tensor 6*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.fake_quantize import ( 7*523fa7a6SAndroid Build Coastguard Worker FakeQuantize, 8*523fa7a6SAndroid Build Coastguard Worker FusedMovingAvgObsFakeQuantize, 9*523fa7a6SAndroid Build Coastguard Worker) 10*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.observer import ( 11*523fa7a6SAndroid Build Coastguard Worker MinMaxObserver, 12*523fa7a6SAndroid Build Coastguard Worker MovingAverageMinMaxObserver, 13*523fa7a6SAndroid Build Coastguard Worker MovingAveragePerChannelMinMaxObserver, 14*523fa7a6SAndroid Build Coastguard Worker PerChannelMinMaxObserver, 15*523fa7a6SAndroid Build Coastguard Worker) 16*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec 17*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import Node 18*523fa7a6SAndroid Build Coastguard Worker 19*523fa7a6SAndroid Build Coastguard Worker 20*523fa7a6SAndroid Build Coastguard Worker@dataclass(eq=True, frozen=True) 21*523fa7a6SAndroid Build Coastguard Workerclass QuantizationConfig: 22*523fa7a6SAndroid Build Coastguard Worker input_activation: Optional[QuantizationSpec] 23*523fa7a6SAndroid Build Coastguard Worker output_activation: Optional[QuantizationSpec] 24*523fa7a6SAndroid Build Coastguard Worker weight: Optional[QuantizationSpec] 25*523fa7a6SAndroid Build Coastguard Worker bias: Optional[QuantizationSpec | Callable] 26*523fa7a6SAndroid Build Coastguard Worker 27*523fa7a6SAndroid Build Coastguard Worker 28*523fa7a6SAndroid Build Coastguard Workerdef _derived_bias_quant_spec(node: Node) -> 
DerivedQuantizationSpec: 29*523fa7a6SAndroid Build Coastguard Worker def _derive_bias_qparams_fn( 30*523fa7a6SAndroid Build Coastguard Worker obs_or_fqs: List, 31*523fa7a6SAndroid Build Coastguard Worker ) -> Tuple[Tensor, Tensor]: 32*523fa7a6SAndroid Build Coastguard Worker assert ( 33*523fa7a6SAndroid Build Coastguard Worker len(obs_or_fqs) == 2 34*523fa7a6SAndroid Build Coastguard Worker ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" 35*523fa7a6SAndroid Build Coastguard Worker act_obs_or_fq = obs_or_fqs[0] 36*523fa7a6SAndroid Build Coastguard Worker weight_obs_or_fq = obs_or_fqs[1] 37*523fa7a6SAndroid Build Coastguard Worker weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() 38*523fa7a6SAndroid Build Coastguard Worker act_scale, act_zp = act_obs_or_fq.calculate_qparams() 39*523fa7a6SAndroid Build Coastguard Worker (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( 40*523fa7a6SAndroid Build Coastguard Worker act_scale, weight_scale 41*523fa7a6SAndroid Build Coastguard Worker ) 42*523fa7a6SAndroid Build Coastguard Worker derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) 43*523fa7a6SAndroid Build Coastguard Worker derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) 44*523fa7a6SAndroid Build Coastguard Worker return (derived_scale, derived_zero) 45*523fa7a6SAndroid Build Coastguard Worker 46*523fa7a6SAndroid Build Coastguard Worker input_act = node.args[0] 47*523fa7a6SAndroid Build Coastguard Worker assert isinstance(input_act, Node) 48*523fa7a6SAndroid Build Coastguard Worker weight = node.args[1] 49*523fa7a6SAndroid Build Coastguard Worker assert isinstance(weight, Node) 50*523fa7a6SAndroid Build Coastguard Worker 51*523fa7a6SAndroid Build Coastguard Worker return DerivedQuantizationSpec( 52*523fa7a6SAndroid Build Coastguard Worker derived_from=[(input_act, node), (weight, node)], 53*523fa7a6SAndroid Build Coastguard Worker 
derive_qparams_fn=_derive_bias_qparams_fn, 54*523fa7a6SAndroid Build Coastguard Worker dtype=torch.int32, 55*523fa7a6SAndroid Build Coastguard Worker quant_min=torch.iinfo(torch.int32).min, 56*523fa7a6SAndroid Build Coastguard Worker quant_max=torch.iinfo(torch.int32).max, 57*523fa7a6SAndroid Build Coastguard Worker ch_axis=0, 58*523fa7a6SAndroid Build Coastguard Worker qscheme=torch.per_channel_symmetric, 59*523fa7a6SAndroid Build Coastguard Worker ) 60*523fa7a6SAndroid Build Coastguard Worker 61*523fa7a6SAndroid Build Coastguard Worker 62*523fa7a6SAndroid Build Coastguard Workerdef get_8a8w_qnn_ptq_config( 63*523fa7a6SAndroid Build Coastguard Worker act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver 64*523fa7a6SAndroid Build Coastguard Worker) -> QuantizationConfig: 65*523fa7a6SAndroid Build Coastguard Worker extra_args: Dict[str, Any] = {"eps": 2**-12} 66*523fa7a6SAndroid Build Coastguard Worker 67*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec = QuantizationSpec( 68*523fa7a6SAndroid Build Coastguard Worker dtype=torch.uint8, 69*523fa7a6SAndroid Build Coastguard Worker qscheme=( 70*523fa7a6SAndroid Build Coastguard Worker torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine 71*523fa7a6SAndroid Build Coastguard Worker ), 72*523fa7a6SAndroid Build Coastguard Worker ch_axis=0, 73*523fa7a6SAndroid Build Coastguard Worker observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), 74*523fa7a6SAndroid Build Coastguard Worker ) 75*523fa7a6SAndroid Build Coastguard Worker 76*523fa7a6SAndroid Build Coastguard Worker weight_quantization_spec = QuantizationSpec( 77*523fa7a6SAndroid Build Coastguard Worker dtype=torch.int8, 78*523fa7a6SAndroid Build Coastguard Worker quant_min=torch.iinfo(torch.int8).min + 1, 79*523fa7a6SAndroid Build Coastguard Worker quant_max=torch.iinfo(torch.int8).max, 80*523fa7a6SAndroid Build Coastguard Worker qscheme=torch.per_tensor_symmetric, 81*523fa7a6SAndroid Build Coastguard Worker 
ch_axis=0, 82*523fa7a6SAndroid Build Coastguard Worker observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), 83*523fa7a6SAndroid Build Coastguard Worker ) 84*523fa7a6SAndroid Build Coastguard Worker 85*523fa7a6SAndroid Build Coastguard Worker bias_quantization_spec = QuantizationSpec( 86*523fa7a6SAndroid Build Coastguard Worker dtype=torch.int32, 87*523fa7a6SAndroid Build Coastguard Worker quant_min=torch.iinfo(torch.int32).min, 88*523fa7a6SAndroid Build Coastguard Worker quant_max=torch.iinfo(torch.int32).max, 89*523fa7a6SAndroid Build Coastguard Worker qscheme=torch.per_tensor_symmetric, 90*523fa7a6SAndroid Build Coastguard Worker observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), 91*523fa7a6SAndroid Build Coastguard Worker ) 92*523fa7a6SAndroid Build Coastguard Worker 93*523fa7a6SAndroid Build Coastguard Worker quantization_config = QuantizationConfig( 94*523fa7a6SAndroid Build Coastguard Worker input_activation=act_quantization_spec, 95*523fa7a6SAndroid Build Coastguard Worker output_activation=act_quantization_spec, 96*523fa7a6SAndroid Build Coastguard Worker weight=weight_quantization_spec, 97*523fa7a6SAndroid Build Coastguard Worker bias=bias_quantization_spec, 98*523fa7a6SAndroid Build Coastguard Worker ) 99*523fa7a6SAndroid Build Coastguard Worker 100*523fa7a6SAndroid Build Coastguard Worker return quantization_config 101*523fa7a6SAndroid Build Coastguard Worker 102*523fa7a6SAndroid Build Coastguard Worker 103*523fa7a6SAndroid Build Coastguard Worker# 4 bits quantization only supports specific ops. 
def get_16a4w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """PTQ config: 16-bit activations, 4-bit per-tensor weights.

    4-bit quantization only supports specific ops. Activations are stored
    as int32 but clamped to the uint16 range because torch has no native
    uint16 quantization support.

    Args:
        act_observer: observer class used for activations.
    """
    extra_args: Dict[str, Any] = {"eps": 2**-20}
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        # Symmetric int4 range (-7..7) stored in an int8 tensor.
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a8w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """PTQ config: 16-bit activations, 8-bit (uint8) per-tensor weights.

    Args:
        act_observer: observer class used for activations.
    """
    extra_args: Dict[str, Any] = {"eps": 2**-20}
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        # int32 storage bounded to the uint16 range (no native uint16 support).
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a16w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """PTQ config: 16-bit activations, 16-bit (int16) per-tensor weights.

    Args:
        act_observer: observer class used for activations.
    """
    extra_args: Dict[str, Any] = {"eps": 2**-20}
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int16,
        # min + 1 keeps the range symmetric around zero.
        quant_min=torch.iinfo(torch.int16).min + 1,
        quant_max=torch.iinfo(torch.int16).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    # torch does not support uint16 quantization, use int32 to bypass
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_ptq_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """PTQ config with per-channel weight quantization and derived bias.

    Args:
        act_dtype: activation dtype; one of uint8/uint16/int8/int16.
        weight_dtype: weight dtype; int8, int16, or the string "int4".
        act_observer: observer class used for activations.
    """
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"

    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

    # torch do not support uint16 quantization, use int32 to bypass
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        # "int4" values are stored in an int8 tensor with a -7..7 range.
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
    )

    # Bias qparams are derived from activation and weight scales per node.
    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


# TODO merge qat and ptq to a function, and use a bool flag to control it
def get_8a8w_qnn_qat_config(
    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
) -> QuantizationConfig:
    """QAT config: 8-bit (uint8) activations, 8-bit (int8) per-tensor weights.

    Args:
        act_symmetric: use a symmetric per-tensor scheme for activations
            instead of the default affine one.
        act_observer: observer class used inside the activation fake-quant.
    """
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        ch_axis=0,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a4w_qnn_qat_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """QAT config: 16-bit activations, 4-bit per-tensor weights.

    Args:
        act_observer: observer class used inside the activation fake-quant.
    """
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8,
        # Symmetric int4 range (-7..7) stored in an int8 tensor.
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_qat_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """QAT config with per-channel weight quantization and derived bias.

    Args:
        act_dtype: activation dtype; one of uint8/uint16/int8/int16.
        weight_dtype: weight dtype; int8, int16, or the string "int4".
        act_observer: observer class used inside the activation fake-quant.
    """
    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"

    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

    # torch do not support uint16 quantization, use int32 to bypass
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        # "int4" values are stored in an int8 tensor with a -7..7 range.
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer=MovingAveragePerChannelMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    # Bias qparams are derived from activation and weight scales per node.
    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config