# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

import copy

from enum import IntEnum, unique

import torch

from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec


@unique
class Precision(IntEnum):
    A16W16 = 0
    A16W8 = 1
    A16W4 = 2
    A8W8 = 3
    A8W4 = 4


class QuantizationConfig:
    """Holds the activation and weight QuantizationSpec for one precision setting."""

    def __init__(
        self, activation_spec: QuantizationSpec, weight_spec: QuantizationSpec
    ):
        self._activation_spec = activation_spec
        self._weight_spec = weight_spec

    @property
    def activation(self):
        return copy.deepcopy(self._activation_spec)

    @property
    def weight(self):
        return copy.deepcopy(self._weight_spec)


def get_quant_config(
    precision: Precision,
    is_per_channel: bool = False,
    is_qat: bool = False,
) -> QuantizationConfig:
    """Return the QuantizationConfig for the requested activation/weight precision."""

    precision_mappings = {
        Precision.A16W16: get_a16w16_quant_config,
        Precision.A16W8: get_a16w8_quant_config,
        Precision.A16W4: get_a16w4_quant_config,
        Precision.A8W8: get_a8w8_quant_config,
        Precision.A8W4: get_a8w4_quant_config,
    }
    if precision not in precision_mappings:
        raise RuntimeError("Unrecognized precision setting.")

    qconfig_fn = precision_mappings[precision]
    return qconfig_fn(is_per_channel, is_qat)


def _get_activation_qspec(
    dtype,
    is_symmetric,
    is_qat,
    observer_cls=MinMaxObserver,
    quant_min=None,
    quant_max=None,
):
    # Default to the full integer range of the dtype unless an explicit
    # quant_min/quant_max is requested (e.g. for 4-bit weights).
    if quant_max is None:
        quant_max = torch.iinfo(dtype).max
    if quant_min is None:
        # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min
        quant_min = torch.iinfo(dtype).min

    qscheme = torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine
    if is_qat:
        observer_or_fake_quant = FakeQuantize.with_args(observer=observer_cls, eps=1e-6)
    else:
        observer_or_fake_quant = observer_cls.with_args(eps=1e-6)

    return QuantizationSpec(
        dtype=dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        observer_or_fake_quant_ctr=observer_or_fake_quant,
    )


def _get_weight_qspec(
    dtype, is_symmetric, is_per_channel, is_qat, quant_min=None, quant_max=None
):
    if not is_per_channel:
        # Per-tensor weights reuse the activation helper; forward quant_min and
        # quant_max so narrowed ranges (e.g. 4-bit weights) are not silently
        # widened back to the full dtype range.
        return _get_activation_qspec(
            dtype,
            is_symmetric,
            is_qat,
            observer_cls=MinMaxObserver,
            quant_min=quant_min,
            quant_max=quant_max,
        )

    if quant_max is None:
        quant_max = torch.iinfo(dtype).max
    if quant_min is None:
        # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min
        quant_min = torch.iinfo(dtype).min

    qscheme = torch.per_channel_symmetric if is_symmetric else torch.per_channel_affine
    if is_qat:
        observer_or_fake_quant = FakeQuantize.with_args(
            observer=PerChannelMinMaxObserver, eps=1e-6
        )
    else:
        observer_or_fake_quant = PerChannelMinMaxObserver.with_args(eps=1e-6)

    return QuantizationSpec(
        dtype=dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        ch_axis=0,
        observer_or_fake_quant_ctr=observer_or_fake_quant,
    )


def get_a16w16_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    wgt_quantization_spec = _get_weight_qspec(
        torch.int16, True, is_per_channel, is_qat
    )
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config


def get_a16w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int8, True, is_per_channel, is_qat)
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config


def get_a16w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    # 4-bit weights are stored in int8 with the range clamped to [-8, 7].
    wgt_quantization_spec = _get_weight_qspec(
        torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
    )
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config


def get_a8w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int8, False, is_per_channel, is_qat)
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config


def get_a8w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat)
    # 4-bit weights are stored in int8 with the range clamped to [-8, 7].
    wgt_quantization_spec = _get_weight_qspec(
        torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
    )
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config
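

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows how
# a caller might request a config and inspect the resulting specs. The variable
# names below are hypothetical; only get_quant_config/Precision come from this
# file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 8-bit asymmetric activations with per-channel symmetric 8-bit weights (PTQ).
    example_config = get_quant_config(
        Precision.A8W8, is_per_channel=True, is_qat=False
    )
    print(example_config.activation)  # per-tensor QuantizationSpec for activations
    print(example_config.weight)  # per-channel QuantizationSpec (ch_axis=0) for weights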