# mypy: allow-untyped-defs
from abc import ABC
from typing import Callable, Dict, List, Optional, Type

import torch
from torch.ao.quantization.backend_config import (
    BackendConfig,
    DTypeConfig,
    ObservationType,
)
from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls
from torch.fx.graph import Node

from .utils import all_node_args_have_no_tensors


__all__ = [
    "QuantizeHandler",
    "BinaryOpQuantizeHandler",
    "CatQuantizeHandler",
    "ConvReluQuantizeHandler",
    "LinearReLUQuantizeHandler",
    "BatchNormQuantizeHandler",
    "EmbeddingQuantizeHandler",
    "RNNDynamicQuantizeHandler",
    "DefaultNodeQuantizeHandler",
    "FixedQParamsOpQuantizeHandler",
    "CopyNodeQuantizeHandler",
    "GeneralTensorShapeOpQuantizeHandler",
    "CustomModuleQuantizeHandler",
    "StandaloneModuleQuantizeHandler",
]


def _default_root_node_getter(node_pattern):
    if node_pattern is None:
        return node_pattern
    while not isinstance(node_pattern, Node):
        node_pattern = node_pattern[-1]
    return node_pattern
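
# For example, given a nested pattern in the "complex" (reversed) format, such
# as (relu_node, (bn_node, conv_node)) (node names illustrative), the getter
# repeatedly takes the last element of each tuple until it reaches a Node, so
# the conv node is returned as the root of the pattern:
#
#   _default_root_node_getter((relu_node, (bn_node, conv_node)))  # -> conv_node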


# Base Pattern Handler
class QuantizeHandler(ABC):  # noqa: B024
    """Base handler class for the quantizer patterns"""

    def __init__(
        self,
        node_pattern: NodePattern,
        modules: Dict[str, torch.nn.Module],
        root_node_getter: Optional[Callable] = None,
        is_custom_module=False,
        is_standalone_module=False,
    ):
55        """Records pattern information in __init__, which will be used
56        in convert
57        """
        self.node_pattern = node_pattern
        self.modules = modules
        if root_node_getter is None:
            root_node_getter = _default_root_node_getter
        self.root_node = root_node_getter(node_pattern)
        self.is_custom_module_ = is_custom_module
        self.is_standalone_module_ = is_standalone_module
        self.num_tensor_args = 0
        # determine how many of the first two args are Tensors (versus scalars)
        # this distinguishes things like "x + y" from "x + 2" or "2 + x"
        if isinstance(self.root_node, Node):
            cache_for_no_tensor_check: Dict[Node, bool] = {}
            for arg_idx in range(len(self.root_node.args)):
                arg = self.root_node.args[arg_idx]
                if isinstance(arg, Node) and (
                    not all_node_args_have_no_tensors(
                        arg, self.modules, cache_for_no_tensor_check
                    )
                ):
                    self.num_tensor_args += 1
    def is_general_tensor_value_op(self) -> bool:
        """
        Returns True if the operator works for both floating point and
        quantized input, and either does some computation based on the input
        Tensor's values, only rearranges the Tensor values, or only queries
        some metadata about the Tensor. For these ops we need to insert an
        observer/fake_quant for the output of the operator (the same observer
        instance as the input), since the distribution of values differs
        between the input and output Tensors (relevant for HistogramObserver)
        while they share the same quantization parameters.
        Example operators: avgpool2d, reshape, transpose, maxpool2d
        Example observed operator:
        observer_0 - avgpool2d - observer_0 (same observer instance as input)
        """
        return False

    def is_custom_module(self):
        return self.is_custom_module_

    def is_standalone_module(self):
        return self.is_standalone_module_


def _get_quantize_handler_cls(
    observation_type: ObservationType,
    dtype_configs: List[DTypeConfig],
    num_tensor_args_to_observation_type: Dict[int, ObservationType],
) -> Type[QuantizeHandler]:
    """
    Return a configurable QuantizeHandler that matches the given specifications from the backend.
    """

    class ConfigurableQuantizeHandler(QuantizeHandler):
        def __init__(
            self,
            node_pattern: NodePattern,
            modules: Dict[str, torch.nn.Module],
            root_node_getter: Optional[Callable] = None,
        ):
            super().__init__(node_pattern, modules, root_node_getter)
            if num_tensor_args_to_observation_type:
                assert self.num_tensor_args in num_tensor_args_to_observation_type, (
                    f"Must provide an observation_type config for num_tensor_args={self.num_tensor_args}"
                    f" in num_tensor_args_to_observation_type for {node_pattern}"
                )
                self.observation_type = num_tensor_args_to_observation_type[
                    self.num_tensor_args
                ]
            else:
                self.observation_type = observation_type
            self.dtype_configs = dtype_configs

        def is_general_tensor_value_op(self) -> bool:
            return (
                self.observation_type
                == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
            )

    return ConfigurableQuantizeHandler
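
# A minimal usage sketch (a hypothetical configuration; in practice the values
# come from a BackendPatternConfig, and node_pattern / named_modules below are
# illustrative placeholders):
#
#   handler_cls = _get_quantize_handler_cls(
#       ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
#       dtype_configs=[DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8)],
#       num_tensor_args_to_observation_type={},
#   )
#   handler = handler_cls(node_pattern, named_modules)
#   handler.is_general_tensor_value_op()  # True for this observation type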


def _get_pattern_to_quantize_handlers(
    backend_config: BackendConfig,
) -> Dict[Pattern, QuantizerCls]:
144    """
145    Note: Quantize handler is just a holder for some check methods like
146    (should_insert_observer_for_output), maybe this can be a enum as well,
147    we can refactor this after we convert the path for fbgemm/qnnpack fully to the
148    new path, this is not exposed to backend developers
149    """
    pattern_to_quantize_handlers = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        observation_type = config.observation_type
        dtype_configs = config.dtype_configs
        num_tensor_args_to_observation_type = (
            config._num_tensor_args_to_observation_type
        )
        pattern_to_quantize_handlers[pattern] = _get_quantize_handler_cls(
            observation_type, dtype_configs, num_tensor_args_to_observation_type
        )
    return pattern_to_quantize_handlers
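
# For example (a sketch; get_native_backend_config is the standard entry point,
# and the exact set of pattern keys depends on the backend config):
#
#   from torch.ao.quantization.backend_config import get_native_backend_config
#
#   handlers = _get_pattern_to_quantize_handlers(get_native_backend_config())
#   # each value is a ConfigurableQuantizeHandler subclass, e.g.
#   # handlers[torch.nn.Conv2d](node_pattern, named_modules)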


# TODO: remove this class; it is still exposed in torch.ao.quantization,
# but we should be able to break backward compatibility
class BinaryOpQuantizeHandler(QuantizeHandler):
    pass


class CatQuantizeHandler(QuantizeHandler):
    pass


# TODO: remove this class
class ConvReluQuantizeHandler(QuantizeHandler):
    pass


# TODO: remove this class
class LinearReLUQuantizeHandler(QuantizeHandler):
    pass


# TODO: remove this class
class BatchNormQuantizeHandler(QuantizeHandler):
    pass


# TODO: remove this class
class EmbeddingQuantizeHandler(QuantizeHandler):
    pass


# TODO: remove this class
class RNNDynamicQuantizeHandler(QuantizeHandler):
    pass


# TODO: remove this class
class DefaultNodeQuantizeHandler(QuantizeHandler):
    """Common quantized op; the first input and first output will be quantized"""


# TODO: remove this class
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
    pass


# TODO: remove
class CopyNodeQuantizeHandler(QuantizeHandler):
    pass


# TODO: remove
class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler):
    pass


# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
class CustomModuleQuantizeHandler(QuantizeHandler):
    pass


# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
class StandaloneModuleQuantizeHandler(QuantizeHandler):
    pass
226