1import copy 2from typing import Any, Callable, Dict, Optional, Set, Union 3 4import torch 5import torch.ao.nn as ao_nn 6import torch.ao.nn.intrinsic as nni 7import torch.ao.nn.intrinsic.qat as nniqat 8import torch.ao.nn.intrinsic.quantized as nniq 9import torch.ao.nn.intrinsic.quantized.dynamic as nniqd 10import torch.ao.nn.qat as nnqat 11import torch.ao.nn.qat.dynamic as nnqatd 12import torch.ao.nn.quantized as nnq 13import torch.ao.nn.quantized.dynamic as nnqd 14import torch.ao.nn.quantized.reference as nnqr 15 16# Because `torch.ao.nn` uses lazy imports, we need to make 17# sure we import the contents explicitly here. 18import torch.ao.nn.sparse 19import torch.nn.functional as F 20from torch import nn 21from torch.ao.quantization.fake_quantize import ( 22 default_fixed_qparams_range_0to1_fake_quant, 23 default_fixed_qparams_range_neg1to1_fake_quant, 24) 25from torch.ao.quantization.stubs import DeQuantStub, QuantStub 26from torch.ao.quantization.utils import get_combined_dict 27from torch.nn.utils.parametrize import type_before_parametrizations 28 29 30__all__ = [ 31 "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS", 32 "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS", 33 "DEFAULT_QAT_MODULE_MAPPINGS", 34 "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS", 35 "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS", 36 "DEFAULT_MODULE_TO_ACT_POST_PROCESS", 37 "DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS", 38 "DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS", 39 "no_observer_set", 40 "get_default_static_quant_module_mappings", 41 "get_default_static_quant_reference_module_mappings", 42 "get_embedding_static_quant_module_mappings", 43 "get_default_static_sparse_quant_module_mappings", 44 "get_static_quant_module_class", 45 "get_dynamic_quant_module_class", 46 "get_default_qat_module_mappings", 47 "get_embedding_qat_module_mappings", 48 "get_default_dynamic_quant_module_mappings", 49 "get_default_dynamic_sparse_quant_module_mappings", 50 "get_default_qconfig_propagation_list", 51 "get_default_compare_output_module_list", 52 "get_default_float_to_quantized_operator_mappings", 53 "get_quantized_operator", 54] 55 56# Default map for swapping float module to reference quantized modules 57DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = { 58 QuantStub: nnq.Quantize, 59 DeQuantStub: nnq.DeQuantize, 60 nn.Linear: nnqr.Linear, 61 nn.Conv1d: nnqr.Conv1d, 62 nn.Conv2d: nnqr.Conv2d, 63 nn.Conv3d: nnqr.Conv3d, 64 nn.ConvTranspose1d: nnqr.ConvTranspose1d, 65 nn.ConvTranspose2d: nnqr.ConvTranspose2d, 66 nn.ConvTranspose3d: nnqr.ConvTranspose3d, 67 nn.Embedding: nnqr.Embedding, 68 nn.EmbeddingBag: nnqr.EmbeddingBag, 69 nn.GRUCell: nnqr.GRUCell, 70 nn.LSTMCell: nnqr.LSTMCell, 71 nn.RNNCell: nnqr.RNNCell, 72 nn.LSTM: nnqr.LSTM, 73} 74 75# Default map for swapping float module to quantized ones 76DEFAULT_STATIC_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = { 77 QuantStub: nnq.Quantize, 78 DeQuantStub: nnq.DeQuantize, 79 nn.BatchNorm2d: nnq.BatchNorm2d, 80 nn.BatchNorm3d: nnq.BatchNorm3d, 81 nn.Dropout: nnq.Dropout, 82 nn.Conv1d: nnq.Conv1d, 83 nn.Conv2d: nnq.Conv2d, 84 nn.Conv3d: nnq.Conv3d, 85 nn.ConvTranspose1d: nnq.ConvTranspose1d, 86 nn.ConvTranspose2d: nnq.ConvTranspose2d, 87 nn.ConvTranspose3d: nnq.ConvTranspose3d, 88 nn.ELU: nnq.ELU, 89 nn.Embedding: nnq.Embedding, 90 nn.EmbeddingBag: nnq.EmbeddingBag, 91 nn.GroupNorm: nnq.GroupNorm, 92 nn.Hardswish: nnq.Hardswish, 93 nn.InstanceNorm1d: nnq.InstanceNorm1d, 94 nn.InstanceNorm2d: nnq.InstanceNorm2d, 95 nn.InstanceNorm3d: nnq.InstanceNorm3d, 96 nn.LayerNorm: nnq.LayerNorm, 97 nn.LeakyReLU: nnq.LeakyReLU, 98 nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear, 99 nn.Linear: nnq.Linear, 100 nn.ReLU6: nnq.ReLU6, 101 nn.Dropout: nnq.Dropout, 102 nn.PReLU: nnq.PReLU, 103 # Wrapper Modules: 104 nnq.FloatFunctional: nnq.QFunctional, 105 # Intrinsic modules: 106 nni.BNReLU2d: nniq.BNReLU2d, 107 nni.BNReLU3d: nniq.BNReLU3d, 108 nni.ConvReLU1d: nniq.ConvReLU1d, 109 nni.ConvReLU2d: nniq.ConvReLU2d, 110 nni.ConvReLU3d: nniq.ConvReLU3d, 111 nni.ConvAdd2d: nniq.ConvAdd2d, 112 nni.ConvAddReLU2d: nniq.ConvAddReLU2d, 113 nni.LinearReLU: nniq.LinearReLU, 114 nni.LinearLeakyReLU: nniq.LinearLeakyReLU, 115 nni.LinearTanh: nniq.LinearTanh, 116 nniqat.ConvBn1d: nnq.Conv1d, 117 nniqat.ConvBn2d: nnq.Conv2d, 118 nniqat.ConvBn3d: nnq.Conv3d, 119 nniqat.ConvBnReLU1d: nniq.ConvReLU1d, 120 nniqat.ConvBnReLU2d: nniq.ConvReLU2d, 121 nniqat.ConvBnReLU3d: nniq.ConvReLU3d, 122 nniqat.ConvReLU2d: nniq.ConvReLU2d, 123 nniqat.ConvReLU3d: nniq.ConvReLU3d, 124 nniqat.LinearReLU: nniq.LinearReLU, 125 nniqat.LinearBn1d: nnq.Linear, 126 # QAT modules: 127 nnqat.Linear: nnq.Linear, 128 nnqat.Conv2d: nnq.Conv2d, 129 nnqat.Conv3d: nnq.Conv3d, 130} 131 132# Default map for swapping float module to qat modules 133DEFAULT_QAT_MODULE_MAPPINGS: Dict[Callable, Any] = { 134 nn.Conv2d: nnqat.Conv2d, 135 nn.Conv3d: nnqat.Conv3d, 136 nn.Linear: nnqat.Linear, 137 nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear, 138 # Intrinsic modules: 139 nni.ConvBn1d: nniqat.ConvBn1d, 140 nni.ConvBn2d: nniqat.ConvBn2d, 141 nni.ConvBn3d: nniqat.ConvBn3d, 142 nni.ConvBnReLU1d: nniqat.ConvBnReLU1d, 143 nni.ConvBnReLU2d: nniqat.ConvBnReLU2d, 144 nni.ConvBnReLU3d: nniqat.ConvBnReLU3d, 145 nni.ConvReLU2d: nniqat.ConvReLU2d, 146 nni.ConvReLU3d: nniqat.ConvReLU3d, 147 nni.LinearReLU: nniqat.LinearReLU, 148 nni.LinearBn1d: nniqat.LinearBn1d, 149} 150 151# Default map for swapping dynamic modules 152DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = { 153 nn.GRUCell: nnqd.GRUCell, 154 nn.Linear: nnqd.Linear, 155 nnqatd.Linear: nnqd.Linear, 156 nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear, 157 nn.LSTM: nnqd.LSTM, 158 nn.GRU: nnqd.GRU, 159 nn.LSTMCell: nnqd.LSTMCell, 160 nn.RNNCell: nnqd.RNNCell, 161 nni.LinearReLU: nniqd.LinearReLU, 162 nn.EmbeddingBag: nnq.EmbeddingBag, 163 nn.Embedding: nnq.Embedding, 164 # Don't want to enable these by default because the numerical 165 # accuracy is poor compared to other dynamic ops 166 # nn.Conv1d: nnqd.Conv1d, 167 # nn.Conv2d: nnqd.Conv2d, 168 # nn.Conv3d: nnqd.Conv3d, 169 # nn.ConvTranspose1d: nnqd.ConvTranspose1d, 170 # nn.ConvTranspose2d: nnqd.ConvTranspose2d, 171 # nn.ConvTranspose3d: nnqd.ConvTranspose3d, 172} 173 174# Allowlist for propagating the qconfig 175_INCLUDE_QCONFIG_PROPAGATE_LIST: Set[Callable] = { 176 nn.Sequential, 177} 178 179# Default mapping from floating point function or torch ops to quantized ops 180# TODO: merge with default static mapping 181DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS: Dict[Union[Callable, str], Callable] = { 182 F.elu: torch.ops.quantized.elu, 183 F.hardswish: torch.ops.quantized.hardswish, 184 F.instance_norm: torch.ops.quantized.instance_norm, 185 F.layer_norm: torch.ops.quantized.layer_norm, 186 F.leaky_relu: torch.ops.quantized.leaky_relu, 187 F.dropout: torch.ops.quantized.dropout, 188} 189 190# mapping from module to output activation post process class 191DEFAULT_MODULE_TO_ACT_POST_PROCESS: Dict[Callable, Callable] = { 192 nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant, 193 nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant, 194 nn.Softmax: default_fixed_qparams_range_0to1_fake_quant, 195 nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant, 196} 197 198# Default map for swapping float module to static sparse quantized ones 199DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = { 200 nn.Linear: ao_nn.sparse.quantized.Linear 201} 202 203# Default map for swapping float module to dynamic sparse quantized ones 204DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = { 205 nn.Linear: ao_nn.sparse.quantized.dynamic.Linear 206} 207 208 209def no_observer_set() -> Set[Any]: 210 r"""These modules cannot have observers inserted by default.""" 211 no_observers = {nn.quantizable.LSTM, nn.quantizable.MultiheadAttention} 212 return no_observers 213 214 215def get_default_static_quant_module_mappings() -> Dict[Callable, Any]: 216 """Get module mapping for post training static quantization""" 217 return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) 218 219 220def get_default_static_quant_reference_module_mappings() -> Dict[Callable, Any]: 221 """Get reference module mapping for post training static quantization""" 222 return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS) 223 224 225def get_embedding_static_quant_module_mappings() -> Dict[Callable, Any]: 226 """Get module mapping, including mapping for embedding QAT""" 227 mapping = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) 228 mapping[nnqat.EmbeddingBag] = nnq.EmbeddingBag 229 mapping[nnqat.Embedding] = nnq.Embedding 230 return mapping 231 232 233def get_default_static_sparse_quant_module_mappings() -> Dict[Callable, Any]: 234 """Get module mapping for post training static sparse quantization""" 235 return copy.deepcopy(DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS) 236 237 238def get_static_quant_module_class( 239 float_module_class: Callable, 240 additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None, 241 is_reference: bool = False, 242) -> Any: 243 r"""n Get the statically quantized module class corresponding to 244 the floating point module class 245 """ 246 if additional_static_quant_mapping is None: 247 additional_static_quant_mapping = {} 248 all_mappings = get_combined_dict( 249 DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS 250 if is_reference 251 else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, 252 additional_static_quant_mapping, 253 ) 254 static_quant_module_class = all_mappings.get(float_module_class, None) 255 assert static_quant_module_class is not None, ( 256 f"Floating point module class {str(float_module_class)}" 257 + " does not have a corresponding quantized module class" 258 ) 259 return copy.deepcopy(static_quant_module_class) 260 261 262def get_dynamic_quant_module_class( 263 float_module_class: Callable, 264 additional_dynamic_quant_mapping: Optional[Dict[Callable, Any]] = None, 265) -> Any: 266 r"""n Get the dynamically quantized module class corresponding to 267 the floating point module class 268 """ 269 if additional_dynamic_quant_mapping is None: 270 additional_dynamic_quant_mapping = {} 271 all_mappings = get_combined_dict( 272 DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping 273 ) 274 dynamic_quant_module_class = all_mappings.get(float_module_class, None) 275 assert dynamic_quant_module_class is not None, ( 276 f"Floating point module class {str(float_module_class)}" 277 + " does not have a corresponding quantized module class" 278 ) 279 return copy.deepcopy(dynamic_quant_module_class) 280 281 282def get_default_qat_module_mappings() -> Dict[Callable, Any]: 283 """Get default module mapping for quantization aware training""" 284 return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) 285 286 287def get_embedding_qat_module_mappings() -> Dict[Callable, Any]: 288 """Get module mapping for quantization aware training 289 This is includes default values in addition to 290 enabling qat for embeddings. 291 """ 292 mapping = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) 293 mapping[nn.EmbeddingBag] = nnqat.EmbeddingBag 294 mapping[nn.Embedding] = nnqat.Embedding 295 return mapping 296 297 298def get_default_dynamic_quant_module_mappings() -> Dict[Callable, Any]: 299 """Get module mapping for post training dynamic quantization""" 300 return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS 301 302 303def get_default_dynamic_sparse_quant_module_mappings() -> Dict[Callable, Any]: 304 """Get module mapping for post training dynamic sparse quantization""" 305 return DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS 306 307 308def get_default_qconfig_propagation_list() -> Set[Callable]: 309 """Get the default list of module types that we'll attach qconfig 310 attribute to in prepare 311 """ 312 QCONFIG_PROPAGATE_MODULE_CLASS_LIST = ( 313 set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) 314 | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) 315 | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) 316 | _INCLUDE_QCONFIG_PROPAGATE_LIST 317 ) 318 return copy.deepcopy(QCONFIG_PROPAGATE_MODULE_CLASS_LIST) 319 320 321def get_default_compare_output_module_list() -> Set[Callable]: 322 """Get list of module class types that we will record output 323 in numeric suite 324 """ 325 NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = ( 326 set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values()) 327 | set(DEFAULT_QAT_MODULE_MAPPINGS.values()) 328 | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values()) 329 | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) 330 | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) 331 | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) 332 | _INCLUDE_QCONFIG_PROPAGATE_LIST 333 ) 334 return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST) 335 336 337def get_default_float_to_quantized_operator_mappings() -> ( 338 Dict[Union[Callable, str], Callable] 339): 340 return copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS) 341 342 343# TODO: merge with get_static_quant_module_class 344def get_quantized_operator(float_op: Union[Callable, str]) -> Callable: 345 """Get the quantized operator corresponding to the float operator""" 346 quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) 347 assert ( 348 quantized_op is not None 349 ), f"Operator {str(float_op)} does not have corresponding quantized op" 350 return quantized_op 351 352 353def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]: 354 r"""Get the special activation post process for `module`, this has 355 higher priority than the activation post process in `qconfig` 356 e.g. 357 input: torch.nn.Sigmoid 358 output: default_affine_fixed_qparam_fake_quant 359 """ 360 return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get( 361 type_before_parametrizations(module), None 362 ) 363 364 365def _has_special_act_post_process(module: torch.nn.Module) -> bool: 366 return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS 367