# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import cast, Optional, Union

import torch
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
    TagImplicitQDqPass,
)
from executorch.backends.xnnpack.utils.quant_utils import (
    extract_qdq_affine_op_args_for_decomposed_ops,
    is_affine_qdq,
    is_dequant,
    is_dynamic_qdq,
    is_per_channel,
    is_per_channel_group,
    is_quant,
)
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_param_tensor,
    is_param_node,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


class QuantParams:
34    """
35    QuantParams class, to represent the paramaters and meta data needed
36    to quantize a tensor. The metadata can technically all be encapsulated
37    within the quant torch.fx.Node, however, there are some cases in which
38    nodes which are meant to be quantized for XNNPACK are not quantized
39    in PyTorch IR, specifically bias nodes. In this case, we can still build
40    quantizer class to serialize the quantized attributes needed for XNNPACK.
41
42    Attributes:
43        per_channel: Whether this quantization is per channel or per tensor
44        q_input: node that is the input to this quantization
45        scale: tensor or float that is used as the quantization scale
46        zp: tensor or float that is used as the quantization zero point
47        axis: used for per_channel quantizaiton, representing the axis
48        dtype: dtype of the type being quantized to
49        qmin: quantization minimum
50        qmax: quantization maximum
51        is_output: whether this is an output node or not
52        is_input: whether this is an input node or not
53    """

    def __init__(
        self,
        per_channel: bool,
        q_input: torch.fx.Node,
        scale: Union[torch.Tensor, float],
        zp: Union[torch.Tensor, float],
        axis: int,
        dtype: torch.dtype,
        qmax: int,
        qmin: int,
        is_output: bool,
        is_input: bool,
        is_dynamic: bool = False,
        num_nonbatch_dims: int = 1,
        group_size: int = 0,
    ) -> None:
        self.per_channel = per_channel
        self.q_input = q_input
        self.scale = scale
        self.zp = zp
        self.axis = axis
        self.dtype = dtype
        self.qmax = qmax
        self.qmin = qmin
        self.is_output = is_output
        self.is_input = is_input
        self.is_dynamic = is_dynamic
        self.num_nonbatch_dims = num_nonbatch_dims
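        # is_qc4w flags 4-bit per-channel weight quantization: static,
        # per-channel, with a [-8, 7] quantization range stored in an int8
        # container.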
        self.is_qc4w = (
            self.per_channel
            and not self.is_dynamic
            and self.qmin == -8
            and self.qmax == 7
            and self.dtype == torch.int8
        )

        # Groupwise quantization for weight
        self.per_channel_group = False
        self.group_size = group_size
        if self.group_size > 0:
            assert (
                self.per_channel is True
            ), "Only per channel quantization supports groupwise quantization"
            assert (
                cast(torch.Tensor, scale).ndim == 2
            ), "Scale must be 2D for per channel groupwise quant"
            self.per_channel_group = True
            assert group_size > 0, "Group size must be greater than 0"
        self.is_per_channel_group = self.per_channel and self.group_size > 0

    def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
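        """
        Quantize the given floating-point tensor with this QuantParams'
        scale/zero-point, per channel or per tensor. Tensors that already
        have the target dtype are returned unchanged.
        """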
        # Do nothing if already quantized by the Quantizer
        if tensor.dtype == self.dtype:
            return tensor

        if self.per_channel:
            assert (
                self.per_channel_group is False
            ), f"Not expecting per channel group quantization, got q dtype: {self.dtype}, tensor.dtype {tensor.dtype}"
            assert (
                tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
            ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"

            assert (
                tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
            ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"

            # Assuming folded quant weights
            # TODO Add support for unfolded weights
            assert not self.is_qc4w, "Not expecting QC4W per channel tensor"

            return exir_ops.edge.quantized_decomposed.quantize_per_channel.default(
                tensor, self.scale, self.zp, self.axis, self.qmin, self.qmax, self.dtype
            )
        else:
            return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
                tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
            )

    @classmethod
    def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
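        """
        Build QuantParams for a dynamically quantized input: scale and zero
        point are computed at runtime, so only placeholder values are
        recorded here.
        """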
        q_input = quant_node.args[0]  # fp32 input
        assert isinstance(q_input, torch.fx.Node)
        # TODO - materialize this from the quant_node scale count and val shape
        num_nonbatch_dims = 1

        return cls(
            per_channel=False,  # True is not valid
            q_input=q_input,
            scale=0.0,  # no need
            zp=0.0,  # no need
            axis=0,  # no need
            dtype=torch.float32,  # will be quantized at runtime
            qmax=0,  # no need
            qmin=0,  # no need
            is_output=False,
            is_input=q_input.op == "placeholder",
            is_dynamic=True,
            num_nonbatch_dims=num_nonbatch_dims,
        )

    @classmethod
    def from_q_dq_node(
        cls, quant_node: torch.fx.Node, ep: Optional[ExportedProgram] = None
    ) -> QuantParams:
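        """
        Build QuantParams from a quantize or dequantize node, handling
        per-tensor, per-channel, per-channel-group, and dynamic quantization.
        The ExportedProgram is required to resolve per-channel scale and
        zero-point parameter tensors.
        """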
        check_or_raise(
            is_quant(quant_node) or is_dequant(quant_node),
            f"building quantizer from q/dq node but was given node:{quant_node}",
        )
        q_input = quant_node.all_input_nodes[0]

        # TODO: Use presence of choose_qparam node to determine if this is a dynamic quantization
        if is_dynamic_qdq(quant_node):
            return cls._from_dynamic_input_node(quant_node)

        per_channel = is_per_channel(quant_node)

        _groupwise = is_per_channel_group(quant_node)
        quant_node_args = quant_node.args
        if _groupwise and is_affine_qdq(quant_node):
            quant_node_args = extract_qdq_affine_op_args_for_decomposed_ops(quant_node)

        scale = quant_node_args[1]
        zp = quant_node_args[2]
        axis = 0
        if per_channel:
            assert isinstance(scale, torch.fx.Node) and isinstance(scale.target, str)
            assert isinstance(zp, torch.fx.Node) and isinstance(zp.target, str)
            assert (
                ep is not None
            ), "ExportedProgram must be provided to extract per channel params"

            def _get_tensor(node):
                param = get_param_tensor(ep, node)
                assert param is not None, f"Expected to find param tensor for {node}"
                return cast(torch.Tensor, param)

            scale = _get_tensor(scale)
            zp = _get_tensor(zp)
            axis = cast(int, quant_node_args[3])

            if _groupwise:
                scale_tensor = cast(torch.Tensor, scale)
                if scale_tensor.ndim == 1:
                    scale_tensor = scale_tensor.reshape(-1, 1)
                    zp = zp.reshape(-1, 1)
                    scale = scale_tensor

                assert (
                    scale_tensor.ndim == 2
205                ), "Weight scale must be 2D for per_channel_group [de]quant node, got {scale.ndim}D"
                axis = 0  # axis is ignored for groupwise quantization

        check_or_raise(
            bool(
                quant_node_args[-1] != torch.uint8
                and quant_node_args[-1] != torch.quint8
            ),
            "XNNPACK does not support unsigned quantization",
        )

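        # Trailing q/dq args are (..., qmin, qmax, dtype) for the standard
        # decomposed ops and (..., qmin, qmax, dtype, group_size, output_dtype)
        # for the groupwise/affine ops, hence the negative indexing below.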
        if _groupwise:
            _ = quant_node_args[-1]  # output dtype - not used
            group_size = cast(int, quant_node_args[-2])
            dtype = cast(torch.dtype, quant_node_args[-3])
            qmax = cast(int, quant_node_args[-4])
            qmin = cast(int, quant_node_args[-5])
        else:
            group_size = 0
            dtype = cast(torch.dtype, quant_node_args[-1])
            qmax = cast(int, quant_node_args[-2])
            qmin = cast(int, quant_node_args[-3])

        is_output = any(
            user_node.op == "output" for user_node in quant_node.users.keys()
        )
        is_input = q_input.op == "placeholder"
        return cls(
            per_channel,
            q_input,
            scale,
            zp,
            axis,
            dtype,
            qmax,
            qmin,
            is_output,
            is_input,
            group_size=group_size,
        )

    @classmethod
    def from_weights(
        cls, tensor_node: torch.fx.Node, ep: Optional[ExportedProgram] = None
    ) -> Optional[QuantParams]:
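        """
        Build QuantParams for a statically quantized weight. Expects the
        weight to be produced by a dq (or q -> dq) chain rooted at a
        get_attr/placeholder node; returns None if the node is not a dequant.
        """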
        if not is_dequant(tensor_node):
            return None

        # source node for quant params
        src = tensor_node

        # is the input of the dq a q node?
        dq_input = src.all_input_nodes[0]
        if is_quant(dq_input):
            src = dq_input

        # Replace this with a pointer to the actual weight value.
        # If no one else uses this weight value, it can be taken out of the top-level module.
        check_or_raise(
            src.all_input_nodes[0].op in ["get_attr", "placeholder"],
            f"q->dq->permute_copy not derived from static weight, input to the q or dq (for folded quant) node: {src.all_input_nodes[0]}",
        )

        return cls.from_q_dq_node(src, ep)

    @classmethod
    def from_inputs(
        cls, tensor_node: torch.fx.Node, ep: ExportedProgram
    ) -> Optional[QuantParams]:
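        """
        Build QuantParams for a quantized input tensor, i.e. a tensor produced
        by an implicit dq node. For a param-backed q -> dq chain the quant
        params are taken from the q node; otherwise they come from the dq
        node. Returns None for unquantized inputs.
        """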
        # tensor_node is quantized if it is produced by a dequant node
        if is_dequant(tensor_node) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(
            tensor_node
        ):
            dq_input = cast(torch.fx.Node, tensor_node.args[0])
            if is_quant(dq_input):
                q_input = cast(torch.fx.Node, dq_input.args[0])
                if is_param_node(ep, q_input):
                    return cls.from_q_dq_node(dq_input)
            return cls.from_q_dq_node(tensor_node)

        return None

    @classmethod
    def from_outputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
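        """
        Build QuantParams for an output tensor that is consumed by a single
        implicit q node; returns None otherwise.
        """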
        # tensor_node can also be quantized if it is used in a q -> dq chain
        if len(tensor_node.users) == 1:
            q = list(tensor_node.users.keys())[0]
            # Check if user is a q node
            if is_quant(q) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(q):
                return cls.from_q_dq_node(q)

        return None

    @classmethod
    def from_bias(
        cls,
        bias: torch.fx.Node,
        weight_quantizer: Optional[QuantParams],
        input_quantizer: Optional[QuantParams],
    ) -> Optional[QuantParams]:
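        """
        Build QuantParams for a bias node from the weight and input
        quantization parameters: bias is quantized to int32 with
        scale = weight_scale * input_scale and a zero point of 0. Returns
        None for unquantized or dynamically quantized ops.
        """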
        if weight_quantizer is None or input_quantizer is None:
            check_or_raise(
                weight_quantizer is None and input_quantizer is None,
                "Weight and input must either both be quantized or both be unquantized",
            )
            return None

        if input_quantizer.is_dynamic:
            # No need to quantize bias for dynamic quantization
            return None

        check_or_raise(
            not input_quantizer.per_channel,
            "Input cannot be quantized per channel",
        )

        # Only per_tensor quantization is supported for input here
        check_or_raise(
            isinstance(input_quantizer.scale, float),
            f"q_input scale should be float, but got {input_quantizer.scale}",
        )
        return cls(
            per_channel=weight_quantizer.per_channel,
            q_input=bias,
            scale=weight_quantizer.scale * cast(float, input_quantizer.scale),
            zp=weight_quantizer.zp * 0,
            axis=0,  # not using weight_quantizer.axis because bias is always of shape [out_channels] i.e. 1D
            dtype=torch.int32,
            qmin=-(2**31),
            qmax=(2**31) - 1,
            is_output=False,
            is_input=False,
        )