# mypy: allow-untyped-defs
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fuser_method_mappings import _reverse2, _reverse3
from torch.ao.quantization.utils import Pattern

from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig


__all__ = [
    "get_pattern_to_dtype_configs",
    "get_qat_module_classes",
    "get_fused_module_classes",
    "get_pattern_to_input_type_to_index",
    "get_root_module_to_quantized_reference_module",
    "get_fuser_method_mapping",
    "get_module_to_qat_module",
    "get_fusion_pattern_to_root_node_getter",
    "get_fusion_pattern_to_extra_inputs_getter",
    "remove_boolean_dispatch_from_name",
    "pattern_to_human_readable",
    "entry_to_pretty_str",
]


def get_pattern_to_dtype_configs(
    backend_config: BackendConfig,
) -> Dict[Pattern, List[DTypeConfig]]:
    pattern_to_dtype_configs: Dict[Pattern, List[DTypeConfig]] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        pattern_to_dtype_configs[pattern] = config.dtype_configs
    return pattern_to_dtype_configs


def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
    qat_module_classes = []
    for config in backend_config.configs:
        if config.qat_module is not None:
            qat_module_classes.append(config.qat_module)
    return tuple(set(qat_module_classes))


def get_fused_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
    fused_module_classes = []
    for config in backend_config.configs:
        if config.fused_module is not None:
            fused_module_classes.append(config.fused_module)
    return tuple(set(fused_module_classes))


def get_pattern_to_input_type_to_index(
    backend_config: BackendConfig,
) -> Dict[Pattern, Dict[str, int]]:
    pattern_to_input_type_to_index: Dict[Pattern, Dict[str, int]] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        pattern_to_input_type_to_index[pattern] = config._input_type_to_index
    return pattern_to_input_type_to_index


def get_root_module_to_quantized_reference_module(
    backend_config: BackendConfig,
) -> Dict[Type[torch.nn.Module], Type[torch.nn.Module]]:
    mapping: Dict[Type[torch.nn.Module], Type[torch.nn.Module]] = {}
    for config in backend_config.configs:
        if (
            config.root_module is not None
            and config.reference_quantized_module is not None
        ):
            mapping[config.root_module] = config.reference_quantized_module
    return mapping


def get_fuser_method_mapping(
    backend_config: BackendConfig,
) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
    fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config.fuser_method is not None:
            # Note: both the fuser method and the pattern are specified in forward order
            # in the BackendConfig, but the internal pattern matching code uses the
            # reversed nested tuple format, so we need to convert both to the internal format
            fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config)
            fuser_method_mapping[pattern] = fuser_method
    return fuser_method_mapping

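
# Illustrative usage sketch (not part of this module's API): the getters above all
# walk a BackendConfig and index its BackendPatternConfigs by pattern. Assuming a
# backend config such as the native one is available, they can be used as follows:
#
#     from torch.ao.quantization.backend_config import get_native_backend_config
#
#     backend_config = get_native_backend_config()
#     pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
#     qat_module_classes = get_qat_module_classes(backend_config)
#     fuser_method_mapping = get_fuser_method_mapping(backend_config)
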
def get_module_to_qat_module(
    backend_config: BackendConfig,
) -> Dict[Pattern, Type[torch.nn.Module]]:
    module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config.qat_module is not None:
            module_to_qat_module[pattern] = config.qat_module
    return module_to_qat_module


def get_fusion_pattern_to_root_node_getter(
    backend_config: BackendConfig,
) -> Dict[Pattern, Callable]:
    """Get a map from fusion pattern to a function that returns the root node
    from the fusion pattern. The most common one is:

        def get_root_node(node_pattern):
            while not isinstance(node_pattern[-1], Node):
                node_pattern = node_pattern[-1]
            return node_pattern[-1]

    This works for all patterns whose root node is the "last node" in the pattern,
    e.g. (torch.add, MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d))
    """
    root_node_getter_mapping: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config._root_node_getter is not None:
            root_node_getter_mapping[pattern] = config._root_node_getter
    return root_node_getter_mapping


def get_fusion_pattern_to_extra_inputs_getter(
    backend_config: BackendConfig,
) -> Dict[Pattern, Callable]:
    """Get a map from fusion pattern to a function that returns extra input nodes
    from the fusion pattern, in the order required by the root node. This is optional;
    if not specified, we will not copy over any extra inputs for the root node.

    Example:

        # Suppose the pattern is
        # (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d))
        # and the root node is torch.nn.Conv2d. The node matched by MatchAllNode
        # is an extra argument to the fused module, so we unpack the pattern and
        # return that node:
        def extra_inputs_getter(pattern) -> List[Any]:
            add, extra_input, conv_pattern = pattern
            return [extra_input]
    """
    extra_inputs_getter_mapping: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config._extra_inputs_getter is not None:
            extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter
    return extra_inputs_getter_mapping

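
# Illustrative sketch of how the two mappings above are consumed during fusion
# (names such as `matched_node_pattern` are hypothetical, for demonstration only).
# For a matched pattern like (add_node, extra_input_node, (bn_node, conv_node)):
#
#     root_node_getter = root_node_getter_mapping[pattern]
#     extra_inputs_getter = extra_inputs_getter_mapping.get(pattern)
#     root_node = root_node_getter(matched_node_pattern)          # e.g. conv_node
#     extra_inputs = (
#         extra_inputs_getter(matched_node_pattern) if extra_inputs_getter else []
#     )                                                           # e.g. [extra_input_node]
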
def remove_boolean_dispatch_from_name(p) -> Any:
    """
    Some ops have a default string representation such as
    '<function boolean_dispatch.<locals>.fn at 0x7ff1106bf280>';
    this function replaces them with the hardcoded function names.
    """
    if p is F.fractional_max_pool2d:
        return "torch.nn.functional.fractional_max_pool2d"
    elif p is F.fractional_max_pool3d:
        return "torch.nn.functional.fractional_max_pool3d"
    elif p is F.max_pool1d:
        return "torch.nn.functional.max_pool1d"
    elif p is F.max_pool2d:
        return "torch.nn.functional.max_pool2d"
    elif p is F.max_pool3d:
        return "torch.nn.functional.max_pool3d"
    elif p is F.adaptive_max_pool1d:
        return "torch.nn.functional.adaptive_max_pool1d"
    elif p is F.adaptive_max_pool2d:
        return "torch.nn.functional.adaptive_max_pool2d"
    elif p is F.adaptive_max_pool3d:
        return "torch.nn.functional.adaptive_max_pool3d"
    assert "boolean_dispatch" not in str(p), (
        f"{p} does not have a human readable representation in "
        + "quantization documentation"
    )
    return p


def pattern_to_human_readable(p) -> Any:
    if isinstance(p, tuple):
        # nested patterns, recurse
        return tuple(pattern_to_human_readable(inner_p) for inner_p in p)
    elif isinstance(p, str):
        # method names are already human readable
        return p
    else:
        p = remove_boolean_dispatch_from_name(p)
        return p


# TODO(future PR): move backend_config_dict to use dataclass and move this logic to
# the corresponding __str__ function
def entry_to_pretty_str(entry) -> str:
    """
    Given a backend_config_dict entry, returns a string with the human-readable
    representation of it.
    """
    s = "{\n"

    # always output the pattern first
    if "pattern" in entry:
        pattern_str = pattern_to_human_readable(entry["pattern"])

        s += f"  'pattern': {pattern_str},\n"

    # custom output for dtype_configs to make it look nice
    if "dtype_configs" in entry:
        s += "  'dtype_configs': [\n"
        for dtype_config in entry["dtype_configs"]:
            s += "    {\n"
            for k, v in dtype_config.items():
                s += f"      '{k}': {v},\n"
            s += "    },\n"
        s += "  ],\n"

    # custom output for num_tensor_args_to_observation_type to make it look nice
    if "num_tensor_args_to_observation_type" in entry:
        s += "  'num_tensor_args_to_observation_type': {\n"
        for k, v in entry["num_tensor_args_to_observation_type"].items():
            s += f"    {k}: {v},\n"
        s += "  },\n"

    # output all the other fields
    custom_handled_fields = [
        "pattern",
        "dtype_configs",
        "num_tensor_args_to_observation_type",
    ]
    for field_name in entry:
        if field_name in custom_handled_fields:
            continue
        s += f"  '{field_name}': {entry[field_name]},\n"

    s += "}"
    return s

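
# Illustrative sketch for entry_to_pretty_str (the entry below is made up for
# demonstration; real entries come from a backend_config_dict such as the one
# produced by BackendConfig.to_dict()):
#
#     entry = {
#         "pattern": (torch.nn.ReLU, torch.nn.Linear),
#         "dtype_configs": [{"input_dtype": torch.quint8, "output_dtype": torch.quint8}],
#     }
#     print(entry_to_pretty_str(entry))
#     # {
#     #   'pattern': (<class '...ReLU'>, <class '...Linear'>),
#     #   'dtype_configs': [
#     #     {
#     #       'input_dtype': torch.quint8,
#     #       'output_dtype': torch.quint8,
#     #     },
#     #   ],
#     # }
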
def _get_pattern_in_reversed_nested_tuple_format(
    config: BackendPatternConfig,
) -> Pattern:
    """
    Return the pattern specified in the given config in the reversed nested tuple
    format used internally in the quantization pattern matching code.

    If the pattern is not a tuple, or the pattern is already specified in the reversed
    nested tuple format, return the pattern as is. Otherwise:

        For 2-tuples (a, b), return (b, a).
        For 3-tuples (a, b, c), return (c, (b, a)).

    For example:

        * Given nn.Linear, return nn.Linear
        * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear)
        * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return
          (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))

    For context, the reason why this is needed is that the user-facing BackendConfig
    API accepts the flat 2-or-3-tuple format in forward order. While this simple
    format handles the vast majority of use cases, it does not handle the more
    complex ones, and so the internal pattern matching code for quantization uses
    the following, more general reversed nested tuple format instead:

        operator = module_type | functional | torch op | native op | MatchAllNode
        Pattern = (operator, Pattern, Pattern, ...) | operator

    In the future, we expect to replace the above complex format with the one used
    by the subgraph rewriter in torch.fx, so we don't have to maintain our own
    complex pattern matching code. Then we won't need this helper function anymore.
    """
    if config._pattern_complex_format is not None:
        return config._pattern_complex_format
    if config.pattern is None:
        raise ValueError(
            "Either 'pattern' or 'pattern_complex_format' must be specified"
        )
    if not isinstance(config.pattern, tuple):
        return config.pattern

    # Pattern is specified in the simple tuple format, need to convert
    if len(config.pattern) == 2:
        (a, b) = config.pattern
        return (b, a)
    elif len(config.pattern) == 3:
        (a, b, c) = config.pattern
        return (c, (b, a))
    else:
        raise ValueError(
            f"Expected a tuple with 2 or 3 elements, got: {config.pattern}"
        )


def _get_fuser_method_in_reversed_nested_tuple_format(
    config: BackendPatternConfig,
) -> Callable:
    """
    Return the fuser method specified in the given config in the reversed nested
    tuple format used internally in the quantization pattern matching code.

    If the pattern is specified in the reversed nested tuple format, we assume the
    fuser method is also specified in this format and simply return it as is.
    Otherwise, we convert the fuser method as follows:

        * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv)
        * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv),
          where bn_conv is a 2-tuple (bn, conv)

    The first argument of a fuser method is always `is_qat` and is not affected
    by the conversion. We currently only support functions with 3 or 4 arguments.
    """
    assert config.fuser_method is not None
    if config._pattern_complex_format is not None:
        return config.fuser_method
    if not isinstance(config.pattern, tuple):
        raise ValueError(f"Expected pattern to be a tuple, got: {config.pattern}")

    # Pattern is specified in the simple tuple format, need to convert
    if len(config.pattern) == 2:
        return _reverse2(config.fuser_method)
    elif len(config.pattern) == 3:
        return _reverse3(config.fuser_method)
    else:
        raise ValueError(
            f"Expected a tuple with 2 or 3 elements, got: {config.pattern}"
        )
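

# Illustrative sketch of the two private helpers above (the fuser method here is
# hypothetical; only the argument reordering matters):
#
#     import torch.nn as nn
#     from torch.ao.quantization.backend_config import BackendPatternConfig
#
#     def fuse_linear_relu(is_qat, linear, relu):
#         ...  # hypothetical fuser method, shown for illustration only
#
#     config = BackendPatternConfig((nn.Linear, nn.ReLU)).set_fuser_method(fuse_linear_relu)
#     _get_pattern_in_reversed_nested_tuple_format(config)    # (nn.ReLU, nn.Linear)
#     reversed_f = _get_fuser_method_in_reversed_nested_tuple_format(config)
#     # reversed_f(is_qat, relu, linear) forwards to fuse_linear_relu(is_qat, linear, relu)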