1# mypy: allow-untyped-defs 2import copy 3import warnings 4from dataclasses import asdict 5from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union 6 7import torch 8from torch._subclasses import FakeTensor 9from torch.ao.quantization import ( 10 _DerivedObserverOrFakeQuantize, 11 FixedQParamsFakeQuantize, 12 FixedQParamsObserver, 13 ObserverBase, 14 ObserverOrFakeQuantize, 15 PlaceholderObserver, 16) 17from torch.ao.quantization.backend_config import ( 18 BackendConfig, 19 DTypeConfig, 20 get_native_backend_config, 21) 22from torch.ao.quantization.backend_config.utils import ( 23 get_fusion_pattern_to_root_node_getter, 24 get_module_to_qat_module, 25 get_pattern_to_dtype_configs, 26) 27from torch.ao.quantization.observer import _is_activation_post_process, _PartialWrapper 28from torch.ao.quantization.qconfig import _is_reuse_input_qconfig, QConfigAny 29from torch.ao.quantization.qconfig_mapping import QConfigMapping 30from torch.ao.quantization.quantize import convert, propagate_qconfig_ 31from torch.ao.quantization.quantizer import ( 32 DerivedQuantizationSpec, 33 EdgeOrNode, 34 FixedQParamsQuantizationSpec, 35 QuantizationSpec, 36 QuantizationSpecBase, 37 SharedQuantizationSpec, 38) 39from torch.ao.quantization.utils import ( 40 _parent_name, 41 get_qconfig_dtypes, 42 get_swapped_custom_module_class, 43 NodePattern, 44 Pattern, 45) 46from torch.fx import GraphModule 47from torch.fx.graph import Graph, Node 48from torch.fx.node import Argument 49 50from ._equalize import is_equalization_observer, node_supports_equalization 51from .custom_config import PrepareCustomConfig, StandaloneModuleConfigEntry 52from .match_utils import _find_matches, _MatchResultWithQConfig 53from .pattern_utils import _sorted_patterns_dict 54from .qconfig_mapping_utils import ( 55 _generate_node_name_to_qconfig, 56 _get_flattened_qconfig_dict, 57 _update_qconfig_for_fusion, 58 _update_qconfig_for_qat, 59) 60from .quantize_handler import ( 61 _default_root_node_getter, 62 _get_pattern_to_quantize_handlers, 63 QuantizeHandler, 64) 65from .utils import ( 66 _insert_dequant_stubs_for_custom_module_lstm_output, 67 _is_custom_module_lstm, 68 _maybe_get_custom_module_lstm_from_node_arg, 69 _qconfig_satisfies_dtype_config_constraints, 70 all_node_args_have_no_tensors, 71 assert_and_get_unique_device, 72 get_custom_module_class_keys, 73 get_new_attr_name_with_prefix, 74 get_non_observable_arg_indexes_and_types, 75 node_arg_is_bias, 76 node_arg_is_weight, 77 NON_QUANTIZABLE_WEIGHT_OPS, 78 ObservedGraphModuleAttrs, 79) 80 81 82__all__ = [ 83 "insert_observers_for_model", 84 "prepare", 85 "propagate_dtypes_for_known_nodes", 86] 87 88 89# list of dtypes to not add observers to 90_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None] 91_OBS_DTYPE_LIST = [ 92 torch.quint8, 93 torch.qint8, 94 torch.qint32, 95 torch.float16, 96 torch.uint8, 97 torch.int8, 98 torch.int16, 99 torch.int32, 100 torch.float8_e5m2, 101 torch.float8_e4m3fn, 102] 103 104_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float) 105 106# note: the following default target dtype info dicts are temporary, 107# should be moved to the new programmable API class soon 108_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = { 109 "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation, 110 "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation, 111} 112 113_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = { 114 "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation, 115 "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation, 116} 117 118 119def _get_observer_kwargs( 120 quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec] 121): 122 kwargs_dict = asdict(quant_spec) 123 return copy.deepcopy(kwargs_dict) 124 125 126def _get_qspec_for_arg( 127 arg: Node, 128 input_qspec_map: Dict[Node, QuantizationSpecBase], 129 named_modules: Dict[str, torch.nn.Module], 130) -> Optional[QuantizationSpecBase]: 131 while _is_activation_post_process_node(arg, named_modules): 132 arg = arg.args[0] # type: ignore[assignment] 133 return input_qspec_map.get(arg, None) 134 135 136def _create_obs_or_fq_from_qspec( 137 quantization_spec: Optional[QuantizationSpecBase], 138 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 139 is_qat: bool, 140): 141 """Create observer or fake quantize objects based on quantization spec 142 143 Args: 144 quantization_spec: used to store parameters to create the observer or fake quantizer 145 obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant 146 instance, it may be reused for different edge/output depending on configuration 147 """ 148 if quantization_spec is None: 149 return None 150 if isinstance(quantization_spec, SharedQuantizationSpec): 151 edge_or_node = quantization_spec.edge_or_node 152 assert edge_or_node in obs_or_fq_map, ( 153 "please make sure only refer to edge or node that has " 154 f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}" 155 ) 156 return obs_or_fq_map[edge_or_node] 157 elif isinstance(quantization_spec, DerivedQuantizationSpec): 158 # can't use asdict, so not calling get_observer_kwargs here 159 kwargs = { 160 "dtype": quantization_spec.dtype, 161 "derive_qparams_fn": quantization_spec.derive_qparams_fn, 162 "quant_min": quantization_spec.quant_min, 163 "quant_max": quantization_spec.quant_max, 164 "qscheme": quantization_spec.qscheme, 165 "ch_axis": quantization_spec.ch_axis, 166 } 167 edge_or_nodes = quantization_spec.derived_from 168 obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes] 169 kwargs["obs_or_fqs"] = obs_or_fqs 170 return _DerivedObserverOrFakeQuantize.with_args(**kwargs)() 171 elif isinstance(quantization_spec, FixedQParamsQuantizationSpec): 172 kwargs = _get_observer_kwargs(quantization_spec) 173 observer_ctr = FixedQParamsObserver.with_args(**kwargs) 174 if is_qat: 175 return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)() 176 else: 177 return observer_ctr() 178 179 assert isinstance(quantization_spec, QuantizationSpec) 180 observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr 181 kwargs = _get_observer_kwargs(quantization_spec) 182 kwargs.pop("observer_or_fake_quant_ctr") 183 # we will remove is_dynamic from QuantizationSpec because 184 # it seems that dynamic range quantization 185 obs_or_fq_class = observer_or_fake_quant_ctr 186 if isinstance(observer_or_fake_quant_ctr, _PartialWrapper): 187 obs_or_fq_class = observer_or_fake_quant_ctr.p.func # type: ignore[union-attr, assignment] 188 if "PerChannel" not in obs_or_fq_class.__name__: # type: ignore[operator, union-attr] 189 kwargs.pop("ch_axis") 190 return observer_or_fake_quant_ctr.with_args(**kwargs)() 191 192 193def _needs_obs_or_fq( 194 prev_output_dtype: Any, 195 prev_output_is_dynamic: bool, 196 cur_target_dtype: Any, 197 cur_target_is_dynamic: bool, 198 reuse_input_obs_or_fq: bool, 199 is_zeroth_arg: bool = False, 200) -> bool: 201 """ 202 note: we will treat "not specified" as torch.float for now 203 utility function that checks if we should insert an observer or fake quant node 204 base on the requested dtype for the nodes from user 205 206 is_zeroth_arg: we only dynamically quantize the first arg of the node right now 207 this should be removed when we enable configuring dynamic quantization 208 for a specific argument, this can be removed if we deprecate fx graph mode 209 quantization 210 211 """ 212 213 # need to insert placeholder observer for dynamic quantization so that it can 214 # be converted to choose_qparams -> q -> dq in convert step 215 if cur_target_is_dynamic: 216 assert ( 217 cur_target_dtype in _OBS_DTYPE_LIST 218 ), f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" 219 assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST 220 return is_zeroth_arg 221 if reuse_input_obs_or_fq: 222 return False 223 # non dynamic quantization 224 if cur_target_dtype in _OBS_DTYPE_LIST: 225 return ( 226 prev_output_dtype in _OBS_DTYPE_LIST + [torch.float] 227 and cur_target_dtype != prev_output_dtype 228 ) 229 230 # lots of error checking are skipped here for now 231 return False 232 233 234def _is_activation_post_process_node( 235 node: Node, named_modules: Dict[str, torch.nn.Module] 236) -> bool: 237 return ( 238 isinstance(node, torch.fx.Node) 239 and node.op == "call_module" 240 and _is_activation_post_process(named_modules[str(node.target)]) 241 ) 242 243 244def _get_dtype_and_is_dynamic( 245 obs_or_fq: Optional[ObserverOrFakeQuantize], 246) -> Tuple[Optional[torch.dtype], bool]: 247 """Given a constructor for observer or fake quant module, returns 248 a Tuple of dtype and is_dynamic 249 """ 250 # TODO: instead of instantiating the instance, we can use inspect to get the default args 251 if obs_or_fq is None: 252 return None, False 253 else: 254 return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False) # type: ignore[return-value] 255 256 257def _is_input_arg_dtype_supported_by_backend( 258 arg: Argument, 259 node: Node, 260 qconfig: QConfigAny, 261 dtype_config: DTypeConfig, 262 backend_config: BackendConfig, 263) -> bool: 264 """Check if the configured qconfig for the argument 265 is supported by the backend or not 266 """ 267 if isinstance(arg, (list, tuple)): 268 return all( 269 _is_input_arg_dtype_supported_by_backend( 270 a, node, qconfig, dtype_config, backend_config 271 ) 272 for a in arg 273 ) 274 if not isinstance(arg, Node): 275 return True 276 # TODO: support check for standalone module 277 is_weight = node_arg_is_weight(node, arg) 278 is_bias = node_arg_is_bias(node, arg) 279 is_activation = not is_weight and not is_bias 280 if is_activation: 281 input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( 282 "input_act_obs_or_fq_ctr" 283 ) 284 input_act_obs_or_fq = ( 285 input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None 286 ) 287 qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic( 288 input_act_obs_or_fq 289 ) 290 # TODO(future PR): remove the cast to bool below after figuring 291 # out why backend_config has is_dynamic set to None in some cases. 292 return (dtype_config.input_dtype is None) or ( 293 dtype_config.input_dtype == qconfig_dtype 294 and bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) 295 and _qconfig_satisfies_dtype_config_constraints( 296 qconfig, dtype_config.input_dtype_with_constraints 297 ) 298 ) 299 elif is_weight: 300 # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well 301 weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get( 302 "weight_obs_or_fq_ctr", None 303 ) 304 weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None 305 qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq) 306 backend_config_weight_dtype = dtype_config.weight_dtype 307 dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype 308 qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( 309 qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False 310 ) 311 return backend_config_weight_dtype is None or ( 312 dtype_matches and qconfig_satisfies_constraints 313 ) 314 else: # bias 315 # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well 316 bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get( 317 "bias_obs_or_fq_ctr", None 318 ) 319 bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None 320 qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq) 321 backend_config_bias_dtype = dtype_config.bias_dtype 322 return ( 323 backend_config_bias_dtype is None 324 or qconfig_bias_dtype == backend_config_bias_dtype 325 ) 326 327 328def _is_output_dtype_supported_by_backend( 329 node: Node, 330 qconfig: QConfigAny, 331 dtype_config: DTypeConfig, 332) -> bool: 333 """Check if the configured qconfig for the output 334 is supported by the backend or not 335 """ 336 # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well 337 backend_config_output_dtype = dtype_config.output_dtype 338 # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend 339 # from input activation check can be reused here 340 qconfig_output_dtype = None 341 output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( 342 "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR 343 ) 344 output_act_obs_or_fq = ( 345 output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None 346 ) 347 qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic( 348 output_act_obs_or_fq 349 ) 350 # TODO: this is a hack because we can only specify one activation_obs_or_fq for 351 # qconfig (qconfig.activation), and we are only supporting dynamically quantized 352 # linear op which has fp32 output dtype, this should be removed if we generalize 353 # the structure of qconfig in the future 354 if qconfig_output_is_dynamic: 355 qconfig_output_dtype = torch.float32 356 dtype_matches = qconfig_output_dtype == backend_config_output_dtype 357 qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( 358 qconfig, dtype_config.output_dtype_with_constraints 359 ) 360 return backend_config_output_dtype is None or ( 361 dtype_matches and qconfig_satisfies_constraints 362 ) 363 364 365def _is_observer_in_same_graph( 366 node: Node, 367 named_modules: Dict[str, torch.nn.Module], 368 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 369 is_qat, 370): 371 """Check if observer in same graph 372 when the node output is not fp32 and input is 'placeholder' 373 the input is assumed to be quantized, so it is observed 374 in a different place rather than not observed. 375 """ 376 node_output_dtype = _get_arg_target_dtype_as_output( 377 node, named_modules, obs_or_fq_map, is_qat 378 ) 379 if len(node.args) > 0 and isinstance(node.args[0], Node): 380 if ( 381 node_output_dtype in [torch.quint8, torch.uint8] 382 and node.args[0].op == "placeholder" 383 ): 384 return False 385 return True 386 387 388def _is_pattern_dtype_config_and_qconfig_supported_by_backend( 389 pattern: Optional[Pattern], 390 matched_node_pattern: Optional[List[Node]], 391 qconfig: QConfigAny, 392 backend_config: BackendConfig, 393) -> bool: 394 """Check if the dtype configuration of a pattern is supported by 395 the backend or not, and whether the qconfig satisfies constraints 396 specified in the corresponding dtype config. 397 """ 398 if backend_config is None or pattern is None: 399 return True 400 assert matched_node_pattern is not None and len(matched_node_pattern) >= 1 401 pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) 402 dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, []) 403 pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) 404 405 root_node_getter = pattern_to_root_node_getter.get( 406 pattern, _default_root_node_getter 407 ) 408 root_node = root_node_getter(matched_node_pattern) 409 input_node = root_node 410 output_node = matched_node_pattern[0] 411 for dtype_config in dtype_configs: 412 # check if arg dtype are supported 413 supported = True 414 for arg in list(input_node.args) + list(input_node.kwargs.values()): 415 supported = supported and _is_input_arg_dtype_supported_by_backend( 416 arg, input_node, qconfig, dtype_config, backend_config 417 ) 418 # check if output dtype is supported 419 supported = supported and _is_output_dtype_supported_by_backend( 420 output_node, qconfig, dtype_config 421 ) 422 if supported: 423 return True 424 return False 425 426 427def _get_standalone_module_configs( 428 node: Node, 429 named_modules: Dict[str, torch.nn.Module], 430 prepare_custom_config: PrepareCustomConfig, 431 parent_qconfig: QConfigAny, 432 parent_backend_config: Optional[BackendConfig], 433) -> Tuple[ 434 QConfigMapping, Tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig] 435]: 436 """ 437 Returns the standalone module QConfigMapping and PrepareCustomConfig 438 for `node`, assuming that the module pointed to by `node` is 439 a standalone modules. 440 """ 441 module_name = str(node.target) 442 module_type = type(named_modules[module_name]) # type: ignore[index] 443 # name config has precedence over type config 444 config_entry = StandaloneModuleConfigEntry(None, (), None, None) 445 config_entry = prepare_custom_config.standalone_module_classes.get( 446 module_type, config_entry 447 ) 448 config_entry = prepare_custom_config.standalone_module_names.get( 449 module_name, config_entry 450 ) 451 # fallback to use parent module's qconfig if user didn't specify qconfig dict 452 qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global( 453 parent_qconfig 454 ) 455 example_inputs = config_entry.example_inputs 456 prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig() 457 backend_config = config_entry.backend_config or parent_backend_config 458 return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config) 459 460 461def _qat_swap_modules( 462 root: torch.nn.Module, module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]] 463) -> None: 464 convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False) 465 466 467def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]): 468 if isinstance(matched_node_pattern, Node): 469 s.add(matched_node_pattern.name) 470 elif isinstance(matched_node_pattern, (list, tuple)): 471 for maybe_node in matched_node_pattern: 472 _add_matched_node_name_to_set(maybe_node, s) 473 474 475def _insert_obs_or_fq( 476 node: Node, 477 obs_or_fq: ObserverOrFakeQuantize, 478 model: torch.nn.Module, 479 named_modules: Dict[str, torch.nn.Module], 480 graph: Graph, 481) -> Node: 482 """ 483 Attaches `obs_or_fq` to `model`, and creates a node which calls 484 `obs_or_fq` on the output of `node`. 485 486 obs_or_fq: an instance of Observer or FakeQuantize module 487 """ 488 model_device = assert_and_get_unique_device(model) 489 if model_device: 490 obs_or_fq.to(model_device) 491 # add obs_or_fq module as attribute 492 if is_equalization_observer(obs_or_fq): 493 prefix = node.name + "_equalization_process_" 494 else: 495 prefix = "activation_post_process_" 496 get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix) 497 obs_or_fq_name = get_new_obs_or_fq_name(model) 498 setattr(model, obs_or_fq_name, obs_or_fq) 499 named_modules[obs_or_fq_name] = obs_or_fq 500 with graph.inserting_after(node): 501 new_obs = graph.create_node("call_module", obs_or_fq_name, (node,), {}) 502 return new_obs 503 504 505def _set_target_dtype_info_for_matched_node_pattern( 506 matched_node_pattern: NodePattern, 507 last_node: Node, 508 qconfig: QConfigAny, 509 qhandler: Optional[QuantizeHandler], 510 backend_config: BackendConfig, 511 named_modules: Dict[str, torch.nn.Module], 512 cache_for_no_tensor_check: Dict[Node, bool], 513 processed_nodes: Set[Node], 514) -> None: 515 """Sets the target_dtype_info for each node in matched_node_pattern 516 Note: processed_nodes is used to ensure we only process each node once 517 """ 518 if isinstance(matched_node_pattern, (list, tuple)): 519 for node_pattern in matched_node_pattern: 520 _set_target_dtype_info_for_matched_node_pattern( 521 node_pattern, 522 last_node, 523 qconfig, 524 qhandler, 525 backend_config, 526 named_modules, 527 cache_for_no_tensor_check, 528 processed_nodes, 529 ) 530 531 # set target_dtype_info if matched_node_pattern is a Node 532 # other types of matched object, e.g. int, float literals, are ignored 533 elif isinstance(matched_node_pattern, Node): 534 # for pyre 535 assert isinstance(matched_node_pattern, Node) 536 node = matched_node_pattern 537 if node in processed_nodes: 538 return 539 processed_nodes.add(node) 540 541 if qconfig is None: 542 return 543 # TODO: refactor the following code in terms of apply a qconfig to a pattern 544 # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1) 545 # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act, 546 # and set output_obs_or_fq_ctr based on qconfig.output_act 547 # this also requires we extend the structure of QConfig to support more fine 548 # grained configurations 549 target_dtype_info: Dict[str, Any] = _get_target_activation_dtype_for_node( 550 node, 551 qconfig, 552 qhandler, 553 named_modules, 554 backend_config, 555 cache_for_no_tensor_check, 556 ) 557 node.meta["target_dtype_info"] = target_dtype_info 558 559 560def _get_target_activation_dtype_for_node( 561 node: Node, 562 qconfig: QConfigAny, 563 qhandler: Optional[QuantizeHandler], 564 named_modules: Dict[str, torch.nn.Module], 565 backend_config: BackendConfig, 566 cache_for_no_tensor_check: Dict[Node, bool], 567) -> Dict[str, Any]: 568 """ 569 For each op attribute in the op's input activation, output activation, 570 weight, bias - returns the settings of dtype and is_dynamic we expect 571 for the `quantize` call in the reference model representation, or None 572 if there is no `quantize` call needed. 573 574 For example, if we have a node corresponding to `op0` in 575 576 x0 -> op0 -> x1 577 578 And we want a reference quantized representation to be 579 580 x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1 581 582 Then this function will return 583 584 { 585 "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), 586 "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), 587 } 588 589 TODO(future PR, if needed): explicitly spell out the non-Tensor 590 dtypes. 591 """ 592 args_have_no_tensors = all_node_args_have_no_tensors( 593 node, named_modules, cache_for_no_tensor_check 594 ) 595 if args_have_no_tensors: 596 return { 597 "input_act_obs_or_fq_ctr": None, 598 "output_act_obs_or_fq_ctr": None, 599 } 600 # get qconfig to determine the eventual dtype of this node 601 if qconfig is not None: 602 act_dtype, weight_dtype, input_act_is_dynamic = get_qconfig_dtypes(qconfig) 603 604 # Currently `QConfig` only has one `activation` field. 605 # For static quantization, it is reused for both input 606 # and output activation. For dynamic quantization, this 607 # field is currently only used for the input activation, 608 # with the output activation being in fp32. 609 # In the future this may change as we add more fields 610 # to the `QConfig` object. 611 output_act_dtype = act_dtype if (not input_act_is_dynamic) else torch.float 612 613 bias_dtype = ( 614 torch.float16 615 if ( 616 act_dtype == torch.float16 617 and weight_dtype == torch.float16 618 and (not input_act_is_dynamic) 619 ) 620 else torch.float 621 ) 622 623 is_general_tensor_value_op = ( 624 qhandler is not None and qhandler.is_general_tensor_value_op() 625 ) 626 627 _is_standalone_module = qhandler is not None and qhandler.is_standalone_module() 628 629 weight_index = None 630 if ( 631 isinstance(node, Node) 632 and node.op == "call_function" 633 and node.target in backend_config._pattern_complex_format_to_config 634 ): 635 weight_index = backend_config._pattern_complex_format_to_config[ 636 node.target 637 ]._input_type_to_index.get("weight") 638 639 bias_index = None 640 if ( 641 isinstance(node, Node) 642 and node.op == "call_function" 643 and node.target in backend_config._pattern_complex_format_to_config 644 ): 645 bias_index = backend_config._pattern_complex_format_to_config[ 646 node.target 647 ]._input_type_to_index.get("bias") 648 649 return { 650 "input_act_obs_or_fq_ctr": qconfig.activation, 651 "weight_obs_or_fq_ctr": qconfig.weight, 652 "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype), 653 "weight_index": weight_index, 654 "bias_index": bias_index, 655 "output_act_obs_or_fq_ctr": qconfig.activation, 656 "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig), 657 "input_output_share_observers": is_general_tensor_value_op, 658 "_is_standalone_module": _is_standalone_module, 659 } 660 return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO) 661 662 663def _get_output_act_obs_or_fq( 664 arg: Node, 665 named_modules: Dict[str, torch.nn.Module], 666 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 667 is_qat: bool, 668) -> ObserverOrFakeQuantize: 669 """Get the constructor for observer or fake quant object for 670 the argument in the original graph as the output of previous node, 671 skipping inserted observers 672 673 We are assuming that the observers are inserted correctly, and the dtype for 674 argument in quantized graph will match what is specified by the qconfig 675 """ 676 assert isinstance(arg, Node) 677 if "quantization_annotation" in arg.meta: 678 return _create_obs_or_fq_from_qspec( 679 arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat 680 ) 681 682 # Custom module LSTM output is a tuple that we broke down into the internal nodes in order 683 # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`). 684 # Since we modified the graph in this case, we must trace back from the args through 685 # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would 686 # not be able to accurately detect whether this node is a consumer of custom module LSTM. 687 custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg( 688 arg, named_modules 689 ) 690 output_act_obs_or_fq_ctr = None 691 if custom_module_lstm_node is not None: 692 output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"][ 693 "output_act_obs_or_fq_ctr" 694 ] 695 output_act_obs_or_fq = ( 696 output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None 697 ) 698 elif _is_activation_post_process_node(arg, named_modules): 699 observed_arg = arg.args[0] 700 assert isinstance( 701 observed_arg, Node 702 ), "Currently we only support observing Node" 703 if "quantization_annotation" in observed_arg.meta: 704 output_act_obs_or_fq = _create_obs_or_fq_from_qspec( 705 observed_arg.meta["quantization_annotation"].output_qspec, 706 obs_or_fq_map, 707 is_qat, 708 ) 709 else: 710 assert "target_dtype_info" in observed_arg.meta 711 output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][ 712 "output_act_obs_or_fq_ctr" 713 ] 714 output_act_obs_or_fq = ( 715 output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None 716 ) 717 else: 718 if "target_dtype_info" in arg.meta: 719 output_act_obs_or_fq_ctr = arg.meta["target_dtype_info"].get( 720 "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR 721 ) 722 else: 723 output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR 724 output_act_obs_or_fq = ( 725 output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None 726 ) 727 728 return output_act_obs_or_fq 729 730 731def _get_arg_target_dtype_as_output( 732 arg: Node, 733 named_modules: Dict[str, torch.nn.Module], 734 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 735 is_qat: bool, 736) -> Optional[torch.dtype]: 737 arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq( 738 arg, named_modules, obs_or_fq_map, is_qat 739 ) 740 arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic( 741 arg_as_output_act_obs_or_fq 742 ) 743 return arg_as_output_target_dtype 744 745 746def _get_arg_as_input_act_obs_or_fq( 747 arg: Node, 748 node: Node, 749 named_modules: Dict[str, torch.nn.Module], 750 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 751 is_qat: bool, 752) -> Optional[ObserverOrFakeQuantize]: 753 """Get the observer or fake quant constructor for the Argument `arg`, as input 754 to Node `node` 755 """ 756 assert isinstance(arg, Node) 757 # "input_qspec_map" is the more general design we'll use for pt2e path 758 # it is a map from input argument node to observer or fake quant constructor, for example 759 # for the following graph: 760 # x -> conv -> output 761 # 762 # we may annotate conv node like the following: 763 # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...) 764 # 765 if "quantization_annotation" in node.meta: 766 input_qspec_map = node.meta["quantization_annotation"].input_qspec_map 767 input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules) 768 if input_arg_qspec is None: 769 input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR() 770 else: 771 input_arg_obs_or_fq = _create_obs_or_fq_from_qspec( 772 input_arg_qspec, obs_or_fq_map, is_qat 773 ) 774 return input_arg_obs_or_fq 775 776 # we can remove the following path in the future if fx graph mode quantization is 777 # no longer used 778 is_weight = node_arg_is_weight(node, arg) 779 is_bias = node_arg_is_bias(node, arg) 780 is_activation = not is_weight and not is_bias 781 obs_or_fq_ctr = None 782 if is_activation: 783 obs_or_fq_ctr = node.meta["target_dtype_info"].get( 784 "input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR 785 ) 786 elif is_weight: 787 if node.target not in NON_QUANTIZABLE_WEIGHT_OPS: 788 obs_or_fq_ctr = node.meta["target_dtype_info"].get( 789 "weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR 790 ) 791 else: 792 obs_or_fq_ctr = node.meta["target_dtype_info"].get( 793 "bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR 794 ) 795 return obs_or_fq_ctr() if obs_or_fq_ctr else None 796 797 798def _maybe_insert_input_observer_for_arg_or_kwarg( 799 node: Union[Node, Any], 800 arg: Argument, 801 qconfig: QConfigAny, 802 model: torch.nn.Module, 803 named_modules: Dict[str, torch.nn.Module], 804 graph: Graph, 805 qhandler: Optional[QuantizeHandler], 806 prepare_custom_config: PrepareCustomConfig, 807 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 808 is_qat: bool, 809 backend_config: Optional[BackendConfig] = None, 810) -> Argument: 811 """ 812 Given a `node` and an `arg`, inserts an input observer between 813 `node` and `arg` if necessary. 814 """ 815 # for ops such as torch.cat([x0, x1]), 816 # traverse through the list 817 if isinstance(arg, (list, tuple)): 818 new_arg_to_return = [] 819 for inner_arg in arg: 820 new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( 821 node, 822 inner_arg, 823 qconfig, 824 model, 825 named_modules, 826 graph, 827 qhandler, 828 prepare_custom_config, 829 obs_or_fq_map, 830 is_qat, 831 backend_config, 832 ) 833 new_arg_to_return.append(new_inner_arg) 834 return type(arg)(new_arg_to_return) 835 836 if not isinstance(arg, Node): 837 return arg 838 assert isinstance(arg, Node) 839 # default (no observer) 840 new_arg = arg 841 842 is_standalone_module = qhandler is not None and qhandler.is_standalone_module() 843 # TODO: move this to a separate function 844 if not is_standalone_module: 845 # Note: qconfig can be None in this branch this we are getting act/fq from 846 # node.meta now 847 # regular flow for most nodes, except standalone modules 848 849 if "quantization_annotation" in node.meta: 850 reuse_input_obs_or_fq = node.meta[ 851 "quantization_annotation" 852 ]._reuse_input_obs_or_fq 853 else: 854 assert "target_dtype_info" in node.meta 855 # TODO: we are assuming "target_dtype_info" exists here, maybe 856 # a default value also need to be provided here 857 target_dtype_info = node.meta["target_dtype_info"] 858 # for nodes that doesn't have `reuse_input_obs_or_fq` configured, 859 # we'll default to False, this makes configuring this field optional for users 860 reuse_input_obs_or_fq = target_dtype_info.get( 861 "reuse_input_obs_or_fq", False 862 ) 863 arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq( 864 arg, node, named_modules, obs_or_fq_map, is_qat 865 ) 866 ( 867 arg_as_input_target_dtype, 868 arg_as_input_target_is_dynamic, 869 ) = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq) 870 871 arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq( 872 arg, named_modules, obs_or_fq_map, is_qat 873 ) 874 ( 875 arg_as_output_target_dtype, 876 arg_as_output_target_is_dynamic, 877 ) = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq) 878 879 needs_obs_or_fq = _needs_obs_or_fq( 880 arg_as_output_target_dtype, 881 arg_as_output_target_is_dynamic, 882 arg_as_input_target_dtype, 883 arg_as_input_target_is_dynamic, 884 reuse_input_obs_or_fq, 885 is_zeroth_arg=len(node.args) > 0 and arg is node.args[0], 886 ) 887 888 else: 889 assert qconfig is not None 890 # custom flow for standalone modules 891 _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs( 892 node, named_modules, prepare_custom_config, qconfig, backend_config 893 ) 894 sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes 895 896 # for args, this is set to the index of the current arg 897 # for kwargs, this is left at None 898 cur_input_idx = None 899 for arg_idx, arg_to_check in enumerate(node.args): 900 if arg_to_check is arg: 901 cur_input_idx = arg_idx 902 break 903 904 if cur_input_idx is None: 905 needs_obs_or_fq = False 906 else: 907 arg_as_output_target_dtype = _get_arg_target_dtype_as_output( 908 arg, named_modules, obs_or_fq_map, is_qat 909 ) 910 arg_as_input_target_dtype = ( 911 torch.quint8 912 if cur_input_idx in sm_input_quantized_idxs 913 else torch.float 914 ) 915 needs_obs_or_fq = ( 916 arg_as_output_target_dtype != arg_as_input_target_dtype 917 ) and (arg_as_input_target_dtype != torch.float) 918 919 act_post_process_ctr = qconfig.activation 920 arg_as_input_act_obs_or_fq = ( 921 act_post_process_ctr() if act_post_process_ctr else None 922 ) 923 924 if needs_obs_or_fq: 925 existing_obs_node = None 926 927 # Before using the new observer, check if an observer 928 # of the correct type already exists. If it does, use it. 929 # This prevents duplicate observer insertions if a node is 930 # used by multiple nodes. 931 # TODO: this is looking into how the value is used in the future 932 # we should remove this 933 # removing this means we insert one observer for each use, even if they 934 # have the same dtype, we can have an extra pass that removes the extra observers 935 for maybe_obs_node in arg.users.keys(): 936 if maybe_obs_node.op == "call_module": 937 maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] 938 if ( 939 type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) 940 and maybe_obs_mod.dtype 941 == arg_as_input_target_dtype # type: ignore[possibly-undefined] 942 ): 943 arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment] 944 existing_obs_node = maybe_obs_node 945 break 946 947 assert arg_as_input_act_obs_or_fq is not None 948 obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq 949 if existing_obs_node is None: 950 new_obs_node = _insert_obs_or_fq( 951 arg, arg_as_input_act_obs_or_fq, model, named_modules, graph 952 ) 953 # override this arg to be the observed arg 954 new_arg = new_obs_node 955 else: 956 new_arg = existing_obs_node 957 958 return new_arg 959 960 961def _maybe_insert_input_observers_for_node( 962 node: Node, 963 qconfig: QConfigAny, 964 model: torch.nn.Module, 965 named_modules: Dict[str, torch.nn.Module], 966 graph: Graph, 967 qhandler: Optional[QuantizeHandler], 968 prepare_custom_config: PrepareCustomConfig, 969 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 970 is_qat: bool, 971 backend_config: Optional[BackendConfig] = None, 972) -> None: 973 """ 974 If needed, inserts observers to the input args and kwargs of `node`. 975 Note: modifies `node` inplace. 976 977 For example, if cur_node needs an observer after prev_node, we change from 978 979 prev_node -> cur_node 980 981 To 982 983 prev_node -> obs -> cur_node 984 985 Note: backend_config only needed for standalone_module node 986 """ 987 # Look through every input arg. If that arg's target dtype does not 988 # match the current node's target dtype, insert an observer. 989 new_args = [] 990 for arg in node.args: 991 new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( 992 node, 993 arg, 994 qconfig, 995 model, 996 named_modules, 997 graph, 998 qhandler, 999 prepare_custom_config, 1000 obs_or_fq_map, 1001 is_qat, 1002 backend_config, 1003 ) 1004 new_args.append(new_arg) 1005 1006 new_kwargs = {} 1007 for k, kwarg in node.kwargs.items(): 1008 new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg( 1009 node, 1010 kwarg, 1011 qconfig, 1012 model, 1013 named_modules, 1014 graph, 1015 qhandler, 1016 prepare_custom_config, 1017 obs_or_fq_map, 1018 is_qat, 1019 backend_config, 1020 ) 1021 new_kwargs[k] = new_kwarg 1022 1023 # assign the new args and kwargs to the node, inplace 1024 node.args = tuple(new_args) 1025 node.kwargs = new_kwargs 1026 1027 1028def _maybe_insert_input_equalization_observers_for_node( 1029 node: Node, 1030 equalization_qconfig: Any, 1031 model: torch.nn.Module, 1032 named_modules: Dict[str, torch.nn.Module], 1033 graph: Graph, 1034 is_branch: bool, 1035) -> None: 1036 """ 1037 If `node` needs to be equalized, find the input/weight observers it needs in 1038 `equalization_qconfig`, creates them, and inserts it into `graph`. 1039 1040 If `node` does not need an equalization observer, returns None. 1041 """ 1042 if equalization_qconfig is None or not node_supports_equalization( 1043 node, named_modules 1044 ): 1045 return 1046 1047 if is_branch: 1048 warnings.warn(f"Cannot equalize {node} because it is part of a branch.") 1049 return 1050 1051 new_args = [] 1052 for arg in node.args: 1053 if not isinstance(arg, Node) or node_arg_is_bias(node, arg): 1054 new_args.append(arg) 1055 continue 1056 1057 is_weight = node_arg_is_weight(node, arg) 1058 1059 act_eq_process_ctr = ( 1060 equalization_qconfig.weight 1061 if is_weight 1062 else equalization_qconfig.input_activation 1063 ) 1064 1065 new_eq_obs_mod = act_eq_process_ctr() 1066 new_eq_obs_node = _insert_obs_or_fq( 1067 arg, new_eq_obs_mod, model, named_modules, graph 1068 ) 1069 1070 new_args.append(new_eq_obs_node) 1071 1072 # assign the new args and kwargs to the node, inplace 1073 node.args = tuple(new_args) 1074 1075 1076def _maybe_insert_output_observer_for_node( 1077 node: Node, 1078 model: torch.nn.Module, 1079 named_modules: Dict[str, torch.nn.Module], 1080 graph: Graph, 1081 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 1082 is_qat: bool, 1083) -> Optional[Node]: 1084 """ 1085 If `node` needs an output observer, creates it, inserts it into `graph` 1086 and returns it. 1087 1088 If `node` does not need an output observer, returns None. 1089 1090 Note: inserting dynamic quantization ops for output is not supported in fx graph mode 1091 quantization code path right now 1092 """ 1093 assert node.op != "output", "observer insertion for outputs is handled elsewhere" 1094 1095 is_standalone_module = False 1096 if "quantization_annotation" in node.meta: 1097 output_act_obs_or_fq = _create_obs_or_fq_from_qspec( 1098 node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat 1099 ) 1100 else: 1101 assert "target_dtype_info" in node.meta 1102 is_standalone_module = node.meta["target_dtype_info"].get( 1103 "_is_standalone_module", False 1104 ) 1105 output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( 1106 "output_act_obs_or_fq_ctr" 1107 ) 1108 output_act_obs_or_fq = ( 1109 output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None 1110 ) 1111 target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq) 1112 # uncomment after we support reuse_input_obs_or_fq properly by having separate 1113 # implemntations for this key instead of reusing the input_output_share_observers 1114 # code 1115 # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False) 1116 # for now we set this to False since reuse_input_obs_or_fq for 1117 # the output of a node is implementation in the same code path as observer sharing, 1118 # we should refactor this part to make it clearer in the future 1119 # and we would be able to read this from config directly 1120 reuse_input_obs_or_fq = False 1121 1122 # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False 1123 # because the prev_output is the output of an fp32 op, althought technically 1124 # we should get the dtype of the output from node.meta["val"] in the future 1125 # if we deprecate fx graph mode quantization 1126 needs_obs_or_fq = _needs_obs_or_fq( 1127 torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq 1128 ) 1129 # currently the activation in QConfig(activation=...,) is for both input 1130 # and output, and when the activation is configured to be dynamic quantization 1131 # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means 1132 # the input should by dynamically quantized, but output should not be quantized 1133 # 1134 # there is no way we can specify different observer/fq for input and output 1135 # activation through QConfig today, this limitation is lifted in the 1136 # quantizer/annotation API in pytorch 2.0 export quantization code path, 1137 # but since this code is reused, annotating output to be dynamically quantized 1138 # would not work either for that. 1139 # we can change QConfig to support input/output activation if we want 1140 # to remove the following check, or if we can deprecate fx graph mode quantization 1141 if target_is_dynamic: 1142 needs_obs_or_fq = False 1143 1144 # we never insert observers to output of standalone module, we assume 1145 # if needed, they are inserted inside the standalone module 1146 needs_obs_or_fq = needs_obs_or_fq and (not is_standalone_module) 1147 1148 if needs_obs_or_fq: 1149 obs_or_fq_map[node] = output_act_obs_or_fq 1150 return _insert_obs_or_fq( 1151 node, output_act_obs_or_fq, model, named_modules, graph 1152 ) 1153 else: 1154 return None 1155 1156 1157def _maybe_insert_observers_before_graph_output( 1158 graph_output_node: Node, 1159 model: torch.nn.Module, 1160 named_modules: Dict[str, torch.nn.Module], 1161 graph: Graph, 1162 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 1163 is_qat: bool, 1164) -> None: 1165 """ 1166 If the output needs to be quantized and there are any nodes 1167 in the output which are not already observed, inserts observers 1168 for those nodes. 1169 """ 1170 1171 def _recursive_maybe_replace_node_with_obs( 1172 maybe_node: Argument, 1173 model: torch.nn.Module, 1174 named_modules: Dict[str, torch.nn.Module], 1175 graph: Graph, 1176 ) -> Argument: 1177 """ 1178 Navigate an arbitrary data structure of lists, tuples, dicts. 1179 For each container type, recurse on all inputs. Once any Node 1180 is found, insert an observer if needed and do not recurse further. 1181 1182 For example, given a structure of 1183 1184 {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}} 1185 1186 we recurse down to bar1 and bar3, observe them if necessary, 1187 and if we inserted an observer then replace the original node 1188 with its observer. 1189 1190 Returns the data structure with all nodes needing observation being 1191 replaced by their observers. 1192 """ 1193 if isinstance(maybe_node, Node): 1194 # check dtype of this node 1195 arg_as_output_target_dtype = _get_arg_target_dtype_as_output( 1196 maybe_node, named_modules, obs_or_fq_map, is_qat 1197 ) 1198 observer_mod = None 1199 arg_as_input_target_dtype = torch.float 1200 if "target_dtype_info" in maybe_node.meta: 1201 observer_cls = maybe_node.meta["target_dtype_info"].get( 1202 "input_act_obs_or_fq_ctr", None 1203 ) 1204 if observer_cls is not None: 1205 observer_mod = observer_cls() 1206 arg_as_input_target_dtype = observer_mod.dtype 1207 # TODO: this does not handle dynamic quantization yet 1208 need_obs = ( 1209 arg_as_output_target_dtype != arg_as_input_target_dtype 1210 and arg_as_input_target_dtype != torch.float 1211 ) 1212 if need_obs: 1213 assert observer_mod is not None 1214 # insert observer 1215 observer_node = _insert_obs_or_fq( 1216 maybe_node, observer_mod, model, named_modules, graph 1217 ) 1218 return observer_node 1219 else: 1220 return maybe_node 1221 elif isinstance(maybe_node, (list, tuple)): 1222 results = [] 1223 for inner_node in maybe_node: 1224 results.append( 1225 _recursive_maybe_replace_node_with_obs( 1226 inner_node, model, named_modules, graph 1227 ) 1228 ) 1229 if isinstance(maybe_node, list): 1230 return results 1231 else: 1232 return tuple(results) 1233 elif isinstance(maybe_node, dict): 1234 results_dict = {} 1235 for k, inner_v in maybe_node.items(): 1236 results_dict[k] = _recursive_maybe_replace_node_with_obs( 1237 inner_v, model, named_modules, graph 1238 ) 1239 return results_dict 1240 elif maybe_node is None: 1241 return None 1242 else: 1243 raise Exception( # noqa: TRY002 1244 "Unhandled type for returned node:", maybe_node 1245 ) 1246 1247 new_args = [] 1248 for old_arg in graph_output_node.args: 1249 new_args.append( 1250 _recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph) 1251 ) 1252 1253 graph_output_node.args = tuple(new_args) # type: ignore[assignment] 1254 1255 1256def _maybe_propagate_dtype_for_node( 1257 node: Node, 1258 target_dtype: Union[torch.dtype, type], 1259 node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig], 1260) -> None: 1261 """ 1262 Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node` 1263 is a general tensor shape op, also call this function recursively on 1264 the first argument, to propagate the dtype to the caller. 1265 """ 1266 node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None 1267 node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None 1268 # if this is a copy node, propagate to first arg 1269 ( 1270 root_node, 1271 _, 1272 pattern, 1273 qhandler, 1274 qconfig, 1275 ) = node_name_to_match_result_with_qconfig.get( 1276 node.name, (None, None, None, None, None) 1277 ) 1278 # TODO: probably need to remove `is_general_tensor_value_op` 1279 if qhandler is not None and qhandler.is_general_tensor_value_op(): 1280 prev_node = node.args[0] 1281 if isinstance(prev_node, Node): 1282 _maybe_propagate_dtype_for_node( 1283 prev_node, target_dtype, node_name_to_match_result_with_qconfig 1284 ) 1285 1286 1287def propagate_dtypes_for_known_nodes( 1288 graph: Graph, 1289 node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig], 1290) -> None: 1291 """ 1292 Currently we assume that inputs to the graph are either `torch.float` or 1293 `torch.quint8`, which is not always correct. For ops such as 1294 `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a 1295 `BoolTensor`. Propagate this information throughout the graph. 1296 1297 Note: not all dtypes in the graph will be correct after this pass, but a 1298 higher percentage of them will be correct. Hopefully in the future we can 1299 replace this with a better way to reason about dtypes of tensors. 1300 """ 1301 for node in graph.nodes: 1302 non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node) 1303 1304 for arg_type in non_observable_arg_dict: 1305 non_observable_indices = non_observable_arg_dict[arg_type](node) 1306 1307 for index in non_observable_indices: 1308 arg = node.args[index] 1309 1310 # when an argument is a tuple, it does not show up as another node so we need to go through 1311 # all elements of the tuple manually 1312 if isinstance(arg, (tuple, list)): 1313 arg_list = list(arg) 1314 else: 1315 arg_list = [arg] 1316 1317 for cur_arg in arg_list: 1318 # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated 1319 if isinstance(cur_arg, torch.fx.node.Node): 1320 _maybe_propagate_dtype_for_node( 1321 cur_arg, arg_type, node_name_to_match_result_with_qconfig 1322 ) 1323 1324 1325def _maybe_make_input_output_share_observers( 1326 node: Node, 1327 model: torch.nn.Module, 1328 named_modules: Dict[str, torch.nn.Module], 1329) -> bool: 1330 """ 1331 Ensures that we share an observer 1332 for all input arguments as well as the output argument. In detail, given 1333 a graph of 1334 1335 x0 -> obs0 -> op -> x2 1336 / 1337 x1 -> obs1 / 1338 1339 where node obs0 points to observer instance observer0, 1340 obs1 points to observer1 and obs2 points to observer2, we make nodes obs1 1341 and ob2 point to observer0. 1342 Returns: whether the operation succeeded or not 1343 """ 1344 first_arg = None 1345 # find the first non-Tensor arg 1346 for i in range(len(node.args)): 1347 if isinstance(node.args[i], (Node, list, tuple)): 1348 first_arg = node.args[i] 1349 break 1350 1351 # if there is no non-Tensor arg, return directly 1352 if first_arg is None: 1353 return False 1354 1355 if isinstance(first_arg, (list, tuple)): 1356 first_arg_arg = first_arg[0] 1357 elif isinstance(first_arg, Node): 1358 first_arg_arg = first_arg 1359 else: 1360 return False 1361 1362 # if we have a graph such as 1363 # observed_node -> non_observed_node -> cat 1364 # we need to navigate up to the first observer 1365 iteration_guard = 0 1366 while not _is_activation_post_process_node(first_arg_arg, named_modules): 1367 if not isinstance(first_arg_arg, Node): 1368 return False 1369 # did not find an activation_post_process for the op 1370 if first_arg_arg.op == "placeholder": 1371 return False 1372 # trace back the args until we found the first Tensor/Node 1373 trace_back_node = None 1374 for i in range(len(first_arg_arg.args)): 1375 trace_back_node = first_arg_arg.args[i] 1376 if isinstance(trace_back_node, Node): 1377 break 1378 if trace_back_node is None: 1379 return False 1380 first_arg_arg = trace_back_node 1381 1382 iteration_guard += 1 1383 if iteration_guard > 10000: 1384 raise AssertionError("Unable to find observer of previous node") 1385 1386 assert isinstance(first_arg_arg, Node) 1387 target_to_use = first_arg_arg.target 1388 assert isinstance(target_to_use, str) 1389 obs_mod_to_use = named_modules[target_to_use] 1390 1391 if isinstance(first_arg, (list, tuple)): 1392 # set all other input observer nodes to use that module 1393 for input_idx, input_arg in enumerate(first_arg): 1394 if input_idx == 0: 1395 continue 1396 iteration_guard = 0 1397 while not _is_activation_post_process_node(input_arg, named_modules): 1398 # failed to trace back since no input arg for the current node 1399 if len(input_arg.args) < 1: 1400 return False 1401 input_arg = input_arg.args[0] 1402 iteration_guard += 1 1403 if iteration_guard > 10000: 1404 raise AssertionError("Unable to find observer of previous node") 1405 1406 parent_name, name = _parent_name(input_arg.target) 1407 setattr(named_modules[parent_name], name, obs_mod_to_use) 1408 1409 # set the output observer node to use that module 1410 for output_obs_node in node.users.keys(): 1411 assert _is_activation_post_process_node(output_obs_node, named_modules) 1412 parent_name, name = _parent_name(output_obs_node.target) 1413 setattr(named_modules[parent_name], name, obs_mod_to_use) 1414 1415 # TODO(future PR): delete the orphaned observer modules 1416 return True 1417 1418 1419def _remove_output_observer( 1420 node: Node, model: torch.nn.Module, named_modules: Dict[str, torch.nn.Module] 1421): 1422 items = list(node.users.items()) 1423 for output_obs_node, _ in items: 1424 assert _is_activation_post_process_node(output_obs_node, named_modules) 1425 output_obs_node.replace_all_uses_with(node) 1426 model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator] 1427 1428 1429def _swap_custom_module_to_observed( 1430 node: Node, 1431 qconfig: QConfigAny, 1432 named_modules: Dict[str, torch.nn.Module], 1433 prepare_custom_config: PrepareCustomConfig, 1434): 1435 custom_module = named_modules[node.target] # type: ignore[index] 1436 custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping 1437 observed_custom_module_class = get_swapped_custom_module_class( 1438 custom_module, custom_module_class_mapping, qconfig 1439 ) 1440 observed_custom_module = observed_custom_module_class.from_float(custom_module) 1441 parent_name, name = _parent_name(node.target) 1442 setattr(named_modules[parent_name], name, observed_custom_module) 1443 1444 1445def insert_observers_for_model( 1446 model: GraphModule, 1447 node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig], 1448 node_name_to_qconfig: Dict[str, QConfigAny], 1449 prepare_custom_config: PrepareCustomConfig, 1450 equalization_config_map: Dict[str, Any], 1451 backend_config: BackendConfig, 1452 observed_node_names: Set[str], 1453 is_qat: bool, 1454) -> Optional[Node]: 1455 """ 1456 Inserts observers, using the following high level algorithm: 1457 1458 For each node in the graph: 1459 1. determine the target dtype of this node in the quantized graph, and save 1460 it for future steps 1461 2. determine the target dtype or all args and kwargs of this node 1462 3. if any arg or kwarg's target dtype does not match the current node's 1463 dtype, insert an observer 1464 4. if the current node needs an output observer, insert it 1465 1466 For example: 1467 1468 - starting graph: 1469 x0 -> linear -> x1 1470 1471 - observed graph after processing x0: 1472 x0(fp32) 1473 1474 - observed graph after processing linear: 1475 x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) 1476 1477 - observed graph after processing x1: 1478 x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1 1479 1480 After a node is processed, the naive observer placement is guaranteed to be 1481 complete for that node and all of its predecessors. There can be future 1482 passes which optimize the graph by deduplicating observers, etc. 1483 """ 1484 1485 # node.meta["target_dtype_info"] stores the target dtype information 1486 # that's derived from qconfig for the Node, for example, if we have 1487 # a conv2d node that has a qconfig 1488 # qconfig = QConfig(activation=..., weight=...) 1489 # # information for input and bias node omitted 1490 # # for getattr node 1491 # # weight = getattr(self, 'weight') 1492 # weight.meta["target_dtype_info"] = { 1493 # 'output_act_obs_or_fq_ctr': qconfig.weight, 1494 # } 1495 # # for conv2d node 1496 # # conv2d = call_function[target=torch.nn.functional.conv2d]( 1497 # # args=(input, weight, bias)) 1498 # conv2d.meta["target_dtype_info"] = { 1499 # 'input_act_obs_or_fq_ctr': qconfig.activation 1500 # 'weight_obs_or_fq_ctr': qconfig.weight, 1501 # 'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32), 1502 # 'output_act_obs_or_fq_ctr': qconfig.activation, 1503 # } 1504 # 1505 cache_for_no_tensor_check: Dict[Node, bool] = {} 1506 1507 # first, populate the dtype map based only on qconfig and qhandler 1508 # this assumes: 1509 # graph inputs are fp32 by default, and int8 where overriden 1510 # other nodes output dtype is specified by the qconfig 1511 named_modules = dict(model.named_modules(remove_duplicate=False)) 1512 1513 input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes 1514 output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes 1515 processed_nodes: Set[Node] = set() 1516 # initialize target_dtype_info 1517 for node in model.graph.nodes: 1518 node.meta["target_dtype_info"] = copy.copy( 1519 _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO 1520 ) 1521 1522 inputs_seen_counter = 0 1523 outputs_seen_counter = 0 1524 placeholder_node_to_input_index: Dict[Node, int] = {} 1525 # TODO: we probably don't need this counter since each graph will only have 1526 # one output node? 1527 output_node_to_output_index: Dict[Node, int] = {} 1528 for node in model.graph.nodes: 1529 if node.op == "placeholder": 1530 placeholder_node_to_input_index[node] = inputs_seen_counter 1531 inputs_seen_counter += 1 1532 if node.op == "output": 1533 output_node_to_output_index[node] = outputs_seen_counter 1534 outputs_seen_counter += 1 1535 1536 # Step 1, set the observer or fake quantize module constructor for each node in the 1537 # matched_node_pattern 1538 1539 for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): 1540 ( 1541 last_node, 1542 matched_node_pattern, 1543 pattern, 1544 qhandler, 1545 qconfig, 1546 ) = match_res_with_qconfig 1547 assert qhandler is not None 1548 _set_target_dtype_info_for_matched_node_pattern( 1549 matched_node_pattern, 1550 last_node, 1551 qconfig, 1552 qhandler, 1553 backend_config, 1554 named_modules, 1555 cache_for_no_tensor_check, 1556 processed_nodes, 1557 ) 1558 1559 # Step 2. Special cases for some operators, we might be able to remove them 1560 # in the future if we know dtype information of each node better 1561 1562 # Step 2.1. some settings are not based on patterns, we need to process each node 1563 # instead 1564 for node in model.graph.nodes: 1565 if ( 1566 node.op == "placeholder" 1567 and placeholder_node_to_input_index[node] in input_quantized_idxs 1568 ): 1569 # users are not supposed to call calculate_qparams on PlaceholderObserver, and 1570 # this is OK because we are using this as a way to encode the dtypes of input 1571 # tensor, we won't actually insert these observers in the graph and won't 1572 # actually call calculate_qparams 1573 node.meta["target_dtype_info"] = copy.copy( 1574 _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO 1575 ) 1576 elif node.op in ("call_module", "call_method", "call_function"): 1577 args_have_no_tensors = all_node_args_have_no_tensors( 1578 node, named_modules, cache_for_no_tensor_check 1579 ) 1580 if args_have_no_tensors: 1581 node.meta["target_dtype_info"] = { 1582 "input_act_obs_or_fq_ctr": None, 1583 "output_act_obs_or_fq_ctr": None, 1584 } 1585 elif ( 1586 node.op == "output" 1587 and output_node_to_output_index[node] in output_quantized_idxs 1588 ): 1589 # TODO(future PR): update the output_quantized_idxs API to match 1590 # arbitrary data structures. There is always a single output, and 1591 # that output can have arbitrary nesting of values. List[int] is 1592 # not the right data type for this. 1593 1594 # TODO(future PR): support more dtypes in model outputs, if necessary 1595 node.meta["target_dtype_info"] = copy.copy( 1596 _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO 1597 ) 1598 1599 # Step 2.2, for nodes with known input dtypes, propagate them throughout the 1600 # graph. For example, if there is a call such as 1601 # x1 = x0.masked_fill(mask, 1) 1602 # we propagate the type of mask to be torch.bool 1603 propagate_dtypes_for_known_nodes( 1604 model.graph, node_name_to_match_result_with_qconfig 1605 ) 1606 1607 # Step 3, check if the requested target_dtype_info is supported by backend or not 1608 # if not, we'll reset the target_dtye_info to use the default (float Tensor) 1609 1610 # reset the counters and set of processed_nodes 1611 processed_nodes: Set[Node] = set() 1612 for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): 1613 ( 1614 last_node, 1615 matched_node_pattern, 1616 pattern, 1617 qhandler, 1618 qconfig, 1619 ) = match_res_with_qconfig 1620 is_supported_by_backend = ( 1621 _is_pattern_dtype_config_and_qconfig_supported_by_backend( 1622 pattern, matched_node_pattern, qconfig, backend_config 1623 ) 1624 ) 1625 assert qhandler is not None 1626 1627 # get output_act_dtype so that we don't also reset the special typed nodes 1628 # TODO: we might want to handle these more uniformly with the default path 1629 # this can be improved if we can use node.meta["val"] 1630 output_act_or_fq_ctr = node.meta["target_dtype_info"][ 1631 "output_act_obs_or_fq_ctr" 1632 ] 1633 output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None 1634 output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq) 1635 if not is_supported_by_backend and output_act_dtype not in [ 1636 None, 1637 int, 1638 float, 1639 torch.bool, 1640 ]: 1641 # restore target_dtype_info to default if it is not supported by backend 1642 _set_target_dtype_info_for_matched_node_pattern( 1643 matched_node_pattern, 1644 last_node, 1645 torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig, 1646 None, 1647 backend_config, 1648 named_modules, 1649 cache_for_no_tensor_check, 1650 processed_nodes, 1651 ) 1652 1653 # After this point, the current node and all of its arguments 1654 # have a target_dtype_info assigned. Now, we insert observers for inputs 1655 # of this node (if needed for this node), and the output of this node 1656 # (if needed for this node). 1657 1658 # Since we are mutating the graph as we go, we iterate over the original 1659 # nodes before observer insertion, instead of model.graph.nodes. 1660 nodes_before_observation = list(model.graph.nodes) 1661 1662 # Avoid duplicates custom module swaps for multiple nodes with same target. 1663 custom_module_names_already_swapped: Set[str] = set() 1664 1665 # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index 1666 # reset inputs/outputs counters 1667 inputs_seen_counter = 0 1668 outputs_seen_counter = 0 1669 results_node = None 1670 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {} 1671 1672 # TODO: change this to insert obs/fq by pattern instead of by node 1673 for node in nodes_before_observation: 1674 if node.op == "placeholder": 1675 # if a graph input is in fp32, it does not need observation 1676 # if a graph input is in int8, we assume the observation happens 1677 # outside of the graph, and no additional observation is needed 1678 pass 1679 1680 elif node.op in ("call_module", "call_method", "call_function", "output"): 1681 # check for matches 1682 ( 1683 last_node, 1684 matched_node_pattern, 1685 pattern, 1686 qhandler, 1687 qconfig, 1688 ) = node_name_to_match_result_with_qconfig.get( # type: ignore[assignment] 1689 node.name, (None, None, None, None, None) 1690 ) 1691 equalization_qconfig = equalization_config_map.get(node.name, None) 1692 1693 this_node_dtype_info = node.meta["target_dtype_info"] 1694 if "val" in node.meta: 1695 output_is_a_tensor = this_node_dtype_info is not None and isinstance( 1696 node.meta["val"], FakeTensor 1697 ) 1698 else: 1699 output_is_a_tensor = this_node_dtype_info is not None 1700 1701 skip_inserting_observers = ( 1702 (qconfig is None) or not output_is_a_tensor 1703 ) and (not node.op == "output") 1704 1705 # TODO: take a closer look to see if we can remove this check 1706 # right now it is here because of `observed_node_names`, we are using 1707 # it as an indicator for swapping the modules to reference modules in 1708 # convert 1709 is_supported_by_backend = ( 1710 _is_pattern_dtype_config_and_qconfig_supported_by_backend( 1711 pattern, matched_node_pattern, qconfig, backend_config 1712 ) 1713 ) 1714 1715 if not skip_inserting_observers and is_supported_by_backend: 1716 named_modules = dict(model.named_modules(remove_duplicate=False)) 1717 if node.op != "output": 1718 assert matched_node_pattern is not None 1719 # add matched nodes to the observed node name set 1720 _add_matched_node_name_to_set( 1721 matched_node_pattern, observed_node_names 1722 ) 1723 1724 # This is currently only used for equalization. 1725 # Checks if the current node is in a branch in which the two 1726 # first layers are both being quantized. 1727 # 1728 # ex. conv2 1729 # / 1730 # x -> conv1 1731 # 1732 # If this is the case, we will not apply equalization to the 1733 # initial two layers. 1734 is_quantized_branch = False 1735 if ( 1736 len(node.args) > 0 1737 and isinstance(node.args[0], Node) 1738 and len(node.args[0].users) > 1 1739 ): 1740 for user in node.args[0].users: 1741 # Checks if there exists another user being quantized 1742 is_user_quantized = node_name_to_qconfig.get( 1743 user.name, None 1744 ) is not None or ( 1745 user.op == "call_module" 1746 and isinstance( 1747 named_modules[str(user.target)], ObserverBase 1748 ) 1749 ) 1750 if user != node and is_user_quantized: 1751 is_quantized_branch = True 1752 1753 pattern_to_root_node_getter = ( 1754 get_fusion_pattern_to_root_node_getter(backend_config) 1755 ) 1756 root_node_getter = pattern_to_root_node_getter.get( 1757 pattern, _default_root_node_getter 1758 ) 1759 root_node = root_node_getter(matched_node_pattern) 1760 is_input_node_of_the_pattern = node is root_node 1761 if is_input_node_of_the_pattern: 1762 # this modifies node inplace 1763 _maybe_insert_input_observers_for_node( 1764 node, 1765 qconfig, 1766 model, 1767 named_modules, 1768 model.graph, 1769 qhandler, 1770 prepare_custom_config, 1771 obs_or_fq_map, 1772 is_qat, 1773 backend_config, 1774 ) 1775 1776 # insert equalization input observers if needed 1777 _maybe_insert_input_equalization_observers_for_node( 1778 node, 1779 equalization_qconfig, 1780 model, 1781 named_modules, 1782 model.graph, 1783 is_quantized_branch, 1784 ) 1785 1786 is_last_node_of_pattern = node is last_node 1787 input_output_share_observers = node.meta["target_dtype_info"].get( 1788 "input_output_share_observers", False 1789 ) 1790 reuse_input_obs_or_fq = node.meta["target_dtype_info"].get( 1791 "reuse_input_obs_or_fq", False 1792 ) 1793 1794 if is_last_node_of_pattern: 1795 if _is_custom_module_lstm( 1796 node, named_modules, qconfig, qhandler 1797 ): 1798 # Currently custom module outputs are assumed to be already quantized, 1799 # so we need to insert a DeQuantStub after the output. For custom module 1800 # LSTM specifically, the outputs are also a nested tuple, so we must first 1801 # break down the tuple to insert DeQuantStubs after the internal nodes. 1802 1803 # TODO: This currently diverges from how custom modules are handled today, 1804 # where we insert observers after the output instead of DeQuantStubs, and 1805 # replace these observers with "dequantize" nodes during convert. Conceptually, 1806 # these output observers are the same as DeQuantStubs. In the future, we 1807 # should resolve this inconsistency by inserting DeQuantStubs for all custom 1808 # modules, not just for LSTM. 1809 _insert_dequant_stubs_for_custom_module_lstm_output( 1810 node, model, named_modules, model.graph 1811 ) 1812 if node.target not in custom_module_names_already_swapped: 1813 custom_module_names_already_swapped.add(node.target) 1814 _swap_custom_module_to_observed( 1815 node, qconfig, named_modules, prepare_custom_config 1816 ) 1817 else: 1818 # this returns the new observer node if it was needed 1819 maybe_output_obs_node = ( 1820 _maybe_insert_output_observer_for_node( 1821 node, 1822 model, 1823 named_modules, 1824 model.graph, 1825 obs_or_fq_map, 1826 is_qat, 1827 ) 1828 ) 1829 1830 if maybe_output_obs_node is not None: 1831 # Update users of original node to use the output observer 1832 # instead. For example, change 1833 # 1834 # next_node 1835 # / 1836 # cur_node -> obs 1837 # 1838 # to 1839 # 1840 # next_node 1841 # / 1842 # cur_node -> obs 1843 # 1844 # We need to save orig users before updating uses because 1845 # the list of users will change as we update uses 1846 orig_users = list(node.users.keys()) 1847 for user_node in orig_users: 1848 if user_node is maybe_output_obs_node: 1849 continue 1850 user_node.replace_input_with( 1851 node, maybe_output_obs_node 1852 ) 1853 1854 _is_observer_in_same_graph_ = ( 1855 _is_observer_in_same_graph( 1856 node, named_modules, obs_or_fq_map, is_qat 1857 ) 1858 ) 1859 1860 # for ops whose inputs and outputs share observer/fqs, we modify the graph 1861 # to make all inputs and outputs use the first input's 1862 # observer/fq 1863 if ( 1864 input_output_share_observers 1865 and _is_observer_in_same_graph_ 1866 ) or reuse_input_obs_or_fq: 1867 if not _maybe_make_input_output_share_observers( 1868 node, model, named_modules 1869 ): 1870 _remove_output_observer( 1871 node, model, named_modules 1872 ) 1873 1874 if qhandler is not None and qhandler.is_custom_module(): 1875 if ( 1876 node.target 1877 not in custom_module_names_already_swapped 1878 ): 1879 custom_module_names_already_swapped.add( 1880 node.target 1881 ) 1882 _swap_custom_module_to_observed( 1883 node, 1884 qconfig, 1885 named_modules, 1886 prepare_custom_config, 1887 ) 1888 1889 else: # output 1890 _maybe_insert_observers_before_graph_output( 1891 node, model, named_modules, model.graph, obs_or_fq_map, is_qat 1892 ) 1893 1894 # 1895 # After this point, the current node has input and output observers 1896 # that it needs for itself inserted. 1897 # 1898 1899 # increment the counters, so future inputs and outputs are assigned 1900 # correct dtypes 1901 if node.op == "placeholder": 1902 inputs_seen_counter += 1 1903 elif node.op == "output": 1904 outputs_seen_counter += 1 1905 results_node = node 1906 1907 return results_node 1908 1909 1910def _run_prepare_fx_on_standalone_modules( 1911 model: torch.nn.Module, 1912 is_qat: bool, 1913 named_modules: Dict[str, torch.nn.Module], 1914 node_name_to_match_result_with_qconfig: Any, 1915 prepare_custom_config: PrepareCustomConfig, 1916 backend_config: BackendConfig, 1917) -> None: 1918 """ 1919 Runs prepare_fx on each standalone module. Note: this does 1920 not modify the graph, it just replaces the unobserved modules with 1921 their observed versions. 1922 """ 1923 for ( 1924 root_node, 1925 _, 1926 pattern, 1927 qhandler, 1928 qconfig, 1929 ) in node_name_to_match_result_with_qconfig.values(): 1930 if qhandler is None: 1931 continue 1932 elif not qhandler.is_standalone_module(): 1933 continue 1934 1935 ( 1936 sm_qconfig_mapping, 1937 sm_example_inputs, 1938 sm_prepare_custom_config, 1939 sm_backend_config, 1940 ) = _get_standalone_module_configs( 1941 root_node, named_modules, prepare_custom_config, qconfig, backend_config 1942 ) 1943 1944 standalone_module = named_modules[root_node.target] 1945 prepare = ( 1946 torch.ao.quantization.quantize_fx._prepare_standalone_module_fx 1947 ) # type: ignore[attr-defined] 1948 observed_standalone_module = prepare( 1949 standalone_module, 1950 sm_qconfig_mapping, 1951 is_qat, 1952 example_inputs=sm_example_inputs, 1953 prepare_custom_config=sm_prepare_custom_config, 1954 backend_config=sm_backend_config, 1955 ) 1956 parent_name, name = _parent_name(root_node.target) 1957 setattr(named_modules[parent_name], name, observed_standalone_module) 1958 named_modules[root_node.target] = observed_standalone_module 1959 1960 1961def _save_state( 1962 observed: GraphModule, 1963 node_name_to_qconfig: Dict[str, QConfigAny], 1964 node_name_to_scope: Dict[str, Tuple[str, type]], 1965 prepare_custom_config: PrepareCustomConfig, 1966 equalization_node_name_to_qconfig: Dict[str, Any], 1967 qconfig_mapping: QConfigMapping, 1968 is_qat: bool, 1969 observed_node_names: Set[str], 1970) -> None: 1971 observed.meta["_observed_graph_module_attrs"] = ObservedGraphModuleAttrs( 1972 node_name_to_qconfig=node_name_to_qconfig, 1973 node_name_to_scope=node_name_to_scope, 1974 prepare_custom_config=prepare_custom_config, 1975 equalization_node_name_to_qconfig=equalization_node_name_to_qconfig, 1976 qconfig_mapping=qconfig_mapping, 1977 is_qat=is_qat, 1978 observed_node_names=observed_node_names, 1979 ) 1980 1981 1982def prepare( 1983 model: GraphModule, 1984 qconfig_mapping: Union[QConfigMapping, Dict[str, Any]], 1985 is_qat: bool, 1986 node_name_to_scope: Dict[str, Tuple[str, type]], 1987 example_inputs: Tuple[Any, ...], 1988 prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None, 1989 _equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None, 1990 backend_config: Union[BackendConfig, Dict[str, Any], None] = None, 1991 is_standalone_module: bool = False, 1992) -> GraphModule: 1993 """standalone_module means it a submodule that is not inlined in 1994 parent module, and will be quantized separately as one unit. 1995 1996 How the standalone module is observed is specified by `input_quantized_idxs` and 1997 `output_quantized_idxs` in the prepare_custom_config for the standalone module 1998 Args: 1999 node_name_to_scope: mapping from node name to the scope of the module which contains the node. 2000 The scope is a tuple of fully qualified path of the module and the type of the module 2001 Returns: 2002 model(GraphModule): prepared standalone module 2003 attributes related to standalone module 2004 in model.meta["_observed_graph_module_attrs"]: 2005 is_observed_standalone_module (bool): boolean value that shows whether the 2006 current model is a observed standalone module or not 2007 standalone_module_input_quantized_idxs(List[Int]): a list of 2008 indexes for the graph input that is expected to be quantized, 2009 same as input_quantized_idxs configuration provided 2010 for the standalone module 2011 standalone_module_output_quantized_idxs(List[Int]): a list of 2012 indexs for the graph output that is quantized 2013 same as input_quantized_idxs configuration provided 2014 for the standalone module 2015 """ 2016 if prepare_custom_config is None: 2017 prepare_custom_config = PrepareCustomConfig() 2018 if _equalization_config is None: 2019 _equalization_config = QConfigMapping() 2020 2021 if isinstance(qconfig_mapping, dict): 2022 warnings.warn( 2023 "Passing a QConfig dictionary to prepare is deprecated and will not be supported " 2024 "in a future version. Please pass in a QConfigMapping instead.", 2025 FutureWarning, 2026 stacklevel=2, 2027 ) 2028 qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) 2029 2030 if isinstance(_equalization_config, dict): 2031 warnings.warn( 2032 "Passing a QConfig dictionary to prepare for equalization is deprecated and will not " 2033 "be supported in a future version. Please pass in a QConfigMapping instead.", 2034 FutureWarning, 2035 stacklevel=2, 2036 ) 2037 _equalization_config = QConfigMapping.from_dict(_equalization_config) 2038 2039 if isinstance(prepare_custom_config, dict): 2040 warnings.warn( 2041 "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " 2042 "in a future version. Please pass in a PrepareCustomConfig instead.", 2043 FutureWarning, 2044 stacklevel=2, 2045 ) 2046 prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) 2047 2048 if isinstance(backend_config, dict): 2049 warnings.warn( 2050 "Passing a backend_config_dict to prepare is deprecated and will not be supported " 2051 "in a future version. Please pass in a BackendConfig instead.", 2052 FutureWarning, 2053 stacklevel=2, 2054 ) 2055 backend_config = BackendConfig.from_dict(backend_config) 2056 2057 assert isinstance(qconfig_mapping, QConfigMapping) 2058 assert isinstance(_equalization_config, QConfigMapping) 2059 qconfig_mapping = copy.deepcopy(qconfig_mapping) 2060 _equalization_config = copy.deepcopy(_equalization_config) 2061 2062 # mapping from a tuple of nodes in reverse order to uninitialized 2063 # QuantizeHandler subclass. For example, 2064 # { 2065 # # match a single node 2066 # (<class 'torch.nn.modules.conv.Conv3d'>: 2067 # <class 'torch.ao.quantization.fx.quantize.ConvRelu'>), 2068 # # match multiple nodes in reverse order 2069 # ((<function relu at 0x7f766a7360d0>, <built-in function add>): 2070 # <class 'torch.ao.quantization.fx.quantize.Add'>), 2071 # } 2072 2073 pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {} 2074 if backend_config is None: 2075 backend_config = get_native_backend_config() 2076 pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config) 2077 pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler) 2078 2079 root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) 2080 2081 _update_qconfig_for_fusion(model, qconfig_mapping) 2082 _update_qconfig_for_fusion(model, _equalization_config) 2083 flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping) 2084 # TODO: support regex as well 2085 propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict()) 2086 2087 if is_qat: 2088 module_to_qat_module = get_module_to_qat_module(backend_config) 2089 _qat_swap_modules(model, module_to_qat_module) 2090 _update_qconfig_for_qat(qconfig_mapping, backend_config) 2091 2092 # mapping from fully qualified module name to module instance 2093 # for example, 2094 # { 2095 # '': Model(...), 2096 # 'linear': Linear(...), 2097 # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), 2098 # } 2099 named_modules = dict(model.named_modules(remove_duplicate=False)) 2100 2101 # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches 2102 equalization_node_name_to_qconfig = _generate_node_name_to_qconfig( 2103 model, named_modules, model.graph, _equalization_config, node_name_to_scope 2104 ) 2105 node_name_to_qconfig = _generate_node_name_to_qconfig( 2106 model, named_modules, model.graph, qconfig_mapping, node_name_to_scope 2107 ) 2108 2109 # match the patterns that will get quantized 2110 standalone_module_names = list(prepare_custom_config.standalone_module_names.keys()) 2111 standalone_module_classes = list( 2112 prepare_custom_config.standalone_module_classes.keys() 2113 ) 2114 2115 custom_module_classes = get_custom_module_class_keys( 2116 prepare_custom_config.float_to_observed_mapping 2117 ) 2118 matches_without_qconfig = _find_matches( 2119 model.graph, 2120 named_modules, 2121 pattern_to_quantize_handler, 2122 root_node_getter_mapping, 2123 standalone_module_names, 2124 standalone_module_classes, 2125 custom_module_classes, 2126 ) 2127 2128 # map qconfig instances to matches 2129 node_name_to_match_result_with_qconfig = {} 2130 for node_name, match_without_qconfig in matches_without_qconfig.items(): 2131 match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name]) 2132 node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig 2133 2134 _run_prepare_fx_on_standalone_modules( 2135 model, 2136 is_qat, 2137 named_modules, 2138 node_name_to_match_result_with_qconfig, 2139 prepare_custom_config, 2140 backend_config, 2141 ) 2142 2143 # record names for the set of observed node, so that in convert step 2144 # we know whether we need to convert a floating point module to reference 2145 # quantized module or not 2146 observed_node_names: Set[str] = set() 2147 2148 result_node = insert_observers_for_model( 2149 model, 2150 node_name_to_match_result_with_qconfig, 2151 node_name_to_qconfig, 2152 prepare_custom_config, 2153 equalization_node_name_to_qconfig, 2154 backend_config, 2155 observed_node_names, 2156 is_qat, 2157 ) 2158 model = GraphModule(model, model.graph) 2159 2160 _save_state( 2161 model, 2162 node_name_to_qconfig, 2163 node_name_to_scope, 2164 prepare_custom_config, 2165 equalization_node_name_to_qconfig, 2166 qconfig_mapping, 2167 is_qat, 2168 observed_node_names, 2169 ) 2170 2171 if is_standalone_module: 2172 assert result_node is not None 2173 assert isinstance(result_node.args[0], Node), ( 2174 "standalone module only supports returning simple value currently" 2175 "(not tuple, dict etc.)" 2176 ) 2177 # these inputs are observed in parent 2178 # converting List[int] to Tensor since module attribute is 2179 # Union[Tensor, Module] 2180 input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes 2181 output_quantized_idxs: List[ 2182 int 2183 ] = prepare_custom_config.output_quantized_indexes 2184 observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] 2185 # inplace modification 2186 observed_graph_module_attrs.is_observed_standalone_module = True 2187 observed_graph_module_attrs.standalone_module_input_quantized_idxs = ( 2188 input_quantized_idxs 2189 ) 2190 observed_graph_module_attrs.standalone_module_output_quantized_idxs = ( 2191 output_quantized_idxs 2192 ) 2193 return model 2194