# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


def _resolve_slice_index(index: int, extent: int) -> int:
    """Return *index* as a non-negative position within ``[0, extent]``.

    A negative index counts from the end of the dimension (``extent`` is added
    to it once); the result is then clamped so that it never falls outside the
    dimension. Intended to mirror PyTorch's slice index handling.
    """
    if index < 0:
        index += extent
    return max(0, min(index, extent))


@register_node_visitor
class SliceVisitor(NodeVisitor):
    """Lowers ``aten.slice_copy.Tensor`` to a single TOSA ``SLICE`` operator."""

    target = "aten.slice_copy.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add a TOSA SLICE operator for ``node`` to ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer that receives the SLICE operator.
            inputs: Exactly four arguments: the tensor to slice, the dimension
                to slice along, the start index and the end index.
            output: Destination tensor of the operator.
            is_quant_node: Not read by this visitor.

        Raises:
            AssertionError: If ``inputs`` does not hold four arguments, or if
                the resolved slice is empty or larger than the dimension.
            IndexError: If the slicing dimension is out of range for the
                input shape.
        """
        # aten.slice_copy supports slicing in 1d at a time.
        # The arguments are dimension of slicing, start index and end index.
        assert len(inputs) == 4, (
            f"{self.target} expects 4 arguments, got {len(inputs)}"
        )
        input_node, dim, start, end = inputs

        # Translate and check parameters in Pytorch dim order.
        shape = input_node.shape
        dim = dim.number
        # Indexing first keeps the original IndexError for out-of-range dims.
        extent = shape[dim]
        if dim < 0:
            # dim_order only holds non-negative dimension indices, so a
            # negative dim would otherwise never match in the comparisons
            # below and the op would degrade to an identity slice.
            # Assumes shape is a sized sequence (list/tuple) - TODO confirm.
            dim += len(shape)

        start_index = _resolve_slice_index(start.number, extent)
        end_index = _resolve_slice_index(end.number, extent)
        size = end_index - start_index
        assert size > 0, (
            f"{self.target}: empty slice [{start.number}:{end.number}] on "
            f"dim {dim} of extent {extent} is not supported"
        )
        assert size <= extent, (
            f"{self.target}: slice size {size} exceeds extent {extent}"
        )

        # Convert aten args to Tosa's start and size attributes and in TOSA dim order.
        attr = ts.TosaSerializerAttribute()
        start_attr = [start_index if i == dim else 0 for i in input_node.dim_order]
        size_attr = [size if i == dim else shape[i] for i in input_node.dim_order]
        attr.SliceAttribute(start_attr, size_attr)

        tosa_graph.addOperator(
            TosaOp.Op().SLICE, [input_node.name], [output.name], attr
        )