xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fuse_modules.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3from typing import List, Optional
4
5import torch.nn as nn
6
7# for backward compatibility
8from torch.ao.quantization.fuser_method_mappings import (  # noqa: F401  # noqa: F401
9    fuse_conv_bn,
10    fuse_conv_bn_relu,
11    get_fuser_method,
12)
13from torch.nn.utils.parametrize import type_before_parametrizations
14
15
16__all__ = [
17    "fuse_known_modules",
18    "fuse_modules",
19    "fuse_modules_qat",
20]
21
22
23# Generalization of getattr
24def _get_module(model, submodule_key):
25    tokens = submodule_key.split(".")
26    cur_mod = model
27    for s in tokens:
28        cur_mod = getattr(cur_mod, s)
29    return cur_mod
30
31
32# Generalization of setattr
33def _set_module(model, submodule_key, module):
34    tokens = submodule_key.split(".")
35    sub_tokens = tokens[:-1]
36    cur_mod = model
37    for s in sub_tokens:
38        cur_mod = getattr(cur_mod, s)
39
40    setattr(cur_mod, tokens[-1], module)
41
42
43def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
44    r"""Return a list of known fuse modules.
45
46    Returns a list of modules that fuses the operations specified
47     in the input module list.
48
49    Fuses only the following sequence of modules:
50    conv, bn
51    conv, bn, relu
52    conv, relu
53    linear, bn
54    linear, relu
55    For these sequences, the first element in the output module list performs
56    the fused operation. The rest of the elements are set to nn.Identity()
57    """
58    types = tuple(type_before_parametrizations(m) for m in mod_list)
59    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
60    if fuser_method is None:
61        raise NotImplementedError(f"Cannot fuse modules: {types}")
62    new_mod: List[Optional[nn.Module]] = [None] * len(mod_list)
63    fused = fuser_method(is_qat, *mod_list)
64    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
65    # Move pre forward hooks of the base module to resulting fused module
66    for pre_hook_fn in mod_list[0]._forward_pre_hooks.values():
67        fused.register_forward_pre_hook(pre_hook_fn)
68    mod_list[0]._forward_pre_hooks.clear()
69    # Move post forward hooks of the last module to resulting fused module
70    for hook_fn in mod_list[-1]._forward_hooks.values():
71        fused.register_forward_hook(hook_fn)
72    mod_list[-1]._forward_hooks.clear()
73    new_mod[0] = fused
74
75    for i in range(1, len(mod_list)):
76        identity = nn.Identity()
77        identity.training = mod_list[0].training
78        new_mod[i] = identity
79
80    return new_mod
81
82
def _fuse_modules_helper(
    model,
    modules_to_fuse,
    is_qat,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    """Fuse a single group of named submodules of *model*, in place.

    Args:
        model: root module owning the submodules.
        modules_to_fuse: dotted submodule names forming one fusion group.
        is_qat: forwarded to ``fuser_func``.
        fuser_func: callable ``(modules, is_qat, extra_mapping) -> modules``
            whose result is indexed positionally to replace each name.
        fuse_custom_config_dict: optional config; only its
            ``"additional_fuser_method_mapping"`` entry is read here.
    """
    config = {} if fuse_custom_config_dict is None else fuse_custom_config_dict
    extra_mapping = config.get("additional_fuser_method_mapping", {})

    originals = [_get_module(model, name) for name in modules_to_fuse]
    replacements = fuser_func(originals, is_qat, extra_mapping)

    # Write each replacement back under the name it was read from.
    for idx, name in enumerate(modules_to_fuse):
        _set_module(model, name, replacements[idx])
105
106
def _fuse_modules(
    model,
    modules_to_fuse,
    is_qat,
    inplace=False,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    """Shared implementation behind ``fuse_modules`` and ``fuse_modules_qat``.

    Operates on a deep copy of *model* unless ``inplace`` is true, and returns
    the (possibly copied) model. *modules_to_fuse* is either one flat list of
    names (a single group) or a list of such lists (several groups).
    """
    target = model if inplace else copy.deepcopy(model)

    # A flat list of strings is a single group; otherwise each item is a group.
    is_single_group = all(isinstance(name, str) for name in modules_to_fuse)
    groups = [modules_to_fuse] if is_single_group else modules_to_fuse

    for group in groups:
        _fuse_modules_helper(
            target, group, is_qat, fuser_func, fuse_custom_config_dict
        )
    return target
130
131
def fuse_modules(
    model,
    modules_to_fuse,
    inplace=False,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    r"""Fuse a list of modules into a single module.

    This is the non-QAT entry point: fuser methods are invoked with
    ``is_qat=False``.

    With the default ``fuser_func`` (``fuse_known_modules``), supported
    sequences include, for example:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu
    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity. Sequences without a registered fuser method are NOT left
    unchanged: the default ``fuser_func`` raises an error for them
    (``NotImplementedError`` when no fuser method is found).

    Args:
        model: Model containing the modules to be fused
        modules_to_fuse: list of list of module names to fuse. Can also be a list
                         of strings if there is only a single list of modules to fuse.
        inplace: bool specifying if fusion happens in place on the model, by default
                 a new model is returned
        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
                    of the same length. For example,
                    fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
                    Defaults to torch.ao.quantization.fuse_known_modules
        fuse_custom_config_dict: custom configuration for fusion

    .. code-block:: python

       # Example of fuse_custom_config_dict
       fuse_custom_config_dict = {
           # Additional fuser_method mapping
           "additional_fuser_method_mapping": {
               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
           },
       }

    Returns:
        model with fused modules. A new (deep) copy is created and returned
        unless inplace=True, in which case ``model`` itself is modified and
        returned.

    Examples::

            >>> # xdoctest: +SKIP
            >>> m = M().eval()
            >>> # m is a module containing the sub-modules below
            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

            >>> m = M().eval()
            >>> # Alternately provide a single list of modules to fuse
            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

    """
    return _fuse_modules(
        model,
        modules_to_fuse,
        is_qat=False,
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict,
    )
201
202
def fuse_modules_qat(
    model,
    modules_to_fuse,
    inplace=False,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    """QAT counterpart of `fuse_modules`.

    Accepts the same arguments as `fuse_modules`; the only difference is that
    fuser methods are invoked with ``is_qat=True``.
    """
    passthrough = dict(
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict,
    )
    return _fuse_modules(model, modules_to_fuse, is_qat=True, **passthrough)
219