# mypy: allow-untyped-defs
"""Eager-mode helpers that replace runs of sub-modules with fused modules."""
import copy
from typing import List, Optional

import torch.nn as nn

# for backward compatibility
from torch.ao.quantization.fuser_method_mappings import (  # noqa: F401
    fuse_conv_bn,
    fuse_conv_bn_relu,
    get_fuser_method,
)
from torch.nn.utils.parametrize import type_before_parametrizations


__all__ = [
    "fuse_known_modules",
    "fuse_modules",
    "fuse_modules_qat",
]


# Generalization of getattr
def _get_module(model, submodule_key):
    """Return the sub-module of ``model`` addressed by a dotted path.

    ``submodule_key`` is split on ``"."`` and each token is resolved with
    ``getattr`` on the result of the previous one, e.g. ``"block.conv"``
    resolves to ``model.block.conv``.
    """
    cur_mod = model
    for token in submodule_key.split("."):
        cur_mod = getattr(cur_mod, token)
    return cur_mod


# Generalization of setattr
def _set_module(model, submodule_key, module):
    """Assign ``module`` at the dotted path ``submodule_key`` inside ``model``.

    All tokens but the last locate the parent object; the last token is the
    attribute name that is set on that parent.
    """
    *parent_tokens, leaf = submodule_key.split(".")
    parent = model
    for token in parent_tokens:
        parent = getattr(parent, token)

    setattr(parent, leaf, module)


def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Return a list of known fuse modules.

    Returns a list of modules that fuses the operations specified
    in the input module list.

    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu
    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity()

    Args:
        mod_list: modules to fuse, in execution order.
        is_qat: forwarded as the first argument of the resolved fuser method.
        additional_fuser_method_mapping: optional extra mapping forwarded to
            ``get_fuser_method`` when looking up the fuser method.

    Returns:
        A list with the same length as ``mod_list``: the fused module followed
        by ``nn.Identity()`` placeholders.

    Raises:
        NotImplementedError: if ``get_fuser_method`` returns ``None`` for the
            types found in ``mod_list``.
    """
    # Look at the types underneath any parametrization wrappers.
    types = tuple(type_before_parametrizations(m) for m in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError(f"Cannot fuse modules: {types}")
    new_mod: List[Optional[nn.Module]] = [None] * len(mod_list)
    fused = fuser_method(is_qat, *mod_list)
    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
    # Move pre forward hooks of the base module to resulting fused module
    for pre_hook_fn in mod_list[0]._forward_pre_hooks.values():
        fused.register_forward_pre_hook(pre_hook_fn)
    mod_list[0]._forward_pre_hooks.clear()
    # Move post forward hooks of the last module to resulting fused module
    for hook_fn in mod_list[-1]._forward_hooks.values():
        fused.register_forward_hook(hook_fn)
    mod_list[-1]._forward_hooks.clear()
    new_mod[0] = fused

    # Every remaining slot becomes an Identity so the containing model keeps
    # the same attribute names; its training flag mirrors the first module.
    for i in range(1, len(mod_list)):
        identity = nn.Identity()
        identity.training = mod_list[0].training
        new_mod[i] = identity

    return new_mod


def _fuse_modules_helper(
    model,
    modules_to_fuse,
    is_qat,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    """Fuse one group of sub-modules of ``model`` in place.

    ``modules_to_fuse`` is a sequence of dotted sub-module names. The named
    modules are handed to ``fuser_func`` and each name is then re-bound to the
    module at the same index of the list ``fuser_func`` returns.
    """
    if fuse_custom_config_dict is None:
        fuse_custom_config_dict = {}
    additional_fuser_method_mapping = fuse_custom_config_dict.get(
        "additional_fuser_method_mapping", {}
    )
    mod_list = [_get_module(model, item) for item in modules_to_fuse]

    # Fuse list of modules
    new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)

    # Replace original module list with fused module list
    for i, item in enumerate(modules_to_fuse):
        _set_module(model, item, new_mod_list[i])


def _fuse_modules(
    model,
    modules_to_fuse,
    is_qat,
    inplace=False,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    """Shared implementation behind ``fuse_modules`` and ``fuse_modules_qat``.

    Deep-copies ``model`` unless ``inplace`` is true, then fuses either the
    single group given as a flat list of names or every group of a list of
    lists. Returns the (possibly copied) model.
    """
    if not inplace:
        model = copy.deepcopy(model)

    # Nothing to fuse: return the model untouched instead of handing an empty
    # group to the fuser (``all()`` over an empty list is True, so the empty
    # list would otherwise be treated as one group with zero modules).
    if not modules_to_fuse:
        return model

    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
        # Handle case of modules_to_fuse being a list
        _fuse_modules_helper(
            model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict
        )
    else:
        # Handle case of modules_to_fuse being a list of lists
        for module_list in modules_to_fuse:
            _fuse_modules_helper(
                model, module_list, is_qat, fuser_func, fuse_custom_config_dict
            )
    return model


def fuse_modules(
    model,
    modules_to_fuse,
    inplace=False,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    r"""Fuse a list of modules into a single module.

    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu
    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

    With the default ``fuser_func``, a sequence for which no fuser method is
    found is not silently left unchanged: an error is raised instead.

    Args:
        model: Model containing the modules to be fused
        modules_to_fuse: list of list of module names to fuse. Can also be a list
                         of strings if there is only a single list of modules to fuse.
                         An empty list leaves the model unmodified.
        inplace: bool specifying if fusion happens in place on the model, by default
                 a new model is returned
        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
                    of the same length. For example,
                    fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
                    Defaults to torch.ao.quantization.fuse_known_modules
        `fuse_custom_config_dict`: custom configuration for fusion

    .. code-block:: python

       # Example of fuse_custom_config_dict
       fuse_custom_config_dict = {
           # Additional fuser_method mapping
           "additional_fuser_method_mapping": {
               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
           },
       }

    Returns:
        model with fused modules. A new copy is created if inplace=False.

    Examples::

            >>> # xdoctest: +SKIP
            >>> m = M().eval()
            >>> # m is a module containing the sub-modules below
            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

            >>> m = M().eval()
            >>> # Alternately provide a single list of modules to fuse
            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

    """
    return _fuse_modules(
        model,
        modules_to_fuse,
        is_qat=False,
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict,
    )


def fuse_modules_qat(
    model,
    modules_to_fuse,
    inplace=False,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    """QAT version for `fuse_modules`.

    Takes the same arguments as `fuse_modules`; the only difference is that
    ``is_qat=True`` is passed down to ``fuser_func``.
    """
    return _fuse_modules(
        model,
        modules_to_fuse,
        is_qat=True,
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict,
    )