from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
)
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
from torch.fx import Node


@dataclass(eq=True, frozen=True)
class QuantizationConfig:
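    """Bundle of quantization specs for an op pattern: input activation,
    output activation, weight, and bias. The bias entry may also be a
    callable that derives a per-node spec (see _derived_bias_quant_spec).
    """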
    input_activation: Optional[QuantizationSpec]
    output_activation: Optional[QuantizationSpec]
    weight: Optional[QuantizationSpec]
    bias: Optional[QuantizationSpec | Callable]


def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec:
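    """Build a DerivedQuantizationSpec for a bias tensor whose qparams are
    derived from the node's input activation and weight observers:
    scale = act_scale * weight_scale and zero_point = 0, quantized to int32
    with a per-channel symmetric scheme.
    """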
    def _derive_bias_qparams_fn(
        obs_or_fqs: List,
    ) -> Tuple[Tensor, Tensor]:
        assert (
            len(obs_or_fqs) == 2
        ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
        act_obs_or_fq = obs_or_fqs[0]
        weight_obs_or_fq = obs_or_fqs[1]
        weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
        act_scale, act_zp = act_obs_or_fq.calculate_qparams()
        (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors(
            act_scale, weight_scale
        )
        derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
        derived_zero = torch.zeros(derived_scale.size()).to(torch.int32)
        return (derived_scale, derived_zero)

    input_act = node.args[0]
    assert isinstance(input_act, Node)
    weight = node.args[1]
    assert isinstance(weight, Node)

    return DerivedQuantizationSpec(
        derived_from=[(input_act, node), (weight, node)],
        derive_qparams_fn=_derive_bias_qparams_fn,
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        ch_axis=0,
        qscheme=torch.per_channel_symmetric,
    )


def get_8a8w_qnn_ptq_config(
    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
) -> QuantizationConfig:
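    """PTQ config: uint8 activations (per-tensor affine by default, symmetric
    when act_symmetric is True), per-tensor symmetric int8 weights, and
    per-tensor symmetric int32 bias.
    """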
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        ch_axis=0,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


# 4-bit quantization is only supported for specific ops.
def get_16a4w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
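    """PTQ config: 16-bit activations (uint16 range carried in int32, since
    torch has no uint16 quantized dtype) and 4-bit weights emulated as int8
    clamped to [-7, 7], with per-tensor symmetric int32 bias.
    """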
    extra_args: Dict[str, Any] = {"eps": 2**-20}
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a8w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
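    """PTQ config: 16-bit activations (uint16 range carried in int32) with
    per-tensor symmetric uint8 weights and per-tensor symmetric int32 bias.
    """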
    extra_args: Dict[str, Any] = {"eps": 2**-20}
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a16w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
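    """PTQ config: 16-bit activations (uint16 range carried in int32) with
    per-tensor symmetric int16 weights and per-tensor symmetric int32 bias.
    """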
    extra_args: Dict[str, Any] = {"eps": 2**-20}
    # torch does not support uint16 quantization, use int32 to bypass
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int16,
        quant_min=torch.iinfo(torch.int16).min + 1,
        quant_max=torch.iinfo(torch.int16).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_ptq_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
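    """PTQ config with per-channel symmetric weights and per-tensor affine
    activations. weight_dtype may be the string "int4", which is emulated as
    int8 clamped to [-7, 7]; act_dtype torch.uint16 is carried in int32. Bias
    qparams are derived per node from the activation and weight observers.
    """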
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO: accept "int4" temporarily; remove it once torch supports a torch.int4 dtype
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype}, is not one of the supported types, {supported_act_types}"

    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype}, is not one of the supported types, {supported_weight_dtypes}"

    # torch does not support uint16 quantization, use int32 to bypass
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


# TODO: merge the QAT and PTQ configs into one function and use a bool flag to select between them
def get_8a8w_qnn_qat_config(
    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
) -> QuantizationConfig:
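    """QAT counterpart of get_8a8w_qnn_ptq_config: the same uint8 activation
    and int8 weight per-tensor specs, but wired to FakeQuantize and
    FusedMovingAvgObsFakeQuantize constructors so fake quantization runs
    during training.
    """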
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        ch_axis=0,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a4w_qnn_qat_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
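    """QAT counterpart of get_16a4w_qnn_ptq_config: 16-bit activations (uint16
    range carried in int32) and 4-bit weights emulated as int8 in [-7, 7],
    using fake-quantize constructors for training.
    """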
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_qat_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
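    """QAT counterpart of get_ptq_per_channel_quant_config: per-channel
    symmetric weights and per-tensor affine activations built with
    fake-quantize constructors; bias qparams are derived per node.
    """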
    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO: accept "int4" temporarily; remove it once torch supports a torch.int4 dtype
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype}, is not one of the supported types, {supported_act_types}"

    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype}, is not one of the supported types, {supported_weight_dtypes}"

    # torch does not support uint16 quantization, use int32 to bypass
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer=MovingAveragePerChannelMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config