xref: /aosp_15_r20/external/executorch/backends/arm/tosa_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8import logging
9import os
10from typing import Any, cast
11
12import numpy as np
13import serializer.tosa_serializer as ts
14import torch
15from executorch.backends.arm.tosa_mapping import TosaArg
16
17from executorch.backends.arm.tosa_quant_utils import (
18    get_quant_arg_downstream,
19    get_quant_arg_upstream,
20    q_op,
21)
22from executorch.exir.dialects._ops import ops as exir_ops
23from serializer.tosa_serializer import TosaOp
24from torch.fx import Node
25
# Module logger. Verbose (INFO) output is opted into by setting the
# environment variable TOSA_DBG_VERBOSE to "1"; otherwise only warnings and
# above are emitted by this logger.
logger = logging.getLogger(__name__)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO if TOSA_DBG_VERBOSE else logging.WARNING)
32
33
def dbg_node(node):
    """Log a node's op, name, target, args, kwargs and meta at INFO level.

    Each entry of ``node.meta`` is logged as a key/value pair; when the value
    is a list, its elements are additionally logged one per line.

    Args:
        node: Object exposing ``op``, ``name``, ``target``, ``args``,
            ``kwargs`` and a dict-like ``meta``.
    """
    logger.info("OP")
    described_attrs = (
        ("op", "op"),
        ("name", "name"),
        ("node target", "target"),
        ("node args", "args"),
        ("node kwargs", "kwargs"),
    )
    for label, attr_name in described_attrs:
        logger.info(f"  {label} is {getattr(node, attr_name)}")

    logger.info("  node.meta = ")
    for key, value in node.meta.items():
        logger.info(f"    '{key}' = {value}")
        if not isinstance(value, list):
            continue
        for element in value:
            logger.info(f"      {element} ")
48
49
def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
    """Write the TOSA flatbuffer and its JSON test-harness description.

    Two files are written into ``path`` (created if missing):
    ``output{suffix}.tosa`` holding the result of ``tosa_graph.serialize()``
    and ``desc{suffix}.json`` holding the result of
    ``tosa_graph.writeJson(...)``.

    Args:
        tosa_graph (ts.TosaSerializer): Graph to serialize.
        path (str): Output directory.
        suffix (str): Optional suffix inserted into both file names.
    """
    tosa_name = f"output{suffix}.tosa"

    logger.info(f"Emitting debug output to: {path=}, {suffix=}")

    os.makedirs(path, exist_ok=True)

    flatbuffer = tosa_graph.serialize()
    json_desc = tosa_graph.writeJson(tosa_name)

    # (destination, open mode, payload, message if the file is missing)
    artifacts = (
        (
            os.path.join(path, tosa_name),
            "wb",
            flatbuffer,
            "Failed to write TOSA flatbuffer",
        ),
        (
            os.path.join(path, f"desc{suffix}.json"),
            "w",
            json_desc,
            "Failed to write TOSA JSON",
        ),
    )
    for destination, mode, payload, missing_msg in artifacts:
        with open(destination, mode) as out_file:
            out_file.write(payload)
        assert os.path.exists(destination), missing_msg
70
71
def dbg_fail(node, tosa_graph, path):
    """Dump debug artifacts for a node that could not be handled, then raise.

    The TOSA graph is written to ``path`` via ``dbg_tosa_dump``, the node is
    described via ``dbg_node``, and a ``RuntimeError`` is always raised.

    Args:
        node: The node that triggered the failure.
        tosa_graph: The TOSA graph to dump.
        path: Directory that receives the debug output.

    Raises:
        RuntimeError: Always.
    """
    dbg_tosa_dump(tosa_graph, path)
    # Logger.warn is a deprecated alias (removed in Python 3.13); use warning.
    logger.warning("Internal error due to poorly handled node:")
    dbg_node(node)
    logger.warning(f"Debug output captured in '{path}'.")
    raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")
78
79
def promote_shape(tosa_fb, arg, promoted_shape, out_dtype):
    """Reshape ``arg`` to ``promoted_shape`` via a TOSA RESHAPE operator.

    Helper to match TOSA's broadcasting rank requirement.
    Ref: TOSA 0.80.0 specification - 1.9.3. Data Layouts from
    https://www.mlplatform.org/tosa/tosa_spec.html

    Args:
        tosa_fb: Graph builder providing ``addIntermediate`` and
            ``addOperator``.
        arg: Input tensor exposing ``shape`` and ``name``.
        promoted_shape: Target shape; must have the same element count as
            ``arg.shape``.
        out_dtype: Data type of the intermediate result tensor.

    Returns:
        The intermediate tensor holding the reshaped result.
    """
    element_counts_match = np.prod(arg.shape) == np.prod(promoted_shape)
    assert element_counts_match, "Incompatible promoted shape"

    reshaped = tosa_fb.addIntermediate(promoted_shape, out_dtype)
    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(promoted_shape)
    tosa_fb.addOperator(
        TosaOp.Op().RESHAPE, [arg.name], [reshaped.name], reshape_attr
    )
    return reshaped
90
91
def transpose_helper(tosa_fb, input, new_order, out_dtype):
    """Insert a TOSA TRANSPOSE of ``input`` and return the result tensor.

    Helper transpose function to match TOSA's shape requirements.
    E.g., TOSA 0.80.0 specification - 2.3.3 CONV2D shapes:
    https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d

    Args:
        tosa_fb: Graph builder providing ``addIntermediate`` and
            ``addOperator``.
        input: Tensor to transpose, exposing ``shape`` and ``name``.
        new_order: Permutation of the input dimensions; must contain every
            index in ``[0, rank)`` exactly once.
        out_dtype: Data type of the intermediate result tensor.

    Returns:
        The intermediate tensor holding the transposed result.

    Raises:
        AssertionError: If ``new_order`` has the wrong length, contains
            duplicates, or contains a negative or out-of-range dim.
    """
    rank = len(input.shape)

    # Check new_order's length is equal to input rank
    assert rank == len(new_order), "Wrong shape order length"

    # Check no duplications
    assert len(set(new_order)) == len(new_order), "Contain duplicated dim numbers"

    # Check all dims are valid. These checks previously used `assert True`,
    # which can never fail, so invalid dims were silently accepted.
    for idx in new_order:
        assert idx >= 0, "Negative dim number"
        assert idx < rank, "Dim is greater than input rank"

    input_shape_transposed = [input.shape[i] for i in new_order]
    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(new_order)
    input_transposed = tosa_fb.addIntermediate(input_shape_transposed, out_dtype)
    tosa_fb.addOperator(
        TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr
    )
    return input_transposed
117
118
def getNodeArgs(node: Node) -> list[TosaArg]:
    """Wrap every positional argument of ``node`` in a ``TosaArg``."""
    return list(map(TosaArg, node.args))
121
122
def get_input_tensor(node: Node) -> TosaArg:
    """Return the first positional argument of ``node`` wrapped in a ``TosaArg``."""
    first_arg = node.args[0]
    return TosaArg(first_arg)
125
126
def get_output_node(node: Node) -> Node:
    """Return the first user of ``node``.

    Raises:
        IndexError: If ``node`` has no users.
    """
    consumers = list(node.users)
    return consumers[0]
129
130
131""" TOSA reshape returns a tensor with the same type/values as the input.
132    No data conversion happens during a reshape operation. """
133
134
135def build_reshape(tosa_fb, input_name, new_shape, output_name):
136    attr = ts.TosaSerializerAttribute()
137    attr.ReshapeAttribute(new_shape)
138    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr)
139
140
def is_bias_node_for_quantized_conv(node):
    """Tell whether ``node`` feeds a convolution that is followed by ``q_op``.

    Only the first user of ``node`` is inspected, and only the first user of
    that convolution is compared against ``q_op``.
    """
    consumer = list(node.users)[0]
    if not (consumer.target == exir_ops.edge.aten.convolution.default):
        return False
    conv_consumer = list(consumer.users)[0]
    return conv_consumer.target == q_op
147
148
def is_consumer_node_depthwise_conv2d(node):
    """Tell whether the first user of ``node`` is a depthwise convolution.

    The first user must be an edge ``aten.convolution``; it counts as
    depthwise when its last argument (read as the group count) equals
    dim 1 of its first argument's shape (``in_channels``) and dim 0 of its
    second argument's shape (``out_channels``) is a multiple of that.
    """
    consumer = list(node.users)[0]
    if not (consumer.target == exir_ops.edge.aten.convolution.default):
        return False

    conv_args = getNodeArgs(consumer)
    group_arg = conv_args[-1]
    in_channels = conv_args[0].shape[1]
    out_channels = conv_args[1].shape[0]
    return bool(
        (in_channels == group_arg.number) and (out_channels % in_channels) == 0
    )
160
161
def build_avg_pool_2d_common(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    input_tensor: TosaArg,
    kernel_size: list,
    stride: list,
    padding: list,
    is_quant_node: bool,
    output: TosaArg,
):
    """Add a TOSA AVG_POOL2D operator to ``tosa_graph``.

    Args:
        node (torch.fx.Node): The pooling node; used to look up quantization
            parameters when ``is_quant_node`` is set.
        tosa_graph (ts.TosaSerializer): The TOSA graph to extend.
        input_tensor (TosaArg): Input of the pooling operator.
        kernel_size (list): Pooling kernel passed to the pool attribute.
        stride (list): Stride passed to the pool attribute.
        padding (list): Padding passed to the pool attribute.
        is_quant_node (bool): Whether the node is quantized.
        output (TosaArg): Output of the pooling operator.
    """
    if is_quant_node:
        # Accumulator type always is int32 when input tensor is an integer type.
        accumulator_type = ts.DType.INT32
        input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp
        output_zp = get_quant_arg_downstream(list(node.users)[0]).zp
    else:
        # Non-quantized: accumulate in the input dtype, zero points are zero.
        accumulator_type = input_tensor.dtype
        input_zp = 0
        output_zp = 0

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel=kernel_size,
        stride=stride,
        pad=padding,
        input_zp=input_zp,
        output_zp=output_zp,
        accum_dtype=accumulator_type,
    )

    tosa_graph.addOperator(
        TosaOp.Op().AVG_POOL2D, [input_tensor.name], [output.name], pool_attr
    )
202
203
def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]:
    """Return the first two input nodes of ``node``, in order.

    When ``node`` has a single input, that input is returned twice.

    Args:
        node (Node): Node whose ``all_input_nodes`` are inspected.
        check (bool): When True, additionally require at most two inputs.

    Returns:
        tuple[Node, Node]: The first and second input node.

    Raises:
        AssertionError: If there are no input nodes, or if ``check`` is True
            and there are more than two.
    """
    producers = node.all_input_nodes
    num_inputs = len(producers)

    assert num_inputs > 0, f"Node '{node.name}' requires >0 input, got {num_inputs}."
    if check:
        assert (
            num_inputs <= 2
        ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}."

    first = producers[0]
    second = first if num_inputs == 1 else producers[1]
    return first, second
226
227
def tosa_shape(shape, dim_order):
    """Return ``shape`` permuted according to ``dim_order`` as a tuple."""
    return tuple(shape[axis] for axis in dim_order)
230
231
def expand_dims(
    tosa_graph: ts.TosaSerializer,
    input_node: TosaArg,
    dtype: int,
    dim: int,
) -> Any:
    """Insert a reshape that adds a size-1 axis at position ``dim``.

    This is the equivalent of an expand_dims (a.k.a. unsqueeze) operation,
    implemented with a TOSA reshape into a new intermediate tensor.

    Args:
        tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
        input_node (TosaArg): The tensor whose shape is expanded.
        dtype (ts.DType): The data type of the intermediate output tensor.
        dim (int): The position at which the new axis is inserted.

    Returns:
        Any: The output tensor of the inserted operation in the TOSA graph.
    """
    original_shape = list(input_node.shape)
    # Slice-concatenation places the new axis exactly where list.insert would.
    new_shape = original_shape[:dim] + [1] + original_shape[dim:]

    expanded = tosa_graph.addIntermediate(new_shape, dtype)
    build_reshape(tosa_graph, input_node.name, new_shape, expanded.name)
    return expanded
259
260
def get_resize_parameters(
    input_size: torch.Tensor,
    output_size: torch.Tensor,
    resize_mode: int,
    align_corners: bool,
):
    """Compute the tosa.resize scale, offset and border parameters.

    The per-dimension scale is the reduced fraction ``scale_n / scale_d``.
    With ``align_corners`` set, both sizes are decremented by one for every
    dimension where input and output are larger than 1.

    Args:
        input_size (torch.Tensor): Size of the input
        output_size (torch.Tensor): Size of the output
        resize_mode (tosa.ResizeMode): The TOSA resize mode (not used by the
            computation)
        align_corners (bool): Align the corners pixels of the input and output

    Returns:
        scale_n (torch.Tensor), scale_d (torch.Tensor),
        offset (torch.Tensor), border (torch.Tensor)
    """
    assert torch.all(input_size > 0)
    assert torch.all(output_size > 0)

    numerators = []
    denominators = []
    for si, so in zip(input_size, output_size):
        corner_aligned = align_corners and si > 1 and so > 1
        numerators.append(so - 1 if corner_aligned else so)
        denominators.append(si - 1 if corner_aligned else si)

    scale_n = torch.tensor(numerators)
    scale_d = torch.tensor(denominators)

    # Reduce each scale fraction to lowest terms.
    common_divisor = torch.gcd(scale_n, scale_d)
    scale_n = scale_n // common_divisor
    scale_d = scale_d // common_divisor

    # No half-pixel centre support in PyTorch, no offset needed
    offset = torch.zeros_like(input_size)
    border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset

    return scale_n, scale_d, offset, border
304