1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4from dataclasses import dataclass 5from enum import Enum 6from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union 7 8import torch 9 10 11if TYPE_CHECKING: 12 from torch.ao.quantization.utils import Pattern 13 14 15__all__ = [ 16 "BackendConfig", 17 "BackendPatternConfig", 18 "DTypeConfig", 19 "DTypeWithConstraints", 20 "ObservationType", 21] 22 23 24# DTypeConfig dict keys 25INPUT_DTYPE_DICT_KEY = "input_dtype" 26OUTPUT_DTYPE_DICT_KEY = "output_dtype" 27WEIGHT_DTYPE_DICT_KEY = "weight_dtype" 28BIAS_DTYPE_DICT_KEY = "bias_dtype" 29IS_DYNAMIC_DICT_KEY = "is_dynamic" 30 31# BackendConfig dict keys 32NAME_DICT_KEY = "name" 33CONFIGS_DICT_KEY = "configs" 34 35# BackendPatternConfig dict keys 36PATTERN_DICT_KEY = "pattern" 37PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format" 38OBSERVATION_TYPE_DICT_KEY = "observation_type" 39DTYPE_CONFIGS_DICT_KEY = "dtype_configs" 40ROOT_MODULE_DICT_KEY = "root_module" 41QAT_MODULE_DICT_KEY = "qat_module" 42REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root" 43FUSED_MODULE_DICT_KEY = "fused_module" 44FUSER_METHOD_DICT_KEY = "fuser_method" 45ROOT_NODE_GETTER_DICT_KEY = "root_node_getter" 46EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter" 47NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type" 48INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index" 49 50 51# TODO: maybe rename this to something that's not related to observer 52# e.g. QParamsType 53class ObservationType(Enum): 54 """An enum that represents different ways of how an operator/operator pattern 55 should be observed 56 """ 57 58 OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0 59 """this means input and output are observed with different observers, based 60 on qconfig.activation 61 example: conv, linear, softmax 62 """ 63 64 OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1 65 """this means the output will use the same observer instance as input, based 66 on qconfig.activation 67 example: torch.cat, maxpool 68 """ 69 70 INPUT_OUTPUT_NOT_OBSERVED = 2 71 """this means the input and output are never observed 72 example: x.shape, x.size 73 """ 74 75 76@dataclass 77class DTypeWithConstraints: 78 """ 79 Config for specifying additional constraints for a given dtype, such as quantization 80 value ranges, scale value ranges, and fixed quantization params, to be used in 81 :class:`~torch.ao.quantization.backend_config.DTypeConfig`. 82 83 The constraints currently supported are: 84 85 * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper 86 bounds for the minimum and maximum quantized values respectively. If 87 the QConfig's `quant_min` and `quant_max` fall outside this range, 88 then the QConfig will be ignored. 89 90 * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper 91 bounds for the minimum and maximum scale values respectively. If the 92 QConfig's minimum scale value (currently exposed as `eps`) falls below 93 the lower bound, then the QConfig will be ignored. Note that the upper 94 bound is currently not enforced. 95 96 * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements 97 for scale and zero point, to be used for operators with fixed quantization 98 parameters such as sigmoid and tanh. If the observer specified in the QConfig 99 is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if 100 the quantization parameters don't match, then the QConfig will be ignored. 101 """ 102 103 dtype: Optional[torch.dtype] = None 104 quant_min_lower_bound: Union[int, float, None] = None 105 quant_max_upper_bound: Union[int, float, None] = None 106 scale_min_lower_bound: Union[int, float, None] = None 107 scale_max_upper_bound: Union[int, float, None] = None 108 scale_exact_match: Optional[float] = None 109 zero_point_exact_match: Optional[int] = None 110 111 112@dataclass 113class DTypeConfig: 114 """ 115 Config object that specifies the supported data types passed as arguments to 116 quantize ops in the reference model spec, for input and output activations, 117 weights, and biases. 118 119 For example, consider the following reference model: 120 121 quant1 - [dequant1 - fp32_linear - quant2] - dequant2 122 123 The pattern in the square brackets refers to the reference pattern of 124 statically quantized linear. Setting the input dtype as `torch.quint8` 125 in the DTypeConfig means we pass in `torch.quint8` as the dtype argument 126 to the first quantize op (quant1). Similarly, setting the output dtype as 127 `torch.quint8` means we pass in `torch.quint8` as the dtype argument to 128 the second quantize op (quant2). 129 130 Note that the dtype here does not refer to the interface dtypes of the 131 op. For example, the "input dtype" here is not the dtype of the input 132 tensor passed to the quantized linear op. Though it can still be the 133 same as the interface dtype, this is not always the case, e.g. the 134 interface dtype is fp32 in dynamic quantization but the "input dtype" 135 specified in the DTypeConfig would still be quint8. The semantics of 136 dtypes here are the same as the semantics of the dtypes specified in 137 the observers. 138 139 These dtypes are matched against the ones specified in the user's 140 QConfig. If there is a match, and the QConfig satisfies the constraints 141 specified in the DTypeConfig (if any), then we will quantize the given 142 pattern using this DTypeConfig. Otherwise, the QConfig is ignored and 143 the pattern will not be quantized. 144 145 Example usage:: 146 147 >>> # xdoctest: +SKIP(failing) 148 >>> dtype_config1 = DTypeConfig( 149 ... input_dtype=torch.quint8, 150 ... output_dtype=torch.quint8, 151 ... weight_dtype=torch.qint8, 152 ... bias_dtype=torch.float) 153 154 >>> dtype_config2 = DTypeConfig( 155 ... input_dtype=DTypeWithConstraints( 156 ... dtype=torch.quint8, 157 ... quant_min_lower_bound=0, 158 ... quant_max_upper_bound=255, 159 ... ), 160 ... output_dtype=DTypeWithConstraints( 161 ... dtype=torch.quint8, 162 ... quant_min_lower_bound=0, 163 ... quant_max_upper_bound=255, 164 ... ), 165 ... weight_dtype=DTypeWithConstraints( 166 ... dtype=torch.qint8, 167 ... quant_min_lower_bound=-128, 168 ... quant_max_upper_bound=127, 169 ... ), 170 ... bias_dtype=torch.float) 171 172 >>> dtype_config1.input_dtype 173 torch.quint8 174 175 >>> dtype_config2.input_dtype 176 torch.quint8 177 178 >>> dtype_config2.input_dtype_with_constraints 179 DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \ 180scale_min_lower_bound=None, scale_max_upper_bound=None) 181 """ 182 183 input_dtype_with_constraints: DTypeWithConstraints 184 output_dtype_with_constraints: DTypeWithConstraints 185 weight_dtype_with_constraints: DTypeWithConstraints 186 bias_dtype: Optional[torch.dtype] 187 is_dynamic: Optional[bool] 188 189 def __init__( 190 self, 191 input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None, 192 output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None, 193 weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None, 194 bias_dtype: Optional[torch.dtype] = None, 195 is_dynamic: Optional[bool] = None, 196 ): 197 if isinstance(input_dtype, DTypeWithConstraints): 198 self.input_dtype_with_constraints = input_dtype 199 else: 200 self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype) 201 202 if isinstance(output_dtype, DTypeWithConstraints): 203 self.output_dtype_with_constraints = output_dtype 204 else: 205 self.output_dtype_with_constraints = DTypeWithConstraints( 206 dtype=output_dtype 207 ) 208 209 if isinstance(weight_dtype, DTypeWithConstraints): 210 self.weight_dtype_with_constraints = weight_dtype 211 else: 212 self.weight_dtype_with_constraints = DTypeWithConstraints( 213 dtype=weight_dtype 214 ) 215 216 self.bias_dtype = bias_dtype 217 self.is_dynamic = is_dynamic 218 219 @property 220 def input_dtype(self) -> Optional[torch.dtype]: 221 return self.input_dtype_with_constraints.dtype 222 223 @property 224 def output_dtype(self) -> Optional[torch.dtype]: 225 return self.output_dtype_with_constraints.dtype 226 227 @property 228 def weight_dtype(self) -> Optional[torch.dtype]: 229 return self.weight_dtype_with_constraints.dtype 230 231 @classmethod 232 def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig: 233 """ 234 Create a ``DTypeConfig`` from a dictionary with the following items (all optional): 235 "input_dtype": torch.dtype or ``DTypeWithConstraints`` 236 "output_dtype": torch.dtype or ``DTypeWithConstraints`` 237 "weight_dtype": torch.dtype or ``DTypeWithConstraints`` 238 "bias_type": torch.dtype 239 "is_dynamic": bool 240 """ 241 input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None) 242 if input_dtype is not None and not isinstance( 243 input_dtype, (torch.dtype, DTypeWithConstraints) 244 ): 245 raise ValueError( 246 "Expected input_dtype to be a torch.dtype or DTypeWithConstraints" 247 ) 248 output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None) 249 if output_dtype is not None and not isinstance( 250 output_dtype, (torch.dtype, DTypeWithConstraints) 251 ): 252 raise ValueError( 253 "Expected output_dtype to be a torch.dtype or DTypeWithConstraints" 254 ) 255 weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None) 256 if weight_dtype is not None and not isinstance( 257 weight_dtype, (torch.dtype, DTypeWithConstraints) 258 ): 259 raise ValueError( 260 "Expected weight_dtype to be a torch.dtype or DTypeWithConstraints" 261 ) 262 bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None) 263 is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None) 264 return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic) 265 266 def to_dict(self) -> Dict[str, Any]: 267 """ 268 Convert this ``DTypeConfig`` to a dictionary with the items described in 269 :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`. 270 """ 271 dtype_config_dict: Dict[str, Any] = {} 272 if self.input_dtype is not None: 273 dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints 274 if self.output_dtype is not None: 275 dtype_config_dict[ 276 OUTPUT_DTYPE_DICT_KEY 277 ] = self.output_dtype_with_constraints 278 if self.weight_dtype is not None: 279 dtype_config_dict[ 280 WEIGHT_DTYPE_DICT_KEY 281 ] = self.weight_dtype_with_constraints 282 if self.bias_dtype is not None: 283 dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype 284 if self.is_dynamic is not None: 285 dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic 286 return dtype_config_dict 287 288 289class BackendConfig: 290 # TODO: refer to NativeBackendConfig once that is implemented 291 """Config that defines the set of patterns that can be quantized on a given backend, and how reference 292 quantized models can be produced from these patterns. 293 294 A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph 295 of the above. Each pattern supported on the target backend can be individually configured through 296 :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of: 297 298 (1) The supported input/output activation, weight, and bias data types 299 300 (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and 301 302 (3) (Optionally) Fusion, QAT, and reference module mappings. 303 304 The format of the patterns is described in: 305 https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md 306 307 Example usage:: 308 309 import torch 310 from torch.ao.quantization.backend_config import ( 311 BackendConfig, 312 BackendPatternConfig, 313 DTypeConfig, 314 ObservationType, 315 ) 316 317 weighted_int8_dtype_config = DTypeConfig( 318 input_dtype=torch.quint8, 319 output_dtype=torch.quint8, 320 weight_dtype=torch.qint8, 321 bias_dtype=torch.float) 322 323 def fuse_conv2d_relu(is_qat, conv, relu): 324 return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu) 325 326 # For quantizing Linear 327 linear_config = BackendPatternConfig(torch.nn.Linear) \ 328 .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ 329 .add_dtype_config(weighted_int8_dtype_config) \ 330 .set_root_module(torch.nn.Linear) \ 331 .set_qat_module(torch.ao.nn.qat.Linear) \ 332 .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear) 333 334 # For fusing Conv2d + ReLU into ConvReLU2d 335 conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \ 336 .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ 337 .add_dtype_config(weighted_int8_dtype_config) \ 338 .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \ 339 .set_fuser_method(fuse_conv2d_relu) 340 341 # For quantizing ConvReLU2d 342 fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \ 343 .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ 344 .add_dtype_config(weighted_int8_dtype_config) \ 345 .set_root_module(torch.nn.Conv2d) \ 346 .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \ 347 .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d) 348 349 backend_config = BackendConfig("my_backend") \ 350 .set_backend_pattern_config(linear_config) \ 351 .set_backend_pattern_config(conv_relu_config) \ 352 .set_backend_pattern_config(fused_conv_relu_config) 353 354 """ 355 356 def __init__(self, name: str = ""): 357 self.name = name 358 # Store all BackendPatternConfigs in a map to handle duplicates 359 # Note: the key in this map uses the complex reversed tuple format. 360 # This is intended only for internal use; users who wish to access 361 # the original patterns should go through `self.configs` instead. 362 self._pattern_complex_format_to_config: Dict[Pattern, BackendPatternConfig] = {} 363 364 def __repr__(self): 365 return f"BackendConfig({self.__dict__})" 366 367 def set_name(self, name: str) -> BackendConfig: 368 """ 369 Set the name of the target backend. 370 """ 371 self.name = name 372 return self 373 374 def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig: 375 """ 376 Set the config for an pattern that can be run on the target backend. 377 This overrides any existing config for the given pattern. 378 """ 379 # Avoid circular dependencies 380 pattern_complex_format = torch.ao.quantization.backend_config.utils._get_pattern_in_reversed_nested_tuple_format( 381 config 382 ) # type: ignore[attr-defined] 383 self._pattern_complex_format_to_config[pattern_complex_format] = config 384 return self 385 386 def set_backend_pattern_configs( 387 self, configs: List[BackendPatternConfig] 388 ) -> BackendConfig: 389 """ 390 Set the configs for patterns that can be run on the target backend. 391 This overrides any existing config for a given pattern if it was previously registered already. 392 """ 393 for conf in configs: 394 self.set_backend_pattern_config(conf) 395 return self 396 397 @property 398 def configs(self) -> List[BackendPatternConfig]: 399 """ 400 Return a copy of the list of configs set in this `BackendConfig`. 401 """ 402 return list(self._pattern_complex_format_to_config.values()) 403 404 @classmethod 405 def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig: 406 """ 407 Create a ``BackendConfig`` from a dictionary with the following items: 408 409 "name": the name of the target backend 410 411 "configs": a list of dictionaries that each represents a `BackendPatternConfig` 412 413 """ 414 conf = cls(backend_config_dict.get(NAME_DICT_KEY, "")) 415 for d in backend_config_dict.get(CONFIGS_DICT_KEY, []): 416 if isinstance(d, BackendPatternConfig): 417 conf.set_backend_pattern_config(d) 418 elif isinstance(d, Dict): 419 conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d)) 420 else: 421 raise ValueError( 422 f"Expected backend_config_dict['{CONFIGS_DICT_KEY}'] to be a dictionary" 423 ) 424 return conf 425 426 def to_dict(self) -> Dict[str, Any]: 427 """ 428 Convert this ``BackendConfig`` to a dictionary with the items described in 429 :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`. 430 """ 431 return { 432 NAME_DICT_KEY: self.name, 433 CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs], 434 } 435 436 437class BackendPatternConfig: 438 """ 439 Config object that specifies quantization behavior for a given operator pattern. 440 For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`. 441 """ 442 443 def __init__(self, pattern: Optional[Pattern] = None): 444 self.pattern: Optional[Pattern] = pattern 445 self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT 446 self.dtype_configs: List[DTypeConfig] = [] 447 self.root_module: Optional[Type[torch.nn.Module]] = None 448 self.qat_module: Optional[Type[torch.nn.Module]] = None 449 self.reference_quantized_module: Optional[Type[torch.nn.Module]] = None 450 self.fused_module: Optional[Type[torch.nn.Module]] = None 451 self.fuser_method: Optional[Callable] = None 452 453 # Temporary/internal configs 454 self._root_node_getter: Optional[Callable] = None 455 self._extra_inputs_getter: Optional[Callable] = None 456 self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {} 457 self._input_type_to_index: Dict[str, int] = {} 458 self._pattern_complex_format: Optional[Pattern] = None 459 460 def __repr__(self): 461 dict_nonempty = { 462 k: v 463 for k, v in self.__dict__.items() 464 if ( 465 (not isinstance(v, (list, dict)) and v is not None) 466 or (isinstance(v, (list, dict)) and len(v) > 0) 467 ) 468 } 469 return f"BackendPatternConfig({dict_nonempty})" 470 471 def set_pattern(self, pattern: Pattern) -> BackendPatternConfig: 472 """ 473 Set the pattern to configure. 474 475 The pattern can be a float module, functional operator, pytorch operator, or a tuple 476 combination of the above. Tuple patterns are treated as sequential patterns, and 477 currently only tuples of 2 or 3 elements are supported. 478 """ 479 if self._pattern_complex_format is not None: 480 raise ValueError( 481 "Only one of 'pattern' or 'pattern_complex_format' can be set" 482 ) 483 self.pattern = pattern 484 return self 485 486 def set_observation_type( 487 self, observation_type: ObservationType 488 ) -> BackendPatternConfig: 489 """ 490 Set how observers should be inserted in the graph for this pattern. 491 492 Observation type here refers to how observers (or quant-dequant ops) will be placed 493 in the graph. This is used to produce the desired reference patterns understood by 494 the backend. Weighted ops such as linear and conv require different observers 495 (or quantization parameters passed to quantize ops in the reference model) for the 496 input and the output. 497 498 There are two observation types: 499 500 `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance 501 will be different from the input. This is the most common observation type. 502 503 `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the 504 same as the input. This is useful for operators like `cat`. 505 506 Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs 507 with observers (and fake quantizes) attached instead of observers themselves. 508 """ 509 self.observation_type = observation_type 510 return self 511 512 def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig: 513 """ 514 Add a set of supported data types passed as arguments to quantize ops in the 515 reference model spec. 516 """ 517 self.dtype_configs.append(dtype_config) 518 return self 519 520 def set_dtype_configs( 521 self, dtype_configs: List[DTypeConfig] 522 ) -> BackendPatternConfig: 523 """ 524 Set the supported data types passed as arguments to quantize ops in the 525 reference model spec, overriding all previously registered data types. 526 """ 527 self.dtype_configs = dtype_configs 528 return self 529 530 def set_root_module( 531 self, root_module: Type[torch.nn.Module] 532 ) -> BackendPatternConfig: 533 """ 534 Set the module that represents the root for this pattern. 535 536 When we construct the reference quantized model during the convert phase, 537 the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU) 538 will be swapped to the corresponding reference quantized modules (e.g. 539 torch.ao.nn.reference.quantized.Linear). This allows custom backends to 540 specify custom reference quantized module implementations to match the 541 numerics of their lowered operators. Since this is a one-to-one mapping, 542 both the root module and the reference quantized module must be specified 543 in the same BackendPatternConfig in order for the conversion to take place. 544 """ 545 self.root_module = root_module 546 return self 547 548 def set_qat_module(self, qat_module: Type[torch.nn.Module]) -> BackendPatternConfig: 549 """ 550 Set the module that represents the QAT implementation for this pattern. 551 """ 552 self.qat_module = qat_module 553 return self 554 555 def set_reference_quantized_module( 556 self, reference_quantized_module: Type[torch.nn.Module] 557 ) -> BackendPatternConfig: 558 """ 559 Set the module that represents the reference quantized implementation for 560 this pattern's root module. 561 562 For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`. 563 """ 564 self.reference_quantized_module = reference_quantized_module 565 return self 566 567 def set_fused_module( 568 self, fused_module: Type[torch.nn.Module] 569 ) -> BackendPatternConfig: 570 """ 571 Set the module that represents the fused implementation for this pattern. 572 """ 573 self.fused_module = fused_module 574 return self 575 576 def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig: 577 """ 578 Set the function that specifies how to fuse this BackendPatternConfig's pattern. 579 580 The first argument of this function should be `is_qat`, and the rest of the arguments 581 should be the items in the tuple pattern. The return value of this function should be 582 the resulting fused module. 583 584 For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be: 585 586 def fuse_linear_relu(is_qat, linear, relu): 587 return torch.ao.nn.intrinsic.LinearReLU(linear, relu) 588 589 For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6. 590 """ 591 self.fuser_method = fuser_method 592 return self 593 594 def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig: 595 self._root_node_getter = root_node_getter 596 return self 597 598 def _set_extra_inputs_getter( 599 self, extra_inputs_getter: Callable 600 ) -> BackendPatternConfig: 601 self._extra_inputs_getter = extra_inputs_getter 602 return self 603 604 def _set_num_tensor_args_to_observation_type( 605 self, num_tensor_args_to_observation_type: Dict[int, ObservationType] 606 ) -> BackendPatternConfig: 607 self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type 608 return self 609 610 def _set_input_type_to_index( 611 self, input_type_to_index: Dict[str, int] 612 ) -> BackendPatternConfig: 613 self._input_type_to_index = input_type_to_index 614 return self 615 616 def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig: 617 """ 618 Set the pattern to configure, using the reversed nested tuple format. 619 620 See the BackendConfig README for more detail: 621 https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification 622 """ 623 if self.pattern is not None: 624 raise ValueError( 625 "Only one of 'pattern' or 'pattern_complex_format' can be set" 626 ) 627 self._pattern_complex_format = pattern 628 return self 629 630 @classmethod 631 def from_dict( 632 cls, backend_pattern_config_dict: Dict[str, Any] 633 ) -> BackendPatternConfig: 634 """ 635 Create a ``BackendPatternConfig`` from a dictionary with the following items: 636 637 "pattern": the pattern being configured 638 "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how 639 observers should be inserted for this pattern 640 "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s 641 "root_module": a :class:`torch.nn.Module` that represents the root for this pattern 642 "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern 643 "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized 644 implementation for this pattern's root module. 645 "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern 646 "fuser_method": a function that specifies how to fuse the pattern for this pattern 647 "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated) 648 649 """ 650 651 def _get_dtype_config(obj: Any) -> DTypeConfig: 652 """ 653 Convert the given object into a ``DTypeConfig`` if possible, else throw an exception. 654 """ 655 if isinstance(obj, DTypeConfig): 656 return obj 657 if isinstance(obj, Dict): 658 return DTypeConfig.from_dict(obj) 659 raise ValueError( 660 f"Expected a list of DTypeConfigs in " 661 f"backend_pattern_config_dict[\"{DTYPE_CONFIGS_DICT_KEY}\"], got '{type(obj)}'" 662 ) 663 664 conf = cls() 665 if PATTERN_DICT_KEY in backend_pattern_config_dict: 666 conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY]) 667 if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict: 668 conf.set_observation_type( 669 backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY] 670 ) 671 for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []): 672 conf.add_dtype_config(_get_dtype_config(d)) 673 conf.set_root_module( 674 backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None) 675 ) 676 conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) 677 conf.set_reference_quantized_module( 678 backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None) 679 ) 680 conf.set_fused_module( 681 backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None) 682 ) 683 conf.set_fuser_method( 684 backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None) 685 ) 686 conf._set_root_node_getter( 687 backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None) 688 ) 689 conf._set_extra_inputs_getter( 690 backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None) 691 ) 692 conf._set_num_tensor_args_to_observation_type( 693 backend_pattern_config_dict.get( 694 NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {} 695 ) 696 ) 697 conf._set_input_type_to_index( 698 backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}) 699 ) 700 if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict: 701 conf._set_pattern_complex_format( 702 backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] 703 ) 704 return conf 705 706 def to_dict(self) -> Dict[str, Any]: 707 """ 708 Convert this ``BackendPatternConfig`` to a dictionary with the items described in 709 :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`. 710 """ 711 backend_pattern_config_dict: Dict[str, Any] = { 712 OBSERVATION_TYPE_DICT_KEY: self.observation_type, 713 DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs], 714 } 715 if self.pattern is not None: 716 backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern 717 if self.root_module is not None: 718 backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module 719 if self.qat_module is not None: 720 backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module 721 if self.reference_quantized_module is not None: 722 backend_pattern_config_dict[ 723 REFERENCE_QUANTIZED_MODULE_DICT_KEY 724 ] = self.reference_quantized_module 725 if self.fused_module is not None: 726 backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module 727 if self.fuser_method is not None: 728 backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method 729 if self._root_node_getter is not None: 730 backend_pattern_config_dict[ 731 ROOT_NODE_GETTER_DICT_KEY 732 ] = self._root_node_getter 733 if self._extra_inputs_getter is not None: 734 backend_pattern_config_dict[ 735 EXTRA_INPUTS_GETTER_DICT_KEY 736 ] = self._extra_inputs_getter 737 if len(self._num_tensor_args_to_observation_type) > 0: 738 backend_pattern_config_dict[ 739 NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY 740 ] = self._num_tensor_args_to_observation_type 741 if len(self._input_type_to_index) > 0: 742 backend_pattern_config_dict[ 743 INPUT_TYPE_TO_INDEX_DICT_KEY 744 ] = self._input_type_to_index 745 if self._pattern_complex_format is not None: 746 backend_pattern_config_dict[ 747 PATTERN_COMPLEX_FORMAT_DICT_KEY 748 ] = self._pattern_complex_format 749 return backend_pattern_config_dict 750