# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# Quantizer for Arm backend
#

from __future__ import annotations

import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set

import torch
import torch.nn.functional as F
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager

from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.arm_quantizer_utils import (
    mark_nodes_as_annotated,
    propagate_annotation,
)
from executorch.backends.arm.quantizer.quantization_annotation import (
    OP_TO_ANNOTATOR,
    OperatorConfig,
    OperatorPatternType,
)
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import GraphModule, Node

__all__ = [
    "ArmQuantizer",
    "get_symmetric_quantization_config",
]


def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
    supported_operators: Dict[str, List[OperatorPatternType]] = {
        # Both conv and linear should be able to handle relu + hardtanh fusion since
        # those are clamp ops
        "conv2d": [
            [torch.nn.Conv2d, torch.nn.ReLU],
            [torch.nn.Conv2d, F.relu],
            [F.conv2d, torch.nn.ReLU],
            [F.conv2d, F.relu],
        ],
        "linear": [[torch.nn.Linear], [F.linear]],
        "add": [[torch.add]],
        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
        "adaptive_avg_pool2d": [
            [torch.nn.AdaptiveAvgPool2d],
            [F.adaptive_avg_pool2d],
        ],
        "mul": [[torch.mul]],
        "sub": [[torch.sub]],
    }
    return copy.deepcopy(supported_operators)


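# Illustrative sketch (not part of the original file): a module whose forward
# matches the ["conv2d"] pattern entry [torch.nn.Conv2d, torch.nn.ReLU] above,
# i.e. a Conv2d whose output feeds directly into a ReLU. The class name and
# channel sizes are hypothetical.
class _ExampleConvReluPattern(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv2d -> ReLU back to back is the shape the pattern list describes;
        # annotators can then treat the pair as one fused quantized unit.
        return self.relu(self.conv(x))

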
def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
    supported_config_and_operators: List[OperatorConfig] = []
    for quantization_config in [
        get_symmetric_quantization_config(),
        get_symmetric_quantization_config(is_per_channel=True),
    ]:
        ops = _supported_symmetric_quantized_operators()
        for pattern_list in ops.values():
            supported_config_and_operators.append(
                OperatorConfig(quantization_config, pattern_list)
            )
    return copy.deepcopy(supported_config_and_operators)


@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
    is_qat: bool = False,
    is_dynamic: bool = False,
    act_qmin: int = -128,
    act_qmax: int = 127,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
):
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
            extra_args["observer"] = dynamic_quant_observer
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=act_qmin,
        quant_max=act_qmax,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    if is_qat:
        # TODO: qat + per channel?
        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
    elif is_per_channel:
        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver

    # Reset extra_args for the weight observer; the activation-specific
    # entries set above must not leak into the weight spec.
    extra_args = {"eps": 2**-12}
    if is_qat:
        if weight_qscheme == torch.per_tensor_symmetric:
            extra_args["observer"] = MovingAverageMinMaxObserver
        else:
            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    bias_quantization_spec = None
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            None,
            weight_quantization_spec,
            bias_quantization_spec,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            act_quantization_spec,
            weight_quantization_spec,
            bias_quantization_spec,
        )
    return quantization_config


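# Illustrative sketch (not part of the original file): typical ways to build
# configs with the factory above. The function name is hypothetical. Because
# of @functools.lru_cache, calls with identical arguments return the same
# cached QuantizationConfig instance.
def _example_symmetric_configs() -> List[QuantizationConfig]:
    ptq_config = get_symmetric_quantization_config()  # per-tensor PTQ
    per_channel_config = get_symmetric_quantization_config(is_per_channel=True)
    qat_config = get_symmetric_quantization_config(is_qat=True)
    return [ptq_config, per_channel_config, qat_config]

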
def _get_supported_config_and_operators() -> List[OperatorConfig]:
    return _get_supported_symmetric_config_and_operators()


NodeFilterType = Callable[[Node], bool]
"""Type for a node filter used by annotators. A node filter is a function that
takes a Node and returns whether the node should be annotated.
"""


def _get_module_name_filter(module_name: str) -> NodeFilterType:
    """Get a module_name_filter function for a given module name. The filter
    accepts a node and checks whether the node comes from a module with the
    given name.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1

    >> module_name_filter = _get_module_name_filter("blocks.sub")
    >> print(module_name_filter(node))
    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
    """

    # Names in nn_module_stack are prefixed with "L['self']."; skip that prefix
    # when comparing against module_name.
    name_start = len("L['self'].")

    def module_name_filter(n: Node) -> bool:
        # nn_module_stack example: {
        #    'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #    'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        # Note: get_attr nodes may not have an nn_module_stack, hence the default.
        nn_module_stack = n.meta.get("nn_module_stack", {})
        names = [name[name_start:] for name, _ in nn_module_stack.values()]
        return module_name in names

    return module_name_filter


def _get_module_type_filter(tp: Callable) -> NodeFilterType:
    """Get a module_type_filter function for a given module type. The filter
    accepts a node and checks whether the node comes from a module of the
    given type.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear

    >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
    >> print(module_type_filter(node))
    True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
    """

    def module_type_filter(n: Node) -> bool:
        # nn_module_stack example: {
        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        nn_module_stack = n.meta.get("nn_module_stack", {})
        types = [t for _, t in nn_module_stack.values()]
        return tp in types

    return module_type_filter


def _get_not_module_type_or_name_filter(
    tp_list: List[Callable], module_name_list: List[str]
) -> NodeFilterType:
    module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]

    def not_module_type_or_name_filter(n: Node) -> bool:
        return not any(f(n) for f in module_type_filters + module_name_list_filters)

    return not_module_type_or_name_filter


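# Illustrative sketch (not part of the original file): how the filters above
# compose. _annotate_for_static_quantization_config below applies the global
# config only to nodes not already covered by a name or type filter, using
# the negated filter. The Linear type and "blocks.sub" name are placeholders.
def _example_global_filter() -> NodeFilterType:
    return _get_not_module_type_or_name_filter(
        tp_list=[torch.nn.Linear], module_name_list=["blocks.sub"]
    )

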
class ArmQuantizer(Quantizer):
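    """Quantizer for the Arm backend, implementing the torch.ao PT2E Quantizer API.

    Usage sketch (the module name "blocks.sub" is a placeholder):

    >> quantizer = ArmQuantizer()
    >> quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
    >> quantizer.set_module_name("blocks.sub", get_symmetric_quantization_config())
    """
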
    supported_config_and_operators = _get_supported_config_and_operators()

    # A list of supported static quantization annotators, in order of application.
    # For example, fusions come before singular ops.
    # The name must match the name used when registering the annotator.
    STATIC_ANNOTATION_ORDER = [
        "linear",
        "conv",
        "adaptive_avg_pool2d",
        "max_pool2d",
        "add",
        "sub",
        "mul",
        "mm",
        "one_to_one",
        "generic",
        "upsample_nearest2d",
    ]

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
        self.io_config: Optional[QuantizationConfig] = None
        self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
        self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

    def set_global(self, quantization_config: QuantizationConfig) -> ArmQuantizer:
        """Set quantization_config for submodules that are not already annotated by name or type filters."""
        self.global_config = quantization_config
        return self

    def set_module_type(
        self, module_type: Callable, quantization_config: QuantizationConfig
    ) -> ArmQuantizer:
        """Set quantization_config for a submodule with type `module_type`. For example,
        quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear) quantizes
        all supported operators/operator patterns in submodules of that type with the
        given `quantization_config`.
        """
        self.module_type_config[module_type] = quantization_config
        return self

    def set_module_name(
        self, module_name: str, quantization_config: Optional[QuantizationConfig]
    ) -> ArmQuantizer:
        """Set quantization_config for a submodule with name `module_name`. For example,
        quantizer.set_module_name("blocks.sub") quantizes all supported operators/operator
        patterns in the submodule with that name with the given `quantization_config`.
        """
        assert (
            quantization_config is not None
        ), "quantization_config == None is not supported yet"
        self.module_name_config[module_name] = quantization_config
        return self

    def set_io(self, quantization_config: QuantizationConfig) -> ArmQuantizer:
        """Set quantization_config for input and output nodes."""
        self.io_config = quantization_config
        return self

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """An initial pass for transforming the graph to prepare it for annotation.
        Currently transforms scalar values to tensor attributes.
        """

        return ArmPassManager().transform_for_annotation_pipeline(graph_module=model)

    def annotate(self, model: GraphModule) -> GraphModule:
        """Performs the quantization annotation on the graph.
        Currently only does static quantization annotation.
        Args:
            model: The model to annotate statically.
        Returns:
            The annotated model.
        """
        model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> GraphModule:
        """Loops over all entries in STATIC_ANNOTATION_ORDER and runs the
        corresponding registered annotator.
        Args:
            model: The model to annotate statically.
            quantization_config: Specifies the QuantizationSpecs for the model's
                input activations, output activations, weights and biases.
            filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
        Returns:
            The annotated model.
        """
        # TODO: implement support for None to cancel out previous annotations
        if quantization_config is None:
            return model

        for op in self.STATIC_ANNOTATION_ORDER:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: GraphModule
    ) -> GraphModule:
        """Matches the correct QuantizationConfig with the correct module using a filter
        when running _annotate_all_static_patterns.
        """
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_static_patterns(
                model, config, _get_module_name_filter(module_name)
            )

        tp_list = list(self.module_type_config.keys())
        for module_type, config in self.module_type_config.items():
            self._annotate_all_static_patterns(
                model, config, _get_module_type_filter(module_type)
            )

        self._annotate_all_static_patterns(
            model,
            self.global_config,
            _get_not_module_type_or_name_filter(tp_list, module_name_list),
        )

        if self.io_config:
            self._annotate_io(model, self.io_config)

        return model

    def _annotate_io(
        self,
        model: GraphModule,
        quantization_config: QuantizationConfig,
    ):
        for node in model.graph.nodes:
            if arm_quantizer_utils.is_annotated(node):
                continue
            if node.op == "placeholder" and len(node.users) > 0:
                # Annotate the qspec of the value produced by the input placeholder.
                _annotate_output_qspec(
                    node,
                    quantization_config.get_output_act_qspec(),
                )
                mark_nodes_as_annotated([node])
            if node.op == "output":
                # Annotate the qspec of the value consumed by the output node.
                parent = node.all_input_nodes[0]
                _annotate_input_qspec_map(
                    node, parent, quantization_config.get_input_act_qspec()
                )
                mark_nodes_as_annotated([node])

    def validate(self, model: GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return cls.supported_config_and_operators

    @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
        op_configs: Set[QuantizationConfig] = set()
        for spec, _ in cls.supported_config_and_operators:
            op_configs.add(spec)
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: Optional[QuantizationConfig]
    ) -> List[OperatorPatternType]:
        if quantization_config is None:
            all_ops = []
            for _, ops in cls.supported_config_and_operators:
                all_ops.extend(ops)
            return all_ops

        for config, ops in cls.supported_config_and_operators:
            # note: this assumes each entry in cls.supported_config_and_operators
            # corresponds to one spec, e.g. we don't have
            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
            # where the first and second entry have the same spec but did not
            # merge the op list
            if config == quantization_config:
                return ops
        return []
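

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original file): the PT2E flow
    # this quantizer plugs into. The toy model, shapes, and the export entry
    # point are assumptions and may differ across PyTorch versions.
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    example_inputs = (torch.randn(1, 3, 16, 16),)
    toy_model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3),
        torch.nn.ReLU(),
    ).eval()

    # Export to an FX GraphModule for PT2E quantization.
    exported = torch.export.export_for_training(toy_model, example_inputs)
    graph_module = exported.module()

    quantizer = ArmQuantizer().set_global(get_symmetric_quantization_config())
    prepared = prepare_pt2e(graph_module, quantizer)  # inserts observers
    prepared(*example_inputs)  # run once to calibrate the observers
    quantized = convert_pt2e(prepared)  # replaces observers with q/dq ops
    print(quantized)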