# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

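"""Utilities for recognizing quantize/dequantize (q/dq) nodes in an exported
FX graph and classifying them (static vs. dynamic, per-tensor, per-channel,
per-token, per-channel-group), used by the XNNPACK backend.
"""
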
import operator
from itertools import accumulate
from typing import cast

import torch
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
)

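# Formatted target names of the quantize ops recognized by these helpers.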
_Q_OPS = {
    "quantize_per_tensor.tensor",
    "quantize_per_tensor.default",
    "quantize_per_channel.default",
    "quantize_per_channel_group.default",
    "quantize_per_token.default",
    "quantize_affine.default",
}

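# Formatted target names of the dequantize ops recognized by these helpers.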
_DQ_OPS = {
    "dequantize_per_tensor.tensor",
    "dequantize_per_tensor.default",
    "dequantize_per_channel.default",
    "dequantize_per_channel_group.default",
    "dequantize_per_token.default",
    "dequantize_affine.default",
}


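# Ops that compute quantization parameters (scale / zero point) from a tensor.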
_QPARAM_OPS = {
    "choose_qparams.tensor",
    "choose_qparams_per_token_asymmetric.default",
    "choose_qparams_affine.default",
}

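# Q/DQ variants whose scale/zero-point are computed at runtime, i.e. the ops
# used for dynamic quantization.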
_DYNAMIC_OPS = {
    "quantize_per_tensor.tensor",
    "quantize_per_token.default",
    "dequantize_per_tensor.tensor",
    "dequantize_per_token.default",
}


def is_dynamic_qdq(node: torch.fx.Node) -> bool:
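    """Return True if the node is a dynamically quantized q/dq op.

    A node qualifies if its target is one of the known dynamic variants
    (_DYNAMIC_OPS) or if it is a per-token q/dq op that is not a
    per-channel-group op.
    """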
    if node.op != "call_function":
        return False
    node_name = format_target_name(node.target.__name__)  # pyre-ignore
    is_dynamic_affine = is_per_token(node) and not is_per_channel_group(node)

    return node_name in _DYNAMIC_OPS or is_dynamic_affine


def is_qparam(node: torch.fx.Node) -> bool:
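    """Return True if the node is a choose_qparams op."""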
    if node.op != "call_function":
        return False
    node_name = format_target_name(node.target.__name__)  # pyre-ignore

    return node_name in _QPARAM_OPS


def is_quant(node: torch.fx.Node) -> bool:
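    """Return True if the node is a quantize op."""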
    if node.op != "call_function":
        return False
    node_name = format_target_name(node.target.__name__)  # pyre-ignore

    return node_name in _Q_OPS


def is_dequant(node: torch.fx.Node) -> bool:
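    """Return True if the node is a dequantize op."""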
    if node.op != "call_function":
        return False
    node_name = format_target_name(node.target.__name__)  # pyre-ignore

    return node_name in _DQ_OPS


def is_per_channel(node: torch.fx.Node) -> bool:
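    """Return True if the node is a per-channel (including per-channel-group)
    quantize/dequantize op.
    """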
    if not (is_quant(node) or is_dequant(node)):
        return False

    is_affine_per_channel_group = is_per_channel_group(node)
    is_per_channel = "per_channel" in node.target.__name__  # pyre-ignore

    return is_per_channel or is_affine_per_channel_group


def is_affine_qdq(node: torch.fx.Node) -> bool:
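    """Return True if the node is an affine quantize or dequantize op."""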
    if not (is_quant(node) or is_dequant(node)):
        return False

    return "quantize_affine" in node.target.__name__  # pyre-ignore


def _get_block_size_input_scale(node: torch.fx.Node):
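    """Return the block_size argument of an affine q/dq node along with the
    example values (meta["val"]) of its input and scale.
    """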
    assert is_affine_qdq(node)
    block_size = node.args[1]
    input_val = node.all_input_nodes[0].meta["val"]
    scale_val = node.all_input_nodes[1].meta["val"]
    return block_size, input_val, scale_val


def is_per_token(node: torch.fx.Node) -> bool:
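    """Return True if the node is a per-token q/dq op.

    For affine q/dq ops this checks that the block size is 1 along every
    dimension except the last, that the last block spans the full innermost
    dimension of the input, and that the scale has one element per token.
    """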
    if not (is_quant(node) or is_dequant(node)):
        return False

    if "per_token" in node.target.__name__:  # pyre-ignore
        return True
    elif is_affine_qdq(node):
        block_size, input_val, scale_val = _get_block_size_input_scale(node)
        flag = True
        scale_numel_expected = 1
        for i in range(len(block_size) - 1):
            flag &= block_size[i] == 1
            scale_numel_expected *= input_val.shape[i]

        flag &= block_size[-1] == input_val.shape[-1]
        flag &= scale_val.numel() == scale_numel_expected
        return flag

    return False


def is_per_channel_group(node: torch.fx.Node) -> bool:
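    """Return True if the node is a per-channel-group (blockwise) q/dq op.

    For affine q/dq ops this checks for a 2D block size of [1, group_size]
    and that the input numel equals group_size times the scale numel.
    """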
    if not (is_quant(node) or is_dequant(node)):
        return False

    if "per_channel_group" in node.target.__name__:  # pyre-ignore
        return True
    elif is_affine_qdq(node):
        block_size, input_val, scale_val = _get_block_size_input_scale(node)
        flag = True
        flag &= len(block_size) == 2
        flag &= block_size[0] == 1
        group_size = block_size[1]
        scale_numel = list(accumulate(scale_val.shape, operator.mul))[-1]
        input_numel = list(accumulate(input_val.shape, operator.mul))[-1]
        flag &= input_numel == group_size * scale_numel
        return flag

    return False


def extract_qdq_affine_op_args_for_decomposed_ops(node: torch.fx.Node):
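    """Remap the arguments of an affine q/dq node into the argument order used
    by the decomposed quantize/dequantize ops, filling in default
    quant_min/quant_max when absent and appending the group size for
    per-channel-group nodes.

    Returns the remapped argument list, or (None, None) if the node is not an
    affine q/dq op, its target dtype is not torch.int8, or its zero-point
    domain is not "INT".
    """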
    if not is_affine_qdq(node):
        return None, None
    # make sure input_dtype and zero_point_domain have expected values
    input_node = node.args[0]
    scale_node = node.args[2]
    zero_point_node = node.args[3]
    args = [input_node, scale_node, zero_point_node]
    assert (
        len(node.args) > 4
    ), f"expecting at least 5 args, got node: {node.format_node()}"

    if node.args[4] != torch.int8:
        return None, None
    target_dtype = cast(torch.dtype, node.args[4])

    if len(node.args) > 6:
        # quant_min
        args.append(node.args[5])
        # quant_max
        args.append(node.args[6])
    else:
        dtype_info = torch.iinfo(target_dtype)
        quant_min = dtype_info.min
        quant_max = dtype_info.max
        args.append(quant_min)
        args.append(quant_max)

    # add target_dtype_node after quant_min/quant_max
    args.append(target_dtype)
    # zero_point_domain
    if len(node.args) > 7 and node.args[7] != "INT":
        return None, None

    if is_per_channel_group(node):
        block_sizes = cast(list[int], node.args[1])
        args.append(block_sizes[-1])

    args.append(node.args[-1])

    return args