# mypy: allow-untyped-defs
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.quant_type import (
    _get_quant_type_to_str,
    _quant_type_from_str,
    QuantType,
)


__all__ = [
    "ConvertCustomConfig",
    "FuseCustomConfig",
    "PrepareCustomConfig",
    "StandaloneModuleConfigEntry",
]


# Keys of the legacy dictionary format read by ``from_dict`` and written by ``to_dict``.
# TODO: replace all usages with these constants
STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name"
STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class"
FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class"
OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class"
NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name"
NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class"
INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs"
OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs"
PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes"


@dataclass
class StandaloneModuleConfigEntry:
    """
    Settings recorded for a single standalone module by
    :meth:`PrepareCustomConfig.set_standalone_module_name` and
    :meth:`PrepareCustomConfig.set_standalone_module_class`.
    """

    # qconfig_mapping for the prepare function called in the submodule,
    # None means use qconfig from parent qconfig_mapping
    qconfig_mapping: Optional[QConfigMapping]
    example_inputs: Tuple[Any, ...]
    prepare_custom_config: Optional[PrepareCustomConfig]
    backend_config: Optional[BackendConfig]


class PrepareCustomConfig:
    """
    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and
    :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`.

    Every ``set_*`` method returns ``self`` so that calls can be chained.

    Example usage::

        prepare_custom_config = (
            PrepareCustomConfig()
            .set_standalone_module_name(
                "module1", qconfig_mapping, example_inputs,
                child_prepare_custom_config, backend_config)
            .set_standalone_module_class(
                MyStandaloneModule, qconfig_mapping, example_inputs,
                child_prepare_custom_config, backend_config)
            .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule)
            .set_non_traceable_module_names(["module2", "module3"])
            .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2])
            .set_input_quantized_indexes([0])
            .set_output_quantized_indexes([0])
            .set_preserved_attributes(["attr1", "attr2"])
        )
    """

    def __init__(self) -> None:
        self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {}
        self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {}
        self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {}
        self.non_traceable_module_names: List[str] = []
        self.non_traceable_module_classes: List[Type] = []
        self.input_quantized_indexes: List[int] = []
        self.output_quantized_indexes: List[int] = []
        self.preserved_attributes: List[str] = []

    def __repr__(self) -> str:
        # Only show the settings that have actually been populated.
        dict_nonempty = {k: v for k, v in self.__dict__.items() if v}
        return f"PrepareCustomConfig({dict_nonempty})"

    def set_standalone_module_name(
        self,
        module_name: str,
        qconfig_mapping: Optional[QConfigMapping],
        example_inputs: Tuple[Any, ...],
        prepare_custom_config: Optional[PrepareCustomConfig],
        backend_config: Optional[BackendConfig],
    ) -> PrepareCustomConfig:
        """
        Set the configuration for running a standalone module identified by ``module_name``.

        If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
        If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
        If ``backend_config`` is None, the parent ``backend_config`` will be used instead.

        An existing entry for the same ``module_name`` is overwritten.
        """
        self.standalone_module_names[module_name] = StandaloneModuleConfigEntry(
            qconfig_mapping, example_inputs, prepare_custom_config, backend_config
        )
        return self

    def set_standalone_module_class(
        self,
        module_class: Type,
        qconfig_mapping: Optional[QConfigMapping],
        example_inputs: Tuple[Any, ...],
        prepare_custom_config: Optional[PrepareCustomConfig],
        backend_config: Optional[BackendConfig],
    ) -> PrepareCustomConfig:
        """
        Set the configuration for running a standalone module identified by ``module_class``.

        If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
        If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
        If ``backend_config`` is None, the parent ``backend_config`` will be used instead.

        An existing entry for the same ``module_class`` is overwritten.
        """
        self.standalone_module_classes[module_class] = StandaloneModuleConfigEntry(
            qconfig_mapping, example_inputs, prepare_custom_config, backend_config
        )
        return self

    def set_float_to_observed_mapping(
        self,
        float_class: Type,
        observed_class: Type,
        quant_type: QuantType = QuantType.STATIC,
    ) -> PrepareCustomConfig:
        """
        Set the mapping from a custom float module class to a custom observed module class.

        The observed module class must have a ``from_float`` class method that converts the float module class
        to the observed module class. This is currently only supported for static quantization.

        Raises:
            ValueError: if ``quant_type`` is not ``QuantType.STATIC``.
        """
        if quant_type != QuantType.STATIC:
            raise ValueError(
                "set_float_to_observed_mapping is currently only supported for static quantization"
            )
        self.float_to_observed_mapping.setdefault(quant_type, {})[
            float_class
        ] = observed_class
        return self

    def set_non_traceable_module_names(
        self, module_names: List[str]
    ) -> PrepareCustomConfig:
        """
        Set the modules that are not symbolically traceable, identified by name.

        The given list replaces any previously set value and is stored as-is (not copied).
        """
        self.non_traceable_module_names = module_names
        return self

    def set_non_traceable_module_classes(
        self, module_classes: List[Type]
    ) -> PrepareCustomConfig:
        """
        Set the modules that are not symbolically traceable, identified by class.

        The given list replaces any previously set value and is stored as-is (not copied).
        """
        self.non_traceable_module_classes = module_classes
        return self

    def set_input_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
        """
        Set the indexes of the inputs of the graph that should be quantized.
        Inputs are otherwise assumed to be in fp32 by default instead.
        """
        self.input_quantized_indexes = indexes
        return self

    def set_output_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
        """
        Set the indexes of the outputs of the graph that should be quantized.
        Outputs are otherwise assumed to be in fp32 by default instead.
        """
        self.output_quantized_indexes = indexes
        return self

    def set_preserved_attributes(self, attributes: List[str]) -> PrepareCustomConfig:
        """
        Set the names of the attributes that will persist in the graph module even if they are not used in
        the model's ``forward`` method.
        """
        self.preserved_attributes = attributes
        return self

    # TODO: remove this
    @classmethod
    def from_dict(
        cls, prepare_custom_config_dict: Dict[str, Any]
    ) -> PrepareCustomConfig:
        """
        Create a ``PrepareCustomConfig`` from a dictionary with the following items:

            "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs,
            child_prepare_custom_config, backend_config) tuples

            "standalone_module_class": a list of (module_class, qconfig_mapping, example_inputs,
            child_prepare_custom_config, backend_config) tuples

            "float_to_observed_custom_module_class": a nested dictionary mapping from quantization
            mode to an inner mapping from float module classes to observed module classes, e.g.
            {"static": {FloatCustomModule: ObservedCustomModule}}

            "non_traceable_module_name": a list of modules names that are not symbolically traceable
            "non_traceable_module_class": a list of module classes that are not symbolically traceable
            "input_quantized_idxs": a list of indexes of graph inputs that should be quantized
            "output_quantized_idxs": a list of indexes of graph outputs that should be quantized
            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``

        Missing items are treated as empty. Inside the standalone module tuples, the
        qconfig mapping, child prepare custom config and backend config may each be
        ``None``, an already constructed config object, or a dictionary that is
        converted with the corresponding ``from_dict``.

        This function is primarily for backward compatibility and may be removed in the future.

        Raises:
            ValueError: if a standalone module tuple holds a config of an unsupported
                type, or if a quantization mode other than static is used in
                "float_to_observed_custom_module_class".
        """

        def _coerce(obj: Any, config_cls: Type, dict_key: str) -> Any:
            """
            Convert the given object into a ``config_cls`` if possible, else throw an exception.
            """
            if obj is None or isinstance(obj, config_cls):
                return obj
            if isinstance(obj, dict):
                return config_cls.from_dict(obj)
            raise ValueError(
                f"Expected {config_cls.__name__} in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
            )

        conf = cls()

        # Standalone modules are registered by name first, then by class; both
        # kinds of entries share the same 5-tuple layout.
        for dict_key, set_standalone_module in (
            (STANDALONE_MODULE_NAME_DICT_KEY, conf.set_standalone_module_name),
            (STANDALONE_MODULE_CLASS_DICT_KEY, conf.set_standalone_module_class),
        ):
            for (
                name_or_class,
                qconfig_dict,
                example_inputs,
                child_config_dict,
                backend_config_dict,
            ) in prepare_custom_config_dict.get(dict_key, []):
                qconfig_mapping = _coerce(qconfig_dict, QConfigMapping, dict_key)
                prepare_custom_config = _coerce(
                    child_config_dict, PrepareCustomConfig, dict_key
                )
                backend_config = _coerce(backend_config_dict, BackendConfig, dict_key)
                set_standalone_module(
                    name_or_class,
                    qconfig_mapping,
                    example_inputs,
                    prepare_custom_config,
                    backend_config,
                )

        float_to_observed = prepare_custom_config_dict.get(
            FLOAT_TO_OBSERVED_DICT_KEY, {}
        )
        for quant_type_name, custom_module_mapping in float_to_observed.items():
            quant_type = _quant_type_from_str(quant_type_name)
            for float_class, observed_class in custom_module_mapping.items():
                conf.set_float_to_observed_mapping(
                    float_class, observed_class, quant_type
                )

        conf.set_non_traceable_module_names(
            prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY, [])
        )
        conf.set_non_traceable_module_classes(
            prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY, [])
        )
        conf.set_input_quantized_indexes(
            prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY, [])
        )
        conf.set_output_quantized_indexes(
            prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY, [])
        )
        conf.set_preserved_attributes(
            prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
        )
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``PrepareCustomConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`.

        Settings that are empty are left out of the result.
        """

        def _make_tuple(key: Any, e: StandaloneModuleConfigEntry):
            qconfig_dict = e.qconfig_mapping.to_dict() if e.qconfig_mapping else None
            prepare_custom_config_dict = (
                e.prepare_custom_config.to_dict() if e.prepare_custom_config else None
            )
            # The backend config is emitted as the object itself, not as a dict;
            # ``from_dict`` accepts either form.
            return (
                key,
                qconfig_dict,
                e.example_inputs,
                prepare_custom_config_dict,
                e.backend_config,
            )

        d: Dict[str, Any] = {}
        if self.standalone_module_names:
            d[STANDALONE_MODULE_NAME_DICT_KEY] = [
                _make_tuple(module_name, entry)
                for module_name, entry in self.standalone_module_names.items()
            ]
        if self.standalone_module_classes:
            d[STANDALONE_MODULE_CLASS_DICT_KEY] = [
                _make_tuple(module_class, entry)
                for module_class, entry in self.standalone_module_classes.items()
            ]
        if self.float_to_observed_mapping:
            d[FLOAT_TO_OBSERVED_DICT_KEY] = {
                _get_quant_type_to_str(quant_type): mapping
                for quant_type, mapping in self.float_to_observed_mapping.items()
            }
        if self.non_traceable_module_names:
            d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names
        if self.non_traceable_module_classes:
            d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY] = self.non_traceable_module_classes
        if self.input_quantized_indexes:
            d[INPUT_QUANTIZED_INDEXES_DICT_KEY] = self.input_quantized_indexes
        if self.output_quantized_indexes:
            d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY] = self.output_quantized_indexes
        if self.preserved_attributes:
            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
        return d


class ConvertCustomConfig:
    """
    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`.

    Every ``set_*`` method returns ``self`` so that calls can be chained.

    Example usage::

        convert_custom_config = (
            ConvertCustomConfig()
            .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule)
            .set_preserved_attributes(["attr1", "attr2"])
        )
    """

    def __init__(self) -> None:
        self.observed_to_quantized_mapping: Dict[QuantType, Dict[Type, Type]] = {}
        self.preserved_attributes: List[str] = []

    def __repr__(self) -> str:
        # Only show the settings that have actually been populated.
        dict_nonempty = {k: v for k, v in self.__dict__.items() if v}
        return f"ConvertCustomConfig({dict_nonempty})"

    def set_observed_to_quantized_mapping(
        self,
        observed_class: Type,
        quantized_class: Type,
        quant_type: QuantType = QuantType.STATIC,
    ) -> ConvertCustomConfig:
        """
        Set the mapping from a custom observed module class to a custom quantized module class.

        The quantized module class must have a ``from_observed`` class method that converts the observed module class
        to the quantized module class.
        """
        self.observed_to_quantized_mapping.setdefault(quant_type, {})[
            observed_class
        ] = quantized_class
        return self

    def set_preserved_attributes(self, attributes: List[str]) -> ConvertCustomConfig:
        """
        Set the names of the attributes that will persist in the graph module even if they are not used in
        the model's ``forward`` method.
        """
        self.preserved_attributes = attributes
        return self

    # TODO: remove this
    @classmethod
    def from_dict(
        cls, convert_custom_config_dict: Dict[str, Any]
    ) -> ConvertCustomConfig:
        """
        Create a ``ConvertCustomConfig`` from a dictionary with the following items:

            "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization
            mode to an inner mapping from observed module classes to quantized module classes, e.g.::

                {
                    "static": {ObservedCustomModule: QuantizedCustomModule},
                    "dynamic": {ObservedCustomModule: QuantizedCustomModule},
                    "weight_only": {ObservedCustomModule: QuantizedCustomModule}
                }

            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``

        Missing items are treated as empty.

        This function is primarily for backward compatibility and may be removed in the future.
        """
        conf = cls()
        observed_to_quantized = convert_custom_config_dict.get(
            OBSERVED_TO_QUANTIZED_DICT_KEY, {}
        )
        for quant_type_name, custom_module_mapping in observed_to_quantized.items():
            quant_type = _quant_type_from_str(quant_type_name)
            for observed_class, quantized_class in custom_module_mapping.items():
                conf.set_observed_to_quantized_mapping(
                    observed_class, quantized_class, quant_type
                )
        conf.set_preserved_attributes(
            convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
        )
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``ConvertCustomConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`.

        Settings that are empty are left out of the result.
        """
        d: Dict[str, Any] = {}
        if self.observed_to_quantized_mapping:
            d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {
                _get_quant_type_to_str(quant_type): mapping
                for quant_type, mapping in self.observed_to_quantized_mapping.items()
            }
        if self.preserved_attributes:
            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
        return d


class FuseCustomConfig:
    """
    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`.

    Example usage::

        fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
    """

    def __init__(self) -> None:
        self.preserved_attributes: List[str] = []

    def __repr__(self) -> str:
        # Only show the settings that have actually been populated.
        dict_nonempty = {k: v for k, v in self.__dict__.items() if v}
        return f"FuseCustomConfig({dict_nonempty})"

    def set_preserved_attributes(self, attributes: List[str]) -> FuseCustomConfig:
        """
        Set the names of the attributes that will persist in the graph module even if they are not used in
        the model's ``forward`` method.
        """
        self.preserved_attributes = attributes
        return self

    # TODO: remove this
    @classmethod
    def from_dict(cls, fuse_custom_config_dict: Dict[str, Any]) -> FuseCustomConfig:
        """
        Create a ``FuseCustomConfig`` from a dictionary with the following items:

            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``

        A missing item is treated as an empty list.

        This function is primarily for backward compatibility and may be removed in the future.
        """
        conf = cls()
        conf.set_preserved_attributes(
            fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
        )
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``FuseCustomConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig.from_dict`.

        The "preserved_attributes" item is left out when the list is empty.
        """
        d: Dict[str, Any] = {}
        if self.preserved_attributes:
            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
        return d