# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import itertools
from dataclasses import dataclass
from typing import Callable, Dict, List, NamedTuple, Optional

import torch
import torch.nn.functional as F
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import (
    _conv1d_bn_example_inputs,
    _conv2d_bn_example_inputs,
    _get_aten_graph_module_for_pattern,
    _is_conv_node,
    _is_conv_transpose_node,
)
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    QuantizationSpecBase,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


__all__ = [
    "OperatorConfig",
    "OperatorPatternType",
    "QuantizationConfig",
    "get_input_act_qspec",
    "get_output_act_qspec",
    "get_weight_qspec",
    "get_bias_qspec",
    "OP_TO_ANNOTATOR",
    "propagate_annotation",
]


# In the absence of a better name, just winging it with QuantizationConfig
@dataclass(eq=True, frozen=True)
class QuantizationConfig:
    input_activation: Optional[QuantizationSpec]
    output_activation: Optional[QuantizationSpec]
    weight: Optional[QuantizationSpec]
    bias: Optional[QuantizationSpec]
    # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
    is_qat: bool = False


OperatorPatternType = List[Callable]
OperatorPatternType.__module__ = (
    "torch.ao.quantization.quantizer.xnnpack_quantizer_utils"
)

AnnotatorType = Callable[
    [
        torch.fx.GraphModule,
        Optional[QuantizationConfig],
        Optional[Callable[[Node], bool]],
    ],
    Optional[List[List[Node]]],
]
OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {}


def register_annotator(op: str):
    def decorator(annotator: AnnotatorType):
        OP_TO_ANNOTATOR[op] = annotator
        return annotator

    return decorator


class OperatorConfig(NamedTuple):
    # fix: List[str] should be List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
    # Basically we are mapping a quantization config to a list of patterns.
    # A pattern is defined as a list of nn module, function, or builtin function names,
    # e.g. [nn.Conv2d, torch.relu, torch.add].
    # We have not resolved whether fusion can be considered an internal detail of the
    # quantizer, and hence whether it needs to be communicated to the user.
    # Note this pattern is not really informative since it does not really
    # tell us the graph structure resulting from the list of ops.
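    # Illustrative (hypothetical) example of such a mapping:
    #   OperatorConfig(config=some_quantization_config,
    #                  operators=[[torch.nn.Conv2d, torch.nn.ReLU]])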
    config: QuantizationConfig
    operators: List[OperatorPatternType]


def _is_annotated(nodes: List[Node]):
    """
    Given a list of nodes (that represents an operator pattern),
    return True if any of the nodes is annotated, otherwise return False.
    """
    annotated = False
    for node in nodes:
        annotated = annotated or (
            "quantization_annotation" in node.meta
            and node.meta["quantization_annotation"]._annotated
        )
    return annotated


def _mark_nodes_as_annotated(nodes: List[Node]):
    for node in nodes:
        if node is not None:
            if "quantization_annotation" not in node.meta:
                node.meta["quantization_annotation"] = QuantizationAnnotation()
            node.meta["quantization_annotation"]._annotated = True


def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.input_activation is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.input_activation
    assert quantization_spec.qscheme in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]
    return quantization_spec


def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.output_activation is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.output_activation
    assert quantization_spec.qscheme in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]
    return quantization_spec


def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.weight is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.weight
    if quantization_spec.qscheme not in [
        torch.per_tensor_symmetric,
        torch.per_channel_symmetric,
    ]:
        raise ValueError(
            f"Unsupported quantization_spec {quantization_spec} for weight"
        )
    return quantization_spec


def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.bias is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.bias
    assert (
        quantization_spec.dtype == torch.float
    ), "Only float dtype is supported for bias right now"
    return quantization_spec


@register_annotator("linear")
def _annotate_linear(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    input_act_qspec = get_input_act_qspec(quantization_config)
    output_act_qspec = get_output_act_qspec(quantization_config)
    weight_qspec = get_weight_qspec(quantization_config)
    bias_qspec = get_bias_qspec(quantization_config)
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
            continue
        if filter_fn and not filter_fn(node):
            continue
        act_node = node.args[0]
        weight_node = node.args[1]
        bias_node = None
        if len(node.args) > 2:
            bias_node = node.args[2]

        if _is_annotated([node]) is False:  # type: ignore[list-item]
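            # Map each linear input (activation, weight, and optional bias) to its
            # qspec, annotate the output, and mark the nodes so later annotators
            # skip this partition.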
            _annotate_input_qspec_map(
                node,
                act_node,
                input_act_qspec,
            )
            _annotate_input_qspec_map(
                node,
                weight_node,
                weight_qspec,
            )
            nodes_to_mark_annotated = [node, weight_node]
            if bias_node:
                _annotate_input_qspec_map(
                    node,
                    bias_node,
                    bias_qspec,
                )
                nodes_to_mark_annotated.append(bias_node)
            _annotate_output_qspec(node, output_act_qspec)
            _mark_nodes_as_annotated(nodes_to_mark_annotated)
            annotated_partitions.append(nodes_to_mark_annotated)

    return annotated_partitions


@register_annotator("linear_relu")
def _annotate_linear_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    input_act_qspec = get_input_act_qspec(quantization_config)
    output_act_qspec = get_output_act_qspec(quantization_config)
    weight_qspec = get_weight_qspec(quantization_config)
    bias_qspec = get_bias_qspec(quantization_config)
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
        ]:
            continue
        relu_node = node
        maybe_linear_node = node.args[0]
        if (
            not isinstance(maybe_linear_node, Node)
            or maybe_linear_node.op != "call_function"
            or maybe_linear_node.target != torch.ops.aten.linear.default
        ):
            continue

        linear_node = maybe_linear_node
        input_qspec_map = {}
        input_act = linear_node.args[0]
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = input_act_qspec

        weight = linear_node.args[1]
        assert isinstance(weight, Node)
        input_qspec_map[weight] = weight_qspec

        # adding weight node to the partition as well
        partition = [relu_node, linear_node, weight]
        bias = linear_node.args[2] if len(linear_node.args) > 2 else None
        if isinstance(bias, Node):
            input_qspec_map[bias] = bias_qspec
            partition.append(bias)

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("conv")
def _annotate_conv(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for n in gm.graph.nodes:
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
        ]:
            continue
        conv_node = n

        input_qspec_map = {}
        input_act = conv_node.args[0]
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

        weight = conv_node.args[1]
        assert isinstance(weight, Node)
        input_qspec_map[weight] = get_weight_qspec(quantization_config)

        # adding weight node to the partition as well
        partition = [conv_node, conv_node.args[1]]

        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
        if isinstance(bias, Node):
            input_qspec_map[bias] = get_bias_qspec(quantization_config)
            partition.append(bias)

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=get_output_act_qspec(quantization_config),
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


def _do_annotate_conv_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
    is_conv_transpose: bool = False,
):
    annotated_partitions = []
    for n in gm.graph.nodes:
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
        ]:
            continue
        relu_node = n
        maybe_conv_node = n.args[0]

        is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node
        if not isinstance(maybe_conv_node, Node) or not is_conv_node(maybe_conv_node):
            continue
        conv_node = maybe_conv_node

        input_qspec_map = {}
        input_act = conv_node.args[0]
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

        weight = conv_node.args[1]
        assert isinstance(weight, Node)
        input_qspec_map[weight] = get_weight_qspec(quantization_config)

        # adding weight node to the partition as well
        partition = [relu_node, conv_node, conv_node.args[1]]
        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
        if isinstance(bias, Node):
            input_qspec_map[bias] = get_bias_qspec(quantization_config)
            partition.append(bias)

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map, _annotated=True
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("conv_relu")
def _annotate_conv_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    return _do_annotate_conv_relu(
        gm, quantization_config, filter_fn, is_conv_transpose=False
    )


@register_annotator("conv_transpose_relu")
def _annotate_conv_transpose_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    return _do_annotate_conv_relu(
        gm, quantization_config, filter_fn, is_conv_transpose=True
    )


@register_annotator("conv_bn")
def _annotate_conv_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """
    Find conv + batchnorm partitions
    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
    """
    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)


@register_annotator("conv_bn_relu")
def _annotate_conv_bn_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """
    Find conv + batchnorm + relu partitions
    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
    """
    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)


@register_annotator("conv_transpose_bn")
def _annotate_conv_transpose_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """
    Find conv_transpose + batchnorm partitions
    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
    """
    return _do_annotate_conv_bn(
        gm, quantization_config, filter_fn, has_relu=False, is_conv_transpose=True
    )


@register_annotator("conv_transpose_bn_relu")
def _annotate_conv_transpose_bn_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """
    Find conv_transpose + batchnorm + relu partitions
    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
    """
    return _do_annotate_conv_bn(
        gm, quantization_config, filter_fn, has_relu=True, is_conv_transpose=True
    )


def _do_annotate_conv_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]],
    has_relu: bool,
    is_conv_transpose: bool = False,
) -> List[List[Node]]:
    """
    Match conv-bn[-relu] patterns in the graph (built from a `conv_fn` and an
    optional relu) and return a list of annotated partitions.

    The output of the pattern must include a dictionary from string name to node
    for the following names: "input", "conv", "weight", "bias", and "output".
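    For example (illustrative), the `_conv_bn` pattern function defined below
    returns ``output, {"input": x, "conv": conv, "weight": conv_weight,
    "bias": conv_bias, "output": output}``, which is what
    `SubgraphMatcherWithNameNodeMap` uses to recover the named nodes from each match.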
    """

    def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
        def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
            conv = conv_fn(x, conv_weight, conv_bias)
            bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
            if has_relu:
                output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
            else:
                output = bn
            return output, {
                "input": x,
                "conv": conv,
                "weight": conv_weight,
                "bias": conv_bias,
                "output": output,
            }

        return _WrapperModule(_conv_bn)

    # Needed for matching, otherwise the matches get filtered out due to unused
    # nodes returned by batch norm
    gm.graph.eliminate_dead_code()
    gm.recompile()

    matches = []
    if is_conv_transpose:
        combinations = [
            (F.conv_transpose1d, _conv1d_bn_example_inputs),
            (F.conv_transpose2d, _conv2d_bn_example_inputs),
        ]
    else:
        combinations = [
            (F.conv1d, _conv1d_bn_example_inputs),  # type: ignore[list-item]
            (F.conv2d, _conv2d_bn_example_inputs),  # type: ignore[list-item]
        ]

    # Add `is_cuda` and `relu_is_inplace` dimensions
    combinations = itertools.product(  # type: ignore[assignment]
        combinations,
        [True, False] if torch.cuda.is_available() else [False],  # is_cuda
        [True, False] if has_relu else [False],  # relu_is_inplace
    )

    # Match against all conv dimensions and cuda variants
    for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:  # type: ignore[misc]
        pattern = get_pattern(conv_fn, relu_is_inplace)  # type: ignore[has-type]
        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)  # type: ignore[has-type]
        pattern.graph.eliminate_dead_code()
        pattern.recompile()
        matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
        matches.extend(matcher.match(gm.graph))

    # Annotate nodes returned in the matches
    annotated_partitions = []
    for match in matches:
        name_node_map = match.name_node_map
        input_node = name_node_map["input"]
        conv_node = name_node_map["conv"]
        weight_node = name_node_map["weight"]
        bias_node = name_node_map["bias"]
        output_node = name_node_map["output"]

        # TODO: annotate the uses of input, weight, and bias separately instead
        # of assuming they come from a single conv node. This is not possible today
        # because input may have multiple users, and we can't rely on the conv node
        # always being the first user.
        # This was the case in models with skip connections like resnet18.

        # Validate conv args
        if conv_node.args[0] is not input_node:
            raise ValueError(f"Conv arg did not contain input node {input_node}")
        if conv_node.args[1] is not weight_node:
            raise ValueError(f"Conv arg did not contain weight node {weight_node}")
        if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
            raise ValueError(f"Conv arg did not contain bias node {bias_node}")

        # Skip if the partition is already annotated or is filtered out by the user
        partition = [conv_node, weight_node]
        if bias_node is not None:
            partition.append(bias_node)
        if _is_annotated(partition):
            continue
        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        # Annotate conv inputs and pattern output
        input_qspec_map = {}
        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
        if bias_node is not None:
            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        output_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("gru_io_only")
def _annotate_gru_io_only(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
    gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values()))
    annotated_partitions = []
    for gru_partition in gru_partitions:
        annotated_partitions.append(gru_partition.nodes)
        output_nodes = gru_partition.output_nodes
        input_nodes = gru_partition.input_nodes
        # skip annotation if it is already annotated
        if _is_annotated(input_nodes + output_nodes):
            continue
        # inside each GRU partition, we should be able to annotate each linear
        # subgraph
        input_qspec_map: Dict[Node, QuantizationSpecBase] = {}
        input_act = input_nodes[0]
        input_act_user = next(iter(input_act.users.keys()))
        assert isinstance(input_act, Node)
        assert isinstance(input_act_user, Node)
        input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                input_act: get_input_act_qspec(quantization_config),
            },
            _annotated=True,
        )

        hidden_state = input_nodes[1]
        hidden_state_user = next(iter(hidden_state.users.keys()))
        assert isinstance(hidden_state, Node)
        assert isinstance(hidden_state_user, Node)
        hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                hidden_state: get_input_act_qspec(quantization_config),
            },
            _annotated=True,
        )

        assert len(output_nodes) == 2, "expecting GRU to have two outputs"
        for output in output_nodes:
            output.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=get_output_act_qspec(quantization_config),
                _annotated=True,
            )
        nodes_to_mark_annotated = list(gru_partition.nodes)
        _mark_nodes_as_annotated(nodes_to_mark_annotated)
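    # Note: only the GRU input, hidden state, and two outputs carry qspecs
    # ("io_only"); the internal nodes are merely marked as annotated above so
    # other annotators will skip them.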
    return annotated_partitions


@register_annotator("adaptive_avg_pool2d")
def _annotate_adaptive_avg_pool2d(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Always annotate adaptive_avg_pool2d op"""
    module_partitions = get_source_partitions(
        gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
    )
    partitions = list(itertools.chain.from_iterable(module_partitions.values()))
    annotated_partitions = []
    for partition in partitions:
        pool_node = partition.output_nodes[0]
        if (
            pool_node.op != "call_function"
            or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
        ):
            raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")

        if _is_annotated([pool_node]):
            continue

        annotated_partitions.append(partition.nodes)
        input_act = pool_node.args[0]
        assert isinstance(input_act, Node)

        # only annotate input/output sharing operators
        # when the output of the input node is annotated
        if (
            "quantization_annotation" not in input_act.meta
            or not input_act.meta["quantization_annotation"]._annotated
            or input_act.meta["quantization_annotation"].output_qspec is None
        ):
            input_act_qspec = get_input_act_qspec(quantization_config)
        else:
            input_act_qspec = SharedQuantizationSpec(input_act)

        # output sharing with input
        output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
        pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                input_act: input_act_qspec,
            },
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions


def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule):
    """Check if the input is a large scalar value, so that we can skip quantization for the node
    since the histc op (in HistogramObserver) only works for values up to a certain upper bound
    """
    if node.op == "get_attr":
        qualified_name = str(node.target)
        module_path, _, name = qualified_name.rpartition(".")
        submod = gm.get_submodule(module_path)
        tensor = getattr(submod, name)
        # torch.histc works until this upper bound
        HISTC_UPPER_BOUND = 3.4028235e15
        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
    return False


def _is_input_non_float_tensor(node: Node):
    """Check if the input is not a float tensor, so that we can skip quantization for the node
    since observers only work with float Tensors
    """
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.float32


@register_annotator("add_relu")
def _annotate_add_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
        ]:
            continue
        relu_node = node
        maybe_add = node.args[0]
        if (
            not isinstance(maybe_add, Node)
            or maybe_add.op != "call_function"
            or maybe_add.target
            not in [
                torch.ops.aten.add.Tensor,
                torch.ops.aten.add_.Tensor,
            ]
        ):
            continue

        add_node = maybe_add
        partition = [relu_node, add_node]

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = add_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            partition.append(input_act0)
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = add_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            partition.append(input_act1)
            input_qspec_map[input_act1] = input_act_qspec

        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("add")
def _annotate_add(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.add.Tensor,
            torch.ops.aten.add_.Tensor,
        ]:
            continue
        add_node = node
        partition = [add_node]

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)
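        # Skip annotation if either operand is a large scalar constant or a
        # non-float tensor: the observers inserted later cannot handle those.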
        input_qspec_map = {}
        input_act0 = add_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            input_qspec_map[input_act0] = input_act_qspec
            partition.append(input_act0)

        input_act1 = add_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            input_qspec_map[input_act1] = input_act_qspec
            partition.append(input_act1)

        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("mul_relu")
def _annotate_mul_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
        ]:
            continue
        relu_node = node
        maybe_mul = node.args[0]
        if (
            not isinstance(maybe_mul, Node)
            or maybe_mul.op != "call_function"
            or maybe_mul.target
            not in [
                torch.ops.aten.mul.Tensor,
                torch.ops.aten.mul_.Tensor,
            ]
        ):
            continue

        mul_node = maybe_mul
        partition = [relu_node, mul_node]

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            partition.append(input_act0)
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            partition.append(input_act1)
            input_qspec_map[input_act1] = input_act_qspec

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("mul")
def _annotate_mul(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.mul_.Tensor,
        ]:
            continue

        mul_node = node
        partition = [mul_node]
        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)
        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            input_qspec_map[input_act0] = input_act_qspec
            partition.append(input_act0)

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            input_qspec_map[input_act1] = input_act_qspec
            partition.append(input_act1)

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        annotated_partitions.append(partition)
    return annotated_partitions


# TODO: remove Optional in return type, fix annotated_partitions logic
@register_annotator("cat")
def _annotate_cat(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
    cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
    annotated_partitions = []
    for cat_partition in cat_partitions:
        cat_node = cat_partition.output_nodes[0]
        if _is_annotated([cat_node]):
            continue

        if cat_node.target != torch.ops.aten.cat.default:
            # TODO: change this to AnnotationException
            raise Exception(  # noqa: TRY002
                f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target},"
                " please check if you are calling the correct capture API"
            )

        annotated_partitions.append(cat_partition.nodes)

        input_act_qspec = get_input_act_qspec(quantization_config)
        inputs = cat_node.args[0]

        input_qspec_map = {}
        input_act0 = inputs[0]  # type: ignore[index]
        if isinstance(input_act0, Node):
            input_qspec_map[input_act0] = input_act_qspec

        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))  # type: ignore[arg-type]
        for input_act in inputs[1:]:  # type: ignore[index]
            input_qspec_map[input_act] = shared_with_input0_qspec  # type: ignore[index]

        output_act_qspec = shared_with_input0_qspec

        cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions


def _is_share_obs_or_fq_op(op: Callable) -> bool:
    return op in [
        torch.ops.aten.hardtanh.default,
        torch.ops.aten.hardtanh_.default,
        torch.ops.aten.max_pool2d.default,
        torch.ops.aten.mean.default,
        torch.ops.aten.mean.dim,
        torch.ops.aten.permute.default,
        torch.ops.aten.permute_copy.default,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.squeeze_copy.dim,
        # TODO: remove?
        torch.ops.aten.adaptive_avg_pool2d.default,
        torch.ops.aten.view_copy.default,
        torch.ops.aten.view.default,
        torch.ops.aten.slice_copy.Tensor,
        torch.ops.aten.flatten.using_ints,
    ]


def propagate_annotation(model: torch.fx.GraphModule) -> None:
    for n in model.graph.nodes:
        if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
            continue

        prev_node = n.args[0]
        if not isinstance(prev_node, Node):
            continue

        quantization_annotation = prev_node.meta.get("quantization_annotation", None)
        if not quantization_annotation:
            continue

        output_qspec = quantization_annotation.output_qspec
        if not output_qspec:
            continue

        # make sure current node is not annotated
        if (
            "quantization_annotation" in n.meta
            and n.meta["quantization_annotation"]._annotated
        ):
            continue

        shared_qspec = SharedQuantizationSpec(prev_node)
        # propagate the previous output_qspec to the current node
        n.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                prev_node: shared_qspec,
            },
            output_qspec=shared_qspec,
            _annotated=True,
        )


# TODO: make the list of ops customizable
def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for n in model.graph.nodes:
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.add.Tensor,
            torch.ops.aten.mul.Tensor,
        ]:
            continue
        args = list(n.args)
        new_args = []
        for i in range(len(args)):
            if isinstance(args[i], torch.fx.Node):
                new_args.append(args[i])
                continue
            prefix = "_tensor_constant_"
            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
            tensor_constant_name = get_new_attr_name(model)
            float_tensor = torch.tensor(float(args[i]))
            model.register_buffer(tensor_constant_name, float_tensor)
            fake_mode = n.meta["val"].fake_mode
            with model.graph.inserting_before(n):
                get_attr_node = model.graph.create_node(
                    "get_attr", tensor_constant_name, (), {}
                )
                get_attr_node.meta["val"] = fake_mode.from_tensor(
                    float_tensor, static_shapes=True
                )
                new_args.append(get_attr_node)
        n.args = tuple(new_args)
    model.recompile()
    return model
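

# Illustrative usage sketch (not part of this module's API): a quantizer built on
# these helpers typically dispatches to the registered annotators via
# `OP_TO_ANNOTATOR` on an exported graph module and then propagates annotations
# through observer-sharing ops, e.g.
#
#     # `quantization_config` is assumed to be a QuantizationConfig assembled
#     # elsewhere (for example by an XNNPACKQuantizer-style frontend).
#     OP_TO_ANNOTATOR["linear_relu"](gm, quantization_config)
#     OP_TO_ANNOTATOR["linear"](gm, quantization_config)
#     OP_TO_ANNOTATOR["conv_bn_relu"](gm, quantization_config)
#     propagate_annotation(gm)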