xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/quantization_mappings.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import copy
2from typing import Any, Callable, Dict, Optional, Set, Union
3
4import torch
5import torch.ao.nn as ao_nn
6import torch.ao.nn.intrinsic as nni
7import torch.ao.nn.intrinsic.qat as nniqat
8import torch.ao.nn.intrinsic.quantized as nniq
9import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
10import torch.ao.nn.qat as nnqat
11import torch.ao.nn.qat.dynamic as nnqatd
12import torch.ao.nn.quantized as nnq
13import torch.ao.nn.quantized.dynamic as nnqd
14import torch.ao.nn.quantized.reference as nnqr
15
16# Because `torch.ao.nn` uses lazy imports, we need to make
17# sure we import the contents explicitly here.
18import torch.ao.nn.sparse
19import torch.nn.functional as F
20from torch import nn
21from torch.ao.quantization.fake_quantize import (
22    default_fixed_qparams_range_0to1_fake_quant,
23    default_fixed_qparams_range_neg1to1_fake_quant,
24)
25from torch.ao.quantization.stubs import DeQuantStub, QuantStub
26from torch.ao.quantization.utils import get_combined_dict
27from torch.nn.utils.parametrize import type_before_parametrizations
28
29
# Public API of this module: the default mapping tables and their accessor
# functions. Underscore-prefixed names defined in this module (e.g.
# _INCLUDE_QCONFIG_PROPAGATE_LIST, _get_special_act_post_process) are not listed.
__all__ = [
    "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
    "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
    "DEFAULT_QAT_MODULE_MAPPINGS",
    "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
    "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
    "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
    "DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS",
    "DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS",
    "no_observer_set",
    "get_default_static_quant_module_mappings",
    "get_default_static_quant_reference_module_mappings",
    "get_embedding_static_quant_module_mappings",
    "get_default_static_sparse_quant_module_mappings",
    "get_static_quant_module_class",
    "get_dynamic_quant_module_class",
    "get_default_qat_module_mappings",
    "get_embedding_qat_module_mappings",
    "get_default_dynamic_quant_module_mappings",
    "get_default_dynamic_sparse_quant_module_mappings",
    "get_default_qconfig_propagation_list",
    "get_default_compare_output_module_list",
    "get_default_float_to_quantized_operator_mappings",
    "get_quantized_operator",
]
55
# Default map for swapping float modules to *reference* quantized modules
# (nnqr). Selected by get_static_quant_module_class(..., is_reference=True)
# and returned (deep-copied) by get_default_static_quant_reference_module_mappings.
# Note the stubs still map to the regular nnq Quantize/DeQuantize modules.
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.Linear: nnqr.Linear,
    nn.Conv1d: nnqr.Conv1d,
    nn.Conv2d: nnqr.Conv2d,
    nn.Conv3d: nnqr.Conv3d,
    nn.ConvTranspose1d: nnqr.ConvTranspose1d,
    nn.ConvTranspose2d: nnqr.ConvTranspose2d,
    nn.ConvTranspose3d: nnqr.ConvTranspose3d,
    nn.Embedding: nnqr.Embedding,
    nn.EmbeddingBag: nnqr.EmbeddingBag,
    nn.GRUCell: nnqr.GRUCell,
    nn.LSTMCell: nnqr.LSTMCell,
    nn.RNNCell: nnqr.RNNCell,
    nn.LSTM: nnqr.LSTM,
}
74
# Default map for swapping float modules (and the quant/dequant stubs) to
# their quantized counterparts in post training static quantization.
# Besides plain float modules it also covers the FloatFunctional wrapper,
# fused intrinsic modules (nni), fused QAT modules (nniqat) and QAT modules
# (nnqat), so QAT-prepared modules also have a quantized target here.
# Each key must appear only once: a repeated key is silently collapsed by
# the dict literal and only adds noise.
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.Dropout: nnq.Dropout,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.ConvTranspose3d: nnq.ConvTranspose3d,
    nn.ELU: nnq.ELU,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.GroupNorm: nnq.GroupNorm,
    nn.Hardswish: nnq.Hardswish,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.LeakyReLU: nnq.LeakyReLU,
    nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear,
    nn.Linear: nnq.Linear,
    nn.ReLU6: nnq.ReLU6,
    nn.PReLU: nnq.PReLU,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.ConvAdd2d: nniq.ConvAdd2d,
    nni.ConvAddReLU2d: nniq.ConvAddReLU2d,
    nni.LinearReLU: nniq.LinearReLU,
    nni.LinearLeakyReLU: nniq.LinearLeakyReLU,
    nni.LinearTanh: nniq.LinearTanh,
    # Fused QAT conv/linear + bn modules map to quantized modules without a
    # separate batch-norm entry.
    nniqat.ConvBn1d: nnq.Conv1d,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBn3d: nnq.Conv3d,
    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    nniqat.ConvBnReLU3d: nniq.ConvReLU3d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.ConvReLU3d: nniq.ConvReLU3d,
    nniqat.LinearReLU: nniq.LinearReLU,
    nniqat.LinearBn1d: nnq.Linear,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
    nnqat.Conv3d: nnq.Conv3d,
}
131
# Default map for swapping float modules (and fused float intrinsic modules)
# to their quantization-aware-training counterparts. Returned deep-copied by
# get_default_qat_module_mappings; get_embedding_qat_module_mappings extends
# a copy with Embedding/EmbeddingBag entries.
DEFAULT_QAT_MODULE_MAPPINGS: Dict[Callable, Any] = {
    nn.Conv2d: nnqat.Conv2d,
    nn.Conv3d: nnqat.Conv3d,
    nn.Linear: nnqat.Linear,
    nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear,
    # Intrinsic modules:
    nni.ConvBn1d: nniqat.ConvBn1d,
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBn3d: nniqat.ConvBn3d,
    nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvBnReLU3d: nniqat.ConvBnReLU3d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.ConvReLU3d: nniqat.ConvReLU3d,
    nni.LinearReLU: nniqat.LinearReLU,
    nni.LinearBn1d: nniqat.LinearBn1d,
}
150
# Default map for swapping float modules to dynamically quantized modules
# (nnqd). Used by get_dynamic_quant_module_class and
# get_default_dynamic_quant_module_mappings.
# Embedding/EmbeddingBag map to the nnq (not nnqd) classes.
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = {
    nn.GRUCell: nnqd.GRUCell,
    nn.Linear: nnqd.Linear,
    nnqatd.Linear: nnqd.Linear,
    nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.GRU: nnqd.GRU,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
    nni.LinearReLU: nniqd.LinearReLU,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.Embedding: nnq.Embedding,
    # Don't want to enable these by default because the numerical
    # accuracy is poor compared to other dynamic ops
    # nn.Conv1d: nnqd.Conv1d,
    # nn.Conv2d: nnqd.Conv2d,
    # nn.Conv3d: nnqd.Conv3d,
    # nn.ConvTranspose1d: nnqd.ConvTranspose1d,
    # nn.ConvTranspose2d: nnqd.ConvTranspose2d,
    # nn.ConvTranspose3d: nnqd.ConvTranspose3d,
}
173
# Allowlist for propagating the qconfig: extra module types, beyond the keys
# of the default mapping tables, that get_default_qconfig_propagation_list
# and get_default_compare_output_module_list include.
_INCLUDE_QCONFIG_PROPAGATE_LIST: Set[Callable] = {
    nn.Sequential,
}
178
# Default mapping from floating point functions (torch.nn.functional) to the
# corresponding torch.ops.quantized operators. The key annotation also admits
# str keys. Read by get_quantized_operator and (deep-copied) by
# get_default_float_to_quantized_operator_mappings.
# TODO: merge with default static mapping
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS: Dict[Union[Callable, str], Callable] = {
    F.elu: torch.ops.quantized.elu,
    F.hardswish: torch.ops.quantized.hardswish,
    F.instance_norm: torch.ops.quantized.instance_norm,
    F.layer_norm: torch.ops.quantized.layer_norm,
    F.leaky_relu: torch.ops.quantized.leaky_relu,
    F.dropout: torch.ops.quantized.dropout,
}
189
# Mapping from module class to the output activation post process used for
# it, looked up by _get_special_act_post_process / _has_special_act_post_process.
# Modules whose output range is fixed get a fixed-qparams fake quant: the
# "range_0to1" variant for Hardsigmoid/Sigmoid/Softmax and the "range_neg1to1"
# variant for Tanh.
DEFAULT_MODULE_TO_ACT_POST_PROCESS: Dict[Callable, Callable] = {
    nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant,
    nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant,
    nn.Softmax: default_fixed_qparams_range_0to1_fake_quant,
    nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant,
}
197
# Default map for swapping float modules to static sparse quantized ones.
# Returned (deep-copied) by get_default_static_sparse_quant_module_mappings.
DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = {
    nn.Linear: ao_nn.sparse.quantized.Linear
}
202
# Default map for swapping float modules to dynamic sparse quantized ones.
# Returned by get_default_dynamic_sparse_quant_module_mappings.
DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = {
    nn.Linear: ao_nn.sparse.quantized.dynamic.Linear
}
207
208
def no_observer_set() -> Set[Any]:
    r"""Return the module types that cannot have observers inserted by default.

    A new set is built on every call, so callers may mutate the result.
    """
    return {
        nn.quantizable.LSTM,
        nn.quantizable.MultiheadAttention,
    }
213
214
def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
    """Return the module mapping for post training static quantization.

    The result is a deep copy of ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS``,
    so mutating it leaves the module-level default untouched.
    """
    mapping_copy = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
    return mapping_copy
218
219
def get_default_static_quant_reference_module_mappings() -> Dict[Callable, Any]:
    """Return the reference module mapping for post training static quantization.

    The result is a deep copy of
    ``DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS``, so callers may edit it
    freely.
    """
    reference_defaults = DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS
    return copy.deepcopy(reference_defaults)
223
224
def get_embedding_static_quant_module_mappings() -> Dict[Callable, Any]:
    """Return the static quantization module mapping extended for embedding QAT.

    Starts from a deep copy of ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS`` and
    appends entries converting the QAT embedding modules
    (``nnqat.EmbeddingBag``, ``nnqat.Embedding``) to their quantized versions.
    """
    return {
        **copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS),
        nnqat.EmbeddingBag: nnq.EmbeddingBag,
        nnqat.Embedding: nnq.Embedding,
    }
231
232
def get_default_static_sparse_quant_module_mappings() -> Dict[Callable, Any]:
    """Return the module mapping for post training static sparse quantization.

    The returned dict is a deep copy of
    ``DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS``.
    """
    sparse_defaults = DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS
    return copy.deepcopy(sparse_defaults)
236
237
def get_static_quant_module_class(
    float_module_class: Callable,
    additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
    is_reference: bool = False,
) -> Any:
    r"""Get the statically quantized module class corresponding to
    the floating point module class.

    Args:
        float_module_class: the float module class to look up.
        additional_static_quant_mapping: optional extra float -> quantized
            entries, combined with the defaults via ``get_combined_dict``.
        is_reference: if True, look up in
            ``DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS`` instead of
            ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS``.

    Returns:
        The mapped value, passed through ``copy.deepcopy`` (which returns
        classes unchanged).

    Raises:
        AssertionError: if no mapping exists for ``float_module_class``.
    """
    if additional_static_quant_mapping is None:
        additional_static_quant_mapping = {}
    default_mappings = (
        DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS
        if is_reference
        else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS
    )
    all_mappings = get_combined_dict(default_mappings, additional_static_quant_mapping)
    static_quant_module_class = all_mappings.get(float_module_class, None)
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; AssertionError is kept for existing callers.
    if static_quant_module_class is None:
        raise AssertionError(
            f"Floating point module class {str(float_module_class)}"
            + " does not have a corresponding quantized module class"
        )
    return copy.deepcopy(static_quant_module_class)
260
261
def get_dynamic_quant_module_class(
    float_module_class: Callable,
    additional_dynamic_quant_mapping: Optional[Dict[Callable, Any]] = None,
) -> Any:
    r"""Get the dynamically quantized module class corresponding to
    the floating point module class.

    Args:
        float_module_class: the float module class to look up.
        additional_dynamic_quant_mapping: optional extra float -> quantized
            entries, combined with ``DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS``
            via ``get_combined_dict``.

    Returns:
        The mapped value, passed through ``copy.deepcopy`` (which returns
        classes unchanged).

    Raises:
        AssertionError: if no mapping exists for ``float_module_class``.
    """
    if additional_dynamic_quant_mapping is None:
        additional_dynamic_quant_mapping = {}
    all_mappings = get_combined_dict(
        DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping
    )
    dynamic_quant_module_class = all_mappings.get(float_module_class, None)
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; AssertionError is kept for existing callers.
    if dynamic_quant_module_class is None:
        raise AssertionError(
            f"Floating point module class {str(float_module_class)}"
            + " does not have a corresponding quantized module class"
        )
    return copy.deepcopy(dynamic_quant_module_class)
280
281
def get_default_qat_module_mappings() -> Dict[Callable, Any]:
    """Return the default module mapping for quantization aware training.

    The result is a deep copy of ``DEFAULT_QAT_MODULE_MAPPINGS``.
    """
    qat_defaults = DEFAULT_QAT_MODULE_MAPPINGS
    return copy.deepcopy(qat_defaults)
285
286
def get_embedding_qat_module_mappings() -> Dict[Callable, Any]:
    """Return the module mapping for quantization aware training with
    embeddings enabled.

    This includes the default QAT mappings (deep-copied from
    ``DEFAULT_QAT_MODULE_MAPPINGS``) plus entries swapping ``nn.EmbeddingBag``
    and ``nn.Embedding`` for their ``nnqat`` counterparts.
    """
    return {
        **copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS),
        nn.EmbeddingBag: nnqat.EmbeddingBag,
        nn.Embedding: nnqat.Embedding,
    }
296
297
def get_default_dynamic_quant_module_mappings() -> Dict[Callable, Any]:
    """Get module mapping for post training dynamic quantization.

    Returns a deep copy of ``DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS``, like the
    other ``get_default_*`` mapping getters in this module. Returning the
    global dict itself would let a caller that adds or removes entries
    silently change the defaults seen by every later caller.
    """
    return copy.deepcopy(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS)
301
302
def get_default_dynamic_sparse_quant_module_mappings() -> Dict[Callable, Any]:
    """Get module mapping for post training dynamic sparse quantization.

    Returns a deep copy of ``DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS``,
    matching ``get_default_static_sparse_quant_module_mappings``, so callers
    cannot mutate the module-level default through the returned dict.
    """
    return copy.deepcopy(DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS)
306
307
def get_default_qconfig_propagation_list() -> Set[Callable]:
    """Return the default set of module types that get a ``qconfig``
    attribute attached in prepare.

    The set is the union of the float-module keys of the default static, QAT
    and dynamic mappings, plus ``_INCLUDE_QCONFIG_PROPAGATE_LIST``.
    """
    # set().union iterates each dict, which yields its keys.
    propagate_classes: Set[Callable] = set().union(
        DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
        DEFAULT_QAT_MODULE_MAPPINGS,
        DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
        _INCLUDE_QCONFIG_PROPAGATE_LIST,
    )
    return copy.deepcopy(propagate_classes)
319
320
def get_default_compare_output_module_list() -> Set[Callable]:
    """Return the set of module types whose output the numeric suite records.

    It contains both the float (key) and quantized/QAT (value) module types
    of the default static, QAT and dynamic mappings, plus
    ``_INCLUDE_QCONFIG_PROPAGATE_LIST``.
    """
    compared_types: Set[Callable] = set(_INCLUDE_QCONFIG_PROPAGATE_LIST)
    for default_mapping in (
        DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
        DEFAULT_QAT_MODULE_MAPPINGS,
        DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    ):
        compared_types.update(default_mapping.keys())
        compared_types.update(default_mapping.values())
    return copy.deepcopy(compared_types)
335
336
def get_default_float_to_quantized_operator_mappings() -> (
    Dict[Union[Callable, str], Callable]
):
    """Return the default float function -> quantized operator mapping.

    The result is a deep copy of
    ``DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS``, so callers may edit it
    without affecting the module-level default.
    """
    operator_defaults = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS
    return copy.deepcopy(operator_defaults)
341
342
# TODO: merge with get_static_quant_module_class
def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
    """Return the quantized operator corresponding to ``float_op``.

    Looks ``float_op`` up in ``DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS``;
    a missing entry fails the assertion (when assertions are enabled).
    """
    found_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
    assert found_op is not None, (
        f"Operator {str(float_op)} does not have corresponding quantized op"
    )
    return found_op
351
352
def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]:
    r"""Get the special activation post process for ``module``. This has
    higher priority than the activation post process in ``qconfig``.

    The lookup key is ``type_before_parametrizations(module)``, so a module
    with parametrizations is matched by its original class.

    e.g.
    input: an instance of torch.nn.Sigmoid
    output: default_fixed_qparams_range_0to1_fake_quant

    Returns ``None`` when the module's type has no entry in
    ``DEFAULT_MODULE_TO_ACT_POST_PROCESS``.
    """
    module_type = type_before_parametrizations(module)
    return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(module_type, None)
363
364
def _has_special_act_post_process(module: torch.nn.Module) -> bool:
    """Return True if ``module`` is in training mode and its type has an entry
    in ``DEFAULT_MODULE_TO_ACT_POST_PROCESS``.

    The type is resolved with ``type_before_parametrizations``, the same
    lookup ``_get_special_act_post_process`` performs. With plain ``type()``,
    a parametrized module (whose runtime class is a generated subclass) would
    be reported as having no special post process even though
    ``_get_special_act_post_process`` finds one for it.
    """
    return (
        module.training
        and type_before_parametrizations(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS
    )
367