# mypy: allow-untyped-defs
import copy
import inspect
import itertools
import warnings

import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import (
    _activation_is_memoryless,
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
)
from torch.ao.quantization.quantization_mappings import (
    _get_special_act_post_process,
    _has_special_act_post_process,
    get_default_dynamic_quant_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    no_observer_set,
)
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.nn.utils.parametrize import type_before_parametrizations

from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations


__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]


# TODO remove this once BC is no longer required to avoid a SEV
is_activation_post_process = _is_activation_post_process


_DEFAULT_CUSTOM_CONFIG_DICT = {
    "float_to_observed_custom_module_class": {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    "observed_to_quantized_custom_module_class": {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    },
}


def get_default_custom_config_dict():
    r"""Defines the default custom config dict."""
    return _DEFAULT_CUSTOM_CONFIG_DICT


def _propagate_qconfig_helper(
    module,
    qconfig_dict,
    qconfig_parent=None,
    prefix="",
    prepare_custom_config_dict=None,
):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                     configuration
        qconfig_parent: quantization config of parent module; we will fall back to
                       this config when there is no specified config for current
                       module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules
                                    see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """

    module_qconfig = qconfig_dict.get(
        type_before_parametrizations(module), qconfig_parent
    )
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, "qconfig", module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + "." + name if prefix else name
        # do not propagate qconfig to child if child is non traceable
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child)
            in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )


def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration, qconfig applies to all submodules of a
            given module unless qconfig for the submodules are specified (when
            the submodule already has qconfig attribute)
        prepare_custom_config_dict: dictionary for custom handling of modules
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
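
    Example (a minimal sketch; `MyModel` is an illustrative placeholder and the
    qconfig choices below are only one possible configuration):

    .. code-block:: python

       import torch
       import torch.ao.quantization as tq

       model = MyModel()
       # use the default static qconfig for the whole model, but skip nn.Linear
       qconfig_dict = {"": tq.default_qconfig, torch.nn.Linear: None}
       tq.propagate_qconfig_(model, qconfig_dict)
       # every leaf module except nn.Linear instances now carries a `.qconfig`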
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    _propagate_qconfig_helper(
        module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict
    )


def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output"""
    return self.activation_post_process(output)


def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that calls observer on the input"""
    return self.activation_post_process(input[0])


def _register_activation_post_process_hook(module, pre_hook=False):
    assert hasattr(
        module, "activation_post_process"
    ), "Expect activation_post_process attribute already attached to the module"
    if pre_hook:
        handle = module.register_forward_pre_hook(
            _observer_forward_pre_hook, prepend=True
        )
    else:
        handle = module.register_forward_hook(_observer_forward_hook, prepend=True)


def _add_observer_(
    module,
    qconfig_propagation_list=None,
    non_leaf_module_list=None,
    device=None,
    custom_module_class_mapping=None,
):
    r"""Add observers for the leaf children of the module.

    This function inserts an observer module into every leaf child module that
    has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
        qconfig_propagation_list: a list of quantizable modules that will have observers added to them
            if they are leaf nodes
        device: parent device, if any
        non_leaf_module_list: list of non-leaf modules we want to add observer

    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = _get_unique_devices_(module)
        assert (
            len(devices) <= 1
        ), f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = (
            qconfig.activation()
            if special_act_post_process is None
            else special_act_post_process()
        )
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, "qconfig") and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """Adds an activation post process module and registers
        a pre or post hook that calls the module
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module(
                "activation_post_process",
                get_activation_post_process(
                    m.qconfig, device, special_act_post_process
                ),
            )
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            _register_activation_post_process_hook(
                m, pre_hook=_activation_is_memoryless(m.qconfig)
            )

    for name, child in module.named_children():
        # TODO remove Dropout special after codebase stable
        if type_before_parametrizations(child) in [nn.Dropout]:
            continue
        elif issubclass(
            type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
        ):
            if needs_observation(child):
                assert hasattr(
                    child, "activation_post_process"
                ), f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
                child.activation_post_process = get_activation_post_process(
                    child.qconfig, device
                )
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif (
            non_leaf_module_list is not None
            and type_before_parametrizations(child) in non_leaf_module_list
        ):
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif (
            needs_observation(child)
            and type_before_parametrizations(child) in custom_module_class_mapping
        ):
            observed_child = custom_module_class_mapping[
                type_before_parametrizations(child)
            ].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            #       Once there are more, we should move them to a separate list
            if (
                custom_module_class_mapping[type_before_parametrizations(child)]
                not in no_observer_set()
            ):
                insert_activation_post_process(observed_child)
        else:
            _add_observer_(
                child,
                qconfig_propagation_list,
                non_leaf_module_list,
                device,
                custom_module_class_mapping,
            )

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if (
        has_no_children_ignoring_parametrizations(module)
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)
    # This is a special case for AdaRound eager mode
    # AdaRound contains weight_fake_quant that needs to be propagated from the API to convert
    # The leaf-node check based on the number of children is a naive assumption that blocks this,
    # so we add an exception case for AdaRound
    if (
        hasattr(module, "weight_fake_quant")
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)


def _get_unique_devices_(module):
    return {p.device for p in module.parameters()} | {
        p.device for p in module.buffers()
    }


def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig.
    Note that this function will modify the children of the module in place, and it
    can return a new module which wraps the input module as well.

    Args:
        module: input module with qconfig attributes for all the leaf modules
        that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
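
    Example (a minimal sketch; `MyModel` and its `fc` submodule are illustrative
    placeholders):

    .. code-block:: python

       import torch
       import torch.ao.quantization as tq

       model = MyModel()
       # only the submodule named `fc` gets a qconfig, so only `fc` is wrapped
       model.fc.qconfig = tq.default_qconfig
       model = tq.add_quant_dequant(model)
       # model.fc is now a QuantWrapper around the original fc module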
    """
    if (
        has_no_children_ignoring_parametrizations(module)
        and hasattr(module, "qconfig")
        and module.qconfig
    ):
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module


def prepare(
    model,
    inplace=False,
    allow_list=None,
    observer_non_leaf_module_list=None,
    prepare_custom_config_dict=None,
):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.

    Args:
        `model`: input model to be modified in-place
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function

    .. code-block:: python

       # Example of prepare_custom_config_dict:
       prepare_custom_config_dict = {
           # user will manually define the corresponding observed
           # module class which has a from_float class method that converts
           # float custom module to observed custom module
           "float_to_observed_custom_module_class": {
               CustomModule: ObservedCustomModule
           }
       }

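    A typical calibration flow looks like this (a minimal sketch; `MyModel` and
    `calibration_loader` are illustrative placeholders):

    .. code-block:: python

       import torch
       import torch.ao.quantization as tq

       model = MyModel().eval()
       model.qconfig = tq.get_default_qconfig("fbgemm")
       prepared = tq.prepare(model)
       # run representative data through the model to collect statistics
       with torch.no_grad():
           for data in calibration_loader:
               prepared(data)
       quantized = tq.convert(prepared)
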
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = prepare_custom_config_dict.get(
        "float_to_observed_custom_module_class", {}
    )

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if allow_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage
    if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()):
        warnings.warn(
            "None of the submodules got qconfig applied. Make sure you "
            "passed the correct configuration through `qconfig_dict` or "
            "by assigning the `.qconfig` attribute directly to submodules"
        )

    _add_observer_(
        model,
        qconfig_propagation_list,
        observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping,
    )
    return model


def _remove_activation_post_process(module):
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, "activation_post_process") and _is_activation_post_process(
        module.activation_post_process
    ):
        delattr(module, "activation_post_process")

    # remove activation_post_process pre and post hooks
    def remove_hooks(pre_hook=False):
        hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
        observer_hook = (
            _observer_forward_pre_hook if pre_hook else _observer_forward_hook
        )
        handle_ids_to_remove = set()
        for handle_id, hook_fn in hook_map.items():
            if hook_fn is observer_hook:
                handle_ids_to_remove.add(handle_id)
        for handle_id in handle_ids_to_remove:
            hook_map.pop(handle_id)

    remove_hooks(pre_hook=True)
    remove_hooks(pre_hook=False)


# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that new qconfig can be
    propagated.

    Args:
        module: module to be cleaned up
    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)


def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    First it prepares the model for calibration, then it calls
    `run_fn` to run the calibration step, and finally it converts
    the model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module is mutated
        mapping: correspondence between original module types and quantized counterparts

    Return:
        Quantized model.
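
    Example (a minimal sketch; `MyModel`, `calibrate`, and `data_loader` are
    illustrative placeholders, where `calibrate(model, data_loader)` simply runs
    the model on representative data):

    .. code-block:: python

       import torch
       import torch.ao.quantization as tq

       model = MyModel().eval()
       model.qconfig = tq.get_default_qconfig("fbgemm")
       quantized_model = tq.quantize(model, calibrate, [data_loader])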
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, mapping, inplace=True)
    return model


def quantize_dynamic(
    model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False
):
    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.

    For the simplest usage provide the `dtype` argument, which can be float16 or qint8. Weight-only quantization
    is by default performed for layers with large weight sizes, i.e. Linear and RNN variants.

    Fine-grained control is possible with `qconfig_spec` and `mapping`, which act similarly to `quantize()`.
    If `qconfig_spec` is provided, the `dtype` argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to quantization
              configuration, qconfig applies to all submodules of a given
              module unless qconfig for the submodules are specified (when the
              submodule already has qconfig attribute). Entries in the dictionary
              need to be QConfig instances.

            - A set of types and/or submodule names to apply dynamic quantization to,
              in which case the `dtype` argument is used to specify the bit-width

        inplace: carry out model transformations in-place, the original module is mutated
        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
            with which the submodule needs to be replaced

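    Example (a minimal sketch; `MyModel` is an illustrative placeholder):

    .. code-block:: python

       import torch
       import torch.nn as nn
       import torch.ao.quantization as tq

       model = MyModel().eval()
       # dynamically quantize all nn.Linear and nn.LSTM submodules to int8
       dq_model = tq.quantize_dynamic(
           model, qconfig_spec={nn.Linear, nn.LSTM}, dtype=torch.qint8
       )
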
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear: default_dynamic_qconfig,
                nn.LSTM: default_dynamic_qconfig,
                nn.GRU: default_dynamic_qconfig,
                nn.LSTMCell: default_dynamic_qconfig,
                nn.RNNCell: default_dynamic_qconfig,
                nn.GRUCell: default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear: float16_dynamic_qconfig,
                nn.LSTM: float16_dynamic_qconfig,
                nn.GRU: float16_dynamic_qconfig,
                nn.LSTMCell: float16_dynamic_qconfig,
                nn.RNNCell: float16_dynamic_qconfig,
                nn.GRUCell: float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig,
                nn.Embedding: float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please"
            )
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            raise RuntimeError(
                "Unknown dtype specified for quantize_dynamic: ", str(dtype)
            )
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model


def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization-aware training and
    converts its modules to their quantization-aware versions.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
                 replaced.
        inplace: carry out model transformations in-place, the original module
                 is mutated
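
    Example (a minimal sketch; `MyModel` and the training loop are illustrative
    placeholders):

    .. code-block:: python

       import torch
       import torch.ao.quantization as tq

       model = MyModel().train()
       model.qconfig = tq.get_default_qat_qconfig("fbgemm")
       prepared = tq.prepare_qat(model)
       # ... run the usual training loop on `prepared` ...
       quantized = tq.convert(prepared.eval())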
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    assert model.training, "prepare_qat only works on models in training mode"
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model


def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
                function that simply runs the prepared model or a training
                loop
        run_args: positional arguments for `run_fn`

    Return:
        Quantized model.
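
    Example (a minimal sketch; `MyModel`, `train_fn`, and `data_loader` are
    illustrative placeholders, where `train_fn(model, data_loader)` runs a short
    training loop):

    .. code-block:: python

       import torch
       import torch.ao.quantization as tq

       model = MyModel().train()
       model.qconfig = tq.get_default_qat_qconfig("fbgemm")
       quantized_model = tq.quantize_qat(model, train_fn, [data_loader])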
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    if not inplace:
        model = copy.deepcopy(model)
    model.train()
    prepare_qat(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, inplace=True)
    return model


def convert(
    module,
    mapping=None,
    inplace=False,
    remove_qconfig=True,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Converts submodules in the input module to a different module according to `mapping`
    by calling the `from_float` method on the target module class, and removes the qconfig at
    the end if `remove_qconfig` is set to True.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
                   module type, can be overwritten to allow swapping user defined
                   Modules
        `inplace`: carry out model transformations in-place, the original module
                   is mutated
        `convert_custom_config_dict`: custom configuration dictionary for convert function
        `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant

    .. code-block:: python

       # Example of convert_custom_config_dict:
       convert_custom_config_dict = {
           # user will manually define the corresponding quantized
           # module class which has a from_observed class method that converts
           # observed custom module to quantized custom module
           "observed_to_quantized_custom_module_class": {
               ObservedCustomModule: QuantizedCustomModule
           }
       }

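    Typical usage after calibration or QAT (a minimal sketch; `prepared_model`
    stands for a model returned by :func:`prepare` or :func:`prepare_qat` that has
    already been calibrated or trained):

    .. code-block:: python

       import torch.ao.quantization as tq

       quantized_model = tq.convert(prepared_model.eval())
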
    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    if not inplace:
        module = copy.deepcopy(module)
    _convert(
        module,
        mapping,
        inplace=True,
        is_reference=is_reference,
        convert_custom_config_dict=convert_custom_config_dict,
        use_precomputed_fake_quant=use_precomputed_fake_quant,
    )
    if remove_qconfig:
        _remove_qconfig(module)
    return module


def _convert(
    module,
    mapping=None,
    inplace=False,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Converts submodules in input module to a different module according to `mapping`
    by calling `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type, can be overwritten to allow swapping user defined
                 Modules
        inplace: carry out model transformations in-place, the original module
                 is mutated
        is_reference: a flag to enable quantized reference module
        use_precomputed_fake_quant: a flag to enable use of precomputed fake quant

    """
    if mapping is None:
        mapping = (
            get_default_static_quant_reference_module_mappings()
            if is_reference
            else get_default_static_quant_module_mappings()
        )
    if convert_custom_config_dict is None:
        convert_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = convert_custom_config_dict.get(
        "observed_to_quantized_custom_module_class", {}
    )

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if (
            not isinstance(mod, _FusedModule)
            and type_before_parametrizations(mod) not in custom_module_class_mapping
        ):
            _convert(
                mod,
                mapping,
                True,  # inplace
                is_reference,
                convert_custom_config_dict,
                use_precomputed_fake_quant=use_precomputed_fake_quant,
            )
        reassign[name] = swap_module(
            mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant
        )

    for key, value in reassign.items():
        module._modules[key] = value

    return module


def swap_module(
    mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False
):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module

    Return:
        The corresponding quantized module of `mod`
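
    Example (a minimal sketch; assumes `observed_linear` is an `nn.Linear` that
    already went through :func:`prepare` and calibration):

    .. code-block:: python

       import torch.ao.quantization as tq
       from torch.ao.quantization.quantization_mappings import (
           get_default_static_quant_module_mappings,
       )

       mapping = get_default_static_quant_module_mappings()
       quantized_linear = tq.swap_module(observed_linear, mapping, {})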
    """
    new_mod = mod
    if hasattr(mod, "qconfig") and mod.qconfig is not None:
        swapped = False
        if type_before_parametrizations(mod) in custom_module_class_mapping:
            new_mod = custom_module_class_mapping[
                type_before_parametrizations(mod)
            ].from_observed(mod)
            swapped = True
        elif type_before_parametrizations(mod) in mapping:
            qmod = mapping[type_before_parametrizations(mod)]
            if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE:
                assert mod.qconfig is not None
                weight_post_process = mod.qconfig.weight()
                weight_post_process(mod.weight)
                weight_qparams = get_qparam_dict(weight_post_process)
                new_mod = qmod.from_float(mod, weight_qparams)
            else:
                sig = inspect.signature(qmod.from_float)
                if "use_precomputed_fake_quant" in sig.parameters:
                    new_mod = qmod.from_float(
                        mod, use_precomputed_fake_quant=use_precomputed_fake_quant
                    )
                else:
                    new_mod = qmod.from_float(mod)
            swapped = True

        if swapped:
            # Preserve module's pre forward hooks. They'll be called on quantized input
            for pre_hook_fn in mod._forward_pre_hooks.values():
                new_mod.register_forward_pre_hook(pre_hook_fn)
            # Preserve module's post forward hooks except _observer_forward_hook
            # After convert they'll work with quantized output
            for hook_fn in mod._forward_hooks.values():
                if hook_fn is not _observer_forward_hook:
                    new_mod.register_forward_hook(hook_fn)

            # respect device affinity when swapping modules
            devices = _get_unique_devices_(mod)
            assert (
                len(devices) <= 1
            ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
            device = next(iter(devices)) if len(devices) > 0 else None
            if device:
                new_mod.to(device)
    return new_mod


def _get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.

    This is mainly used for quantization accuracy debugging.

    Args:
        mod: the top module whose observers we want to save
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
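
    Example (a minimal sketch; `prepared_model` stands for a model returned by
    :func:`prepare` after calibration):

    .. code-block:: python

       observer_dict = {}
       _get_observer_dict(prepared_model, observer_dict)
       for name, observer in observer_dict.items():
           print(name, observer.calculate_qparams())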
    """

    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + "."

    if hasattr(mod, "activation_post_process"):
        target_dict[
            get_prefix(prefix) + "activation_post_process"
        ] = mod.activation_post_process
    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_observer_dict(child, target_dict, module_prefix)
813