# mypy: allow-untyped-defs
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union

import torch


if TYPE_CHECKING:
    from torch.ao.quantization.utils import Pattern


__all__ = [
    "BackendConfig",
    "BackendPatternConfig",
    "DTypeConfig",
    "DTypeWithConstraints",
    "ObservationType",
]


# DTypeConfig dict keys
INPUT_DTYPE_DICT_KEY = "input_dtype"
OUTPUT_DTYPE_DICT_KEY = "output_dtype"
WEIGHT_DTYPE_DICT_KEY = "weight_dtype"
BIAS_DTYPE_DICT_KEY = "bias_dtype"
IS_DYNAMIC_DICT_KEY = "is_dynamic"

# BackendConfig dict keys
NAME_DICT_KEY = "name"
CONFIGS_DICT_KEY = "configs"

# BackendPatternConfig dict keys
PATTERN_DICT_KEY = "pattern"
PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format"
OBSERVATION_TYPE_DICT_KEY = "observation_type"
DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
ROOT_MODULE_DICT_KEY = "root_module"
QAT_MODULE_DICT_KEY = "qat_module"
REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root"
FUSED_MODULE_DICT_KEY = "fused_module"
FUSER_METHOD_DICT_KEY = "fuser_method"
ROOT_NODE_GETTER_DICT_KEY = "root_node_getter"
EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter"
NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type"
INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index"


# TODO: maybe rename this to something that's not related to observer
# e.g. QParamsType
class ObservationType(Enum):
    """An enum that represents the different ways an operator/operator pattern
    can be observed.
    """

    OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
    """The input and output are observed with different observer instances, based
    on qconfig.activation.
    Examples: conv, linear, softmax
    """

    OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
    """The output uses the same observer instance as the input, based
    on qconfig.activation.
    Examples: torch.cat, maxpool
    """

    INPUT_OUTPUT_NOT_OBSERVED = 2
    """The input and output are never observed.
    Examples: x.shape, x.size
    """


@dataclass
class DTypeWithConstraints:
    """
    Config for specifying additional constraints for a given dtype, such as quantization
    value ranges, scale value ranges, and fixed quantization params, to be used in
    :class:`~torch.ao.quantization.backend_config.DTypeConfig`.

    The constraints currently supported are:

    * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper
      bounds for the minimum and maximum quantized values respectively. If
      the QConfig's `quant_min` and `quant_max` fall outside this range,
      then the QConfig will be ignored.

    * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper
      bounds for the minimum and maximum scale values respectively. If the
      QConfig's minimum scale value (currently exposed as `eps`) falls below
      the lower bound, then the QConfig will be ignored. Note that the upper
      bound is currently not enforced.

    * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements
      for scale and zero point, to be used for operators with fixed quantization
      parameters such as sigmoid and tanh. If the observer specified in the QConfig
      is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if
      the quantization parameters don't match, then the QConfig will be ignored.
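
    For example, a minimal sketch of constraining a dtype to fixed quantization
    parameters (the values shown follow the common 0-to-1 fixed-qparams convention
    used for ops like sigmoid; treat them as illustrative, not prescriptive)::

        import torch
        from torch.ao.quantization.backend_config import DTypeWithConstraints

        # Require scale == 1/256 and zero_point == 0 for this activation dtype
        fixed_qparams_quint8 = DTypeWithConstraints(
            dtype=torch.quint8,
            scale_exact_match=1.0 / 256.0,
            zero_point_exact_match=0,
        )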
    """

    dtype: Optional[torch.dtype] = None
    quant_min_lower_bound: Union[int, float, None] = None
    quant_max_upper_bound: Union[int, float, None] = None
    scale_min_lower_bound: Union[int, float, None] = None
    scale_max_upper_bound: Union[int, float, None] = None
    scale_exact_match: Optional[float] = None
    zero_point_exact_match: Optional[int] = None


@dataclass
class DTypeConfig:
    """
    Config object that specifies the supported data types passed as arguments to
    quantize ops in the reference model spec, for input and output activations,
    weights, and biases.

    For example, consider the following reference model:

      quant1 - [dequant1 - fp32_linear - quant2] - dequant2

    The pattern in the square brackets refers to the reference pattern of
    statically quantized linear. Setting the input dtype as `torch.quint8`
    in the DTypeConfig means we pass in `torch.quint8` as the dtype argument
    to the first quantize op (quant1). Similarly, setting the output dtype as
    `torch.quint8` means we pass in `torch.quint8` as the dtype argument to
    the second quantize op (quant2).

    Note that the dtype here does not refer to the interface dtypes of the
    op. For example, the "input dtype" here is not the dtype of the input
    tensor passed to the quantized linear op. Though it can still be the
    same as the interface dtype, this is not always the case, e.g. the
    interface dtype is fp32 in dynamic quantization but the "input dtype"
    specified in the DTypeConfig would still be quint8. The semantics of
    dtypes here are the same as the semantics of the dtypes specified in
    the observers.

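    For instance, a dynamic quantization config typically keeps fp32 interface
    dtypes while still specifying quantized dtypes for the quantize ops. A sketch
    mirroring the usual dynamic int8 setup (the variable name is illustrative)::

        dynamic_int8_dtype_config = DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.float,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float,
            is_dynamic=True,
        )
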
    These dtypes are matched against the ones specified in the user's
    QConfig. If there is a match, and the QConfig satisfies the constraints
    specified in the DTypeConfig (if any), then we will quantize the given
    pattern using this DTypeConfig. Otherwise, the QConfig is ignored and
    the pattern will not be quantized.

    Example usage::

        >>> # xdoctest: +SKIP(failing)
        >>> dtype_config1 = DTypeConfig(
        ...     input_dtype=torch.quint8,
        ...     output_dtype=torch.quint8,
        ...     weight_dtype=torch.qint8,
        ...     bias_dtype=torch.float)

        >>> dtype_config2 = DTypeConfig(
        ...     input_dtype=DTypeWithConstraints(
        ...         dtype=torch.quint8,
        ...         quant_min_lower_bound=0,
        ...         quant_max_upper_bound=255,
        ...     ),
        ...     output_dtype=DTypeWithConstraints(
        ...         dtype=torch.quint8,
        ...         quant_min_lower_bound=0,
        ...         quant_max_upper_bound=255,
        ...     ),
        ...     weight_dtype=DTypeWithConstraints(
        ...         dtype=torch.qint8,
        ...         quant_min_lower_bound=-128,
        ...         quant_max_upper_bound=127,
        ...     ),
        ...     bias_dtype=torch.float)

        >>> dtype_config1.input_dtype
        torch.quint8

        >>> dtype_config2.input_dtype
        torch.quint8

        >>> dtype_config2.input_dtype_with_constraints
        DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \
scale_min_lower_bound=None, scale_max_upper_bound=None)
    """

    input_dtype_with_constraints: DTypeWithConstraints
    output_dtype_with_constraints: DTypeWithConstraints
    weight_dtype_with_constraints: DTypeWithConstraints
    bias_dtype: Optional[torch.dtype]
    is_dynamic: Optional[bool]

    def __init__(
        self,
        input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        bias_dtype: Optional[torch.dtype] = None,
        is_dynamic: Optional[bool] = None,
    ):
        if isinstance(input_dtype, DTypeWithConstraints):
            self.input_dtype_with_constraints = input_dtype
        else:
            self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype)

        if isinstance(output_dtype, DTypeWithConstraints):
            self.output_dtype_with_constraints = output_dtype
        else:
            self.output_dtype_with_constraints = DTypeWithConstraints(
                dtype=output_dtype
            )

        if isinstance(weight_dtype, DTypeWithConstraints):
            self.weight_dtype_with_constraints = weight_dtype
        else:
            self.weight_dtype_with_constraints = DTypeWithConstraints(
                dtype=weight_dtype
            )

        self.bias_dtype = bias_dtype
        self.is_dynamic = is_dynamic

    @property
    def input_dtype(self) -> Optional[torch.dtype]:
        return self.input_dtype_with_constraints.dtype

    @property
    def output_dtype(self) -> Optional[torch.dtype]:
        return self.output_dtype_with_constraints.dtype

    @property
    def weight_dtype(self) -> Optional[torch.dtype]:
        return self.weight_dtype_with_constraints.dtype

    @classmethod
    def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig:
        """
        Create a ``DTypeConfig`` from a dictionary with the following items (all optional):
            "input_dtype": torch.dtype or ``DTypeWithConstraints``
            "output_dtype": torch.dtype or ``DTypeWithConstraints``
            "weight_dtype": torch.dtype or ``DTypeWithConstraints``
            "bias_dtype": torch.dtype
            "is_dynamic": bool
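
        For example, a minimal call (all keys optional)::

            dtype_config = DTypeConfig.from_dict({
                "input_dtype": torch.quint8,
                "output_dtype": torch.quint8,
                "weight_dtype": torch.qint8,
                "bias_dtype": torch.float,
            })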
        """
        input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
        if input_dtype is not None and not isinstance(
            input_dtype, (torch.dtype, DTypeWithConstraints)
        ):
            raise ValueError(
                "Expected input_dtype to be a torch.dtype or DTypeWithConstraints"
            )
        output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
        if output_dtype is not None and not isinstance(
            output_dtype, (torch.dtype, DTypeWithConstraints)
        ):
            raise ValueError(
                "Expected output_dtype to be a torch.dtype or DTypeWithConstraints"
            )
        weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
        if weight_dtype is not None and not isinstance(
            weight_dtype, (torch.dtype, DTypeWithConstraints)
        ):
            raise ValueError(
                "Expected weight_dtype to be a torch.dtype or DTypeWithConstraints"
            )
        bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
        is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
        return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``DTypeConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.
        """
        dtype_config_dict: Dict[str, Any] = {}
        if self.input_dtype is not None:
            dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints
        if self.output_dtype is not None:
            dtype_config_dict[
                OUTPUT_DTYPE_DICT_KEY
            ] = self.output_dtype_with_constraints
        if self.weight_dtype is not None:
            dtype_config_dict[
                WEIGHT_DTYPE_DICT_KEY
            ] = self.weight_dtype_with_constraints
        if self.bias_dtype is not None:
            dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype
        if self.is_dynamic is not None:
            dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic
        return dtype_config_dict


class BackendConfig:
    # TODO: refer to NativeBackendConfig once that is implemented
    """Config that defines the set of patterns that can be quantized on a given backend, and how reference
    quantized models can be produced from these patterns.

    A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
    of the above. Each pattern supported on the target backend can be individually configured through
    :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:

    (1) The supported input/output activation, weight, and bias data types

    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and

    (3) (Optionally) Fusion, QAT, and reference module mappings.

    The format of the patterns is described in:
    https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md

    Example usage::

        import torch
        from torch.ao.quantization.backend_config import (
            BackendConfig,
            BackendPatternConfig,
            DTypeConfig,
            ObservationType,
        )

        weighted_int8_dtype_config = DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float)

        def fuse_conv2d_relu(is_qat, conv, relu):
            return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)

        # For quantizing Linear
        linear_config = BackendPatternConfig(torch.nn.Linear) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_root_module(torch.nn.Linear) \
            .set_qat_module(torch.ao.nn.qat.Linear) \
            .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)

        # For fusing Conv2d + ReLU into ConvReLU2d
        conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
            .set_fuser_method(fuse_conv2d_relu)

        # For quantizing ConvReLU2d
        fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_root_module(torch.nn.Conv2d) \
            .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \
            .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d)

        backend_config = BackendConfig("my_backend") \
            .set_backend_pattern_config(linear_config) \
            .set_backend_pattern_config(conv_relu_config) \
            .set_backend_pattern_config(fused_conv_relu_config)

    """

    def __init__(self, name: str = ""):
        self.name = name
        # Store all BackendPatternConfigs in a map to handle duplicates
        # Note: the key in this map uses the complex reversed tuple format.
        # This is intended only for internal use; users who wish to access
        # the original patterns should go through `self.configs` instead.
        self._pattern_complex_format_to_config: Dict[Pattern, BackendPatternConfig] = {}

    def __repr__(self):
        return f"BackendConfig({self.__dict__})"

    def set_name(self, name: str) -> BackendConfig:
        """
        Set the name of the target backend.
        """
        self.name = name
        return self

    def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig:
        """
        Set the config for a pattern that can be run on the target backend.
        This overrides any existing config for the given pattern.
        """
        # Avoid circular dependencies
        pattern_complex_format = torch.ao.quantization.backend_config.utils._get_pattern_in_reversed_nested_tuple_format(
            config
        )  # type: ignore[attr-defined]
        self._pattern_complex_format_to_config[pattern_complex_format] = config
        return self

    def set_backend_pattern_configs(
        self, configs: List[BackendPatternConfig]
    ) -> BackendConfig:
        """
        Set the configs for patterns that can be run on the target backend.
        This overrides any existing config previously registered for a given pattern.
        """
        for conf in configs:
            self.set_backend_pattern_config(conf)
        return self

    @property
    def configs(self) -> List[BackendPatternConfig]:
        """
        Return a copy of the list of configs set in this `BackendConfig`.
        """
        return list(self._pattern_complex_format_to_config.values())

    @classmethod
    def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
        """
        Create a ``BackendConfig`` from a dictionary with the following items:

            "name": the name of the target backend

            "configs": a list of dictionaries that each represents a `BackendPatternConfig`

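        For example, a minimal sketch (``linear_config`` and ``conv_relu_config`` stand in
        for ``BackendPatternConfig`` objects such as those in the :class:`BackendConfig`
        example; their ``to_dict()`` forms are also accepted)::

            backend_config = BackendConfig.from_dict({
                "name": "my_backend",
                "configs": [linear_config, conv_relu_config],
            })
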
        """
        conf = cls(backend_config_dict.get(NAME_DICT_KEY, ""))
        for d in backend_config_dict.get(CONFIGS_DICT_KEY, []):
            if isinstance(d, BackendPatternConfig):
                conf.set_backend_pattern_config(d)
            elif isinstance(d, Dict):
                conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d))
            else:
                raise ValueError(
                    f"Expected backend_config_dict['{CONFIGS_DICT_KEY}'] to be a dictionary"
                )
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``BackendConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.
        """
        return {
            NAME_DICT_KEY: self.name,
            CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs],
        }


class BackendPatternConfig:
    """
    Config object that specifies quantization behavior for a given operator pattern.
    For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
    """

    def __init__(self, pattern: Optional[Pattern] = None):
        self.pattern: Optional[Pattern] = pattern
        self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        self.dtype_configs: List[DTypeConfig] = []
        self.root_module: Optional[Type[torch.nn.Module]] = None
        self.qat_module: Optional[Type[torch.nn.Module]] = None
        self.reference_quantized_module: Optional[Type[torch.nn.Module]] = None
        self.fused_module: Optional[Type[torch.nn.Module]] = None
        self.fuser_method: Optional[Callable] = None

        # Temporary/internal configs
        self._root_node_getter: Optional[Callable] = None
        self._extra_inputs_getter: Optional[Callable] = None
        self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {}
        self._input_type_to_index: Dict[str, int] = {}
        self._pattern_complex_format: Optional[Pattern] = None

    def __repr__(self):
        dict_nonempty = {
            k: v
            for k, v in self.__dict__.items()
            if (
                (not isinstance(v, (list, dict)) and v is not None)
                or (isinstance(v, (list, dict)) and len(v) > 0)
            )
        }
        return f"BackendPatternConfig({dict_nonempty})"

    def set_pattern(self, pattern: Pattern) -> BackendPatternConfig:
        """
        Set the pattern to configure.

        The pattern can be a float module, functional operator, pytorch operator, or a tuple
        combination of the above. Tuple patterns are treated as sequential patterns, and
        currently only tuples of 2 or 3 elements are supported.
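
        For example, a 2-tuple and a 3-tuple sequential pattern (a minimal sketch)::

            conv_relu = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU))
            conv_bn_relu = BackendPatternConfig().set_pattern(
                (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU)
            )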
        """
        if self._pattern_complex_format is not None:
            raise ValueError(
                "Only one of 'pattern' or 'pattern_complex_format' can be set"
            )
        self.pattern = pattern
        return self

    def set_observation_type(
        self, observation_type: ObservationType
    ) -> BackendPatternConfig:
        """
        Set how observers should be inserted in the graph for this pattern.

        Observation type here refers to how observers (or quant-dequant ops) will be placed
        in the graph. This is used to produce the desired reference patterns understood by
        the backend. Weighted ops such as linear and conv require different observers
        (or quantization parameters passed to quantize ops in the reference model) for the
        input and the output.

        There are two observation types:

            `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance
            will be different from the input. This is the most common observation type.

            `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the
            same as the input. This is useful for operators like `cat`.

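        For example, a minimal sketch of picking an observation type per pattern
        (``torch.cat`` shares its input observer, while ``torch.nn.Linear`` uses the default)::

            linear_config = BackendPatternConfig(torch.nn.Linear).set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )
            cat_config = BackendPatternConfig(torch.cat).set_observation_type(
                ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
            )
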
        Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs
        with observers (and fake quantizes) attached instead of observers themselves.
        """
        self.observation_type = observation_type
        return self

    def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
        """
        Add a set of supported data types passed as arguments to quantize ops in the
        reference model spec.
        """
        self.dtype_configs.append(dtype_config)
        return self

    def set_dtype_configs(
        self, dtype_configs: List[DTypeConfig]
    ) -> BackendPatternConfig:
        """
        Set the supported data types passed as arguments to quantize ops in the
        reference model spec, overriding all previously registered data types.
        """
        self.dtype_configs = dtype_configs
        return self

    def set_root_module(
        self, root_module: Type[torch.nn.Module]
    ) -> BackendPatternConfig:
        """
        Set the module that represents the root for this pattern.

        When we construct the reference quantized model during the convert phase,
        the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU)
        will be swapped to the corresponding reference quantized modules (e.g.
        torch.ao.nn.quantized.reference.Linear). This allows custom backends to
        specify custom reference quantized module implementations to match the
        numerics of their lowered operators. Since this is a one-to-one mapping,
        both the root module and the reference quantized module must be specified
        in the same BackendPatternConfig in order for the conversion to take place.
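
        For example, a minimal sketch for the fused ``LinearReLU`` pattern::

            config = (
                BackendPatternConfig(torch.ao.nn.intrinsic.LinearReLU)
                .set_root_module(torch.nn.Linear)
                .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
            )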
        """
        self.root_module = root_module
        return self

    def set_qat_module(self, qat_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the QAT implementation for this pattern.
        """
        self.qat_module = qat_module
        return self

    def set_reference_quantized_module(
        self, reference_quantized_module: Type[torch.nn.Module]
    ) -> BackendPatternConfig:
        """
        Set the module that represents the reference quantized implementation for
        this pattern's root module.

        For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`.
        """
        self.reference_quantized_module = reference_quantized_module
        return self

    def set_fused_module(
        self, fused_module: Type[torch.nn.Module]
    ) -> BackendPatternConfig:
        """
        Set the module that represents the fused implementation for this pattern.
        """
        self.fused_module = fused_module
        return self

    def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
        """
        Set the function that specifies how to fuse this BackendPatternConfig's pattern.

        The first argument of this function should be `is_qat`, and the rest of the arguments
        should be the items in the tuple pattern. The return value of this function should be
        the resulting fused module.

        For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be::

            def fuse_linear_relu(is_qat, linear, relu):
                return torch.ao.nn.intrinsic.LinearReLU(linear, relu)

        For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6.
        """
        self.fuser_method = fuser_method
        return self

    def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig:
        self._root_node_getter = root_node_getter
        return self

    def _set_extra_inputs_getter(
        self, extra_inputs_getter: Callable
    ) -> BackendPatternConfig:
        self._extra_inputs_getter = extra_inputs_getter
        return self

    def _set_num_tensor_args_to_observation_type(
        self, num_tensor_args_to_observation_type: Dict[int, ObservationType]
    ) -> BackendPatternConfig:
        self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type
        return self

    def _set_input_type_to_index(
        self, input_type_to_index: Dict[str, int]
    ) -> BackendPatternConfig:
        self._input_type_to_index = input_type_to_index
        return self

    def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig:
        """
        Set the pattern to configure, using the reversed nested tuple format.

        See the BackendConfig README for more detail:
        https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification
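
        For example, the sequential pattern ``(torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU)``
        is written in this format as ``(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))``.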
        """
        if self.pattern is not None:
            raise ValueError(
                "Only one of 'pattern' or 'pattern_complex_format' can be set"
            )
        self._pattern_complex_format = pattern
        return self

    @classmethod
    def from_dict(
        cls, backend_pattern_config_dict: Dict[str, Any]
    ) -> BackendPatternConfig:
        """
        Create a ``BackendPatternConfig`` from a dictionary with the following items:

            "pattern": the pattern being configured
            "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
            observers should be inserted for this pattern
            "dtype_configs": a list of dictionaries that each represents a :class:`~torch.ao.quantization.backend_config.DTypeConfig`
            "root_module": a :class:`torch.nn.Module` that represents the root for this pattern
            "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
            "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized
            implementation for this pattern's root module
            "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
            "fuser_method": a function that specifies how to fuse this pattern
            "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated)

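        For example, a minimal sketch (``weighted_int8_dtype_config`` stands in for a
        ``DTypeConfig`` such as the one in the :class:`BackendConfig` example)::

            linear_config = BackendPatternConfig.from_dict({
                "pattern": torch.nn.Linear,
                "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
                "dtype_configs": [weighted_int8_dtype_config],
                "root_module": torch.nn.Linear,
                "reference_quantized_module": torch.ao.nn.quantized.reference.Linear,
            })
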
        """

        def _get_dtype_config(obj: Any) -> DTypeConfig:
            """
            Convert the given object into a ``DTypeConfig`` if possible, else throw an exception.
            """
            if isinstance(obj, DTypeConfig):
                return obj
            if isinstance(obj, Dict):
                return DTypeConfig.from_dict(obj)
            raise ValueError(
                f"Expected a list of DTypeConfigs in "
                f"backend_pattern_config_dict[\"{DTYPE_CONFIGS_DICT_KEY}\"], got '{type(obj)}'"
            )

        conf = cls()
        if PATTERN_DICT_KEY in backend_pattern_config_dict:
            conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY])
        if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict:
            conf.set_observation_type(
                backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY]
            )
        for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
            conf.add_dtype_config(_get_dtype_config(d))
        conf.set_root_module(
            backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None)
        )
        conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None))
        conf.set_reference_quantized_module(
            backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None)
        )
        conf.set_fused_module(
            backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None)
        )
        conf.set_fuser_method(
            backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None)
        )
        conf._set_root_node_getter(
            backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None)
        )
        conf._set_extra_inputs_getter(
            backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None)
        )
        conf._set_num_tensor_args_to_observation_type(
            backend_pattern_config_dict.get(
                NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}
            )
        )
        conf._set_input_type_to_index(
            backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {})
        )
        if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict:
            conf._set_pattern_complex_format(
                backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY]
            )
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``BackendPatternConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.
        """
        backend_pattern_config_dict: Dict[str, Any] = {
            OBSERVATION_TYPE_DICT_KEY: self.observation_type,
            DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
        }
        if self.pattern is not None:
            backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern
        if self.root_module is not None:
            backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module
        if self.qat_module is not None:
            backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module
        if self.reference_quantized_module is not None:
            backend_pattern_config_dict[
                REFERENCE_QUANTIZED_MODULE_DICT_KEY
            ] = self.reference_quantized_module
        if self.fused_module is not None:
            backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module
        if self.fuser_method is not None:
            backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method
        if self._root_node_getter is not None:
            backend_pattern_config_dict[
                ROOT_NODE_GETTER_DICT_KEY
            ] = self._root_node_getter
        if self._extra_inputs_getter is not None:
            backend_pattern_config_dict[
                EXTRA_INPUTS_GETTER_DICT_KEY
            ] = self._extra_inputs_getter
        if len(self._num_tensor_args_to_observation_type) > 0:
            backend_pattern_config_dict[
                NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY
            ] = self._num_tensor_args_to_observation_type
        if len(self._input_type_to_index) > 0:
            backend_pattern_config_dict[
                INPUT_TYPE_TO_INDEX_DICT_KEY
            ] = self._input_type_to_index
        if self._pattern_complex_format is not None:
            backend_pattern_config_dict[
                PATTERN_COMPLEX_FORMAT_DICT_KEY
            ] = self._pattern_complex_format
        return backend_pattern_config_dict