xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/quantizer/xnnpack_quantizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from __future__ import annotations

import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING

import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.utils import _get_module_name_filter
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _convert_scalars_to_attrs,
    OP_TO_ANNOTATOR,
    OperatorConfig,
    OperatorPatternType,
    propagate_annotation,
    QuantizationConfig,
)


if TYPE_CHECKING:
    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
    from torch.fx import Node


__all__ = [
    "XNNPACKQuantizer",
    "get_symmetric_quantization_config",
]


def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
    gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
    gm.graph.eliminate_dead_code()
    return gm.graph


def _get_linear_patterns(input_size: List[int]):
    in_channels = input_size[-1]
    out_channels = 8  # hard-coded, but the exact value should not matter for the pattern
    weight = torch.ones((out_channels, in_channels))
    bias = torch.ones((out_channels,))
    act = torch.ones(input_size)

    def linear_op(act, weight, bias=None):
        return F.linear(act, weight, bias)

    pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
    pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
    return [pattern_w_bias, pattern_wo_bias]


def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
    supported_operators: Dict[str, List[OperatorPatternType]] = {
        # Both conv and linear should be able to handle relu + hardtanh fusion since
        # those are clamp ops
        "conv2d": [
            [torch.nn.Conv2d, torch.nn.ReLU],
            [torch.nn.Conv2d, F.relu],
            [F.conv2d, torch.nn.ReLU],
            [F.conv2d, F.relu],
        ],
        "linear": [[torch.nn.Linear], [F.linear]],
        "add": [[torch.add]],
        "adaptive_avg_pool2d": [
            [torch.nn.AdaptiveAvgPool2d],
            [F.adaptive_avg_pool2d],
        ],
    }
    return copy.deepcopy(supported_operators)


def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
    supported_config_and_operators: List[OperatorConfig] = []
    for quantization_config in [
        get_symmetric_quantization_config(),
        get_symmetric_quantization_config(is_qat=True),
        get_symmetric_quantization_config(is_per_channel=True),
        get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
    ]:
        ops = _supported_symmetric_quantized_operators()
        for pattern_list in ops.values():
            supported_config_and_operators.append(
                OperatorConfig(quantization_config, pattern_list)
            )
    return copy.deepcopy(supported_config_and_operators)


@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
    is_qat: bool = False,
    is_dynamic: bool = False,
    act_qmin: int = -128,
    act_qmax: int = 127,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
):
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
            extra_args["observer"] = dynamic_quant_observer
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=act_qmin,
        quant_max=act_qmax,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    if is_qat:
        # TODO: qat + per channel?
        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
    elif is_per_channel:
        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver

    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if weight_qscheme == torch.per_tensor_symmetric:
            extra_args["observer"] = MovingAverageMinMaxObserver
        else:
            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    bias_quantization_spec = None
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            None,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            act_quantization_spec,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    return quantization_config
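
# Illustrative usage sketch (comments only, not executed as part of this module):
# the returned config is typically handed to XNNPACKQuantizer below. Because the
# function is wrapped in functools.lru_cache, repeated calls with the same arguments
# return the same cached QuantizationConfig instance.
#
#     ptq_config = get_symmetric_quantization_config(is_per_channel=True)
#     qat_config = get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
#     dyn_config = get_symmetric_quantization_config(is_dynamic=True)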


def _get_supported_config_and_operators() -> List[OperatorConfig]:
    return _get_supported_symmetric_config_and_operators()


def _get_module_type_filter(tp: Callable):
    """Get the module_type_filter function for a given module type. The filter accepts
    a node and checks whether the node comes from a module of that type.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear

    >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
    >> print(module_type_filter(node))
    True  # the node is from the submodule `Sub` (it would also match `Block` and `Linear`)
    """

    tp_str = tp.__module__ + "." + tp.__qualname__

    def module_type_filter(n: Node) -> bool:
        # example: {
        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        nn_module_stack = n.meta.get("nn_module_stack", {})
        types = []
        for _, t in nn_module_stack.values():
            # export() returns str, but older APIs (e.g. capture_pre_autograd_graph)
            # return type. Handle both cases.
            if isinstance(t, type):
                t = t.__module__ + "." + t.__qualname__
            types.append(t)
        return tp_str in types

    return module_type_filter


def _get_not_module_type_or_name_filter(
    tp_list: List[Callable], module_name_list: List[str]
) -> Callable[[Node], bool]:
    module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]

    def not_module_type_or_name_filter(n: Node) -> bool:
        return not any(f(n) for f in module_type_filters + module_name_list_filters)

    return not_module_type_or_name_filter


class XNNPACKQuantizer(Quantizer):
    supported_config_and_operators = _get_supported_config_and_operators()
    STATIC_QAT_ONLY_OPS = [
        "conv_bn_relu",
        "conv_bn",
        "conv_transpose_bn_relu",
        "conv_transpose_bn",
    ]

    # static quantization ops (for both PTQ and QAT)
    # Keep fused patterns before the singular ops they contain so that fusions are matched first
    STATIC_OPS = [
        "linear_relu",
        "linear",
        "conv_relu",
        "conv",
        "conv_transpose_relu",
        "adaptive_avg_pool2d",
        # TODO: move this to BoltNNQuantizer?
        "gru_io_only",
        "add_relu",
        "add",
        "mul_relu",
        "mul",
        "cat",
    ]

    DYNAMIC_OPS = [
        "linear",
    ]

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
        self.operator_type_config: Dict[
            torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
        ] = {}
        self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
        self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

    @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
        op_configs: Set[QuantizationConfig] = {
            spec for spec, _ in cls.supported_config_and_operators
        }
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: Optional[QuantizationConfig]
    ) -> List[OperatorPatternType]:
        if quantization_config is None:
            all_ops = []
            for _, ops in cls.supported_config_and_operators:
                all_ops.extend(ops)
            return all_ops

        for config, ops in cls.supported_config_and_operators:
            # note: this assumes each entry in cls.supported_config_and_operators
            # corresponds to exactly one config, e.g. we don't have
            # [(config1, op_list1), (config1, op_list2), (config2, op_list3)]
            # where the first and second entries share the same config but their
            # op lists were not merged
            if config == quantization_config:
                return ops
        return []

    def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
        self.global_config = quantization_config
        return self

    def set_operator_type(
        self,
        operator_type: torch._ops.OpOverloadPacket,
        quantization_config: QuantizationConfig,
    ) -> XNNPACKQuantizer:
        self.operator_type_config[operator_type] = quantization_config
        return self

    def set_module_type(
        self, module_type: Callable, quantization_config: QuantizationConfig
    ):
        """Set quantization_config for submodules with type `module_type`, for example
        quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear). It will quantize all supported operator/operator
        patterns in submodules of this type with the given `quantization_config`.
        """
        self.module_type_config[module_type] = quantization_config
        return self
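
    # Illustrative sketch (comments only, not executed here): quantize every
    # torch.nn.Linear submodule with a per-channel config via set_module_type.
    #
    #     quantizer = XNNPACKQuantizer()
    #     quantizer.set_module_type(
    #         torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=True)
    #     )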

    def set_module_name(
        self, module_name: str, quantization_config: Optional[QuantizationConfig]
    ):
        """Set quantization_config for the submodule with name `module_name`, for example
        quantizer.set_module_name("blocks.sub"). It will quantize all supported operator/operator
        patterns in the submodule with this name with the given `quantization_config`.
        """
        assert (
            quantization_config is not None
        ), "quantization_config=None is not supported yet"
        self.module_name_config[module_name] = quantization_config
        return self
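
    # Illustrative sketch (comments only): give one named submodule its own config
    # while the rest of the model falls back to the global config; "blocks.sub" is
    # the example module path from the docstring above.
    #
    #     quantizer = XNNPACKQuantizer()
    #     quantizer.set_global(get_symmetric_quantization_config())
    #     quantizer.set_module_name(
    #         "blocks.sub", get_symmetric_quantization_config(is_per_channel=True)
    #     )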

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Transforms scalar values to tensor attributes"""
        return _convert_scalars_to_attrs(model)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Annotate the model; for now only the global config decides between static and dynamic quantization."""
        # Hack to support dynamic linear quantization; will be fixed later.
        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
            model = self._annotate_for_dynamic_quantization_config(model)
        else:
            model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        # TODO: add support for None to cancel out previous annotations
        if quantization_config is None:
            return model

        if quantization_config.is_qat:
            for op in self.STATIC_QAT_ONLY_OPS:
                OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        for op in self.STATIC_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_all_dynamic_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        # TODO: add support for None to cancel out previous annotations
        if quantization_config is None:
            return model

        for op in self.DYNAMIC_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_static_patterns(
                model, config, _get_module_name_filter(module_name)
            )

        tp_list = list(self.module_type_config.keys())
        for module_type, config in self.module_type_config.items():
            self._annotate_all_static_patterns(
                model, config, _get_module_type_filter(module_type)
            )

        self._annotate_all_static_patterns(
            model,
            self.global_config,
            _get_not_module_type_or_name_filter(tp_list, module_name_list),
        )
        return model

    def _annotate_for_dynamic_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_dynamic_patterns(
                model, config, _get_module_name_filter(module_name)
            )

        tp_list = list(self.module_type_config.keys())
        for module_type, config in self.module_type_config.items():
            self._annotate_all_dynamic_patterns(
                model, config, _get_module_type_filter(module_type)
            )

        self._annotate_all_dynamic_patterns(
            model,
            self.global_config,
            _get_not_module_type_or_name_filter(tp_list, module_name_list),
        )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return cls.supported_config_and_operators
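

# End-to-end usage sketch (comments only; illustrative, not part of this module).
# Assumptions: `MyModel` and `example_inputs` are user-provided, and the PT2E entry
# points (torch.export.export_for_training or capture_pre_autograd_graph, plus
# prepare_pt2e / convert_pt2e from torch.ao.quantization.quantize_pt2e) may vary
# slightly across PyTorch releases.
#
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#     quantizer = XNNPACKQuantizer().set_global(
#         get_symmetric_quantization_config(is_per_channel=True)
#     )
#     m = torch.export.export_for_training(MyModel().eval(), example_inputs).module()
#     m = prepare_pt2e(m, quantizer)   # inserts observers based on the annotations
#     m(*example_inputs)               # calibration
#     m = convert_pt2e(m)              # produce the reference quantized model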