# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

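"""Quantizer definitions for the Cadence backend.

This module defines per-tensor uint8 quantization specs for activations and
weights, a pattern-based CadenceAtenQuantizer that annotates fused partitions
in an ATen-level FX graph, and a CadenceQuantizer that composes one such
quantizer per supported op pattern.
"""
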
from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
    AddmmPattern,
    BmmPattern,
    Conv1dPattern,
    Conv2dPattern,
    LayerNormPattern,
    LinearPattern,
    MatmulPattern,
    QuantizationPattern,
    ReluPattern0,
    ReluPattern1,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    find_sequential_partitions_aten,
    is_annotated,
    no_outside_users,
)

from torch import fx

from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)


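# Default spec for activations: asymmetric per-tensor uint8 in [0, 255],
# calibrated with a HistogramObserver.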
act_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

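# Default spec for weights: per-tensor uint8 in [0, 255], calibrated with a
# simple MinMaxObserver.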
wgt_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

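# Biases are left unquantized by default (no QuantizationSpec).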
bias_qspec: Optional[QuantizationSpec] = None


class CadenceAtenQuantizer(Quantizer):
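    """Quantizer that annotates a single op pattern with a quantization config.

    Each instance pairs one QuantizationPattern with the specs to apply to that
    pattern's inputs, weights, biases, and outputs.
    """
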
    def __init__(
        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
    ) -> None:
        super().__init__()
        self.pattern = pattern
        self.quantization_config = quantization_config

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
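        """Annotate every match of this quantizer's pattern in the graph.

        Partitions whose intermediate values escape the pattern, that yield no
        anchors, or that are already annotated are skipped.
        """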
        fused_partitions = find_sequential_partitions_aten(
            model,
            self.pattern.partition_types(),
        )

        input_act_qspec = self.quantization_config.input_activation
        weight_qspec = self.quantization_config.weight
        bias_qspec = self.quantization_config.bias
        output_act_qspec = self.quantization_config.output_activation

        for fused_partition in fused_partitions:
            if not no_outside_users(fused_partition):
                continue

            anchors = self.pattern.get_anchors(model, fused_partition)
            if not anchors:
                continue
            if is_annotated(
                [
                    x[0]
                    for x in anchors.inputs
                    + anchors.weights
                    + anchors.biases
                    + anchors.output
                ]
            ):
                continue

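            # Annotate the output anchors, honoring a per-anchor custom spec
            # when the pattern provides one.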
            for output, *custom_spec in anchors.output:
                # pyre-ignore[16]: no attribute
                output.meta["quantization_annotation"] = QuantizationAnnotation(
                    # pyre-ignore[6]: incompatible parameter type
                    output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                    _annotated=True,
                )

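            # Attach input quantization specs to each (node, arg_index) anchor,
            # again allowing a per-anchor override.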
            def annotate_inputs(
                inputs: Union[
                    List[Tuple[fx.Node, int]],
                    List[Tuple[fx.Node, int, DerivedQuantizationSpec]],
                ],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in inputs:
                    # pyre-ignore[16]: no attribute
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    # pyre-ignore[16]: no attribute
                    annotation.input_qspec_map[node.args[idx]] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    # pyre-ignore[16]: no attribute
                    node.meta["quantization_annotation"] = annotation

            annotate_inputs(anchors.inputs, input_act_qspec)
            annotate_inputs(anchors.weights, weight_qspec)
            # pyre-ignore[6]: incompatible parameter type
            annotate_inputs(anchors.biases, bias_qspec)
        return model

    def validate(self, model: fx.GraphModule) -> None:
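        """No graph-level validation is performed for this quantizer."""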
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []


class CadenceQuantizer(ComposableQuantizer):
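    """Composite quantizer covering the op patterns supported by the Cadence backend.

    If no QuantizationConfig is given, a static uint8 config (act_qspec for
    activations, wgt_qspec for weights, unquantized bias) is used for every
    pattern.

    Usage sketch (assumes the standard PT2E prepare/convert flow; the captured
    model and example inputs below are illustrative):

        from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

        quantizer = CadenceQuantizer()
        prepared = prepare_pt2e(captured_graph_module, quantizer)
        prepared(*example_inputs)  # calibrate observers
        converted = convert_pt2e(prepared)
    """
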
    def __init__(
        self, quantization_config: Optional[QuantizationConfig] = None
    ) -> None:
        static_qconfig = (
            QuantizationConfig(
                act_qspec,
                act_qspec,
                wgt_qspec,
                None,
            )
            if not quantization_config
            else quantization_config
        )

        super().__init__(
            [
                CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
                CadenceAtenQuantizer(BmmPattern(), static_qconfig),
                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
                CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
                CadenceAtenQuantizer(LinearPattern(), static_qconfig),
                CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
                CadenceAtenQuantizer(ReluPattern0(), static_qconfig),
                CadenceAtenQuantizer(ReluPattern1(), static_qconfig),
            ]
        )
171