xref: /aosp_15_r20/external/executorch/backends/arm/tosa_quant_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8# Utiliy functions for TOSA quantized lowerings
9
10import math
11from typing import Callable, cast, NamedTuple, Sequence
12
13import numpy as np
14
15import serializer.tosa_serializer as ts
16import torch.fx
17import tosa.Op as TosaOp
18from executorch.backends.arm.tosa_mapping import TosaArg
19from executorch.exir.dialects._ops import ops as exir_ops
20from serializer.tosa_serializer import TosaSerializerTensor
21from torch.fx import Node
22
23
24q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
25dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
26dq_q_ops = (q_op, dq_op)
27passable_ops = [
28    exir_ops.edge.aten.view_copy.default,
29    exir_ops.edge.aten.permute_copy.default,
30    exir_ops.edge.aten.squeeze_copy.dims,
31    exir_ops.edge.aten.unsqueeze_copy.default,
32    exir_ops.edge.aten.split_with_sizes_copy.default,
33    exir_ops.edge.aten.repeat.default,
34    exir_ops.edge.aten.clone.default,
35    exir_ops.edge.aten.slice_copy.Tensor,
36    exir_ops.edge.aten.cat.default,
37]
38
39
40def register_passable_op(op):
41    """We need to be able to add custom ops such as tosa_transpose to the passable_op list after they have been created"""
42    passable_ops.append(op)
43
44
45class QuantArgs(NamedTuple):
46    scale: float
47    zp: int
48    qmin: int
49    qmax: int
50    dtype: torch.dtype
51
52    def quantize_value(self, x):
53        if not isinstance(x, torch.Tensor):
54            x = torch.Tensor([x])
55        return torch.clip(
56            torch.round(x / self.scale) + self.zp,
57            self.qmin,
58            self.qmax,
59        ).to(self.dtype)
60
61    def dequantize_value(self, qx: int) -> float:
62        return (qx - self.zp) * self.scale
63
64
65def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
66    return np.clip(
67        np.round(x / qargs.scale) + qargs.zp,
68        qargs.qmin,
69        qargs.qmax,
70    ).astype(dtype)
71
72
73def dequantize_value(qx, qargs: QuantArgs):
74    return (qx - qargs.zp) * qargs.scale
75
76
77def qargs_from_qnode(node: torch.fx.Node):
78    assert node.target in dq_q_ops, f"Op {node} is not a quant node."
79
80    return QuantArgs(
81        scale=cast(float, node.args[1]),
82        zp=cast(int, node.args[2]),
83        qmin=cast(int, node.args[3]),
84        qmax=cast(int, node.args[4]),
85        dtype=cast(torch.dtype, node.args[5]),
86    )
87
88
89def get_neighbour_quant_args(
90    node: torch.fx.Node,
91) -> tuple[list[QuantArgs], list[QuantArgs]]:
92    user_q_args = []
93
94    for user in node.users:
95        q_args = search_quant_arg_downstream(user)
96        if q_args:
97            user_q_args.append(q_args)
98
99    input_q_nodes = []
100    for input_node in node.all_input_nodes:
101        q_args = search_quant_arg_upstream(input_node)
102        if q_args:
103            input_q_nodes.append(q_args)
104    return user_q_args, input_q_nodes
105
106
107def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool:
108    first_q_arg = q_arg_list[0]
109    for q_arg in q_arg_list:
110        if q_arg != first_q_arg:
111            return False
112    return True
113
114
115def is_node_quantized(node: torch.fx.Node) -> bool:
116    if node.target in dq_q_ops:
117        return True
118
119    user_q_args, input_q_args = get_neighbour_quant_args(node)
120
121    # If we did not find any neighbouring quant nodes, we are not quantized.
122    if len(input_q_args) == 0 and len(user_q_args) == 0:
123        return False
124
125    if node.target in passable_ops:
126        assert all_q_args_equal(
127            user_q_args + input_q_args
128        ), f"Node {node} needs same quantization parameters on all inputs and outputs."
129
130    return True
131
132
133def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None:
134    """
135    Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node,
136    starting with 'node'.
137    If a  passable node with multiple consumers is encountered,
138    find QuantArgs for all consumers and assert that they are equal.
139    If a node not in passable_ops is encountered, return None.
140    If a node without consumers is encountered, return None.
141    """
142    if node.target in dq_q_ops:
143        return qargs_from_qnode(node)
144    if node.target not in passable_ops:
145        return None
146    consumer_nodes = list(node.users)
147    if len(consumer_nodes) == 0:
148        return None
149    elif len(consumer_nodes) == 1:
150        return search_quant_arg_downstream(consumer_nodes[0])
151    else:
152        consumer_qargs: list[QuantArgs] = []
153        for input in consumer_nodes:
154            quant_args = search_quant_arg_downstream(input)
155            if quant_args:
156                consumer_qargs.append(quant_args)
157        if len(consumer_qargs) == 0:
158            return None
159        assert all_q_args_equal(
160            consumer_qargs
161        ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers."
162        return consumer_qargs[0]
163
164
165def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs:
166    """Calls search_quant_arg_downstream and asserts that QuantArgs are found,
167    meaning return value can't be None.
168    """
169    qargs = search_quant_arg_downstream(node)
170    assert qargs, f"Did not find QuantArgs downstream for node {node}"
171    return qargs
172
173
174def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None:
175    """
176    Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node,
177    starting with 'node'.
178    If a  passable node with multiple inputs is encountered,
179    find QuantArgs for all inputs and assert that they are equal.
180    If a node not in passable_ops is encountered, return None.
181    If a node without inputs is encountered, return None.
182    """
183
184    if node.target in dq_q_ops:
185        return qargs_from_qnode(node)
186    if node.target not in passable_ops:
187        return None
188    input_nodes = list(node.all_input_nodes)
189    if len(input_nodes) == 0:
190        return None
191    elif len(input_nodes) == 1:
192        return search_quant_arg_upstream(input_nodes[0])
193    else:
194        input_qargs: list[QuantArgs] = []
195        for input in input_nodes:
196            quant_args = search_quant_arg_upstream(input)
197            if quant_args:
198                input_qargs.append(quant_args)
199        if len(input_qargs) == 0:
200            return None
201        assert all_q_args_equal(
202            input_qargs
203        ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs."
204        return input_qargs[0]
205
206
207def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
208    """Calls search_quant_arg_upstream and asserts that QuantArgs are found,
209    meaning return value can't be None.
210    """
211    qargs = search_quant_arg_upstream(node)
212    assert qargs, f"Did not find QuantArgs upstream for node {node}"
213    return qargs
214
215
216def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype:
217    if isinstance(node.target, Callable) and "tosa" in node.target.__name__:
218        return node.meta["val"].dtype
219    if node.target in dq_q_ops:
220        return cast(torch.dtype, node.args[5])
221
222    # if not a tosa node, nor a q/dq op, walk the graph until we find a q op
223    user_q_args, input_q_args = get_neighbour_quant_args(node)
224    if len(user_q_args) > 0:
225        return user_q_args[0].dtype
226    elif node.target in passable_ops and len(input_q_args) > 0:
227        return input_q_args[0].dtype
228    else:
229        raise RuntimeError("No quantized node found in graph")
230
231
232# Check if scale32 mode is used for given output element type
233def is_scale32(type):
234    return type == ts.DType.INT8
235
236
237# TOSA uses the RESCALE operation to scale between values with differing precision.
238# The RESCALE operator is defined using an integer multiply, add, and shift.
239# This utility function is for calculating the multier and shift given a scale.
240# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
241def compute_multiplier_and_shift(scale, scaleWidth=32):
242    if scaleWidth == 16:
243        offset = 15
244    elif scaleWidth == 32:
245        offset = 31
246    else:
247        raise AssertionError("unsupported scale width")
248
249    assert isinstance(scale, float)
250
251    mantissa, exponent = math.frexp(scale)
252    shift = exponent
253
254    const_2_power_15_or_31 = 1 << offset
255    shifted_mantissa = round(mantissa * const_2_power_15_or_31)
256
257    assert shifted_mantissa <= const_2_power_15_or_31
258
259    if shifted_mantissa == const_2_power_15_or_31:
260        shifted_mantissa = shifted_mantissa / 2
261        shift += 1
262
263    # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
264    shift = offset - shift
265
266    # INT32_MAX, 2^31 - 1
267    assert shifted_mantissa <= (const_2_power_15_or_31 - 1)
268
269    multiplier = shifted_mantissa
270
271    if shift > 62:
272        multiplier = multiplier >> min(31, shift - 62)
273        shift = 62
274    return multiplier, shift
275
276
277def build_rescale(
278    tosa_fb,
279    scale,
280    input_node,
281    output_name,
282    output_type,
283    output_shape,
284    input_zp,
285    output_zp,
286    is_double_round=False,
287):
288    scale_width = 32 if is_scale32(output_type) else 16
289    multiplier, shift = compute_multiplier_and_shift(scale, scale_width)
290
291    attr_rescale = ts.TosaSerializerAttribute()
292    attr_rescale.RescaleAttribute(
293        input_zp=input_zp,
294        output_zp=output_zp,
295        multiplier=[multiplier],
296        shift=[shift],
297        scale32=is_scale32(output_type),
298        double_round=is_double_round,
299        per_channel=False,
300        input_unsigned=False,
301        output_unsigned=False,
302    )
303
304    tosa_fb.addOperator(
305        TosaOp.Op().RESCALE, [input_node.name], [output_name], attr_rescale
306    )
307
308    return
309
310
311def build_rescale_to_int32(
312    tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=False
313) -> TosaSerializerTensor:
314    multiplier, shift = compute_multiplier_and_shift(rescale_scale)
315    attr_rescale = ts.TosaSerializerAttribute()
316    attr_rescale.RescaleAttribute(
317        input_zp=input_zp,
318        output_zp=0,
319        multiplier=[multiplier],
320        shift=[shift],
321        scale32=is_scale32,
322        double_round=is_double_round,
323        per_channel=False,
324        input_unsigned=False,
325        output_unsigned=False,
326    )
327    input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
328    tosa_fb.addOperator(
329        TosaOp.Op().RESCALE,
330        [input.name],
331        [input_A_rescaled_to_int32.name],
332        attr_rescale,
333    )
334
335    return input_A_rescaled_to_int32
336
337
338def build_rescale_from_int32(
339    tosa_fb,
340    input_name,
341    output_name,
342    output_zp,
343    rescale_scale,
344    is_scale32=True,
345    is_double_round=False,
346) -> None:
347    multiplier, shift = compute_multiplier_and_shift(rescale_scale)
348    attr_rescale_output = ts.TosaSerializerAttribute()
349    attr_rescale_output.RescaleAttribute(
350        input_zp=0,
351        output_zp=output_zp,
352        multiplier=[multiplier],
353        shift=[shift],
354        scale32=is_scale32,
355        double_round=is_double_round,
356        per_channel=False,
357        input_unsigned=False,
358        output_unsigned=False,
359    )
360
361    tosa_fb.addOperator(
362        TosaOp.Op().RESCALE, [input_name], [output_name], attr_rescale_output
363    )
364
365    return
366
367
368def rescale_nodes_to_int32(
369    nodes: Sequence[Node], tosa_graph: ts.TosaSerializer
370) -> tuple[list[TosaSerializerTensor], float]:
371    """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'.
372    The scales are adjusted using the smallest scale of all 'nodes'.
373
374    Returns a list of the rescaled nodes and the scale factor used,
375    needed by rescale_node_back_to_int8.
376    """
377
378    tensors = [TosaArg(node) for node in nodes]
379
380    # Reshape tensor according to tosa dim order
381    for tensor in tensors:
382        dim_order = tensor.dim_order
383        tensor.shape = [tensor.shape[i] for i in dim_order]
384
385    qargs = [get_quant_arg_upstream(node) for node in nodes]
386
387    # Scale the int8 quantized input to a common scale in the integer
388    # domain
389    min_scale = min([qarg.scale for qarg in qargs])
390    scales = [qarg.scale / min_scale for qarg in qargs]
391
392    rescaled_nodes: list[TosaSerializerTensor] = []
393    for tensor, qarg, scale in zip(tensors, qargs, scales):
394        rescaled_nodes.append(
395            build_rescale_to_int32(
396                tosa_graph,
397                tensor,
398                qarg.zp,
399                scale,
400            )
401        )
402    return rescaled_nodes, min_scale
403
404
405def rescale_node_back_to_int8(
406    node: Node,
407    last_tensor: TosaSerializerTensor,
408    scale: float,
409    tosa_graph: ts.TosaSerializer,
410):
411    """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
412    Parameters:
413        node: The original node that is being handled by the rescales.
414        last_tensor:the tosa tensor to rescale back.
415        scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32'
416        tosa_graph: the tosa_graph to manipulate.
417    """
418    qargs_out = get_quant_arg_downstream(list(node.users)[0])
419    output_rescale_scale = scale / qargs_out.scale
420
421    # Rescale Back to INT8
422    build_rescale_from_int32(
423        tosa_graph,
424        last_tensor.name,
425        node.name,
426        qargs_out.zp,
427        output_rescale_scale,
428    )
429
430
431""" Creates a TOSA rescale op based on conv2d parameters. """
432
433
434def build_rescale_conv_output(
435    tosa_fb,
436    op,
437    output_name,
438    output_type,
439    input_scale,
440    weight_scale,
441    output_scale,
442    output_zp,
443):
444    # TODO add check to verify if this is a Per-channel quantization.
445    post_conv2d_scale = (input_scale * weight_scale) / output_scale
446
447    # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0.
448    build_rescale(
449        tosa_fb,
450        post_conv2d_scale,
451        op,
452        output_name,
453        output_type,
454        op.shape,
455        0,
456        output_zp,
457    )
458    return
459