xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_cat.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8from typing import List
9
10import serializer.tosa_serializer as ts
11from executorch.backends.arm.operators.node_visitor import (
12    NodeVisitor,
13    register_node_visitor,
14)
15from executorch.backends.arm.tosa_mapping import TosaArg
16from serializer.tosa_serializer import TosaOp
17from torch.fx import Node
18
19
@register_node_visitor
class CatVisitor(NodeVisitor):
    """Node visitor that lowers ``aten.cat.default`` to a TOSA CONCAT op."""

    target = "aten.cat.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add a single CONCAT operator for this cat node to ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read here).
            tosa_graph: Serializer the CONCAT operator is appended to.
            inputs: ``inputs[0].special`` is iterated as the tensors to join
                (each must expose ``.name``). An optional ``inputs[1].number``
                gives the cat dimension; when it is absent, 0 is used.
            output: Destination tensor arg. Its ``shape`` supplies the rank and
                its ``dim_order`` supplies the axis remapping.
            is_quant_node: Part of the visitor interface; unused here.
        """
        joined = inputs[0].special

        # The dim argument is optional on the node; fall back to 0.
        if len(inputs) >= 2:
            cat_dim = inputs[1].number
        else:
            cat_dim = 0

        # Fold a negative dim (e.g. -1) into the range [0, rank).
        out_rank = len(output.shape)
        cat_dim = (cat_dim + out_rank) % out_rank

        # The TOSA axis is the position of the logical dim within the output's
        # dim_order (presumably its physical layout order -- confirm in TosaArg).
        tosa_axis = output.dim_order.index(cat_dim)

        concat_attr = ts.TosaSerializerAttribute()
        concat_attr.AxisAttribute(tosa_axis)

        input_names = [arg.name for arg in joined]
        tosa_graph.addOperator(
            TosaOp.Op().CONCAT, input_names, [output.name], concat_attr
        )
47        )
48