1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) 2024 MediaTek Inc. 2*523fa7a6SAndroid Build Coastguard Worker# 3*523fa7a6SAndroid Build Coastguard Worker# Licensed under the BSD License (the "License"); you may not use this file 4*523fa7a6SAndroid Build Coastguard Worker# except in compliance with the License. See the license file in the root 5*523fa7a6SAndroid Build Coastguard Worker# directory of this source tree for more details. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerimport copy 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerfrom enum import IntEnum, unique 10*523fa7a6SAndroid Build Coastguard Worker 11*523fa7a6SAndroid Build Coastguard Workerimport torch 12*523fa7a6SAndroid Build Coastguard Worker 13*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.fake_quantize import FakeQuantize 14*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver 15*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantizer import QuantizationSpec 16*523fa7a6SAndroid Build Coastguard Worker 17*523fa7a6SAndroid Build Coastguard Worker 18*523fa7a6SAndroid Build Coastguard Worker@unique 19*523fa7a6SAndroid Build Coastguard Workerclass Precision(IntEnum): 20*523fa7a6SAndroid Build Coastguard Worker A16W16 = 0 21*523fa7a6SAndroid Build Coastguard Worker A16W8 = 1 22*523fa7a6SAndroid Build Coastguard Worker A16W4 = 2 23*523fa7a6SAndroid Build Coastguard Worker A8W8 = 3 24*523fa7a6SAndroid Build Coastguard Worker A8W4 = 4 25*523fa7a6SAndroid Build Coastguard Worker 26*523fa7a6SAndroid Build Coastguard Worker 27*523fa7a6SAndroid Build Coastguard Workerclass QuantizationConfig: 28*523fa7a6SAndroid Build Coastguard Worker 29*523fa7a6SAndroid Build Coastguard Worker def __init__( 30*523fa7a6SAndroid Build Coastguard Worker self, activation_spec: QuantizationSpec, weight_spec: QuantizationSpec 31*523fa7a6SAndroid Build Coastguard Worker ): 32*523fa7a6SAndroid Build Coastguard Worker self._activation_spec = activation_spec 33*523fa7a6SAndroid Build Coastguard Worker self._weight_spec = weight_spec 34*523fa7a6SAndroid Build Coastguard Worker 35*523fa7a6SAndroid Build Coastguard Worker @property 36*523fa7a6SAndroid Build Coastguard Worker def activation(self): 37*523fa7a6SAndroid Build Coastguard Worker return copy.deepcopy(self._activation_spec) 38*523fa7a6SAndroid Build Coastguard Worker 39*523fa7a6SAndroid Build Coastguard Worker @property 40*523fa7a6SAndroid Build Coastguard Worker def weight(self): 41*523fa7a6SAndroid Build Coastguard Worker return copy.deepcopy(self._weight_spec) 42*523fa7a6SAndroid Build Coastguard Worker 43*523fa7a6SAndroid Build Coastguard Worker 44*523fa7a6SAndroid Build Coastguard Workerdef get_quant_config( 45*523fa7a6SAndroid Build Coastguard Worker precision: Precision, 46*523fa7a6SAndroid Build Coastguard Worker is_per_channel: bool = False, 47*523fa7a6SAndroid Build Coastguard Worker is_qat: bool = False, 48*523fa7a6SAndroid Build Coastguard Worker) -> QuantizationConfig: 49*523fa7a6SAndroid Build Coastguard Worker 50*523fa7a6SAndroid Build Coastguard Worker precision_mappings = { 51*523fa7a6SAndroid Build Coastguard Worker Precision.A16W16: get_a16w16_quant_config, 52*523fa7a6SAndroid Build Coastguard Worker Precision.A16W8: get_a16w8_quant_config, 53*523fa7a6SAndroid Build Coastguard Worker Precision.A16W4: get_a16w4_quant_config, 54*523fa7a6SAndroid Build Coastguard Worker Precision.A8W8: get_a8w8_quant_config, 55*523fa7a6SAndroid Build Coastguard Worker Precision.A8W4: get_a8w4_quant_config, 56*523fa7a6SAndroid Build Coastguard Worker } 57*523fa7a6SAndroid Build Coastguard Worker if precision not in precision_mappings: 58*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError("Unrecognized precision setting.") 59*523fa7a6SAndroid Build Coastguard Worker 60*523fa7a6SAndroid Build Coastguard Worker qconfig_fn = precision_mappings[precision] 61*523fa7a6SAndroid Build Coastguard Worker return qconfig_fn(is_per_channel, is_qat) 62*523fa7a6SAndroid Build Coastguard Worker 63*523fa7a6SAndroid Build Coastguard Worker 64*523fa7a6SAndroid Build Coastguard Workerdef _get_activation_qspec( 65*523fa7a6SAndroid Build Coastguard Worker dtype, 66*523fa7a6SAndroid Build Coastguard Worker is_symmetric, 67*523fa7a6SAndroid Build Coastguard Worker is_qat, 68*523fa7a6SAndroid Build Coastguard Worker observer_cls=MinMaxObserver, 69*523fa7a6SAndroid Build Coastguard Worker quant_min=None, 70*523fa7a6SAndroid Build Coastguard Worker quant_max=None, 71*523fa7a6SAndroid Build Coastguard Worker): 72*523fa7a6SAndroid Build Coastguard Worker if quant_max is None: 73*523fa7a6SAndroid Build Coastguard Worker quant_max = torch.iinfo(dtype).max 74*523fa7a6SAndroid Build Coastguard Worker if quant_min is None: 75*523fa7a6SAndroid Build Coastguard Worker # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min 76*523fa7a6SAndroid Build Coastguard Worker quant_min = torch.iinfo(dtype).min 77*523fa7a6SAndroid Build Coastguard Worker 78*523fa7a6SAndroid Build Coastguard Worker qscheme = torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine 79*523fa7a6SAndroid Build Coastguard Worker if is_qat: 80*523fa7a6SAndroid Build Coastguard Worker observer_or_fake_quant = FakeQuantize.with_args(observer=observer_cls, eps=1e-6) 81*523fa7a6SAndroid Build Coastguard Worker else: 82*523fa7a6SAndroid Build Coastguard Worker observer_or_fake_quant = observer_cls.with_args(eps=1e-6) 83*523fa7a6SAndroid Build Coastguard Worker 84*523fa7a6SAndroid Build Coastguard Worker return QuantizationSpec( 85*523fa7a6SAndroid Build Coastguard Worker dtype=dtype, 86*523fa7a6SAndroid Build Coastguard Worker quant_min=quant_min, 87*523fa7a6SAndroid Build Coastguard Worker quant_max=quant_max, 88*523fa7a6SAndroid Build Coastguard Worker qscheme=qscheme, 89*523fa7a6SAndroid Build Coastguard Worker observer_or_fake_quant_ctr=observer_or_fake_quant, 90*523fa7a6SAndroid Build Coastguard Worker ) 91*523fa7a6SAndroid Build Coastguard Worker 92*523fa7a6SAndroid Build Coastguard Worker 93*523fa7a6SAndroid Build Coastguard Workerdef _get_weight_qspec( 94*523fa7a6SAndroid Build Coastguard Worker dtype, is_symmetric, is_per_channel, is_qat, quant_min=None, quant_max=None 95*523fa7a6SAndroid Build Coastguard Worker): 96*523fa7a6SAndroid Build Coastguard Worker if not is_per_channel: 97*523fa7a6SAndroid Build Coastguard Worker return _get_activation_qspec( 98*523fa7a6SAndroid Build Coastguard Worker dtype, is_symmetric, is_qat, observer_cls=MinMaxObserver 99*523fa7a6SAndroid Build Coastguard Worker ) 100*523fa7a6SAndroid Build Coastguard Worker 101*523fa7a6SAndroid Build Coastguard Worker if quant_max is None: 102*523fa7a6SAndroid Build Coastguard Worker quant_max = torch.iinfo(dtype).max 103*523fa7a6SAndroid Build Coastguard Worker if quant_min is None: 104*523fa7a6SAndroid Build Coastguard Worker # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min 105*523fa7a6SAndroid Build Coastguard Worker quant_min = torch.iinfo(dtype).min 106*523fa7a6SAndroid Build Coastguard Worker 107*523fa7a6SAndroid Build Coastguard Worker qscheme = torch.per_channel_symmetric if is_symmetric else torch.per_channel_affine 108*523fa7a6SAndroid Build Coastguard Worker if is_qat: 109*523fa7a6SAndroid Build Coastguard Worker observer_or_fake_quant = FakeQuantize.with_args( 110*523fa7a6SAndroid Build Coastguard Worker observer=PerChannelMinMaxObserver, eps=1e-6 111*523fa7a6SAndroid Build Coastguard Worker ) 112*523fa7a6SAndroid Build Coastguard Worker else: 113*523fa7a6SAndroid Build Coastguard Worker observer_or_fake_quant = PerChannelMinMaxObserver.with_args(eps=1e-6) 114*523fa7a6SAndroid Build Coastguard Worker 115*523fa7a6SAndroid Build Coastguard Worker return QuantizationSpec( 116*523fa7a6SAndroid Build Coastguard Worker dtype=dtype, 117*523fa7a6SAndroid Build Coastguard Worker quant_min=quant_min, 118*523fa7a6SAndroid Build Coastguard Worker quant_max=quant_max, 119*523fa7a6SAndroid Build Coastguard Worker qscheme=qscheme, 120*523fa7a6SAndroid Build Coastguard Worker ch_axis=0, 121*523fa7a6SAndroid Build Coastguard Worker observer_or_fake_quant_ctr=observer_or_fake_quant, 122*523fa7a6SAndroid Build Coastguard Worker ) 123*523fa7a6SAndroid Build Coastguard Worker 124*523fa7a6SAndroid Build Coastguard Worker 125*523fa7a6SAndroid Build Coastguard Workerdef get_a16w16_quant_config(is_per_channel, is_qat) -> QuantizationConfig: 126*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat) 127*523fa7a6SAndroid Build Coastguard Worker wgt_quantization_spec = _get_weight_qspec(torch.int16, True, is_per_channel, is_qat) 128*523fa7a6SAndroid Build Coastguard Worker quantization_config = QuantizationConfig( 129*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec, wgt_quantization_spec 130*523fa7a6SAndroid Build Coastguard Worker ) 131*523fa7a6SAndroid Build Coastguard Worker return quantization_config 132*523fa7a6SAndroid Build Coastguard Worker 133*523fa7a6SAndroid Build Coastguard Worker 134*523fa7a6SAndroid Build Coastguard Workerdef get_a16w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig: 135*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat) 136*523fa7a6SAndroid Build Coastguard Worker wgt_quantization_spec = _get_weight_qspec(torch.int8, True, is_per_channel, is_qat) 137*523fa7a6SAndroid Build Coastguard Worker quantization_config = QuantizationConfig( 138*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec, wgt_quantization_spec 139*523fa7a6SAndroid Build Coastguard Worker ) 140*523fa7a6SAndroid Build Coastguard Worker return quantization_config 141*523fa7a6SAndroid Build Coastguard Worker 142*523fa7a6SAndroid Build Coastguard Worker 143*523fa7a6SAndroid Build Coastguard Workerdef get_a16w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig: 144*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat) 145*523fa7a6SAndroid Build Coastguard Worker wgt_quantization_spec = _get_weight_qspec( 146*523fa7a6SAndroid Build Coastguard Worker torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7 147*523fa7a6SAndroid Build Coastguard Worker ) 148*523fa7a6SAndroid Build Coastguard Worker quantization_config = QuantizationConfig( 149*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec, wgt_quantization_spec 150*523fa7a6SAndroid Build Coastguard Worker ) 151*523fa7a6SAndroid Build Coastguard Worker return quantization_config 152*523fa7a6SAndroid Build Coastguard Worker 153*523fa7a6SAndroid Build Coastguard Worker 154*523fa7a6SAndroid Build Coastguard Workerdef get_a8w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig: 155*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat) 156*523fa7a6SAndroid Build Coastguard Worker wgt_quantization_spec = _get_weight_qspec(torch.int8, False, is_per_channel, is_qat) 157*523fa7a6SAndroid Build Coastguard Worker quantization_config = QuantizationConfig( 158*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec, wgt_quantization_spec 159*523fa7a6SAndroid Build Coastguard Worker ) 160*523fa7a6SAndroid Build Coastguard Worker return quantization_config 161*523fa7a6SAndroid Build Coastguard Worker 162*523fa7a6SAndroid Build Coastguard Worker 163*523fa7a6SAndroid Build Coastguard Workerdef get_a8w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig: 164*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat) 165*523fa7a6SAndroid Build Coastguard Worker wgt_quantization_spec = _get_weight_qspec( 166*523fa7a6SAndroid Build Coastguard Worker torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7 167*523fa7a6SAndroid Build Coastguard Worker ) 168*523fa7a6SAndroid Build Coastguard Worker quantization_config = QuantizationConfig( 169*523fa7a6SAndroid Build Coastguard Worker act_quantization_spec, wgt_quantization_spec 170*523fa7a6SAndroid Build Coastguard Worker ) 171*523fa7a6SAndroid Build Coastguard Worker return quantization_config 172