# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence

import torch
from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
from executorch.backends.qualcomm.quantizer.quantizer import (
    get_16a8w_qnn_ptq_config,
    get_8a8w_qnn_ptq_config,
    get_ptq_per_channel_quant_config,
    QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import Node
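
# Naming note (editorial, inferred from the config helpers imported above):
# "16a8w" means 16-bit activations with 8-bit weights; "8a8w" means 8-bit
# activations with 8-bit weights.

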
def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    Annotate aten.matmul ops with a 16a8w quantization config, and annotate
    the permute/transpose/cat chain feeding each matmul's second input with
    an 8a8w config.
    """

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
        input_nodes = node.args[0]

        first_input_node = input_nodes[0]
        input_qspec_map = {}
        input_qspec_map[first_input_node] = quantization_config.input_activation
        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
            (first_input_node, node)
        )

        for input_node in input_nodes[1:]:
            if input_node not in input_qspec_map:
                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=share_qparams_with_input_act0_qspec,
            _annotated=True,
        )

    def annotate_single_in_single_out(
        node: Node, quantization_config: QuantizationConfig
    ) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_qspec_map[input_act] = quantization_config.input_activation

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_matmul_input1(node: Node):
        quantization_config_8a8w = get_8a8w_qnn_ptq_config(
            act_symmetric=True, act_observer=MinMaxObserver
        )
        # Walk up the producer chain of the matmul's second input, annotating
        # permute/transpose/cat nodes with 8a8w, until something other than a
        # call_function node (e.g. the KV-cache placeholder) is reached.
        while isinstance(node, Node) and node.op == "call_function":
            if node.target in [
                torch.ops.aten.permute.default,
                torch.ops.aten.transpose.int,
            ]:
                annotate_single_in_single_out(node, quantization_config_8a8w)
                node = node.args[0]
            elif node.target == torch.ops.aten.cat.default:
                annotate_cat(node, quantization_config_8a8w)
                node = node.args[0][0]
            else:
                node = node.args[0]

    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            annotate_matmul(node, quantization_config_16a8w)
            annotate_matmul_input1(node.args[1])


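# A minimal registration sketch (hypothetical usage, not part of this file):
# the annotator is passed to QnnQuantizer as a custom annotation callback and
# runs when the model is prepared for PTQ. QnnQuantizer's
# add_custom_quant_annotations entry point is assumed here.
#
#   from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
#   from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#   quantizer = QnnQuantizer()
#   quantizer.add_custom_quant_annotations((annotate_matmul_16a8w,))
#   prepared = prepare_pt2e(exported_graph_module, quantizer)
#   ...  # feed calibration inputs through `prepared`
#   quantized = convert_pt2e(prepared)

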
def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
    """
    Annotate llama SDPA matmul ops with a 16a8w quantization config, and
    annotate the chain feeding each matmul's second input (index_put/cat/
    select ops, up to the past KV cache) with an 8a8w config.
    """

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec
        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
        # The value being written (args[2]) and the output share quantization
        # parameters with the first input (args[0]), i.e. the cache tensor.
        input = node.args[0]
        value = node.args[2]
        input_qspec_map = {}
        input_qspec_map[input] = quantization_config.input_activation
        input_qspec_map[value] = SharedQuantizationSpec((input, node))
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=SharedQuantizationSpec((input, node)),
            _annotated=True,
        )

    def annotate_single_in_single_out(
        node: Node, quantization_config: QuantizationConfig
    ) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_qspec_map[input_act] = quantization_config.input_activation
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
        input_nodes = node.args[0]
        assert isinstance(input_nodes, Sequence)
        first_input_node = input_nodes[0]
        input_qspec_map = {}
        assert isinstance(first_input_node, Node)
        assert isinstance(node, Node)
        input_qspec_map[first_input_node] = quantization_config.input_activation
        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
            (first_input_node, node)
        )
        for input_node in input_nodes[1:]:
            if input_node not in input_qspec_map:
                assert isinstance(input_node, Node)
                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=share_qparams_with_input_act0_qspec,
            _annotated=True,
        )

    def is_edge_condition(node: Node):
        # Stop the recursive walk once a non-call_function node (e.g. a
        # placeholder) is reached.
        return not isinstance(node, Node) or node.op != "call_function"

    def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
        if is_edge_condition(node):
            return
        if node.target in [
            torch.ops.aten.index_put.default,
            torch.ops.aten.index_put_.default,
        ]:
            annotate_index_put(node, quantization_config)
            annotate_matmul_input1(node.args[0], quantization_config)
        elif node.target == torch.ops.aten.cat.default:
            annotate_cat(node, quantization_config)
            # The inputs of the cat op are expected to be select ops
            for arg in node.args[0]:
                annotate_matmul_input1(arg, quantization_config)
        else:
            annotate_single_in_single_out(node, quantization_config)
            annotate_matmul_input1(node.args[0], quantization_config)

    # Annotate matmul ops with 16a8w to get better performance
    quantization_config_16a8w = get_16a8w_qnn_ptq_config()
    # Annotate the second matmul input with 8a8w along the chain up to the past KV cache
    quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            if "nn_module_stack" in node.meta:
                # nn_module_stack values are (fully-qualified name, module type)
                # pairs, ordered from outermost to innermost module.
                module_values_list = list(node.meta["nn_module_stack"].values())
                full_qualified_name = module_values_list[-1][0]
                if "SDPA" in full_qualified_name:
                    annotate_matmul(node, quantization_config_16a8w)
                    annotate_matmul_input1(node.args[1], quantization_config_8a8w)


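# Debugging sketch (hypothetical): to see which module names are available for
# the "SDPA" substring match above, dump the nn_module_stack of each matmul
# node; torch.export stores (fully-qualified name, module type) pairs there.
#
#   for n in gm.graph.nodes:
#       if n.op == "call_function" and n.target == torch.ops.aten.matmul.default:
#           stack = n.meta.get("nn_module_stack", {})
#           print([fqn for fqn, _module_type in stack.values()])

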
def custom_annotate_llama_last_conv_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    Annotate llama's output-head conv2d (L['self'].llama.output) with a 16a8w
    per-channel quantization config.
    """

    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        weight = node.args[1]
        input_qspec_map[weight] = quantization_config.weight

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
        torch.uint16, weight_dtype=torch.int8
    )
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
            if "nn_module_stack" in node.meta:
                module_values_list = list(node.meta["nn_module_stack"].values())
                full_qualified_name = module_values_list[0][0]
                if full_qualified_name == "L['self'].llama.output":
                    annotate_conv2d(
                        node, quantization_config=quantization_config_16a8w_per_channel
                    )


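# Combined registration sketch (hypothetical, same assumed QnnQuantizer API as
# above): pairing this conv annotator with the llama matmul annotator yields
# 16a8w per-tensor SDPA matmuls plus a 16a8w per-channel output head.
#
#   quantizer.add_custom_quant_annotations(
#       (
#           custom_annotate_llama_matmul_16a8w,
#           custom_annotate_llama_last_conv_16a8w,
#       )
#   )

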
def custom_annotate_matmul_16a8w(gm: torch.fx.GraphModule):
    """
    Annotate every aten.matmul op with a 16a8w quantization config. Unlike
    annotate_matmul_16a8w above, this variant leaves the producers of the
    second matmul input unannotated.
    """

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec
        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    # Annotate matmul ops with 16a8w to get better performance
    quantization_config_16a8w = get_16a8w_qnn_ptq_config()
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            annotate_matmul(node, quantization_config_16a8w)


def get_custom_quant_ios_dtype(
    cache_shape: torch.Size,
    node: torch.fx.Node,
    kv_dtype=torch.uint8,
    sharding_dtype=torch.uint16,
):
    """
    Return the custom IO dtype for llama-specific inputs and outputs, or None
    if the node does not need custom tagging.
    """
    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
        return kv_dtype

    # Tag the index_put node that precedes the copy node, because copy is a
    # node QNN skips
    if (
        exir_ops.edge.aten.index_put.default == node.target
        and node.meta["val"].shape == cache_shape
    ):
        return kv_dtype

    # Tag sharding IO
    if exir_ops.edge.llama.fallback.default in [
        u.target for u in node.users
    ] + [node.target]:
        return sharding_dtype

    # Tag index ops as quantized tensors; these are introduced by sharding
    if exir_ops.edge.aten.index.Tensor in [
        u.target for u in node.users
    ] + [node.target]:
        return sharding_dtype
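
# Usage sketch (hypothetical): this function is typically curried with the
# KV-cache shape and applied per node when tagging quantized graph IO. The
# tag_quant_io helper and its import path are assumptions, not part of this
# file.
#
#   from functools import partial
#   from executorch.backends.qualcomm.utils.utils import tag_quant_io
#
#   tag_quant_io(gm, partial(get_custom_quant_ios_dtype, kv_cache_shape))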