xref: /aosp_15_r20/external/executorch/backends/mediatek/quantizer/qconfig.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport copy
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerfrom enum import IntEnum, unique
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport torch
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.fake_quantize import FakeQuantize
14*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
15*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantizer import QuantizationSpec
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker
@unique
class Precision(IntEnum):
    """Precision presets selectable through ``get_quant_config``.

    Each member name has the form ``A<n>W<m>``; in this file the matching
    config builders use ``n``-bit activations and ``m``-bit weights. The
    integer values are unique (enforced by ``@unique``).
    """

    # Presets with 16-bit activations.
    A16W16 = 0
    A16W8 = 1
    A16W4 = 2

    # Presets with 8-bit activations.
    A8W8 = 3
    A8W4 = 4
25*523fa7a6SAndroid Build Coastguard Worker
26*523fa7a6SAndroid Build Coastguard Worker
27*523fa7a6SAndroid Build Coastguard Workerclass QuantizationConfig:
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker    def __init__(
30*523fa7a6SAndroid Build Coastguard Worker        self, activation_spec: QuantizationSpec, weight_spec: QuantizationSpec
31*523fa7a6SAndroid Build Coastguard Worker    ):
32*523fa7a6SAndroid Build Coastguard Worker        self._activation_spec = activation_spec
33*523fa7a6SAndroid Build Coastguard Worker        self._weight_spec = weight_spec
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker    @property
36*523fa7a6SAndroid Build Coastguard Worker    def activation(self):
37*523fa7a6SAndroid Build Coastguard Worker        return copy.deepcopy(self._activation_spec)
38*523fa7a6SAndroid Build Coastguard Worker
39*523fa7a6SAndroid Build Coastguard Worker    @property
40*523fa7a6SAndroid Build Coastguard Worker    def weight(self):
41*523fa7a6SAndroid Build Coastguard Worker        return copy.deepcopy(self._weight_spec)
42*523fa7a6SAndroid Build Coastguard Worker
43*523fa7a6SAndroid Build Coastguard Worker
def get_quant_config(
    precision: Precision,
    is_per_channel: bool = False,
    is_qat: bool = False,
) -> QuantizationConfig:
    """Return the quantization config for the requested precision preset.

    Args:
        precision: Preset that selects which config builder is used.
        is_per_channel: Forwarded unchanged to the selected builder.
        is_qat: Forwarded unchanged to the selected builder.

    Returns:
        The ``QuantizationConfig`` produced by the selected builder.

    Raises:
        RuntimeError: If ``precision`` has no entry in the dispatch table.
    """
    # Dispatch table: one builder per supported preset.
    builders = {
        Precision.A16W16: get_a16w16_quant_config,
        Precision.A16W8: get_a16w8_quant_config,
        Precision.A16W4: get_a16w4_quant_config,
        Precision.A8W8: get_a8w8_quant_config,
        Precision.A8W4: get_a8w4_quant_config,
    }

    builder = builders.get(precision)
    if builder is None:
        raise RuntimeError("Unrecognized precision setting.")

    return builder(is_per_channel, is_qat)
62*523fa7a6SAndroid Build Coastguard Worker
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Workerdef _get_activation_qspec(
65*523fa7a6SAndroid Build Coastguard Worker    dtype,
66*523fa7a6SAndroid Build Coastguard Worker    is_symmetric,
67*523fa7a6SAndroid Build Coastguard Worker    is_qat,
68*523fa7a6SAndroid Build Coastguard Worker    observer_cls=MinMaxObserver,
69*523fa7a6SAndroid Build Coastguard Worker    quant_min=None,
70*523fa7a6SAndroid Build Coastguard Worker    quant_max=None,
71*523fa7a6SAndroid Build Coastguard Worker):
72*523fa7a6SAndroid Build Coastguard Worker    if quant_max is None:
73*523fa7a6SAndroid Build Coastguard Worker        quant_max = torch.iinfo(dtype).max
74*523fa7a6SAndroid Build Coastguard Worker    if quant_min is None:
75*523fa7a6SAndroid Build Coastguard Worker        # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min
76*523fa7a6SAndroid Build Coastguard Worker        quant_min = torch.iinfo(dtype).min
77*523fa7a6SAndroid Build Coastguard Worker
78*523fa7a6SAndroid Build Coastguard Worker    qscheme = torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine
79*523fa7a6SAndroid Build Coastguard Worker    if is_qat:
80*523fa7a6SAndroid Build Coastguard Worker        observer_or_fake_quant = FakeQuantize.with_args(observer=observer_cls, eps=1e-6)
81*523fa7a6SAndroid Build Coastguard Worker    else:
82*523fa7a6SAndroid Build Coastguard Worker        observer_or_fake_quant = observer_cls.with_args(eps=1e-6)
83*523fa7a6SAndroid Build Coastguard Worker
84*523fa7a6SAndroid Build Coastguard Worker    return QuantizationSpec(
85*523fa7a6SAndroid Build Coastguard Worker        dtype=dtype,
86*523fa7a6SAndroid Build Coastguard Worker        quant_min=quant_min,
87*523fa7a6SAndroid Build Coastguard Worker        quant_max=quant_max,
88*523fa7a6SAndroid Build Coastguard Worker        qscheme=qscheme,
89*523fa7a6SAndroid Build Coastguard Worker        observer_or_fake_quant_ctr=observer_or_fake_quant,
90*523fa7a6SAndroid Build Coastguard Worker    )
91*523fa7a6SAndroid Build Coastguard Worker
92*523fa7a6SAndroid Build Coastguard Worker
def _get_weight_qspec(
    dtype, is_symmetric, is_per_channel, is_qat, quant_min=None, quant_max=None
):
    """Build the ``QuantizationSpec`` used for weights.

    Args:
        dtype: Integer torch dtype of the quantized weights.
        is_symmetric: Selects the symmetric scheme when truthy, otherwise the
            affine scheme (per-tensor or per-channel depending on
            ``is_per_channel``).
        is_per_channel: When falsy, a per-tensor spec is built by
            ``_get_activation_qspec`` with ``MinMaxObserver``. When truthy, a
            per-channel spec with ``ch_axis=0`` and
            ``PerChannelMinMaxObserver`` is built here.
        is_qat: When truthy the observer is wrapped in ``FakeQuantize``.
        quant_min: Lower bound; ``None`` means ``torch.iinfo(dtype).min``.
        quant_max: Upper bound; ``None`` means ``torch.iinfo(dtype).max``.

    Returns:
        A ``QuantizationSpec`` for the weights.
    """
    if not is_per_channel:
        # Forward the explicit bounds: without them a narrowed range (for
        # example [-8, 7] stored in int8) would silently fall back to the full
        # dtype range on the per-tensor path.
        return _get_activation_qspec(
            dtype,
            is_symmetric,
            is_qat,
            observer_cls=MinMaxObserver,
            quant_min=quant_min,
            quant_max=quant_max,
        )

    # Fill in any bound the caller left unspecified from the dtype limits.
    # The lower bound is the dtype minimum for symmetric and affine alike.
    if quant_max is None:
        quant_max = torch.iinfo(dtype).max
    if quant_min is None:
        quant_min = torch.iinfo(dtype).min

    qscheme = torch.per_channel_symmetric if is_symmetric else torch.per_channel_affine
    if is_qat:
        observer_or_fake_quant = FakeQuantize.with_args(
            observer=PerChannelMinMaxObserver, eps=1e-6
        )
    else:
        observer_or_fake_quant = PerChannelMinMaxObserver.with_args(eps=1e-6)

    return QuantizationSpec(
        dtype=dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        ch_axis=0,
        observer_or_fake_quant_ctr=observer_or_fake_quant,
    )
123*523fa7a6SAndroid Build Coastguard Worker
124*523fa7a6SAndroid Build Coastguard Worker
def get_a16w16_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Config with int16 symmetric activations and int16 symmetric weights.

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.
    """
    activation_spec = _get_activation_qspec(
        dtype=torch.int16, is_symmetric=True, is_qat=is_qat
    )
    weight_spec = _get_weight_qspec(
        dtype=torch.int16,
        is_symmetric=True,
        is_per_channel=is_per_channel,
        is_qat=is_qat,
    )
    return QuantizationConfig(activation_spec, weight_spec)
132*523fa7a6SAndroid Build Coastguard Worker
133*523fa7a6SAndroid Build Coastguard Worker
def get_a16w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Config with int16 symmetric activations and int8 symmetric weights.

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.
    """
    activation_spec = _get_activation_qspec(
        dtype=torch.int16, is_symmetric=True, is_qat=is_qat
    )
    weight_spec = _get_weight_qspec(
        dtype=torch.int8,
        is_symmetric=True,
        is_per_channel=is_per_channel,
        is_qat=is_qat,
    )
    return QuantizationConfig(activation_spec, weight_spec)
141*523fa7a6SAndroid Build Coastguard Worker
142*523fa7a6SAndroid Build Coastguard Worker
def get_a16w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Config with int16 symmetric activations and 4-bit-range weights.

    The weight spec uses int8 as its dtype, the affine (non-symmetric) scheme,
    and requests the range [-8, 7].

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.
    """
    activation_spec = _get_activation_qspec(
        dtype=torch.int16, is_symmetric=True, is_qat=is_qat
    )
    weight_spec = _get_weight_qspec(
        dtype=torch.int8,
        is_symmetric=False,
        is_per_channel=is_per_channel,
        is_qat=is_qat,
        quant_min=-8,
        quant_max=7,
    )
    return QuantizationConfig(activation_spec, weight_spec)
152*523fa7a6SAndroid Build Coastguard Worker
153*523fa7a6SAndroid Build Coastguard Worker
def get_a8w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Config with int8 affine activations and int8 affine weights.

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.
    """
    activation_spec = _get_activation_qspec(
        dtype=torch.int8, is_symmetric=False, is_qat=is_qat
    )
    weight_spec = _get_weight_qspec(
        dtype=torch.int8,
        is_symmetric=False,
        is_per_channel=is_per_channel,
        is_qat=is_qat,
    )
    return QuantizationConfig(activation_spec, weight_spec)
161*523fa7a6SAndroid Build Coastguard Worker
162*523fa7a6SAndroid Build Coastguard Worker
def get_a8w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Config with int8 affine activations and 4-bit-range weights.

    The weight spec uses int8 as its dtype, the affine (non-symmetric) scheme,
    and requests the range [-8, 7].

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.
    """
    activation_spec = _get_activation_qspec(
        dtype=torch.int8, is_symmetric=False, is_qat=is_qat
    )
    weight_spec = _get_weight_qspec(
        dtype=torch.int8,
        is_symmetric=False,
        is_per_channel=is_per_channel,
        is_qat=is_qat,
        quant_min=-8,
        quant_max=7,
    )
    return QuantizationConfig(activation_spec, weight_spec)
172