from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
)
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
from torch.fx import Node

19*523fa7a6SAndroid Build Coastguard Worker
@dataclass(eq=True, frozen=True)
class QuantizationConfig:
    """Bundle of quantization specs for one annotated operator pattern.

    Groups the specs a quantizer attaches to a node's input activation,
    output activation, weight, and bias. Frozen/eq so configs can be
    compared and used as hashable keys.
    """

    # Spec for the node's input activation(s); None means not quantized.
    input_activation: Optional[QuantizationSpec]
    # Spec for the node's output activation; None means not quantized.
    output_activation: Optional[QuantizationSpec]
    # Spec for the weight tensor; None means not quantized.
    weight: Optional[QuantizationSpec]
    # Either a concrete spec, or a callable (e.g. _derived_bias_quant_spec in
    # this file) invoked per-node to derive the bias spec at annotation time.
    bias: Optional[QuantizationSpec | Callable]
27*523fa7a6SAndroid Build Coastguard Worker
def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec:
    """Build an int32 bias quantization spec derived from input/weight qparams.

    Bias scale is act_scale * weight_scale with zero_point fixed at 0, so the
    quantized bias adds directly into the int32 accumulator of the matmul/conv.

    Args:
        node: FX node whose ``args[0]`` is the input-activation node and
            ``args[1]`` is the weight node.

    Returns:
        A per-channel-symmetric ``DerivedQuantizationSpec`` over the full
        int32 range, deriving qparams from the two observers/fake-quants.
    """

    def _derive_bias_qparams_fn(
        obs_or_fqs: List,
    ) -> Tuple[Tensor, Tensor]:
        assert (
            len(obs_or_fqs) == 2
        ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
        act_obs_or_fq = obs_or_fqs[0]
        weight_obs_or_fq = obs_or_fqs[1]
        weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
        act_scale, act_zp = act_obs_or_fq.calculate_qparams()
        # Weight scales may be per-channel while the activation scale is
        # per-tensor; broadcast them to a common shape before multiplying.
        (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors(
            act_scale, weight_scale
        )
        derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
        # Symmetric bias quantization: zero_point is always 0. zeros_like with
        # an explicit dtype avoids allocating a float tensor and then copying.
        derived_zero = torch.zeros_like(derived_scale, dtype=torch.int32)
        return (derived_scale, derived_zero)

    input_act = node.args[0]
    assert isinstance(input_act, Node)
    weight = node.args[1]
    assert isinstance(weight, Node)

    return DerivedQuantizationSpec(
        derived_from=[(input_act, node), (weight, node)],
        derive_qparams_fn=_derive_bias_qparams_fn,
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        ch_axis=0,
        qscheme=torch.per_channel_symmetric,
    )
60*523fa7a6SAndroid Build Coastguard Worker
61*523fa7a6SAndroid Build Coastguard Worker
def get_8a8w_qnn_ptq_config(
    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
) -> QuantizationConfig:
    """8-bit activation / 8-bit weight PTQ config.

    Args:
        act_symmetric: use a symmetric per-tensor scheme for activations
            instead of the default affine one.
        act_observer: observer class for activation statistics.
    """
    observer_kwargs: Dict[str, Any] = {"eps": 2**-12}
    act_qscheme = (
        torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
    )

    act_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=act_qscheme,
        ch_axis=0,
        observer_or_fake_quant_ctr=act_observer.with_args(**observer_kwargs),
    )

    # quant_min excludes -128 so the int8 weight range is symmetric about zero.
    weight_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**observer_kwargs),
    )

    bias_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**observer_kwargs),
    )

    return QuantizationConfig(
        input_activation=act_spec,
        output_activation=act_spec,
        weight=weight_spec,
        bias=bias_spec,
    )
101*523fa7a6SAndroid Build Coastguard Worker
102*523fa7a6SAndroid Build Coastguard Worker
# 4-bit quantization only supports specific ops.
def get_16a4w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """16-bit activation / 4-bit weight PTQ config."""
    observer_kwargs: Dict[str, Any] = {"eps": 2**-20}

    # torch has no uint16 quantized dtype; int32 carries the uint16
    # range [0, 65535] instead.
    act_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**observer_kwargs),
    )

    # 4-bit symmetric weights stored in int8, clipped to [-7, 7].
    weight_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**observer_kwargs),
    )

    bias_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**observer_kwargs),
    )

    return QuantizationConfig(
        input_activation=act_spec,
        output_activation=act_spec,
        weight=weight_spec,
        bias=bias_spec,
    )
141*523fa7a6SAndroid Build Coastguard Worker
142*523fa7a6SAndroid Build Coastguard Worker
def get_16a8w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """16-bit activation / 8-bit (unsigned) weight PTQ config."""
    observer_kwargs: Dict[str, Any] = {"eps": 2**-20}

    # torch has no uint16 quantized dtype; int32 carries the uint16
    # range [0, 65535] instead.
    act_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**observer_kwargs),
    )

    weight_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**observer_kwargs),
    )

    bias_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**observer_kwargs),
    )

    return QuantizationConfig(
        input_activation=act_spec,
        output_activation=act_spec,
        weight=weight_spec,
        bias=bias_spec,
    )
178*523fa7a6SAndroid Build Coastguard Worker
179*523fa7a6SAndroid Build Coastguard Worker
def get_16a16w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """16-bit activation / 16-bit weight PTQ config."""
    observer_kwargs: Dict[str, Any] = {"eps": 2**-20}

    # torch has no uint16 quantized dtype; int32 carries the uint16
    # range [0, 65535] instead.
    act_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**observer_kwargs),
    )

    # quant_min excludes the most negative int16 value so the weight range
    # is symmetric about zero.
    weight_spec = QuantizationSpec(
        dtype=torch.int16,
        quant_min=torch.iinfo(torch.int16).min + 1,
        quant_max=torch.iinfo(torch.int16).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**observer_kwargs),
    )

    bias_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**observer_kwargs),
    )

    return QuantizationConfig(
        input_activation=act_spec,
        output_activation=act_spec,
        weight=weight_spec,
        bias=bias_spec,
    )
218*523fa7a6SAndroid Build Coastguard Worker
219*523fa7a6SAndroid Build Coastguard Worker
def get_ptq_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """PTQ config with per-channel weights and a node-derived bias spec.

    Args:
        act_dtype: activation dtype (uint8/uint16/int8/int16).
        weight_dtype: weight dtype (int8/int16, or the string "int4").
        act_observer: observer class for activation statistics.
    """
    observer_kwargs: Dict[str, Any] = {"eps": 2**-12}

    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

    # torch has no uint16 quantized dtype; int32 carries the uint16 range.
    act_spec = QuantizationSpec(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**observer_kwargs),
    )

    # "int4" weights are stored in int8 clipped to [-7, 7]; real int dtypes
    # drop their most negative value to keep the range symmetric.
    is_int4 = weight_dtype == "int4"
    weight_spec = QuantizationSpec(
        dtype=torch.int8 if is_int4 else weight_dtype,
        quant_min=-7 if is_int4 else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if is_int4 else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**observer_kwargs),
    )

    return QuantizationConfig(
        input_activation=act_spec,
        output_activation=act_spec,
        weight=weight_spec,
        # Bias qparams are derived per-node from the act/weight observers.
        bias=_derived_bias_quant_spec,
    )
271*523fa7a6SAndroid Build Coastguard Worker
272*523fa7a6SAndroid Build Coastguard Worker
# TODO merge QAT and PTQ into one function, using a bool flag to control it
def get_8a8w_qnn_qat_config(
    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
) -> QuantizationConfig:
    """8-bit activation / 8-bit weight QAT config using fake-quantize modules.

    Args:
        act_symmetric: use a symmetric per-tensor scheme for activations
            instead of the default affine one.
        act_observer: observer class driving the activation fake-quant.
    """
    act_qscheme = (
        torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
    )

    act_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=act_qscheme,
        ch_axis=0,
        observer_or_fake_quant_ctr=FakeQuantize.with_args(
            dtype=torch.uint8,
            qscheme=act_qscheme,
            reduce_range=True,
            observer=act_observer,
        ),
    )

    # quant_min excludes -128 so the int8 weight range is symmetric about zero.
    weight_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
            dtype=torch.int8,
            quant_min=torch.iinfo(torch.int8).min + 1,
            quant_max=torch.iinfo(torch.int8).max,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=True,
            observer=MovingAverageMinMaxObserver,
        ),
    )

    bias_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=FakeQuantize.with_args(
            dtype=torch.int32,
            quant_min=torch.iinfo(torch.int32).min,
            quant_max=torch.iinfo(torch.int32).max,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=True,
            observer=MovingAverageMinMaxObserver,
        ),
    )

    return QuantizationConfig(
        input_activation=act_spec,
        output_activation=act_spec,
        weight=weight_spec,
        bias=bias_spec,
    )
335*523fa7a6SAndroid Build Coastguard Worker
336*523fa7a6SAndroid Build Coastguard Worker
def get_16a4w_qnn_qat_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """16-bit activation / 4-bit weight QAT config using fake-quantize modules."""
    # torch has no uint16 quantized dtype; int32 carries the uint16
    # range [0, 65535] instead.
    act_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=FakeQuantize.with_args(
            dtype=torch.int32,
            quant_min=torch.iinfo(torch.uint16).min,
            quant_max=torch.iinfo(torch.uint16).max,
            qscheme=torch.per_tensor_affine,
            reduce_range=True,
            observer=act_observer,
        ),
    )

    # 4-bit symmetric weights stored in int8, clipped to [-7, 7].
    weight_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
            dtype=torch.int8,
            quant_min=-7,
            quant_max=7,
            qscheme=torch.per_tensor_symmetric,
            ch_axis=0,
            reduce_range=True,
            observer=MovingAverageMinMaxObserver,
        ),
    )

    bias_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=FakeQuantize.with_args(
            dtype=torch.int32,
            quant_min=torch.iinfo(torch.int32).min,
            quant_max=torch.iinfo(torch.int32).max,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=True,
            observer=MovingAverageMinMaxObserver,
        ),
    )

    return QuantizationConfig(
        input_activation=act_spec,
        output_activation=act_spec,
        weight=weight_spec,
        bias=bias_spec,
    )
398*523fa7a6SAndroid Build Coastguard Worker
399*523fa7a6SAndroid Build Coastguard Worker
def get_qat_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    """Build a QAT config: per-tensor affine activations, per-channel
    symmetric weights (channel axis 0), and a derived bias spec.

    Args:
        act_dtype: Activation dtype; one of uint8, uint16, int8, int16.
        weight_dtype: Weight dtype; torch.int8, torch.int16, or the string
            "int4" (emulated as int8 restricted to [-7, 7]).
        act_observer: Observer class driving the activation fake-quantize.

    Returns:
        QuantizationConfig whose specs carry fake-quantize constructors
        suitable for quantization-aware training.
    """
    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"

    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

    # torch do not support uint16 quantization, use int32 to bypass
    act_storage_dtype = torch.int32 if act_dtype == torch.uint16 else act_dtype
    act_range = torch.iinfo(act_dtype)
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=act_storage_dtype,
        quant_min=act_range.min,
        quant_max=act_range.max,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=act_storage_dtype,
        quant_min=act_range.min,
        quant_max=act_range.max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    # "int4" is stored in int8; it uses the symmetric [-7, 7] range, while
    # real integer dtypes drop the most negative value (min + 1) to stay
    # symmetric around zero.
    if weight_dtype == "int4":
        weight_storage_dtype = torch.int8
        weight_qmin, weight_qmax = -7, 7
    else:
        weight_storage_dtype = weight_dtype
        weight_range = torch.iinfo(weight_dtype)
        weight_qmin, weight_qmax = weight_range.min + 1, weight_range.max

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=weight_storage_dtype,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer=MovingAveragePerChannelMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=weight_storage_dtype,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    # Bias params are derived from the input/weight scales at annotation time.
    return QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=_derived_bias_quant_spec,
    )
465