# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class CatVisitor(NodeVisitor):
    """Lower ``aten.cat.default`` to a single TOSA ``CONCAT`` operator."""

    target = "aten.cat.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit a TOSA CONCAT for ``node`` into ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read here).
            tosa_graph: Serializer the CONCAT operator is added to.
            inputs: ``inputs[0].special`` is iterated as the sequence of
                tensors to concatenate (each element's ``.name`` is used).
                The optional ``inputs[1].number`` is the concat dimension;
                it defaults to 0 when absent, and negative values count
                from the end.
            output: Destination tensor. Its ``shape`` gives the rank, and
                its ``dim_order`` maps the logical dim to the TOSA axis.
            is_quant_node: Unused; CONCAT is emitted the same way either way.

        Raises:
            ValueError: If the concat dimension lies outside
                ``[-rank, rank)`` for the output's rank.
        """
        tensors = inputs[0].special
        dim = 0 if len(inputs) < 2 else inputs[1].number
        rank = len(output.shape)

        # Reject out-of-range dims explicitly. The modulo below would
        # otherwise silently wrap them to a wrong axis, or raise
        # ZeroDivisionError for a rank-0 output.
        if not -rank <= dim < rank:
            raise ValueError(
                f"aten.cat: dim {dim} is out of range for output of rank {rank}"
            )

        # Normalize a negative dim to its non-negative equivalent.
        dim = (dim + rank) % rank
        # The TOSA axis is the position of this logical dim within the
        # output's dim_order (its physical layout order), not the logical
        # index itself.
        dim = output.dim_order.index(dim)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(dim)

        tosa_graph.addOperator(
            TosaOp.Op().CONCAT, [tensor.name for tensor in tensors], [output.name], attr
        )