# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import Callable, List

import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)

from torch.export import export_for_training
from torch.fx import Graph, Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)

from .qconfig import QuantizationConfig

# Maps an aten OpOverload to the annotator function that handles it; the
# register_annotator decorator below fills in this registry at import time.
OP_TO_ANNOTATOR = {}


def annotate(graph: Graph, quant_config: QuantizationConfig) -> None:
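    """Annotate the graph for quantization: match multi-node patterns first
    (RMSNorm, fused activations), then annotate the remaining ops via the
    OP_TO_ANNOTATOR registry."""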
    # Pattern annotation
    _annotate_rmsnorm_pattern(graph, quant_config)
    _annotate_fused_activation_pattern(graph, quant_config)

    # Per-op annotation
    for node in graph.nodes:
        if node.op == "placeholder":
            annotate_placeholder(node, quant_config)
        elif node.op == "call_function":
            annotate_func = OP_TO_ANNOTATOR.get(node.target, None)
            if annotate_func is not None:
                annotate_func(node, quant_config)


def register_annotator(ops: List[OpOverload]):
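    """Decorator factory: register the decorated function in OP_TO_ANNOTATOR
    as the annotator for every OpOverload in `ops`."""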

    def decorator(annotator_fn: Callable):
        for op in ops:
            OP_TO_ANNOTATOR[op] = annotator_fn
        return annotator_fn

    return decorator


def _is_annotated(node: Node):
    """Return True if the node already carries a quantization annotation that
    is marked as annotated, otherwise return False."""
    KEY = "quantization_annotation"
    return KEY in node.meta and node.meta[KEY]._annotated


def _mark_as_annotated(nodes: List[Node]):
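    """Attach a QuantizationAnnotation to each node (creating one if needed)
    and flag it as annotated so later passes skip it."""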
    KEY = "quantization_annotation"
    for node in nodes:
        if KEY not in node.meta:
            node.meta[KEY] = QuantizationAnnotation()
        node.meta[KEY]._annotated = True


def _is_float_activation_tensor(node: Node):
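    """Return True if the node's value is a float32 FakeTensor."""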
    if not isinstance(node, Node):
        return False
    if "val" not in node.meta:
        return False
    if not isinstance(node.meta["val"], FakeTensor):
        return False
    return node.meta["val"].dtype == torch.float32


def _annotate_fused_activation_pattern(
    graph: Graph, quant_config: QuantizationConfig
) -> None:
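    """Annotate conv/linear + relu(6) and arithmetic + relu(6) fusions.

    For each relu/relu6 whose single-use producer is a conv/linear op, the
    producer's weight and the relu output are annotated; for a binary
    arithmetic producer, only the relu output is annotated. All touched nodes
    are marked as annotated so the per-op pass leaves them alone.
    """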
    for relu_node in graph.nodes:
        # Check relu/relu6 node
        if relu_node.op != "call_function":
            continue
        if relu_node.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
            torch.ops.aten.relu6.default,
        ]:
            continue

        producer_node = relu_node.args[0]
        if not isinstance(producer_node, Node):
            continue
        if producer_node.op != "call_function":
            continue
        if len(producer_node.users) != 1:
            continue

        # Handle affine + relu fusion
        if producer_node.target in [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
            torch.ops.aten.linear.default,
        ]:
            weight_node = producer_node.args[1]
            _annotate_input_qspec_map(
                producer_node,
                weight_node,
                quant_config.weight,
            )
            _annotate_output_qspec(relu_node, quant_config.activation)
            _mark_as_annotated([producer_node, weight_node, relu_node])
            continue

        # Handle arithmetic + relu fusion
        if producer_node.target in [
            torch.ops.aten.add.Scalar,
            torch.ops.aten.add.Tensor,
            torch.ops.aten.add_.Scalar,
            torch.ops.aten.add_.Tensor,
            torch.ops.aten.div.Scalar,
            torch.ops.aten.div.Tensor,
            torch.ops.aten.div_.Scalar,
            torch.ops.aten.div_.Tensor,
            torch.ops.aten.divide.Scalar,
            torch.ops.aten.divide.Tensor,
            torch.ops.aten.mul.Scalar,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.mul_.Scalar,
            torch.ops.aten.mul_.Tensor,
            torch.ops.aten.rsub.Scalar,
            torch.ops.aten.rsub.Tensor,
            torch.ops.aten.sub.Scalar,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.sub_.Scalar,
            torch.ops.aten.sub_.Tensor,
        ]:
            _annotate_output_qspec(relu_node, quant_config.activation)
            _mark_as_annotated([producer_node, relu_node])
            continue


def _annotate_rmsnorm_pattern(graph: Graph, quant_config: QuantizationConfig) -> None:
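    """Match RMSNorm-style subgraphs, x * rsqrt(mean(x * x) + eps), and
    annotate only the pattern outputs; interior nodes are marked as annotated
    so the per-op pass skips them."""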

    class ExecuTorchPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

    class MTKPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

    for pattern_cls in (ExecuTorchPattern, MTKPattern):
        pattern_gm = export_for_training(pattern_cls(), (torch.randn(3, 3),)).module()
        matcher = SubgraphMatcherWithNameNodeMap(
            pattern_gm, ignore_literals=True, remove_overlapping_matches=False
        )
        matches = matcher.match(graph)
        for match in matches:
            target_nodes = []
            for node in match.nodes_map.values():
                if node in match.placeholder_nodes:
                    continue
                if node.op == "call_function" and node.target in OP_TO_ANNOTATOR:
                    target_nodes.append(node)

            if any(_is_annotated(node) for node in target_nodes):
                continue
            _mark_as_annotated(target_nodes)
            for node in match.returning_nodes:
                _annotate_output_qspec(node, quant_config.activation)


def annotate_placeholder(node: Node, quant_config: QuantizationConfig) -> None:
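    """Annotate graph inputs: float32 activation placeholders get the
    activation qspec; every placeholder is marked as annotated."""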
    if _is_annotated(node):
        return

    if _is_float_activation_tensor(node):
        _annotate_output_qspec(node, quant_config.activation)

    _mark_as_annotated([node])


@register_annotator(
    [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.linear.default,
    ]
)
def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
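    """Annotate conv/linear: the weight input with the weight qconfig and the
    node output with the activation qconfig."""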
    if _is_annotated(node):
        return

    weight_node = node.args[1]
    _annotate_input_qspec_map(
        node,
        weight_node,
        quant_config.weight,
    )
    _annotate_output_qspec(node, quant_config.activation)

    # Mark the weight as annotated because it is a constant node
    _mark_as_annotated([node, weight_node])


@register_annotator(
    [
        torch.ops.aten.add.Scalar,
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add_.Scalar,
        torch.ops.aten.add_.Tensor,
        torch.ops.aten.bmm.default,
        torch.ops.aten.div.Scalar,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.div_.Scalar,
        torch.ops.aten.div_.Tensor,
        torch.ops.aten.divide.Scalar,
        torch.ops.aten.divide.Tensor,
        torch.ops.aten.gelu.default,
        torch.ops.aten.group_norm.default,
        torch.ops.aten.layer_norm.default,
        torch.ops.aten.leaky_relu.default,
        torch.ops.aten.matmul.default,
        torch.ops.aten.mul.Scalar,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mul_.Scalar,
        torch.ops.aten.mul_.Tensor,
        torch.ops.aten.pow.Scalar,
        torch.ops.aten.pow.Tensor_Scalar,
        torch.ops.aten.pow.Tensor_Tensor,
        torch.ops.aten.prelu.default,
        torch.ops.aten.rsub.Scalar,
        torch.ops.aten.rsub.Tensor,
        torch.ops.aten.silu.default,
        torch.ops.aten.sub.Scalar,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Scalar,
        torch.ops.aten.sub_.Tensor,
    ]
)
def annotate_output_qspec(node: Node, quant_config: QuantizationConfig) -> None:
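    """Default annotator: only the node's output gets the activation qconfig."""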
    if _is_annotated(node):
        return
    _annotate_output_qspec(node, quant_config.activation)
    _mark_as_annotated([node])


@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
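    """Annotate the embedding weight (args[0]) as a quantized input using the
    activation qconfig; no output qspec is set here."""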
    if _is_annotated(node):
        return

    wgt_node = node.args[0]
    _annotate_input_qspec_map(node, wgt_node, quant_config.activation)
    _mark_as_annotated([node])