# mypy: allow-untyped-defs
import copy
import inspect
import itertools
import warnings

import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import (
    _activation_is_memoryless,
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
)
from torch.ao.quantization.quantization_mappings import (
    _get_special_act_post_process,
    _has_special_act_post_process,
    get_default_dynamic_quant_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    no_observer_set,
)
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.nn.utils.parametrize import type_before_parametrizations

from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations


__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]


# TODO remove this once BC is no longer required to avoid a SEV
is_activation_post_process = _is_activation_post_process


_DEFAULT_CUSTOM_CONFIG_DICT = {
    "float_to_observed_custom_module_class": {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    "observed_to_quantized_custom_module_class": {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    },
}


def get_default_custom_config_dict():
    r"""Defines the default custom config dict."""
    return _DEFAULT_CUSTOM_CONFIG_DICT


def _propagate_qconfig_helper(
    module,
    qconfig_dict,
    qconfig_parent=None,
    prefix="",
    prepare_custom_config_dict=None,
):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
            configuration
        qconfig_parent: quantization config of the parent module, we will fall back
            to this config when there is no specified config for the current module
        prefix: corresponding prefix of the current module, used as key in
            qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules,
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """

    module_qconfig = qconfig_dict.get(
        type_before_parametrizations(module), qconfig_parent
    )
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, "qconfig", module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + "." + name if prefix else name
        # do not propagate qconfig to the child if the child is non-traceable
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child)
            in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )


def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign the `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration, qconfig applies to all submodules of a
            given module unless qconfig for the submodules is specified (when
            the submodule already has a qconfig attribute)
        prepare_custom_config_dict: dictionary for custom handling of modules,
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
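
    Example (an illustrative sketch; the two-layer ``nn.Sequential`` below is
    only a stand-in for a real float model)::

        from collections import OrderedDict

        import torch.nn as nn
        from torch.ao.quantization import default_qconfig, propagate_qconfig_

        model = nn.Sequential(
            OrderedDict([("body", nn.Linear(4, 4)), ("head", nn.Linear(4, 2))])
        )
        # keys can be module types or submodule names; a ``None`` entry
        # leaves that submodule unquantized
        propagate_qconfig_(model, {nn.Linear: default_qconfig, "head": None})
        # model.body.qconfig is now populated, model.head.qconfig is None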
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    _propagate_qconfig_helper(
        module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict
    )


def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output"""
    return self.activation_post_process(output)


def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that calls observer on the input"""
    return self.activation_post_process(input[0])


def _register_activation_post_process_hook(module, pre_hook=False):
    assert hasattr(
        module, "activation_post_process"
    ), "Expect activation_post_process attribute already attached to the module"
    if pre_hook:
        handle = module.register_forward_pre_hook(
            _observer_forward_pre_hook, prepend=True
        )
    else:
        handle = module.register_forward_hook(_observer_forward_hook, prepend=True)


def _add_observer_(
    module,
    qconfig_propagation_list=None,
    non_leaf_module_list=None,
    device=None,
    custom_module_class_mapping=None,
):
    r"""Add observers for the leaf children of the module.

    This function inserts an observer module into every leaf child module that
    has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
        qconfig_propagation_list: a list of quantizable modules that will have observers added to them
            if they are leaf nodes
        device: parent device, if any
        non_leaf_module_list: list of non-leaf modules we want to add observers to

    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = _get_unique_devices_(module)
        assert (
            len(devices) <= 1
        ), f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = (
            qconfig.activation()
            if special_act_post_process is None
            else special_act_post_process()
        )
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, "qconfig") and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """Adds an activation post process module and registers
        a pre or post hook that calls the module
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module(
                "activation_post_process",
                get_activation_post_process(
                    m.qconfig, device, special_act_post_process
                ),
            )
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            _register_activation_post_process_hook(
                m, pre_hook=_activation_is_memoryless(m.qconfig)
            )

    for name, child in module.named_children():
        # TODO remove Dropout special case after the codebase is stable
        if type_before_parametrizations(child) in [nn.Dropout]:
            continue
        elif issubclass(
            type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
        ):
            if needs_observation(child):
                assert hasattr(
                    child, "activation_post_process"
                ), f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
                child.activation_post_process = get_activation_post_process(
                    child.qconfig, device
                )
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif (
            non_leaf_module_list is not None
            and type_before_parametrizations(child) in non_leaf_module_list
        ):
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif (
            needs_observation(child)
            and type_before_parametrizations(child) in custom_module_class_mapping
        ):
            observed_child = custom_module_class_mapping[
                type_before_parametrizations(child)
            ].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            #       Once there are more, we should move them to a separate list
            if (
                custom_module_class_mapping[type_before_parametrizations(child)]
                not in no_observer_set()
            ):
                insert_activation_post_process(observed_child)
        else:
            _add_observer_(
                child,
                qconfig_propagation_list,
                non_leaf_module_list,
                device,
                custom_module_class_mapping,
            )

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if (
        has_no_children_ignoring_parametrizations(module)
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)
    # This is a special case for AdaRound eager mode
    # AdaRound contains weight_fake_quant to be propagated from API to convert;
    # the leaf-node check based on the number of children is a naive assumption
    # that blocks it, so add an exception case for AdaRound
    if (
        hasattr(module, "weight_fake_quant")
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)


def _get_unique_devices_(module):
    return {p.device for p in module.parameters()} | {
        p.device for p in module.buffers()
    }


def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
    Note that this function will modify the children of module inplace and it
    can return a new module which wraps the input module as well.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
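
    Example (an illustrative sketch; the two-layer ``nn.Sequential`` stands in
    for any float model whose leaf modules carry a qconfig)::

        import torch.nn as nn
        from torch.ao.quantization import add_quant_dequant, default_qconfig

        model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
        model[0].qconfig = default_qconfig
        # the Linear leaf is wrapped in QuantWrapper in place; the ReLU,
        # which has no qconfig, is left untouched
        model = add_quant_dequant(model)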
    """
    if (
        has_no_children_ignoring_parametrizations(module)
        and hasattr(module, "qconfig")
        and module.qconfig
    ):
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module


def prepare(
    model,
    inplace=False,
    allow_list=None,
    observer_non_leaf_module_list=None,
    prepare_custom_config_dict=None,
):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.

    Args:
        `model`: input model to be modified in-place
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function

    .. code-block:: python

        # Example of prepare_custom_config_dict:
        prepare_custom_config_dict = {
            # user will manually define the corresponding observed
            # module class which has a from_float class method that converts
            # float custom module to observed custom module
            "float_to_observed_custom_module_class": {
                CustomModule: ObservedCustomModule
            }
        }

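    A typical eager-mode static quantization flow is sketched below for
    illustration only; the toy ``nn.Sequential`` and the random calibration
    tensor stand in for a real model and dataset, and the ``fbgemm`` backend
    is assumed to be available.

    .. code-block:: python

        import torch
        import torch.nn as nn
        from torch.ao.quantization import (
            DeQuantStub,
            QuantStub,
            convert,
            get_default_qconfig,
            prepare,
        )

        # QuantStub/DeQuantStub mark the boundaries of the quantized region
        model = nn.Sequential(QuantStub(), nn.Linear(4, 4), DeQuantStub()).eval()
        model.qconfig = get_default_qconfig("fbgemm")
        prepared = prepare(model)      # observers are attached
        prepared(torch.randn(2, 4))    # calibration pass(es) over sample data
        quantized = convert(prepared)  # observed -> quantized modules
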
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = prepare_custom_config_dict.get(
        "float_to_observed_custom_module_class", {}
    )

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if allow_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage
    if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()):
        warnings.warn(
            "None of the submodules got qconfig applied. Make sure you "
            "passed correct configuration through `qconfig_dict` or "
            "by assigning the `.qconfig` attribute directly on submodules"
        )

    _add_observer_(
        model,
        qconfig_propagation_list,
        observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping,
    )
    return model


def _remove_activation_post_process(module):
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, "activation_post_process") and _is_activation_post_process(
        module.activation_post_process
    ):
        delattr(module, "activation_post_process")

    # remove activation_post_process pre and post hooks
    def remove_hooks(pre_hook=False):
        hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
        observer_hook = (
            _observer_forward_pre_hook if pre_hook else _observer_forward_hook
        )
        handle_ids_to_remove = set()
        for handle_id, hook_fn in hook_map.items():
            if hook_fn is observer_hook:
                handle_ids_to_remove.add(handle_id)
        for handle_id in handle_ids_to_remove:
            hook_map.pop(handle_id)

    remove_hooks(pre_hook=True)
    remove_hooks(pre_hook=False)


# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that a new qconfig can be
    propagated.

    Args:
        module: module to be cleaned up
    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)


def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    First it will prepare the model for calibration, then it calls
    `run_fn` to run the calibration step, and after that it converts the
    model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module is mutated
        mapping: correspondence between original module types and quantized counterparts

    Return:
        Quantized model.
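
    Example (an illustrative sketch; the toy ``nn.Sequential`` and the random
    calibration tensors stand in for a real model and dataset, and a CPU
    quantization backend such as ``fbgemm`` is assumed to be available)::

        import torch
        import torch.nn as nn
        from torch.ao.quantization import (
            DeQuantStub,
            QuantStub,
            default_qconfig,
            quantize,
        )

        def calibrate(model, data):
            with torch.no_grad():
                for x in data:
                    model(x)

        float_model = nn.Sequential(QuantStub(), nn.Linear(4, 4), DeQuantStub()).eval()
        float_model.qconfig = default_qconfig
        calib_data = [torch.randn(2, 4) for _ in range(4)]
        quantized_model = quantize(float_model, calibrate, [calib_data])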
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, mapping, inplace=True)
    return model


def quantize_dynamic(
    model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False
):
    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.

    For the simplest usage provide the `dtype` argument, which can be float16 or qint8. Weight-only
    quantization is performed by default for layers with large weights - i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig_spec` and `mapping`, which act similarly to `quantize()`.
    If `qconfig_spec` is provided, the `dtype` argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to quantization
              configuration, qconfig applies to all submodules of a given
              module unless qconfig for the submodules is specified (when the
              submodule already has a qconfig attribute). Entries in the dictionary
              need to be QConfig instances.

            - A set of types and/or submodule names to apply dynamic quantization to,
              in which case the `dtype` argument is used to specify the bit-width

        inplace: carry out model transformations in-place, the original module is mutated
        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
            with which the submodule needs to be replaced

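    Example (a minimal sketch; the two-layer model is only a stand-in for any
    float model with ``nn.Linear`` submodules, and a CPU quantization backend
    such as ``fbgemm`` is assumed to be available)::

        import torch
        import torch.nn as nn
        from torch.ao.quantization import quantize_dynamic

        float_model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)).eval()
        # dynamically quantize the weights of every nn.Linear to qint8
        quantized_model = quantize_dynamic(
            float_model, qconfig_spec={nn.Linear}, dtype=torch.qint8
        )
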
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear: default_dynamic_qconfig,
                nn.LSTM: default_dynamic_qconfig,
                nn.GRU: default_dynamic_qconfig,
                nn.LSTMCell: default_dynamic_qconfig,
                nn.RNNCell: default_dynamic_qconfig,
                nn.GRUCell: default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear: float16_dynamic_qconfig,
                nn.LSTM: float16_dynamic_qconfig,
                nn.GRU: float16_dynamic_qconfig,
                nn.LSTMCell: float16_dynamic_qconfig,
                nn.RNNCell: float16_dynamic_qconfig,
                nn.GRUCell: float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig,
                nn.Embedding: float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please"
            )
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            raise RuntimeError(
                "Unknown dtype specified for quantize_dynamic: ", str(dtype)
            )
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model


def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization-aware training and
    converts its submodules to their QAT counterparts according to `mapping`.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
            replaced.
        inplace: carry out model transformations in-place, the original module
            is mutated
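
    Example (an illustrative sketch; the toy ``nn.Sequential`` and the single
    forward pass stand in for a real model and training loop, and the
    ``fbgemm`` backend is assumed to be available)::

        import torch
        import torch.nn as nn
        from torch.ao.quantization import (
            DeQuantStub,
            QuantStub,
            convert,
            get_default_qat_qconfig,
            prepare_qat,
        )

        model = nn.Sequential(QuantStub(), nn.Linear(4, 4), DeQuantStub()).train()
        model.qconfig = get_default_qat_qconfig("fbgemm")
        model = prepare_qat(model)   # float modules -> QAT modules with fake quant
        model(torch.randn(2, 4))     # stand-in for the usual training loop
        quantized = convert(model.eval())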
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    assert model.training, "prepare_qat only works on models in training mode"
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model


def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization-aware training and output a quantized model

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
            function that simply runs the prepared model or a training
            loop
        run_args: positional arguments for `run_fn`

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    if not inplace:
        model = copy.deepcopy(model)
    model.train()
    prepare_qat(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, inplace=True)
    return model


def convert(
    module,
    mapping=None,
    inplace=False,
    remove_qconfig=True,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Converts submodules in the input module to a different module according to `mapping`
    by calling the `from_float` method on the target module class, and removes the qconfig at
    the end if `remove_qconfig` is set to True.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
            module type, can be overwritten to allow swapping user defined
            Modules
        `inplace`: carry out model transformations in-place, the original module
            is mutated
        `convert_custom_config_dict`: custom configuration dictionary for convert function
        `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant

    .. code-block:: python

        # Example of convert_custom_config_dict:
        convert_custom_config_dict = {
            # user will manually define the corresponding quantized
            # module class which has a from_observed class method that converts
            # observed custom module to quantized custom module
            "observed_to_quantized_custom_module_class": {
                ObservedCustomModule: QuantizedCustomModule
            }
        }

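    A typical call, sketched for illustration (``prepared_model`` is assumed
    to be the output of :func:`prepare` after calibration)::

        quantized_model = convert(prepared_model)
        # or swap in reference quantized modules instead:
        reference_model = convert(prepared_model, is_reference=True)
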
    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    if not inplace:
        module = copy.deepcopy(module)
    _convert(
        module,
        mapping,
        inplace=True,
        is_reference=is_reference,
        convert_custom_config_dict=convert_custom_config_dict,
        use_precomputed_fake_quant=use_precomputed_fake_quant,
    )
    if remove_qconfig:
        _remove_qconfig(module)
    return module


def _convert(
    module,
    mapping=None,
    inplace=False,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Converts submodules in input module to a different module according to `mapping`
    by calling `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
            module type, can be overwritten to allow swapping user defined
            Modules
        inplace: carry out model transformations in-place, the original module
            is mutated
        is_reference: a flag to enable quantized reference module
        use_precomputed_fake_quant: a flag to enable use of precomputed fake quant

    """
    if mapping is None:
        mapping = (
            get_default_static_quant_reference_module_mappings()
            if is_reference
            else get_default_static_quant_module_mappings()
        )
    if convert_custom_config_dict is None:
        convert_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = convert_custom_config_dict.get(
        "observed_to_quantized_custom_module_class", {}
    )

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if (
            not isinstance(mod, _FusedModule)
            and type_before_parametrizations(mod) not in custom_module_class_mapping
        ):
            _convert(
                mod,
                mapping,
                True,  # inplace
                is_reference,
                convert_custom_config_dict,
                use_precomputed_fake_quant=use_precomputed_fake_quant,
            )
        reassign[name] = swap_module(
            mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant
        )

    for key, value in reassign.items():
        module._modules[key] = value

    return module


def swap_module(
    mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False
):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module

    Return:
        The corresponding quantized module of `mod`
    """
    new_mod = mod
    if hasattr(mod, "qconfig") and mod.qconfig is not None:
        swapped = False
        if type_before_parametrizations(mod) in custom_module_class_mapping:
            new_mod = custom_module_class_mapping[
                type_before_parametrizations(mod)
            ].from_observed(mod)
            swapped = True
        elif type_before_parametrizations(mod) in mapping:
            qmod = mapping[type_before_parametrizations(mod)]
            if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE:
                assert mod.qconfig is not None
                weight_post_process = mod.qconfig.weight()
                weight_post_process(mod.weight)
                weight_qparams = get_qparam_dict(weight_post_process)
                new_mod = qmod.from_float(mod, weight_qparams)
            else:
                sig = inspect.signature(qmod.from_float)
                if "use_precomputed_fake_quant" in sig.parameters:
                    new_mod = qmod.from_float(
                        mod, use_precomputed_fake_quant=use_precomputed_fake_quant
                    )
                else:
                    new_mod = qmod.from_float(mod)
            swapped = True

        if swapped:
            # Preserve module's pre forward hooks. They'll be called on quantized input
            for pre_hook_fn in mod._forward_pre_hooks.values():
                new_mod.register_forward_pre_hook(pre_hook_fn)
            # Preserve module's post forward hooks except _observer_forward_hook
            # After convert they'll work with quantized output
            for hook_fn in mod._forward_hooks.values():
                if hook_fn is not _observer_forward_hook:
                    new_mod.register_forward_hook(hook_fn)

            # respect device affinity when swapping modules
            devices = _get_unique_devices_(mod)
            assert (
                len(devices) <= 1
            ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
            device = next(iter(devices)) if len(devices) > 0 else None
            if device:
                new_mod.to(device)
    return new_mod


def _get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.
    This is mainly used for quantization accuracy debug
    Args:
        mod: the top module we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
    """

    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + "."

    if hasattr(mod, "activation_post_process"):
        target_dict[
            get_prefix(prefix) + "activation_post_process"
        ] = mod.activation_post_process
    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_observer_dict(child, target_dict, module_prefix)