# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import enum
import functools
import operator
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Node

from .ns_types import NSNodeTargetType, NSResultsType


toq = torch.ops.quantized


# TODO(future PR): consider deleting this enum and using the torch types
# directly. This might be tricky because it is not a one to one mapping.
class NodeInputOrOutputType(enum.Enum):
    FP32 = enum.auto()  # torch.float
    INT8 = enum.auto()  # torch.qint8 or torch.quint8
    FP16 = enum.auto()  # torch.float16
    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
    # TODO(future PR): while these functions can support multiple dtypes,
    #   for the purposes of numerical debugging we want to get the actual
    #   dtype used in the model. We will likely need some kind of dtype
    #   propagation to estimate this.
    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quant, fake quant, etc


def _get_first_input_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> NodeInputOrOutputType:
    """
    Returns the output type of the node feeding the first input of `node`.

    Used for nodes whose dtype is inherited from whatever precedes them.
    """
    first_arg = get_normalized_nth_input(node, gm, 0)
    assert isinstance(first_arg, Node)
    _prev_node_input_type, prev_node_output_type = (
        get_node_first_input_and_output_type(
            first_arg, gm, logger_cls, node_type_to_io_type_map
        )
    )
    return prev_node_output_type


def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    """
    Returns `(first_input_type, output_type)` for `node`.

    The type is looked up in `node_type_to_io_type_map` based on the node's
    target (for `call_function` / `call_method`) or on the type of the module
    it points to (for `call_module`). Nodes which accept several dtypes
    (loggers, observers, fake quants, the "fp32_or_int8" groups,
    `dequantize`, `to`) inherit their input type from the output type of
    the node producing their first input. Anything which cannot be
    classified is reported as `UNKNOWN`.

    Args:
        node: the node to classify
        gm: the graph module owning `node`
        logger_cls: logger class; instances are treated like observers
        node_type_to_io_type_map: mapping with the keys read below
    """
    # TODO(future PR): clean this up
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    unknown = (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        if node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        if node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            prev_node_output_type = _get_first_input_output_type(
                node, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        return unknown

    if node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type)  # type: ignore[arg-type]
            for target_type in MODS_IO_TYPE_FP32_OR_INT8
        )
        if (
            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
            or is_known_fp32_or_int8_input_module
        ):
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            prev_node_output_type = _get_first_input_output_type(
                node, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)

        is_known_fp32_input_module = any(
            isinstance(mod, target_type)  # type: ignore[arg-type]
            for target_type in MODS_IO_TYPE_FP32
        )
        is_known_int8_input_module = any(
            isinstance(mod, target_type)  # type: ignore[arg-type]
            for target_type in MODS_IO_TYPE_INT8
        )
        if is_known_fp32_input_module:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if is_known_int8_input_module:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        return unknown

    if node.op == "call_method":
        if node.target == "dequantize":
            # Dequantize is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance.
            prev_node_output_type = _get_first_input_output_type(
                node, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, NodeInputOrOutputType.FP32)

        if node.target == "to":
            # to is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance. We also look up the target
            # of to and return the correct output type.
            prev_node_output_type = _get_first_input_output_type(
                node, gm, logger_cls, node_type_to_io_type_map
            )

            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
            assert (
                cur_node_dtype_target is torch.float16
            ), f"{cur_node_dtype_target} handling needs to be added"

            return (prev_node_output_type, NodeInputOrOutputType.FP16)

        if node.target in METHS_IO_TYPE_FP32_OR_INT8:
            prev_node_output_type = _get_first_input_output_type(
                node, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)

        return unknown

    return unknown


def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Returns the qparams (scale, zero_point) of the first input to `node`,
    if they can be inferred from the graph.

    The qparams are read from the node producing the first input:

    * `torch.quantize_per_tensor` and the quantized add/mul functions:
      the scale and zero_point argument nodes are resolved on `gm`
    * the quantized modules listed below: the module's `scale` and
      `zero_point` attributes
    * modules listed under "mods_io_type_fp32_or_int8": the lookup recurses
      through the producing node

    Returns None in all other cases.
    """
    prev_node = get_normalized_nth_input(node, gm, 0)

    if not isinstance(prev_node, Node):
        return None

    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]

    def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
        # scale and zero_point are expected to be nodes whose string target
        # names an attribute of `gm`
        scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
        zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
        scale_obj = getattr_from_fqn(gm, scale_node.target)
        zp_obj = getattr_from_fqn(gm, zp_node.target)
        return (scale_obj, zp_obj)

    if prev_node.op == "call_function":
        # quantize - read the args directly
        if prev_node.target == torch.quantize_per_tensor:
            return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
        elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
            return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)

        # TODO(future PR): handle more functionals
        # TODO(future PR): handle functional ops which inherit qparams from input
        return None

    elif prev_node.op == "call_module":
        # get type of the module
        assert isinstance(prev_node.target, str)
        module_obj = getattr_from_fqn(gm, prev_node.target)
        if isinstance(
            module_obj,
            (
                nnq.Linear,
                nnq.Conv1d,
                nnq.Conv2d,
                nnq.Conv3d,
                nnq.BatchNorm2d,
                nnq.BatchNorm3d,
                nnq.ConvTranspose1d,
                nnq.ConvTranspose2d,
                nnq.ELU,
                nnq.GroupNorm,
                nnq.InstanceNorm1d,
                nnq.InstanceNorm2d,
                nnq.InstanceNorm3d,
                nnq.LayerNorm,
                nnq.Hardswish,
                nnq.LeakyReLU,
                nnq.ReLU6,
                nniq.BNReLU2d,
                nniq.BNReLU3d,
                nniq.ConvReLU1d,
                nniq.ConvReLU2d,
                nniq.ConvReLU3d,
                nniq.LinearReLU,
            ),
        ):
            return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]

        is_known_fp32_or_int8_input_module = any(
            isinstance(module_obj, target_type)  # type: ignore[arg-type]
            for target_type in MODS_IO_TYPE_FP32_OR_INT8
        )
        if is_known_fp32_or_int8_input_module:
            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)

    return None


def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it.  If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer.  For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
    """
    # Only `call_module` nodes can point at an observer; any other op
    # (function, method, placeholder, get_attr) ends the walk immediately.
    while node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if not _is_activation_post_process(node_obj):
            break
        assert len(node.args) == 1
        assert isinstance(node.args[0], Node)
        node = node.args[0]
    return node


def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Assumes that all non-param args occur first. Returns the number of
    non-param args expected for a node.  For example, for

      F.linear(x, weight, bias)

    Returns 1, because x is a non-param arg and weight and bias are params.
    For

      lstm_mod(x, hid)

    Returns 2, because both x and hid are non-param args.
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if isinstance(node_obj, nn.LSTM):
            return 2

    # default is 1
    return 1


def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
    """
    Returns the indices of args of the node which we should attach
    loggers to, if input logging is enabled.

    For example,
    * for (x + y), returns [0, 1]
    * for (1 + y), returns [1]
    * for (x + 1), returns [0]
    * for (linear(x, w, b)) returns [0]
    * by default, returns [0]
    """
    if len(node.args) == 0:
        return []
    # TODO(future PR): use relationship map instead of hardcoding
    binary_ops = (
        torch.add,
        torch.ops.quantized.add,
        operator.add,
        torch.mul,
        torch.ops.quantized.mul,
        operator.mul,
    )
    if node.op == "call_function" and node.target in binary_ops:
        # Slice instead of indexing 0 and 1 directly: the second operand may
        # have been passed as a keyword, leaving a single positional arg.
        return [i for i, arg in enumerate(node.args[:2]) if isinstance(arg, Node)]
    return [0]


def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns a string representation of the type of the function or module
    pointed to by this node, or '' for other node types.
    """
    target_type = ""
    if node.op in ("call_function", "call_method"):
        target_type = torch.typename(node.target)
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        target_mod = getattr_from_fqn(gm, node.target)
        target_type = torch.typename(target_mod)
    return target_type


def rekey_logger_info_on_node_name_of_model(
    results: NSResultsType,
    model_name: str,
) -> NSResultsType:
    """
    Rekeys the layer name of a results dictionary to use node names
    from `model_name`.

    For example, transforms

        {'base_op_1_0': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    into

        {'linear1': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    Note: we cannot use these node names directly because they are not
    guaranteed to be consistent across models. This is why we extract
    the results first and rekey afterwards.

    Layers which have no results for `model_name` keep their old key.
    """
    new_results = {}
    for old_layer_name, result_type_to_results in results.items():
        new_layer_name = None
        for model_name_to_results in result_type_to_results.values():
            for cur_model_name, list_of_results in model_name_to_results.items():
                if cur_model_name != model_name:
                    continue
                assert len(list_of_results)
                new_layer_name = list_of_results[0]["ref_node_name"]
        if new_layer_name is not None:
            new_results[new_layer_name] = result_type_to_results
        else:
            new_results[old_layer_name] = result_type_to_results
    return new_results


def maybe_add_missing_fqns(results: NSResultsType) -> None:
    """
    If `fqn` entries are filled in for one of the models in `results`, copies
    them over to any models which do not have them filled out.

    A common use case benefitting from this is comparing a model prepared by
    quantization to a quantized model. In this case, the model prepared by
    quantization would have `fqn` entries, and the quantized model would not.

    `results` is modified in place.
    """

    # Check in the first result to find any model with fqn entries defined.
    # The two trailing `break`s make sure only the first layer and the first
    # result type are inspected.
    model_name_with_fqns = None
    for result_type_to_results in results.values():
        for model_name_to_results in result_type_to_results.values():
            for model_name, model_results in model_name_to_results.items():
                if len(model_results) > 0:
                    if model_results[0]["fqn"] is not None:
                        model_name_with_fqns = model_name
                        break
            break
        break

    if not model_name_with_fqns:
        return

    for result_type_to_results in results.values():
        for model_name_to_results in result_type_to_results.values():
            ref_model_results = model_name_to_results[model_name_with_fqns]
            for model_name, model_results in model_name_to_results.items():
                if model_name == model_name_with_fqns:
                    continue
                for i, model_result in enumerate(model_results):
                    model_result["fqn"] = ref_model_results[i]["fqn"]


def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
    """
    Decorator for comparison functions taking two tensors as their first
    two positional arguments.

    The wrapped function:
    * is applied element-wise (returning a list) when both arguments are
      tuples or both are lists
    * receives dequantized tensors when either argument is quantized
    * is skipped (None is returned) when, after dequantization, either
      tensor is not `torch.float`, or when the arguments are neither
      tensors nor matching tuples/lists
    """

    @functools.wraps(f)
    def inner(*args, **kwargs):
        a0, a1, *a_other = args

        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
            isinstance(a0, list) and isinstance(a1, list)
        ):
            results = []
            for el0, el1 in zip(a0, a1):
                new_args = (el0, el1, *a_other)
                results.append(inner(*new_args, **kwargs))
            return results

        elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
            if a0.is_quantized:
                a0 = a0.dequantize()
            if a1.is_quantized:
                a1 = a1.dequantize()

        # for the purposes of this util, only handle floats
        if a0.dtype != torch.float or a1.dtype != torch.float:
            return None

        new_args = (a0, a1, *a_other)
        return f(*new_args, **kwargs)

    return inner


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the SQNR between `x` and `y`.

    Args:
      x: Tensor or tuple of tensors
      y: Tensor or tuple of tensors

    Return:
      Tensor, or a list of results when tuples/lists are passed in. The
      decorator returns None for inputs which are not `torch.float` after
      dequantization.
    """
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the normalized L2 error between `x` and `y`.

    Args:
      x: Tensor or tuple of tensors
      y: Tensor or tuple of tensors

    Return:
      Tensor, or a list of results when tuples/lists are passed in. The
      decorator returns None for inputs which are not `torch.float` after
      dequantization.
    """
    return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the cosine similarity between `x` and `y`.

    Args:
      x: Tensor or tuple of tensors
      y: Tensor or tuple of tensors

    Return:
      Tensor, or a list of results when tuples/lists are passed in. The
      decorator returns None for inputs which are not `torch.float` after
      dequantization.
    """
    # For convolutions, the shape of the quantized weight has one additional
    # dimension compared to the shape of the fp32 weight. Match the shapes
    # to enable cosine similarity comparison.
    x = x.reshape(1, -1)
    y = y.reshape(1, -1)
    return torch.nn.functional.cosine_similarity(x, y)


def op_type_supports_shadowing(node: Node) -> bool:
    """
    Returns False for `call_function` nodes targeting add, mul, cat or
    stack, and True for every other node.
    """
    if node.op == "call_function":
        if node.target in (
            torch.add,
            torch.mul,
            operator.add,
            operator.mul,
            torch.cat,
            torch.stack,
        ):
            # shadowing for ops with multiple tensor inputs is not implemented yet
            return False
    return True


def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the n'th input to that node, normalizing
    args and kwargs to the best of its ability.

    The node's arguments are normalized to keyword form when possible. If
    normalization is unavailable (returns None or raises RuntimeError), the
    raw `node.args` followed by `node.kwargs` values are indexed instead.
    """
    try:
        norm_args_and_kwargs = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True
        )
    except RuntimeError:
        # this RuntimeError happens when node argument normalization
        # requires typehints to proceed, such as for torch.add where
        # either the first, second or both arguments could be tensors
        norm_args_and_kwargs = None

    if norm_args_and_kwargs is not None:
        args, kwargs = norm_args_and_kwargs
    else:
        args, kwargs = node.args, node.kwargs

    assert len(args) + len(kwargs) > idx
    if idx < len(args):
        return args[idx]  # type: ignore[return-value]
    # note: in Python 3.7+ dicts are ordered
    # kwargs come after the positional args, so shift the index down by
    # the number of positional args.
    return list(kwargs.values())[idx - len(args)]  # type: ignore[return-value]