# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class RepeatVisitor(NodeVisitor):
    """Emit TOSA operators for ``aten.repeat.default``.

    The repeat is expressed as a TILE operator. When the list of multiples is
    longer than the input shape, a RESHAPE that prepends length-1 dimensions
    is emitted first so that TILE sees an input of matching rank.
    """

    target = "aten.repeat.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: list[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add the operators implementing this repeat to ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer that receives the emitted operators.
            inputs: ``inputs[0]`` supplies the name and shape of the tensor to
                repeat; ``inputs[1].special`` supplies the per-dimension
                multiples.
            output: Supplies the output tensor name and the dimension order
                applied to shapes and multiples.
            is_quant_node: Selects INT8 (true) or FP32 (false) as the dtype of
                the intermediate reshaped tensor.
        """
        source = inputs[0]
        repeats = inputs[1].special

        in_rank = len(source.shape)
        out_rank = len(repeats)
        assert out_rank >= in_rank

        # Name of the tensor fed to TILE; replaced below if a RESHAPE is needed.
        tile_input = source.name

        # TILE only supports rank(in) == rank(out), so any extra dimensions
        # requested by the multiples must be introduced by a RESHAPE first.
        extra_dims = out_rank - in_rank
        if extra_dims > 0:
            # Prepend length-1 dimensions so the shape lines up with repeats.
            padded = (1,) * extra_dims + tuple(source.shape)
            reshaped_shape = tosa_shape(padded, output.dim_order)

            # NOTE(review): only INT8 and FP32 are considered here; confirm
            # whether other input dtypes can reach this visitor.
            dtype_name = "INT8" if is_quant_node else "FP32"
            intermediate_dtype = ts.dtype_str_to_val(dtype_name)

            reshaped = tosa_graph.addIntermediate(reshaped_shape, intermediate_dtype)
            reshape_attr = ts.TosaSerializerAttribute()
            reshape_attr.ReshapeAttribute(reshaped_shape)
            tosa_graph.addOperator(
                TosaOp.Op().RESHAPE,
                [tile_input],
                [reshaped.name],
                reshape_attr,
            )
            tile_input = reshaped.name

        tile_attr = ts.TosaSerializerAttribute()
        tile_attr.TileAttribute(tosa_shape(repeats, output.dim_order))
        tosa_graph.addOperator(
            TosaOp.Op().TILE,
            [tile_input],
            [output.name],
            tile_attr,
        )