# mypy: allow-untyped-defs
import re
from collections import defaultdict, OrderedDict
from typing import Any, Callable, Dict, List, Set, Tuple, Union

import torch
from torch.ao.nn.intrinsic import _FusedModule
from torch.ao.quantization import QConfig
from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig
from torch.ao.quantization.backend_config.utils import get_module_to_qat_module
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import (
    _add_module_to_qconfig_obs_ctr,
    qconfig_equals,
    QConfigAny,
)
from torch.ao.quantization.qconfig_mapping import (
    _MODULE_NAME_DICT_KEY,
    _MODULE_NAME_REGEX_DICT_KEY,
    _OBJECT_TYPE_DICT_KEY,
    QConfigMapping,
)
from torch.ao.quantization.utils import _parent_name, get_qconfig_dtypes
from torch.fx import GraphModule
from torch.fx.graph import Graph


__all__: List[str] = []


def _maybe_adjust_qconfig_for_module_name_object_type_order(
    qconfig_mapping: QConfigMapping,
    cur_module_path: str,
    cur_object_type: Callable,
    cur_object_type_idx: int,
    fallback_qconfig: QConfigAny,
) -> QConfigAny:
    for (
        module_name,
        object_type,
        index,
    ), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items():
        if (
            (module_name == cur_module_path)
            and (object_type == cur_object_type)
            and (index == cur_object_type_idx)
        ):
            return qconfig
    return fallback_qconfig


def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
    """
    Update the QConfigMapping to account for fused modules such as LinearReLU.
    This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.
    """
    object_type_dict = qconfig_mapping.object_type_qconfigs
    if len(object_type_dict) == 0:
        return qconfig_mapping

    modules = dict(model.named_modules())

    for node in model.graph.nodes:
        if node.op == "call_module" and node.target in modules:
            maybe_fused_module = modules[str(node.target)]
            if not isinstance(maybe_fused_module, _FusedModule):
                continue

            ops = list(maybe_fused_module._modules.values())
            fused_qconfig = object_type_dict.get(type(ops[0]), None)

            # Raise an error if the modules in the fused module have
            # different qconfigs specified in the qconfig_dict
            # TODO: currently it only works for modules,
            # need to make this work for torch.nn.functional.relu
            # TODO: currently it only works for object_type configurations,
            # ideally it should work for different types of configurations,
            # maybe we want to redesign this part
            for op in ops[1:]:
                if not qconfig_equals(
                    object_type_dict.get(type(op), None), fused_qconfig
                ):
                    raise LookupError(
                        "During fusion, we need to specify the same "
                        + f"qconfigs for all module types in {type(maybe_fused_module)}, "
                        + f"offending type: {type(op)}"
                    )

            if fused_qconfig is not None:
                object_type_dict[type(maybe_fused_module)] = fused_qconfig
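
# Illustrative sketch (`fused_model` and `qconfig` are hypothetical): after
# fusing nn.Linear + nn.ReLU into torch.ao.nn.intrinsic.LinearReLU, an
# object_type lookup keyed on the fused type would miss entries registered for
# the child types, so the function above copies the child entry onto the fused
# type. Note that every child type in the fusion must carry the same qconfig,
# or the LookupError above is raised:
#
#     import torch.ao.nn.intrinsic as nni
#     qconfig_mapping = (
#         QConfigMapping()
#         .set_object_type(torch.nn.Linear, qconfig)
#         .set_object_type(torch.nn.ReLU, qconfig)
#     )
#     _update_qconfig_for_fusion(fused_model, qconfig_mapping)
#     assert qconfig_mapping.object_type_qconfigs[nni.LinearReLU] is qconfig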


def _generate_node_name_to_qconfig(
    root: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    input_graph: Graph,
    qconfig_mapping: QConfigMapping,
    node_name_to_scope: Dict[str, Tuple[str, type]],
) -> Dict[str, QConfigAny]:
    global_qconfig = qconfig_mapping.global_qconfig
    node_name_to_qconfig = {}

    # example:
    #
    #   {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
    #
    # meaning in submodule 'foo.bar', we have seen 0 F.linear and
    # 1 F.conv2d invocations so far.
    submodule_to_object_type_to_cur_idx: Dict[str, Dict[Callable, int]] = defaultdict(
        lambda: defaultdict(int)
    )
    for node in input_graph.nodes:
        qconfig = None
        if node.op == "get_attr":
            module_name, _ = _parent_name(node.target)
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, type(modules[module_name]), module_name, global_qconfig
            )
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
                qconfig, modules.get(node.target, None)
            )
        elif node.op == "call_function":
            # precedence: module_name_qconfig > function_qconfig > global_qconfig,
            # i.e. the module_name config takes precedence over the function qconfig
            function_qconfig = _get_object_type_qconfig(
                qconfig_mapping, node.target, global_qconfig
            )
            module_path, module_type = node_name_to_scope[node.name]
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, module_type, module_path, function_qconfig
            )

            cur_object_type_idx = submodule_to_object_type_to_cur_idx[module_path][
                node.target
            ]
            submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig
            )
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
                qconfig, modules.get(node.target, None)
            )

        elif node.op == "call_method":
            module_path, module_type = node_name_to_scope[node.name]
            # first use node.target (a string) to get the qconfig;
            # this is to support configs like
            # "object_type": [("reshape", qconfig)]
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, node.target, module_path, global_qconfig
            )
            # if there is no special config for the method, we'll fall back to the
            # config for the module that contains the call_method node
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, module_type, module_path, qconfig
            )
            # currently call_method does not support modifying qconfig
            # by order; we can add this later if it is needed.
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
                qconfig, modules.get(node.target, None)
            )

        elif node.op == "call_module":
            # if the node is an observer, just continue - don't add it to the qconfig_map
            if _is_activation_post_process(modules[node.target]):
                continue
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, type(modules[node.target]), node.target, global_qconfig
            )

            module_path, module_type = node_name_to_scope[node.name]
            # Note: for call_module, the module_path is the current module's name,
            # so to meaningfully count invocations, we need to count them in the
            # parent module.
            parent_name, _ = _parent_name(module_path)
            cur_object_type_idx = submodule_to_object_type_to_cur_idx[parent_name][
                module_type
            ]
            submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_mapping, parent_name, module_type, cur_object_type_idx, qconfig
            )
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
                qconfig, modules.get(node.target, None)
            )

            # regex is not supported in eager mode propagate_qconfig_, so we
            # need to set the qconfig explicitly here in case regex is used
            modules[node.target].qconfig = qconfig_with_device_check
        else:
            qconfig_with_device_check = None

        node_name_to_qconfig[node.name] = qconfig_with_device_check
    return node_name_to_qconfig
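
# Illustrative sketch of the per-occurrence override handled above (names like
# "sub" and `custom_qconfig` are hypothetical):
#
#     qconfig_mapping = (
#         QConfigMapping()
#         .set_global(global_qconfig)
#         .set_module_name_object_type_order(
#             "sub", torch.nn.functional.linear, 1, custom_qconfig
#         )
#     )
#
# With this mapping, the second F.linear call traced inside submodule "sub"
# (occurrence index 1, tracked by submodule_to_object_type_to_cur_idx) is
# assigned custom_qconfig, while the first occurrence keeps the qconfig it
# would otherwise get.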


def _check_is_valid_config_dict(
    config_dict: Any, allowed_keys: Set[str], dict_name: str
) -> None:
    r"""Checks that the given config_dict contains only allowed keys

    Args:
        `config_dict`: dictionary whose keys we want to check
        `allowed_keys`: set of keys that are allowed in `config_dict`
        `dict_name`: name of the dictionary, used in the error message
    """

    for k in config_dict.keys():
        if k not in allowed_keys:
            raise ValueError(
                "Expected "
                + dict_name
                + " to have the following keys: "
                + str(allowed_keys)
                + ". But found '"
                + k
                + "' instead."
            )


def _compare_prepare_convert_qconfig_mappings(
    prepare_qconfig_mapping: QConfigMapping, convert_qconfig_mapping: QConfigMapping
):
    r"""Compare the qconfig_mapping passed to convert with the one used in prepare
    and check that the values match

    Args:
        `prepare_qconfig_mapping`: configuration for the prepare quantization step
        `convert_qconfig_mapping`: configuration for the convert quantization step
    """
    assert qconfig_equals(
        prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig
    ), "Expected global qconfigs to be the same in the prepare and convert quantization configs"
    prepare_dicts: List[OrderedDict] = [
        prepare_qconfig_mapping.object_type_qconfigs,
        prepare_qconfig_mapping.module_name_qconfigs,
        prepare_qconfig_mapping.module_name_regex_qconfigs,
    ]
    convert_dicts: List[OrderedDict] = [
        convert_qconfig_mapping.object_type_qconfigs,
        convert_qconfig_mapping.module_name_qconfigs,
        convert_qconfig_mapping.module_name_regex_qconfigs,
    ]
    dict_names = [
        _OBJECT_TYPE_DICT_KEY,
        _MODULE_NAME_DICT_KEY,
        _MODULE_NAME_REGEX_DICT_KEY,
    ]
    for i in range(len(prepare_dicts)):
        for name in prepare_dicts[i].keys():
            assert (
                name in convert_dicts[i]
            ), f"Missing key {dict_names[i]} {name} in convert QConfigMapping \
                when it was present in prepare"
            assert convert_dicts[i][name] is None or qconfig_equals(
                prepare_dicts[i][name], convert_dicts[i][name]
            ), f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \
                prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"


def _is_qconfig_supported_by_dtype_configs(
    qconfig: QConfig, dtype_configs: List[DTypeConfig]
):
    for dtype_config in dtype_configs:
        is_dynamic = dtype_config.is_dynamic
        if is_dynamic is None:
            is_dynamic = False
        input_dtype = dtype_config.input_dtype or torch.float
        weight_dtype = dtype_config.weight_dtype or torch.float
        bias_dtype = dtype_config.bias_dtype or torch.float
        output_dtype = dtype_config.output_dtype or torch.float
        (
            qconfig_activation_dtype,
            qconfig_weight_dtype,
            qconfig_input_act_is_dynamic,
        ) = get_qconfig_dtypes(qconfig)
        qconfig_bias_dtype = (
            torch.float16
            if (
                qconfig_activation_dtype == torch.float16
                and qconfig_weight_dtype == torch.float16
                and not is_dynamic
            )
            else torch.float
        )

        if is_dynamic:
            is_match = (
                qconfig_input_act_is_dynamic
                and input_dtype == qconfig_activation_dtype
                and output_dtype == torch.float
                and weight_dtype == qconfig_weight_dtype
            )
        else:
            is_match = (
                input_dtype == qconfig_activation_dtype
                and output_dtype == qconfig_activation_dtype
                and weight_dtype == qconfig_weight_dtype
                and bias_dtype == qconfig_bias_dtype
            )
        if is_match:
            return True
    return False
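
# Illustrative sketch: a static int8 DTypeConfig matches the default static
# qconfig via the non-dynamic branch above (this assumes
# torch.ao.quantization.get_default_qconfig's observers use quint8 activations
# and qint8 weights); a dynamic qconfig would instead need is_dynamic=True and
# a float output_dtype:
#
#     from torch.ao.quantization import get_default_qconfig
#     static_int8 = DTypeConfig(
#         input_dtype=torch.quint8,
#         output_dtype=torch.quint8,
#         weight_dtype=torch.qint8,
#         bias_dtype=torch.float,
#     )
#     assert _is_qconfig_supported_by_dtype_configs(
#         get_default_qconfig(), [static_int8]
#     )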


def _get_object_type_qconfig(
    qconfig_mapping: QConfigMapping,
    object_type: Union[Callable, str],
    fallback_qconfig: QConfigAny,
) -> QConfigAny:
    return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)


def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
    for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
        if re.match(regex_pattern, module_name):
            # first match wins
            return qconfig
    return fallback_qconfig


def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
    if module_name == "":
        # module name qconfig not found
        return fallback_qconfig
    if module_name in qconfig_mapping.module_name_qconfigs:
        return qconfig_mapping.module_name_qconfigs[module_name]
    else:
        # recursively try the parent module's name, e.g. "foo.bar" falls back
        # to "foo" if "foo.bar" has no qconfig of its own
        parent, _ = _parent_name(module_name)
        return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)


def _maybe_adjust_qconfig_for_module_type_or_name(
    qconfig_mapping, module_type, module_name, global_qconfig
):
    # get the qconfig for module_name, falling back to the
    # module_name_regex qconfig, the module_type qconfig, and
    # the global qconfig if necessary
    module_type_qconfig = _get_object_type_qconfig(
        qconfig_mapping, module_type, global_qconfig
    )
    module_name_regex_qconfig = _get_module_name_regex_qconfig(
        qconfig_mapping, module_name, module_type_qconfig
    )
    module_name_qconfig = _get_module_name_qconfig(
        qconfig_mapping, module_name, module_name_regex_qconfig
    )
    return module_name_qconfig


def _get_flattened_qconfig_dict(
    qconfig_mapping: QConfigMapping,
) -> Dict[Union[Callable, str], QConfigAny]:
    """Flatten the global, object_type, and module_name qconfigs
    into a single qconfig_dict so that it can be used by the
    propagate_qconfig_ function.
    "module_name_regex" is ignored for now since it's not supported
    in propagate_qconfig_, but it can be fixed later.

    For example:
    Input: {
        "": qconfig,
        "object_type": [
            (torch.add, qconfig)
        ],
        "module_name": [
            ("conv", qconfig)
        ]
    }

    Output: {
        "": qconfig,
        torch.add: qconfig,
        "conv": qconfig
    }
    """
    flattened: Dict[Union[Callable, str], QConfigAny] = {
        "": qconfig_mapping.global_qconfig
    }
    for obj, qconfig in qconfig_mapping.object_type_qconfigs.items():
        flattened[obj] = qconfig
    for obj, qconfig in qconfig_mapping.module_name_qconfigs.items():
        flattened[obj] = qconfig
    return flattened
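
# Illustrative sketch of the precedence implemented above (qconfig_global,
# qconfig_type, qconfig_regex, and qconfig_name are hypothetical QConfigs):
#
#     qconfig_mapping = (
#         QConfigMapping()
#         .set_global(qconfig_global)
#         .set_object_type(torch.nn.Conv2d, qconfig_type)
#         .set_module_name_regex("block.*", qconfig_regex)
#         .set_module_name("block.conv", qconfig_name)
#     )
#
# For a module "block.conv" of type nn.Conv2d,
# _maybe_adjust_qconfig_for_module_type_or_name returns qconfig_name: an exact
# module_name match beats a module_name_regex match, which beats an
# object_type match, which beats the global qconfig. A sibling "block.other"
# would get qconfig_regex, and a Conv2d outside "block" would get qconfig_type.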


def _update_qconfig_for_qat(
    qconfig_mapping: QConfigMapping, backend_config: BackendConfig
):
    """
    Update the qconfig_mapping to account for module swaps during QAT.
    During QAT, each nn.Module type is swapped for the corresponding nn.qat.modules type.
    """
    module_to_qat_module_class = get_module_to_qat_module(backend_config)
    object_type_dict = qconfig_mapping.object_type_qconfigs
    new_object_type_dict = object_type_dict.copy()
    for k, v in new_object_type_dict.items():
        if k in module_to_qat_module_class:
            object_type_dict[module_to_qat_module_class[k]] = v
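
# Illustrative sketch (assuming the native backend config maps e.g.
# torch.nn.Linear to torch.ao.nn.qat.Linear; `qconfig` is a hypothetical
# QAT-capable QConfig):
#
#     from torch.ao.quantization.backend_config import get_native_backend_config
#     qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
#     _update_qconfig_for_qat(qconfig_mapping, get_native_backend_config())
#     # the mapping now also resolves the swapped QAT module type:
#     # qconfig_mapping.object_type_qconfigs[torch.ao.nn.qat.Linear] is qconfig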