1# mypy: allow-untyped-defs 2import copy 3import functools 4import itertools 5import operator 6import warnings 7from dataclasses import dataclass 8from typing import ( 9 Any, 10 Callable, 11 Dict, 12 List, 13 Optional, 14 Sequence, 15 Set, 16 Tuple, 17 TYPE_CHECKING, 18 Union, 19) 20from typing_extensions import TypeAlias 21 22import torch 23import torch.nn.functional as F 24from torch.ao.quantization.fake_quantize import ( 25 FakeQuantize, 26 FusedMovingAvgObsFakeQuantize, 27) 28from torch.ao.quantization.observer import ( 29 HistogramObserver, 30 MovingAverageMinMaxObserver, 31 MovingAveragePerChannelMinMaxObserver, 32 PerChannelMinMaxObserver, 33 PlaceholderObserver, 34) 35from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions 36from torch.ao.quantization.quantizer.quantizer import ( 37 QuantizationAnnotation, 38 QuantizationSpec, 39 Quantizer, 40 SharedQuantizationSpec, 41) 42from torch.ao.quantization.quantizer.utils import _get_module_name_filter 43from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( 44 get_bias_qspec, 45 get_input_act_qspec, 46 get_output_act_qspec, 47 get_weight_qspec, 48 OperatorConfig, 49 OperatorPatternType, 50 QuantizationConfig, 51) 52from torch.fx import Node 53from torch.fx.passes.utils.source_matcher_utils import ( 54 get_source_partitions, 55 SourcePartition, 56) 57 58 59FilterFn: TypeAlias = Callable[[List[Node]], bool] 60 61 62if TYPE_CHECKING: 63 from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor 64 65__all__ = [ 66 "X86InductorQuantizer", 67 "get_default_x86_inductor_quantization_config", 68] 69 70 71@dataclass 72class _X86InductorQuantizationAnnotation(QuantizationAnnotation): 73 # _is_output_of_quantized_pattern: 74 # * Node as output node of a fusion pattern. 75 # * The fusion pattern supports int8 data type. 76 # * The fusion pattern has inputs annotated to insert observer. 77 # * The quantization_config is not `None`. 78 _is_output_of_quantized_pattern: bool = False 79 80 81# Operators that: 82# 1. Operators are optimized to run with int8 when int8 input provided. 83# 2. Operators do not support int8 input and produce fp32 output. 84int8_in_int8_out_ops: Set = { 85 torch.ops.aten.max_pool2d.default, 86 torch.ops.aten.cat.default, 87 torch.ops.aten.avg_pool2d.default, 88 torch.ops.aten.adaptive_avg_pool2d.default, 89 torch.ops.aten.flatten.using_ints, 90} 91 92# Operators that support the int8 data type for quantization config propagation. 93# A superset of int8_in_int8_out_ops incorporating additional operators. 94propagation_quantizable_ops = int8_in_int8_out_ops 95 96# Operators support the int8 data type 97# and recipe is configured by default in X86InductorQuantizer. 98default_quantizable_ops = propagation_quantizable_ops | { 99 torch.ops.aten.conv2d.default, 100 torch.ops.aten.linear.default, 101} 102 103# A superset of default_quantizable_ops includes operators support the int8 data type 104# but not enabled by default recipe of X86InductorQuantizer. 105quantizable_ops = default_quantizable_ops | { 106 torch.ops.aten.matmul.default, 107} 108 109QUANT_ANNOTATION_KEY = "quantization_annotation" 110 111 112def _skip_annotate(nodes: List[Node], filter_fn: Optional[FilterFn] = None) -> bool: 113 """Determine whether to skip annotation for a list of nodes.""" 114 115 # 1) Skip annotate if any node is already annotated 116 if _is_any_annotated(nodes): 117 return True 118 119 # 2) Proceed annotate if a) a filter function is provided 120 # and b) the given nodes list passes the filter function check. 121 if filter_fn and filter_fn(nodes): 122 return False 123 124 return True 125 126 127def _create_module_name_filter(module_name: str) -> FilterFn: 128 """Create a filter function for a given module name. 129 130 The filter function takes a list of nodes (as determined by the annotate function) 131 and return True if *all* nodes come from the specified module name, False otherwise. 132 133 For example: 134 linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` 135 relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` 136 137 >> module_name_filter = _create_module_name_filter_inner("sub") 138 >> print(module_name_filter([relu, linear_1])) 139 # True # These two nodes are determined by `_annotate_linear_unary` function and from "sub". 140 """ 141 142 filter_fn = _get_module_name_filter(module_name) 143 144 def check_all_nodes_from_module(nodes: List[Node]) -> bool: 145 all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes) 146 return all_nodes_from_module_name 147 148 return check_all_nodes_from_module 149 150 151def _create_operator_type_filter( 152 operator_type: Callable, 153) -> FilterFn: 154 """Create a filter function for a given operator type. 155 156 The filter function takes a list of nodes and returns True if it contains 157 exactly one node with the specified operator type, False otherwise. 158 159 For example: 160 linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` 161 relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` 162 163 >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default) 164 >> print(operator_type_filter([relu, linear_1])) 165 # True # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`. 166 """ 167 168 def operator_type_filter(nodes: List[Node]): 169 num_nodes_with_operator_type = sum( 170 node.target == operator_type for node in nodes 171 ) 172 if num_nodes_with_operator_type > 1: 173 raise NotImplementedError( 174 f"Several nodes within a single pattern are {operator_type}." 175 ) 176 return num_nodes_with_operator_type == 1 177 178 return operator_type_filter 179 180 181def _global_config_filter(nodes: List[Node]) -> bool: 182 """Filter function for global configuration. 183 184 This filter function takes a list of nodes and returns True if there is exactly one node 185 in the list that is a default quantizable operation, False otherwise. 186 """ 187 num_nodes_in_default_quantizable_ops = sum( 188 node.target in default_quantizable_ops for node in nodes 189 ) 190 if num_nodes_in_default_quantizable_ops > 1: 191 raise NotImplementedError( 192 "Several nodes within a single pattern are default quantizable operations." 193 ) 194 return num_nodes_in_default_quantizable_ops == 1 195 196 197def _map_module_function_to_aten_operator_type(): 198 module_function_to_aten_operator: Dict[Callable, torch._ops.OpOverloadPacket] = {} 199 map_list = ( 200 ([torch.nn.Conv2d, F.conv2d], torch.ops.aten.conv2d.default), 201 ([torch.nn.Linear, F.linear], torch.ops.aten.linear.default), 202 ([torch.nn.MaxPool2d, F.max_pool2d], torch.ops.aten.max_pool2d.default), 203 ( 204 [ 205 torch.cat, 206 ], 207 torch.ops.aten.cat.default, 208 ), 209 ([torch.nn.AvgPool2d, F.avg_pool2d], torch.ops.aten.avg_pool2d.default), 210 ( 211 [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], 212 torch.ops.aten.adaptive_avg_pool2d.default, 213 ), 214 ( 215 [ 216 torch.flatten, 217 ], 218 torch.ops.aten.flatten.using_ints, 219 ), 220 ( 221 [ 222 torch.matmul, 223 ], 224 torch.ops.aten.matmul.default, 225 ), 226 ) 227 for map_item in map_list: 228 module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[call-overload] 229 return module_function_to_aten_operator 230 231 232def _mark_nodes_as_annotated(nodes: List[Node]): 233 for node in nodes: 234 if node is not None: 235 if QUANT_ANNOTATION_KEY not in node.meta: 236 node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation() 237 node.meta[QUANT_ANNOTATION_KEY]._annotated = True 238 239 240def _is_node_annotated(_node): 241 """ 242 return True if the node is annotated, otherwise return False 243 """ 244 return ( 245 QUANT_ANNOTATION_KEY in _node.meta 246 and _node.meta[QUANT_ANNOTATION_KEY]._annotated 247 ) 248 249 250def _is_any_annotated(nodes: List[Node]): 251 """ 252 Given a list of nodes (that represents an operator pattern), 253 check if any of the node is annotated, return True if any of the node 254 is annotated, otherwise return False. 255 """ 256 return any(_is_node_annotated(node) for node in nodes) 257 258 259def _is_all_annotated(nodes: List[Node]): 260 """ 261 Given a list of nodes (that represents an operator pattern), 262 return True if all of the node is annotated, otherwise return False. 263 """ 264 return all(_is_node_annotated(node) for node in nodes) 265 266 267def _is_quantized_op_pt2e(node: torch.fx.Node): 268 """ 269 Used for pt2e flow to check if the node is a quantized node: 270 Case1: the node has been annotated as output node of a fusion pattern. 271 Case2: the node has been annotated as single quantized node. 272 """ 273 if not _is_any_annotated([node]): 274 # The node has not been annotated, directly return False 275 return False 276 quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None) 277 assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation) 278 return quantization_annotation._is_output_of_quantized_pattern 279 280 281def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]: 282 # TODO: Add more supported operators here. 283 supported_operators: Dict[str, List[OperatorPatternType]] = { 284 "conv2d": [ 285 [torch.nn.Conv2d], 286 [F.conv2d], 287 ], 288 } 289 290 # Append Conv Optional(Add) Optioinal(ReLU) 291 conv_add_relu_options = itertools.product( 292 [torch.nn.Conv2d, F.conv2d], 293 [torch.add, operator.add, None], # add 294 [torch.nn.ReLU, F.relu, None], # relu 295 ) 296 for conv_op, add_op, relu_op in conv_add_relu_options: 297 if add_op is None: 298 # Append Conv ReLU 299 supported_operators["conv2d"].append([conv_op, relu_op]) # type: ignore[list-item] 300 elif relu_op is None: 301 # Append Conv Add 302 supported_operators["conv2d"].append([conv_op, add_op]) # type: ignore[list-item] 303 else: 304 # Append Conv Add ReLU 305 supported_operators["conv2d"].append([conv_op, add_op, relu_op]) # type: ignore[list-item] 306 307 return copy.deepcopy(supported_operators) 308 309 310def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]: 311 supported_config_and_operators: List[OperatorConfig] = [] 312 for quantization_config in [ 313 get_default_x86_inductor_quantization_config(), 314 ]: 315 ops = _supported_quantized_operators() 316 for pattern_list in ops.values(): 317 supported_config_and_operators.append( 318 OperatorConfig(quantization_config, pattern_list) 319 ) 320 return copy.deepcopy(supported_config_and_operators) 321 322 323@functools.lru_cache 324def get_default_x86_inductor_quantization_config( 325 is_qat: bool = False, 326 is_dynamic: bool = False, 327): 328 extra_args: Dict[str, Any] = {"eps": 2**-12} 329 if is_qat: 330 if is_dynamic: 331 act_observer_or_fake_quant_ctr = FakeQuantize 332 dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( 333 averaging_constant=1 334 ) 335 extra_args["observer"] = dynamic_quant_observer 336 else: 337 act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] 338 else: 339 if is_dynamic: 340 act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] 341 else: 342 act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] 343 344 # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py 345 act_quantization_spec = QuantizationSpec( 346 dtype=torch.uint8, 347 quant_min=0, 348 quant_max=255, # reduce_range=False 349 qscheme=torch.per_tensor_affine, 350 is_dynamic=is_dynamic, 351 observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( 352 **extra_args 353 ), 354 ) 355 356 weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( 357 FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver 358 ) 359 360 if is_qat: 361 # Only support per channel quant for now 362 extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] 363 weight_quantization_spec = QuantizationSpec( 364 dtype=torch.int8, 365 quant_min=-128, 366 quant_max=127, 367 qscheme=torch.per_channel_symmetric, 368 ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv 369 is_dynamic=False, 370 observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( 371 **extra_args 372 ), 373 ) 374 bias_quantization_spec = None # will use placeholder observer by default 375 quantization_config = QuantizationConfig( 376 act_quantization_spec, 377 act_quantization_spec, 378 weight_quantization_spec, 379 bias_quantization_spec, 380 is_qat, 381 ) 382 return quantization_config 383 384 385def _get_supported_config_and_operators() -> List[OperatorConfig]: 386 return _get_supported_x86_inductor_config_and_operators() 387 388 389def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None: 390 """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`).""" 391 if not isinstance(nodes, list): 392 nodes = [nodes] 393 for node in nodes: 394 node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 395 _annotated=True 396 ) 397 398 399def _config_checker(method: Callable) -> Callable: 400 @functools.wraps(method) 401 def wrapper( 402 quantizer: "X86InductorQuantizer", 403 name: Any, 404 quantization_config: Optional["QuantizationConfig"], 405 ) -> "X86InductorQuantizer": 406 if quantizer._need_skip_config(quantization_config): 407 warnings.warn( 408 f"Skip the quantization config for {name}.", 409 ) 410 return quantizer 411 return method(quantizer, name, quantization_config) 412 413 return wrapper 414 415 416@dataclass 417class _CurrentQuantizationMode: 418 r"""Configuration defining the current quantization mode for the quantizer. 419 420 All possible current quantization modes are listed below: 421 ---------------------------------------------------------------------------------------------------------- 422 | dynamic_state 423 qat_state |--------------------------------------------------------------------------------------------- 424 | None | True | False 425 ---------------------------------------------------------------------------------------------------------- 426 None | quantizer does not receive a non-None `quantization_config` | \ | \ 427 False | quantizer will not do QAT | dynamic | static 428 True | quantizer will do QAT | QAT + dynamic | QAT + static 429 """ 430 431 qat_state: Optional[bool] 432 dynamic_state: Optional[bool] 433 434 435class X86InductorQuantizer(Quantizer): 436 supported_config_and_operators = _get_supported_config_and_operators() 437 module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type() 438 439 def __init__(self) -> None: 440 super().__init__() 441 self.global_config: Optional[QuantizationConfig] = None 442 self.operator_type_qconfig: Dict[ 443 torch._ops.OpOverloadPacket, Optional[QuantizationConfig] 444 ] = {} 445 self.module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} 446 447 @classmethod 448 def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: 449 op_configs: Set[QuantizationConfig] = { 450 spec for spec, _ in cls.supported_config_and_operators 451 } 452 return list(op_configs) 453 454 @classmethod 455 def get_supported_operator_for_quantization_config( 456 cls, quantization_config: Optional[QuantizationConfig] 457 ) -> List[OperatorPatternType]: 458 if quantization_config is None: 459 all_ops = [] 460 for _, ops in cls.supported_config_and_operators: 461 all_ops.extend(ops) 462 return all_ops 463 464 for config, ops in cls.supported_config_and_operators: 465 if config == quantization_config: 466 return ops 467 return [] 468 469 def _get_current_quantization_mode(self) -> _CurrentQuantizationMode: 470 """Retrieves the current quantization mode based on all configurations.""" 471 qat_state = None 472 dynamic_state = None 473 474 # As we use `_need_skip_config` to skip all invalid configurations, 475 # we can safely assume that the all existing non-None configurations 476 # have the same quantization mode. 477 for qconfig in ( 478 list(self.module_name_qconfig.values()) 479 + list(self.operator_type_qconfig.values()) 480 + [self.global_config] 481 ): 482 if qconfig is not None: 483 # Query the `is_qat` state 484 if qat_state is None: 485 qat_state = qconfig.is_qat 486 else: 487 assert qat_state == qconfig.is_qat, ( 488 f"All non-None quantization configs should have the same `is_qat`," 489 f"but got {qat_state} and {qconfig.is_qat}." 490 ) 491 # Query the `is_dynamic` state 492 input_activation_spec = qconfig.input_activation 493 if input_activation_spec is not None: 494 if dynamic_state is None: 495 dynamic_state = input_activation_spec.is_dynamic 496 else: 497 assert dynamic_state == input_activation_spec.is_dynamic, ( 498 f"All non-None `input_activation_spec` should have the same `is_dynamic`," 499 f"but got {dynamic_state} and {input_activation_spec.is_dynamic}." 500 ) 501 return _CurrentQuantizationMode( 502 qat_state=qat_state, dynamic_state=dynamic_state 503 ) 504 505 def _need_skip_config( 506 self, quantization_config: Optional[QuantizationConfig] 507 ) -> bool: 508 """Check if the provided quantization config is valid for X86InductorQuantizer. 509 510 Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. 511 To avoid such a mix, we compare the incoming configuration with current configuration status. 512 Refer the `_CurrentQuantizationMode` definition for all possible modes. 513 """ 514 if quantization_config is None: 515 return False 516 517 need_skip = False 518 current_mode = self._get_current_quantization_mode() 519 if ( 520 current_mode.qat_state is not None 521 and current_mode.qat_state != quantization_config.is_qat 522 ): 523 warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.") 524 need_skip = True 525 if current_mode.dynamic_state is not None: 526 input_activation_spec = quantization_config.input_activation 527 if ( 528 input_activation_spec is not None 529 and current_mode.dynamic_state != input_activation_spec.is_dynamic 530 ): 531 warnings.warn( 532 "Mixed dynamic and static quantization config is not supported." 533 ) 534 need_skip = True 535 return need_skip 536 537 def set_global(self, quantization_config: QuantizationConfig): 538 if self._need_skip_config(quantization_config): 539 warnings.warn("Skip the global quantization config.") 540 return self 541 self.global_config = quantization_config 542 return self 543 544 def get_global_quantization_config(self): 545 if not isinstance(self.global_config, QuantizationConfig): 546 warnings.warn( 547 "The global_config for X86InductorQuantizer is currently invalid. \ 548 Please ensure that you use set_global to establish the global quantization configuration." 549 ) 550 return self.global_config 551 552 @_config_checker 553 def set_function_type_qconfig( 554 self, 555 function_type: Callable, 556 quantization_config: Optional[QuantizationConfig], 557 ) -> "X86InductorQuantizer": 558 if function_type in X86InductorQuantizer.module_function_to_aten_operator_type: 559 self._set_aten_operator_qconfig( 560 X86InductorQuantizer.module_function_to_aten_operator_type[ 561 function_type 562 ], 563 quantization_config, 564 ) 565 else: 566 warnings.warn( 567 f"function: Unable to customize quantization config for {function_type} by X86InductorQuantizer." 568 ) 569 return self 570 571 @_config_checker 572 def set_module_type_qconfig( 573 self, 574 module_type: torch.nn.Module, 575 quantization_config: Optional[QuantizationConfig], 576 ) -> "X86InductorQuantizer": 577 if module_type in X86InductorQuantizer.module_function_to_aten_operator_type: 578 self._set_aten_operator_qconfig( 579 X86InductorQuantizer.module_function_to_aten_operator_type[module_type], 580 quantization_config, 581 ) 582 else: 583 warnings.warn( 584 f"Module: Unable to customize quantization config for {module_type} by X86InductorQuantizer." 585 ) 586 return self 587 588 @_config_checker 589 def set_module_name_qconfig( 590 self, module_name: str, quantization_config: Optional[QuantizationConfig] 591 ): 592 """Set quantization_config for a submodule with name: `module_name`, for example: 593 quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator 594 patterns in the submodule with this module name with the given `quantization_config` 595 596 The supported operators include `quantizable_ops` and `propagation_quantizable_ops`. 597 """ 598 self.module_name_qconfig[module_name] = quantization_config 599 return self 600 601 def _set_aten_operator_qconfig( 602 self, 603 operator_type: torch._ops.OpOverloadPacket, 604 quantization_config: Optional[QuantizationConfig], 605 ) -> "X86InductorQuantizer": 606 if operator_type in quantizable_ops: 607 self.operator_type_qconfig[operator_type] = quantization_config 608 else: 609 warnings.warn( 610 f"operator: Unable to quantize {operator} by X86InductorQuantizer." 611 ) 612 return self 613 614 def _annotate_conv_node_helper( 615 self, 616 conv_node: torch.fx.Node, 617 annotate_output: bool, 618 quantization_config: Optional[QuantizationConfig], 619 ) -> None: 620 """Helper function to annotate the conv node""" 621 if quantization_config is None: 622 _annotate_nodes_not_quantize(conv_node) 623 return 624 input_qspec_map = {} 625 input_node = conv_node.args[0] 626 assert isinstance(input_node, Node) 627 input_qspec_map[input_node] = get_input_act_qspec(quantization_config) 628 weight_node = conv_node.args[1] 629 assert isinstance(weight_node, Node) 630 input_qspec_map[weight_node] = get_weight_qspec(quantization_config) 631 bias_node = None if len(conv_node.args) == 2 else conv_node.args[2] 632 if isinstance(bias_node, Node): 633 input_qspec_map[bias_node] = get_bias_qspec(quantization_config) 634 if annotate_output: 635 conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 636 input_qspec_map=input_qspec_map, 637 _annotated=True, 638 _is_output_of_quantized_pattern=True, 639 ) 640 else: 641 conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 642 input_qspec_map=input_qspec_map, 643 _annotated=True, 644 ) 645 646 def _annotate_linear_node_helper( 647 self, 648 linear_node: torch.fx.Node, 649 annotate_output: bool, 650 quantization_config: Optional[QuantizationConfig], 651 ) -> None: 652 """Helper function to annotate the linear node""" 653 if quantization_config is None: 654 _annotate_nodes_not_quantize(linear_node) 655 return 656 input_qspec_map = {} 657 assert linear_node.target in (torch.ops.aten.linear.default,) 658 has_bias = len(linear_node.args) == 3 659 input_index = 0 660 weight_index = 1 661 bias_index = 2 662 663 input_node = linear_node.args[input_index] 664 assert isinstance(input_node, Node) 665 input_qspec_map[input_node] = get_input_act_qspec(quantization_config) 666 667 weight_node = linear_node.args[weight_index] 668 assert isinstance(weight_node, Node) 669 input_qspec_map[weight_node] = get_weight_qspec(quantization_config) 670 671 bias_node = linear_node.args[bias_index] if has_bias else None 672 if isinstance(bias_node, Node): 673 input_qspec_map[bias_node] = get_bias_qspec(quantization_config) 674 675 if annotate_output: 676 linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 677 input_qspec_map=input_qspec_map, 678 _annotated=True, 679 _is_output_of_quantized_pattern=True, 680 ) 681 else: 682 linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 683 input_qspec_map=input_qspec_map, _annotated=True 684 ) 685 686 def _get_output_nodes_of_partitions( 687 self, 688 partition_list: List[SourcePartition], 689 ) -> List[torch.fx.Node]: 690 """Helper function to get the output node list from partition list""" 691 output_node_list = [] 692 for partition in partition_list: 693 if len(partition.output_nodes) > 1: 694 raise ValueError("Input partition has more than one output node") 695 output_node = partition.output_nodes[0] 696 assert isinstance(output_node, Node) 697 output_node_list.append(output_node) 698 if len(output_node_list) != len(partition_list): 699 raise ValueError( 700 "length of output_node_list should equal to length of partition_list" 701 ) 702 return output_node_list 703 704 def _get_input_idx_for_binary_node( 705 self, 706 conv_gemm_node: torch.fx.Node, 707 binary_node: torch.fx.Node, 708 ): 709 """Helper function to check conv_gemm and extra input node index 710 for binary node fused with conv_gemm. 711 """ 712 conv_gemm_node_idx = None 713 extra_input_node_idx = None 714 if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr] 715 binary_node.args[0] == conv_gemm_node 716 ): 717 conv_gemm_node_idx = 0 718 extra_input_node_idx = 1 719 elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr] 720 binary_node.args[1] == conv_gemm_node 721 ): 722 conv_gemm_node_idx = 1 723 extra_input_node_idx = 0 724 extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] 725 assert isinstance(extra_input_node, Node) 726 return conv_gemm_node_idx, extra_input_node_idx 727 728 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 729 """Annotate the given model with quantization configurations. 730 731 Annotation contracts: 732 1. Annotate each node according to the user's qconfig in the following order: 733 `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. 734 2. Avoid re-annotating nodes already annotated in prior stages. For example, 735 if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again 736 during the processing of the 'operator_type_qconfig' or 'global_config'. 737 3. For config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`. 738 739 For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. 740 This filter function checks if the node is marked by current stage and not annotated by the previous stage. 741 """ 742 for module_name, quantization_config in self.module_name_qconfig.items(): 743 self._annotate_with_config( 744 model, quantization_config, _create_module_name_filter(module_name) 745 ) 746 747 for operator_type, quantization_config in self.operator_type_qconfig.items(): 748 self._annotate_with_config( 749 model, quantization_config, _create_operator_type_filter(operator_type) 750 ) 751 752 if self.global_config: 753 self._annotate_with_config( 754 model, 755 self.global_config, 756 _global_config_filter, 757 ) 758 759 # Once we've annotated the model with quantization configurations, we also need to annotate 760 # the output of quantizable operations. For example, if we annotated `maxpool2d` to quantize its inputs, 761 # we will quantize its output accordingly. This enables us to fuse the dq-operator-q into a quantized op. 762 # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ 763 # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 764 765 self._annotate_output_for_int8_in_int8_out_pattern_entry(model) 766 767 return model 768 769 def _annotate_with_config( 770 self, 771 model: torch.fx.GraphModule, 772 quantization_config: Optional[QuantizationConfig], 773 filter_fn: FilterFn, 774 ) -> None: 775 """Annotate the model with the given quantization configuration. 776 777 High-level description of quantization recipe for X86 Inductor Backend: 778 Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. 779 Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model 780 from start to the end. If a pattern supports computation with int8 data type and inputs connected to 781 quantized patterns, annotate its inputs as quantized pattern. 782 """ 783 784 # Step1: Recipe of fusion patterns like conv/linear. 785 self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn) 786 self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn) 787 self._annotate_matmul(model, quantization_config, filter_fn) 788 789 # Step2: Recipe to propagate annotation for patterns beside conv/linear. 790 # Go through all the nodes from start to end. 791 # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ 792 # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 793 794 self._annotate_propagation_quantizable_pattern_entry( 795 model, quantization_config, filter_fn 796 ) 797 798 def _annotate_qat_conv2d_fusion_pattern( 799 self, 800 model: torch.fx.GraphModule, 801 quantization_config: Optional[QuantizationConfig], 802 filter_fn: Optional[FilterFn] = None, 803 ): 804 # Annotate QAT Specific patterns 805 self._annotate_qat_conv2d_bn_binary_unary(model, quantization_config, filter_fn) 806 self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn) 807 self._annotate_qat_conv2d_bn_unary(model, quantization_config, filter_fn) 808 self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn) 809 810 def _annotate_qat_conv2d_bn_binary_unary( 811 self, 812 gm: torch.fx.GraphModule, 813 quantization_config: Optional[QuantizationConfig], 814 filter_fn: Optional[FilterFn] = None, 815 ) -> None: 816 fused_partitions = find_sequential_partitions( 817 gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] 818 ) 819 for fused_partition in fused_partitions: 820 ( 821 conv_partition, 822 bn_partition, 823 binary_partition, 824 unary_partition, 825 ) = fused_partition 826 827 ( 828 conv_node, 829 bn_output_node, 830 binary_node, 831 unary_node, 832 ) = self._get_output_nodes_of_partitions( 833 [conv_partition, bn_partition, binary_partition, unary_partition] 834 ) 835 if len(bn_output_node.users) != 1: 836 # Conv BN pattern should only has 1 user. 837 continue 838 ( 839 bn_output_node_idx, 840 extra_input_node_idx, 841 ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) 842 if (bn_output_node_idx is None) or (extra_input_node_idx is None): 843 continue 844 if bn_output_node != binary_node.args[bn_output_node_idx]: 845 raise ValueError(f"{bn_output_node} doesn't match input of binary node") 846 extra_input_node = binary_node.args[extra_input_node_idx] 847 848 if ( 849 conv_node.op != "call_function" 850 or conv_node.target != torch.ops.aten.conv2d.default 851 ): 852 continue 853 854 if _skip_annotate( 855 [unary_node, binary_node, bn_output_node, conv_node], filter_fn 856 ): 857 continue 858 859 self._annotate_conv_node_helper(conv_node, False, quantization_config) 860 861 if quantization_config is not None: 862 binary_node_input_qspec_map = {} 863 binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( 864 quantization_config 865 ) 866 binary_node.meta[ 867 QUANT_ANNOTATION_KEY 868 ] = _X86InductorQuantizationAnnotation( 869 input_qspec_map=binary_node_input_qspec_map, 870 _annotated=True, 871 ) 872 unary_node.meta[ 873 QUANT_ANNOTATION_KEY 874 ] = _X86InductorQuantizationAnnotation( 875 # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher. 876 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] 877 _annotated=True, 878 _is_output_of_quantized_pattern=True, 879 ) 880 else: 881 _annotate_nodes_not_quantize([binary_node, unary_node]) 882 nodes_to_mark_annotated = list(conv_partition.nodes) 883 nodes_to_mark_annotated.extend(list(bn_partition.nodes)) 884 nodes_to_mark_annotated.extend(list(binary_partition.nodes)) 885 nodes_to_mark_annotated.extend(list(unary_partition.nodes)) 886 _mark_nodes_as_annotated(nodes_to_mark_annotated) 887 888 def _annotate_qat_conv2d_bn_binary( 889 self, 890 gm: torch.fx.GraphModule, 891 quantization_config: Optional[QuantizationConfig], 892 filter_fn: Optional[FilterFn] = None, 893 ) -> None: 894 fused_partitions = find_sequential_partitions( 895 gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] 896 ) 897 for fused_partition in fused_partitions: 898 conv_partition, bn_partition, binary_partition = fused_partition 899 ( 900 conv_node, 901 bn_output_node, 902 binary_node, 903 ) = self._get_output_nodes_of_partitions( 904 [conv_partition, bn_partition, binary_partition] 905 ) 906 if len(bn_output_node.users) != 1: 907 # Conv BN pattern should only has 1 user. 908 continue 909 ( 910 bn_output_node_idx, 911 extra_input_node_idx, 912 ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) 913 if (bn_output_node_idx is None) or (extra_input_node_idx is None): 914 continue 915 if bn_output_node != binary_node.args[bn_output_node_idx]: 916 raise ValueError(f"{bn_output_node} doesn't match input of binary node") 917 918 extra_input_node = binary_node.args[extra_input_node_idx] 919 920 if ( 921 conv_node.op != "call_function" 922 or conv_node.target != torch.ops.aten.conv2d.default 923 ): 924 continue 925 926 if _skip_annotate([binary_node, bn_output_node, conv_node], filter_fn): 927 continue 928 929 self._annotate_conv_node_helper(conv_node, False, quantization_config) 930 931 if quantization_config is not None: 932 binary_node_input_qspec_map = {} 933 binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( 934 quantization_config 935 ) 936 binary_node.meta[ 937 QUANT_ANNOTATION_KEY 938 ] = _X86InductorQuantizationAnnotation( 939 input_qspec_map=binary_node_input_qspec_map, 940 # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher. 941 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] 942 _annotated=True, 943 _is_output_of_quantized_pattern=True, 944 ) 945 else: 946 _annotate_nodes_not_quantize(binary_node) 947 nodes_to_mark_annotated = list(conv_partition.nodes) 948 nodes_to_mark_annotated.extend(list(bn_partition.nodes)) 949 nodes_to_mark_annotated.extend(list(binary_partition.nodes)) 950 _mark_nodes_as_annotated(nodes_to_mark_annotated) 951 952 def _annotate_qat_conv2d_bn_unary( 953 self, 954 gm: torch.fx.GraphModule, 955 quantization_config: Optional[QuantizationConfig], 956 filter_fn: Optional[FilterFn] = None, 957 ) -> None: 958 fused_partitions = [] 959 unary_patterns = [ 960 [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU], 961 [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh], 962 [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish], 963 [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6], 964 [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.SiLU], 965 ] 966 for unary_pattern in unary_patterns: 967 partitions = find_sequential_partitions(gm, unary_pattern) 968 if partitions: 969 # Extend the fused_partitions if partitions is not empty 970 fused_partitions.extend(partitions) 971 972 for fused_partition in fused_partitions: 973 conv_partition, bn_partition, unary_partition = fused_partition 974 ( 975 conv_node, 976 bn_output_node, 977 unary_node, 978 ) = self._get_output_nodes_of_partitions( 979 [conv_partition, bn_partition, unary_partition] 980 ) 981 982 if ( 983 conv_node.op != "call_function" 984 or conv_node.target != torch.ops.aten.conv2d.default 985 ): 986 continue 987 988 if _skip_annotate([unary_node, bn_output_node, conv_node], filter_fn): 989 continue 990 991 self._annotate_conv_node_helper(conv_node, False, quantization_config) 992 if quantization_config is not None: 993 unary_node.meta[ 994 QUANT_ANNOTATION_KEY 995 ] = _X86InductorQuantizationAnnotation( 996 # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher. 997 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] 998 _annotated=True, 999 _is_output_of_quantized_pattern=True, 1000 ) 1001 else: 1002 _annotate_nodes_not_quantize(unary_node) 1003 nodes_to_mark_annotated = list(conv_partition.nodes) 1004 nodes_to_mark_annotated.extend(list(bn_partition.nodes)) 1005 nodes_to_mark_annotated.extend(list(unary_partition.nodes)) 1006 _mark_nodes_as_annotated(nodes_to_mark_annotated) 1007 1008 def _annotate_qat_conv2d_bn( 1009 self, 1010 gm: torch.fx.GraphModule, 1011 quantization_config: Optional[QuantizationConfig], 1012 filter_fn: Optional[FilterFn] = None, 1013 ) -> None: 1014 fused_partitions = find_sequential_partitions( 1015 gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] 1016 ) 1017 for fused_partition in fused_partitions: 1018 conv_partition, bn_partition = fused_partition 1019 conv_node, bn_output_node = self._get_output_nodes_of_partitions( 1020 [conv_partition, bn_partition] 1021 ) 1022 1023 if ( 1024 conv_node.op != "call_function" 1025 or conv_node.target != torch.ops.aten.conv2d.default 1026 ): 1027 continue 1028 1029 if _skip_annotate([bn_output_node, conv_node], filter_fn): 1030 continue 1031 1032 self._annotate_conv_node_helper(conv_node, False, quantization_config) 1033 if quantization_config is not None: 1034 bn_output_node.meta[ 1035 QUANT_ANNOTATION_KEY 1036 ] = _X86InductorQuantizationAnnotation( 1037 # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher. 1038 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] 1039 _annotated=True, 1040 _is_output_of_quantized_pattern=True, 1041 ) 1042 else: 1043 _annotate_nodes_not_quantize(bn_output_node) 1044 nodes_to_mark_annotated = list(conv_partition.nodes) 1045 nodes_to_mark_annotated.extend(list(bn_partition.nodes)) 1046 _mark_nodes_as_annotated(nodes_to_mark_annotated) 1047 1048 def _annotate_conv2d_fusion_pattern( 1049 self, 1050 model: torch.fx.GraphModule, 1051 quantization_config: Optional[QuantizationConfig], 1052 filter_fn: Optional[FilterFn] = None, 1053 ): 1054 if (quantization_config is None) or (quantization_config.is_qat): 1055 # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat 1056 self._annotate_qat_conv2d_fusion_pattern( 1057 model, quantization_config, filter_fn 1058 ) 1059 self._annotate_conv2d_binary_unary(model, quantization_config, filter_fn) 1060 self._annotate_conv2d_binary(model, quantization_config, filter_fn) 1061 self._annotate_conv2d_unary(model, quantization_config, filter_fn) 1062 self._annotate_conv2d(model, quantization_config, filter_fn) 1063 1064 def _annotate_linear_fusion_pattern( 1065 self, 1066 model: torch.fx.GraphModule, 1067 quantization_config: Optional[QuantizationConfig], 1068 filter_fn: Optional[FilterFn] = None, 1069 ): 1070 self._annotate_linear_binary_unary(model, quantization_config, filter_fn) 1071 self._annotate_linear_unary(model, quantization_config, filter_fn) 1072 self._annotate_linear(model, quantization_config, filter_fn) 1073 1074 def _annotate_matmul( 1075 self, 1076 model: torch.fx.GraphModule, 1077 quantization_config: Optional[QuantizationConfig], 1078 filter_fn: Optional[FilterFn] = None, 1079 ): 1080 for node in model.graph.nodes: 1081 if node.target != torch.ops.aten.matmul.default: 1082 continue 1083 if _skip_annotate([node], filter_fn): 1084 continue 1085 1086 if quantization_config is None: 1087 _annotate_nodes_not_quantize(node) 1088 continue 1089 1090 input_qspec_map = {} 1091 matmul_node = node 1092 for input_node in matmul_node.args: 1093 input_qspec_map[input_node] = get_input_act_qspec(quantization_config) 1094 matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 1095 input_qspec_map=input_qspec_map, 1096 _annotated=True, 1097 _is_output_of_quantized_pattern=True, 1098 ) 1099 1100 def _annotate_conv2d_binary_unary( 1101 self, 1102 gm: torch.fx.GraphModule, 1103 quantization_config: Optional[QuantizationConfig], 1104 filter_fn: Optional[FilterFn] = None, 1105 ) -> None: 1106 # Conv2d + add + unary op 1107 fused_partitions = find_sequential_partitions( 1108 gm, [torch.nn.Conv2d, operator.add, torch.nn.ReLU] 1109 ) 1110 for fused_partition in fused_partitions: 1111 conv_partition, binary_partition, unary_partition = fused_partition 1112 conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions( 1113 [conv_partition, binary_partition, unary_partition] 1114 ) 1115 if len(conv_node.users) != 1: 1116 # Conv Node should only has 1 user node 1117 continue 1118 conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( 1119 conv_node, binary_node 1120 ) 1121 if (conv_node_idx is None) or (extra_input_node_idx is None): 1122 continue 1123 if conv_node != binary_node.args[conv_node_idx]: 1124 raise ValueError(f"{conv_node} doesn't match input of binary node") 1125 extra_input_node = binary_node.args[extra_input_node_idx] 1126 if ( 1127 conv_node.op != "call_function" 1128 or conv_node.target != torch.ops.aten.conv2d.default 1129 ): 1130 # No conv node found to be fused with add 1131 continue 1132 if _skip_annotate([unary_node, binary_node, conv_node], filter_fn): 1133 continue 1134 1135 if quantization_config is None: 1136 _annotate_nodes_not_quantize([conv_node, binary_node, unary_node]) 1137 continue 1138 1139 self._annotate_conv_node_helper(conv_node, False, quantization_config) 1140 binary_node_input_qspec_map = {} 1141 binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( 1142 quantization_config 1143 ) 1144 binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 1145 input_qspec_map=binary_node_input_qspec_map, 1146 _annotated=True, 1147 ) 1148 unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 1149 _annotated=True, 1150 _is_output_of_quantized_pattern=True, 1151 ) 1152 1153 def _annotate_conv2d_binary( 1154 self, 1155 gm: torch.fx.GraphModule, 1156 quantization_config: Optional[QuantizationConfig], 1157 filter_fn: Optional[FilterFn] = None, 1158 ) -> None: 1159 # Conv2d + add 1160 fused_partitions = find_sequential_partitions( 1161 gm, [torch.nn.Conv2d, operator.add] 1162 ) 1163 for fused_partition in fused_partitions: 1164 conv_partition, binary_partition = fused_partition 1165 conv_node, binary_node = self._get_output_nodes_of_partitions( 1166 [conv_partition, binary_partition] 1167 ) 1168 if len(conv_node.users) != 1: 1169 # Conv Node should only has 1 user node 1170 continue 1171 conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( 1172 conv_node, binary_node 1173 ) 1174 if (conv_node_idx is None) or (extra_input_node_idx is None): 1175 continue 1176 if conv_node != binary_node.args[conv_node_idx]: 1177 raise ValueError(f"{conv_node} doesn't match input of binary node") 1178 extra_input_node = binary_node.args[extra_input_node_idx] 1179 assert isinstance(conv_node, Node) 1180 if ( 1181 conv_node.op != "call_function" 1182 or conv_node.target != torch.ops.aten.conv2d.default 1183 ): 1184 # No conv node found to be fused with add 1185 continue 1186 if _skip_annotate([binary_node, conv_node], filter_fn): 1187 continue 1188 1189 if quantization_config is None: 1190 _annotate_nodes_not_quantize([conv_node, binary_node]) 1191 continue 1192 1193 self._annotate_conv_node_helper(conv_node, False, quantization_config) 1194 binary_node_input_qspec_map = {} 1195 binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( 1196 quantization_config 1197 ) 1198 binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 1199 input_qspec_map=binary_node_input_qspec_map, 1200 _annotated=True, 1201 _is_output_of_quantized_pattern=True, 1202 ) 1203 1204 def _annotate_conv2d_unary( 1205 self, 1206 gm: torch.fx.GraphModule, 1207 quantization_config: Optional[QuantizationConfig], 1208 filter_fn: Optional[FilterFn] = None, 1209 ) -> None: 1210 fused_partitions = [] 1211 unary_patterns = [ 1212 [torch.nn.Conv2d, torch.nn.ReLU], 1213 [torch.nn.Conv2d, torch.nn.Hardtanh], 1214 [torch.nn.Conv2d, torch.nn.Hardswish], 1215 [torch.nn.Conv2d, torch.nn.ReLU6], 1216 [torch.nn.Conv2d, torch.nn.SiLU], 1217 ] 1218 for unary_pattern in unary_patterns: 1219 partitions = find_sequential_partitions(gm, unary_pattern) 1220 if partitions: 1221 # Extend the fused_partitions if partitions is not empty 1222 fused_partitions.extend(partitions) 1223 1224 for fused_partition in fused_partitions: 1225 conv_partition, unary_partition = fused_partition 1226 conv_node, unary_node = self._get_output_nodes_of_partitions( 1227 [conv_partition, unary_partition] 1228 ) 1229 if ( 1230 conv_node.op != "call_function" 1231 or conv_node.target != torch.ops.aten.conv2d.default 1232 ): 1233 continue 1234 if _skip_annotate([unary_node, conv_node], filter_fn): 1235 continue 1236 1237 if quantization_config is None: 1238 _annotate_nodes_not_quantize([conv_node, unary_node]) 1239 continue 1240 1241 self._annotate_conv_node_helper(conv_node, False, quantization_config) 1242 unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 1243 _annotated=True, 1244 _is_output_of_quantized_pattern=True, 1245 ) 1246 1247 def _annotate_conv2d( 1248 self, 1249 gm: torch.fx.GraphModule, 1250 quantization_config: Optional[QuantizationConfig], 1251 filter_fn: Optional[FilterFn] = None, 1252 ) -> None: 1253 conv_partitions = get_source_partitions( 1254 gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] 1255 ) 1256 conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values())) 1257 for conv_partition in conv_partitions: 1258 if len(conv_partition.output_nodes) > 1: 1259 raise ValueError("conv partition has more than one output node") 1260 conv_node = conv_partition.output_nodes[0] 1261 if ( 1262 conv_node.op != "call_function" 1263 or conv_node.target != torch.ops.aten.conv2d.default 1264 ): 1265 raise ValueError(f"{conv_node} is not an aten conv2d operator") 1266 # skip annotation if it is already annotated 1267 if _skip_annotate([conv_node], filter_fn): 1268 continue 1269 self._annotate_conv_node_helper(conv_node, True, quantization_config) 1270 1271 def _annotate_maxpool2d( 1272 self, 1273 node: Node, 1274 quantization_config: Optional[QuantizationConfig], 1275 ) -> None: 1276 if node.target is not torch.ops.aten.max_pool2d.default: 1277 return 1278 if quantization_config is None: 1279 _annotate_nodes_not_quantize(node) 1280 return 1281 1282 maxpool_node = node 1283 if _is_any_annotated( 1284 [ 1285 maxpool_node, 1286 ] 1287 ): 1288 return 1289 1290 input_node = maxpool_node.args[0] 1291 assert isinstance(input_node, Node) 1292 input_qspec_map = {} 1293 input_qspec_map[input_node] = get_input_act_qspec(quantization_config) 1294 maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 1295 input_qspec_map=input_qspec_map, 1296 _annotated=True, 1297 _is_output_of_quantized_pattern=True, 1298 ) 1299 1300 def _annotate_cat( 1301 self, node: Node, quantization_config: QuantizationConfig 1302 ) -> None: 1303 if quantization_config is None: 1304 _annotate_nodes_not_quantize(node) 1305 return 1306 cat_node = node 1307 input_nodes = cat_node.args[0] 1308 assert isinstance(input_nodes, Sequence) 1309 first_input_node = input_nodes[0] 1310 input_qspec_map = {} 1311 assert isinstance(first_input_node, Node) 1312 assert isinstance(cat_node, Node) 1313 input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config) 1314 share_qparams_with_input_act0_qspec = SharedQuantizationSpec( 1315 (first_input_node, cat_node) 1316 ) 1317 1318 for input_node in input_nodes[1:]: 1319 if input_node not in input_qspec_map: 1320 # There has the case of cat same nodes: torch.cat([input0, input0], 1) 1321 assert isinstance(input_node, Node) 1322 input_qspec_map[input_node] = share_qparams_with_input_act0_qspec 1323 1324 cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 1325 input_qspec_map=input_qspec_map, 1326 _annotated=True, 1327 _is_output_of_quantized_pattern=True, 1328 ) 1329 1330 def _annotate_propagation_quantizable_pattern_entry( 1331 self, 1332 gm: torch.fx.GraphModule, 1333 quantization_config: Optional[QuantizationConfig], 1334 filter_fn: Optional[FilterFn] = None, 1335 ): 1336 for node in gm.graph.nodes: 1337 self._annotate_propagation_quantizable_pattern( 1338 node, quantization_config, filter_fn 1339 ) 1340 1341 def _annotate_propagation_quantizable_pattern( 1342 self, node: Node, quantization_config, filter_fn 1343 ) -> None: 1344 # Propagate annotation to quantizable patterns. 1345 if ( 1346 (node.target in propagation_quantizable_ops) 1347 and (not _is_any_annotated([node])) 1348 and (node.op == "call_function") 1349 ): 1350 1351 def is_all_inputs_connected_to_quantized_op(input_nodes): 1352 # Ensure all the inputs connect to fusion pattern or quantized node 1353 for input_node in input_nodes: 1354 if not _is_quantized_op_pt2e(input_node): 1355 return False 1356 return True 1357 1358 if _skip_annotate([node], filter_fn): 1359 return 1360 1361 if quantization_config is None: 1362 _annotate_nodes_not_quantize(node) 1363 return 1364 1365 if node.target is torch.ops.aten.max_pool2d.default: 1366 # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not 1367 input_nodes_to_check = [node.all_input_nodes[0]] 1368 if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): 1369 if quantization_config is not None: 1370 warnings.warn( 1371 f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}." 1372 ) 1373 return 1374 1375 self._annotate_maxpool2d(node, quantization_config) 1376 return 1377 elif node.target is torch.ops.aten.cat.default: 1378 input_nodes_to_check = node.all_input_nodes 1379 if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): 1380 return 1381 self._annotate_cat(node, quantization_config) 1382 else: 1383 input_node = node.all_input_nodes[0] 1384 if not is_all_inputs_connected_to_quantized_op( 1385 [ 1386 input_node, 1387 ] 1388 ): 1389 return 1390 input_qspec_map = {} 1391 input_qspec_map[input_node] = get_input_act_qspec(quantization_config) 1392 node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 1393 input_qspec_map=input_qspec_map, 1394 _annotated=True, 1395 _is_output_of_quantized_pattern=True, 1396 ) 1397 return 1398 1399 def _annotate_output_share_observer_as_input( 1400 self, input_node: Node, source_node: Node 1401 ): 1402 source_node_quantization_annotation = ( 1403 source_node.meta[QUANT_ANNOTATION_KEY] 1404 if QUANT_ANNOTATION_KEY in source_node.meta 1405 else None 1406 ) 1407 if ( 1408 source_node_quantization_annotation 1409 and source_node_quantization_annotation._is_output_of_quantized_pattern 1410 ): 1411 edge_or_node = (input_node, source_node) 1412 source_node_quantization_annotation.output_qspec = SharedQuantizationSpec( 1413 edge_or_node 1414 ) 1415 return 1416 1417 def _annotate_output_for_int8_in_int8_out_pattern_entry( 1418 self, 1419 model: torch.fx.GraphModule, 1420 ): 1421 for node in model.graph.nodes: 1422 self._annotate_output_for_int8_in_int8_out_pattern(node) 1423 1424 def _annotate_output_for_int8_in_int8_out_pattern( 1425 self, 1426 node: Node, 1427 ) -> None: 1428 r""" 1429 Check and insert observer at output of node in int8_in_int8_out_ops if needed. 1430 Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/ 1431 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495 1432 """ 1433 edge_or_node: Tuple[Node, Node] 1434 if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): 1435 if node.target == torch.ops.aten.max_pool2d.default: 1436 maxpool_node = node 1437 if not _is_all_annotated( 1438 [ 1439 maxpool_node, 1440 ] 1441 ): 1442 return 1443 1444 # Get the quantization_annotation from getitem_node 1445 maxpool_node_quantization_annotation = ( 1446 maxpool_node.meta[QUANT_ANNOTATION_KEY] 1447 if QUANT_ANNOTATION_KEY in maxpool_node.meta 1448 else None 1449 ) 1450 if ( 1451 maxpool_node_quantization_annotation 1452 and maxpool_node_quantization_annotation._is_output_of_quantized_pattern 1453 ): 1454 # Annotate the output_qspec of getitem_node 1455 input_act = maxpool_node.args[0] 1456 assert isinstance(input_act, Node) 1457 assert isinstance(maxpool_node, Node) 1458 edge_or_node = (input_act, maxpool_node) 1459 maxpool_node_quantization_annotation.output_qspec = ( 1460 SharedQuantizationSpec(edge_or_node) 1461 ) 1462 else: 1463 input_node = node.all_input_nodes[0] 1464 self._annotate_output_share_observer_as_input(input_node, node) 1465 return 1466 1467 def _annotate_linear( 1468 self, 1469 gm: torch.fx.GraphModule, 1470 quantization_config: Optional[QuantizationConfig], 1471 filter_fn: Optional[FilterFn] = None, 1472 ) -> None: 1473 linear_partitions = get_source_partitions( 1474 gm.graph, [torch.nn.Linear, torch.nn.functional.linear] 1475 ) 1476 linear_partitions = list( 1477 itertools.chain.from_iterable(linear_partitions.values()) 1478 ) 1479 for partition in linear_partitions: 1480 if len(partition.output_nodes) > 1: 1481 raise ValueError( 1482 "Linear partition cannot have more than one output node" 1483 ) 1484 linear_node = partition.output_nodes[0] 1485 if linear_node.op != "call_function" or linear_node.target not in ( 1486 torch.ops.aten.linear.default, 1487 ): 1488 raise ValueError(f"{linear_node} is not an aten linear operator") 1489 # skip annotation if it is already annotated 1490 if _skip_annotate([linear_node], filter_fn): 1491 continue 1492 self._annotate_linear_node_helper(linear_node, True, quantization_config) 1493 1494 def _annotate_linear_unary( 1495 self, 1496 gm: torch.fx.GraphModule, 1497 quantization_config: Optional[QuantizationConfig], 1498 filter_fn: Optional[FilterFn] = None, 1499 ) -> None: 1500 postop_list = [ 1501 torch.nn.ReLU, 1502 torch.nn.LeakyReLU, 1503 torch.nn.Tanh, 1504 torch.nn.GELU, 1505 ] 1506 fused_partitions: List[tuple] = [] 1507 for postop in postop_list: 1508 fused_partitions = fused_partitions + find_sequential_partitions( 1509 gm, [torch.nn.Linear, postop] 1510 ) 1511 for fused_partition in fused_partitions: 1512 linear_partition, unary_partition = fused_partition 1513 linear_node, unary_node = self._get_output_nodes_of_partitions( 1514 [linear_partition, unary_partition] 1515 ) 1516 if linear_node.op != "call_function" or linear_node.target not in ( 1517 torch.ops.aten.linear.default, 1518 ): 1519 continue 1520 if _skip_annotate([unary_node, linear_node], filter_fn): 1521 continue 1522 1523 if quantization_config is None: 1524 _annotate_nodes_not_quantize([linear_node, unary_node]) 1525 continue 1526 1527 self._annotate_linear_node_helper(linear_node, False, quantization_config) 1528 unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( 1529 _annotated=True, 1530 _is_output_of_quantized_pattern=True, 1531 ) 1532 1533 def _annotate_linear_binary_unary( 1534 self, 1535 gm: torch.fx.GraphModule, 1536 quantization_config: Optional[QuantizationConfig], 1537 filter_fn: Optional[FilterFn] = None, 1538 ) -> None: 1539 # linear + binary_op + (optional) unary op 1540 binary_op_list = [operator.add] 1541 unary_op_list = [torch.nn.ReLU, None] 1542 combinations = itertools.product(binary_op_list, unary_op_list) 1543 for binary_op, unary_op in combinations: 1544 has_unary = unary_op is not None 1545 seq_partition = [torch.nn.Linear, binary_op] 1546 if has_unary: 1547 seq_partition.append(unary_op) 1548 fused_partitions = find_sequential_partitions(gm, seq_partition) 1549 for fused_partition in fused_partitions: 1550 unary_partition, unary_node = None, None 1551 if has_unary: 1552 ( 1553 linear_partition, 1554 binary_partition, 1555 unary_partition, 1556 ) = fused_partition 1557 ( 1558 linear_node, 1559 binary_node, 1560 unary_node, 1561 ) = self._get_output_nodes_of_partitions( 1562 [linear_partition, binary_partition, unary_partition] 1563 ) 1564 else: 1565 linear_partition, binary_partition = fused_partition 1566 linear_node, binary_node = self._get_output_nodes_of_partitions( 1567 [linear_partition, binary_partition] 1568 ) 1569 if len(linear_node.users) != 1: 1570 # Linear Node should only has 1 user node 1571 continue 1572 ( 1573 linear_node_idx, 1574 extra_input_node_idx, 1575 ) = self._get_input_idx_for_binary_node(linear_node, binary_node) 1576 if (linear_node_idx is None) or (extra_input_node_idx is None): 1577 continue 1578 if linear_node != binary_node.args[linear_node_idx]: 1579 raise ValueError( 1580 f"{linear_node} doesn't match input of binary node" 1581 ) 1582 assert isinstance(linear_node, Node) 1583 if ( 1584 linear_node.op != "call_function" 1585 or linear_node.target != torch.ops.aten.linear.default 1586 ): 1587 # No linear node found to be fused with add 1588 continue 1589 node_list = ( 1590 [binary_node, linear_node] 1591 if unary_node is None 1592 else [unary_node, binary_node, linear_node] 1593 ) 1594 if _skip_annotate(node_list, filter_fn): 1595 continue 1596 1597 if quantization_config is None: 1598 _annotate_nodes_not_quantize(node_list) 1599 continue 1600 1601 self._annotate_linear_node_helper( 1602 linear_node, False, quantization_config 1603 ) 1604 # We don't insert q-dq before the binary input node due to accuracy issues 1605 binary_node.meta[ 1606 QUANT_ANNOTATION_KEY 1607 ] = _X86InductorQuantizationAnnotation( 1608 input_qspec_map={}, 1609 _annotated=True, 1610 _is_output_of_quantized_pattern=(not has_unary), 1611 ) 1612 if unary_node is not None: 1613 unary_node.meta[ 1614 QUANT_ANNOTATION_KEY 1615 ] = _X86InductorQuantizationAnnotation( 1616 _annotated=True, 1617 _is_output_of_quantized_pattern=True, 1618 ) 1619 1620 def validate(self, model: torch.fx.GraphModule) -> None: 1621 pass 1622 1623 @classmethod 1624 def get_supported_operators(cls) -> List[OperatorConfig]: 1625 return cls.supported_config_and_operators 1626