# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

import copy

from enum import IntEnum, unique

import torch

from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec


@unique
class Precision(IntEnum):
    """Supported precision modes, named A<a>W<w> for <a>-bit activations
    and <w>-bit weights."""

    A16W16 = 0
    A16W8 = 1
    A16W4 = 2
    A8W8 = 3
    A8W4 = 4


class QuantizationConfig:
    """Pairs an activation QuantizationSpec with a weight QuantizationSpec.

    The `activation` and `weight` properties return deep copies, so each
    caller receives its own spec instance rather than an alias of the
    stored one.
    """

    def __init__(
        self, activation_spec: QuantizationSpec, weight_spec: QuantizationSpec
    ):
        self._activation_spec = activation_spec
        self._weight_spec = weight_spec

    @property
    def activation(self):
        return copy.deepcopy(self._activation_spec)

    @property
    def weight(self):
        return copy.deepcopy(self._weight_spec)


def get_quant_config(
    precision: Precision,
    is_per_channel: bool = False,
    is_qat: bool = False,
) -> QuantizationConfig:
    """Build the QuantizationConfig for the requested precision mode."""

    precision_mappings = {
        Precision.A16W16: get_a16w16_quant_config,
        Precision.A16W8: get_a16w8_quant_config,
        Precision.A16W4: get_a16w4_quant_config,
        Precision.A8W8: get_a8w8_quant_config,
        Precision.A8W4: get_a8w4_quant_config,
    }
    if precision not in precision_mappings:
        raise RuntimeError(f"Unrecognized precision setting: {precision}.")

    qconfig_fn = precision_mappings[precision]
    return qconfig_fn(is_per_channel, is_qat)


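# Usage sketch (the import path below mirrors this file's location in the
# tree and is an assumption about how the module is exposed):
#
#   from executorch.backends.mediatek.quantizer.qconfig import (
#       Precision,
#       get_quant_config,
#   )
#
#   config = get_quant_config(Precision.A8W8, is_per_channel=True)
#   act_spec = config.activation  # int8 affine, per-tensor
#   wgt_spec = config.weight      # int8 affine, per-channel along axis 0

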
def _get_activation_qspec(
    dtype,
    is_symmetric,
    is_qat,
    observer_cls=MinMaxObserver,
    quant_min=None,
    quant_max=None,
):
    """Build a per-tensor QuantizationSpec for activations."""
    if quant_max is None:
        quant_max = torch.iinfo(dtype).max
    if quant_min is None:
        # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min
        quant_min = torch.iinfo(dtype).min

    qscheme = torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine
    if is_qat:
        observer_or_fake_quant = FakeQuantize.with_args(observer=observer_cls, eps=1e-6)
    else:
        observer_or_fake_quant = observer_cls.with_args(eps=1e-6)

    return QuantizationSpec(
        dtype=dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        observer_or_fake_quant_ctr=observer_or_fake_quant,
    )


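# For example, a full-range int16 symmetric activation spec (a sketch of
# what the helper above returns):
#
#   spec = _get_activation_qspec(torch.int16, is_symmetric=True, is_qat=False)
#   assert spec.qscheme == torch.per_tensor_symmetric
#   assert (spec.quant_min, spec.quant_max) == (-32768, 32767)

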
def _get_weight_qspec(
    dtype, is_symmetric, is_per_channel, is_qat, quant_min=None, quant_max=None
):
    """Build a QuantizationSpec for weights, per-channel when requested."""
    if not is_per_channel:
        # The per-tensor case reuses the activation helper; forward any
        # explicit quant_min/quant_max (e.g. the 4-bit range) so they are
        # not silently dropped.
        return _get_activation_qspec(
            dtype,
            is_symmetric,
            is_qat,
            observer_cls=MinMaxObserver,
            quant_min=quant_min,
            quant_max=quant_max,
        )

    if quant_max is None:
        quant_max = torch.iinfo(dtype).max
    if quant_min is None:
        # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min
        quant_min = torch.iinfo(dtype).min

    qscheme = torch.per_channel_symmetric if is_symmetric else torch.per_channel_affine
    if is_qat:
        observer_or_fake_quant = FakeQuantize.with_args(
            observer=PerChannelMinMaxObserver, eps=1e-6
        )
    else:
        observer_or_fake_quant = PerChannelMinMaxObserver.with_args(eps=1e-6)

    return QuantizationSpec(
        dtype=dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        ch_axis=0,
        observer_or_fake_quant_ctr=observer_or_fake_quant,
    )


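# For example, a per-channel symmetric int8 weight spec, quantized along
# ch_axis=0 (the output-channel dimension of conv/linear weights):
#
#   spec = _get_weight_qspec(torch.int8, True, is_per_channel=True, is_qat=False)
#   assert spec.qscheme == torch.per_channel_symmetric
#   assert spec.ch_axis == 0
#   assert (spec.quant_min, spec.quant_max) == (-128, 127)

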
def get_a16w16_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int16, True, is_per_channel, is_qat)
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)


def get_a16w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int8, True, is_per_channel, is_qat)
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)


def get_a16w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    # 4-bit weights are carried in an int8 tensor, restricted to the
    # 4-bit affine range [-8, 7].
    wgt_quantization_spec = _get_weight_qspec(
        torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
    )
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)


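# A quick sanity check of the 4-bit weight range (a sketch, not part of the
# module; relies on _get_weight_qspec forwarding quant_min/quant_max in the
# per-tensor path):
#
#   config = get_a16w4_quant_config(is_per_channel=False, is_qat=False)
#   wgt = config.weight
#   assert wgt.dtype == torch.int8
#   assert (wgt.quant_min, wgt.quant_max) == (-8, 7)

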
def get_a8w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int8, False, is_per_channel, is_qat)
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)


def get_a8w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat)
    # As in the A16W4 config, 4-bit weights are carried in an int8 tensor
    # restricted to the affine range [-8, 7].
    wgt_quantization_spec = _get_weight_qspec(
        torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
    )
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)
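

# End-to-end sketch of picking a QAT config (illustrative only; attaching
# these specs to a quantizer happens outside this module):
#
#   config = get_quant_config(Precision.A16W4, is_per_channel=True, is_qat=True)
#   # config.activation -> int16, per-tensor symmetric, FakeQuantize-backed
#   # config.weight     -> int8-backed 4-bit, per-channel affine, ch_axis=0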