xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_slice.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8from typing import List
9
10import serializer.tosa_serializer as ts
11from executorch.backends.arm.operators.node_visitor import (
12    NodeVisitor,
13    register_node_visitor,
14)
15from executorch.backends.arm.tosa_mapping import TosaArg
16from serializer.tosa_serializer import TosaOp
17from torch.fx import Node
18
19
def _normalize_slice_index(index: int, dim_size: int) -> int:
    """Map a PyTorch slice bound onto the closed range ``[0, dim_size]``.

    Follows Python/PyTorch slicing semantics: a negative bound counts from the
    end of the dimension, and bounds that fall outside the dimension are
    clamped rather than wrapped. This covers very large end values such as
    an open-ended ``x[a:]`` may produce.
    """
    if index < 0:
        index += dim_size
    return max(0, min(index, dim_size))


@register_node_visitor
class SliceVisitor(NodeVisitor):
    """Node visitor lowering ``aten.slice_copy.Tensor`` to a TOSA SLICE op."""

    target = "aten.slice_copy.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a TOSA SLICE operator equivalent to one slice_copy node.

        aten.slice_copy slices along a single dimension at a time. Its
        arguments are the input tensor, the dimension to slice, and the
        start and end indices. Only these four inputs are accepted.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer the SLICE operator is appended to.
            inputs: ``[input, dim, start, end]``. ``dim``, ``start`` and
                ``end`` carry their integer value in ``.number``.
            output: Tensor argument whose name becomes the SLICE output.
            is_quant_node: Not read by this visitor.

        Raises:
            AssertionError: if the input count is not four, ``dim`` is out
                of range for the input rank, or the normalized slice is empty.
        """
        assert len(inputs) == 4, f"slice_copy expects 4 inputs, got {len(inputs)}"
        input_node, dim_arg, start_arg, end_arg = inputs

        # Translate and check parameters in PyTorch dim order.
        shape = input_node.shape
        rank = len(shape)
        dim = dim_arg.number
        # A negative dim must be normalized here. It is compared against the
        # non-negative entries of dim_order below and would otherwise never
        # match, silently emitting an un-sliced copy.
        if dim < 0:
            dim += rank
        assert 0 <= dim < rank, (
            f"slice dim {dim_arg.number} is out of range for rank {rank}"
        )

        dim_size = shape[dim]
        start = _normalize_slice_index(start_arg.number, dim_size)
        end = _normalize_slice_index(end_arg.number, dim_size)
        # Both bounds are now clamped to [0, dim_size], so size <= dim_size
        # always holds. Only non-empty slices are lowered.
        size = end - start
        assert size > 0, (
            f"empty slice [{start_arg.number}:{end_arg.number}] on dim {dim} "
            f"of size {dim_size}"
        )

        # Convert aten args to TOSA's start and size attributes, in TOSA dim order.
        attr = ts.TosaSerializerAttribute()
        start_attr = [start if i == dim else 0 for i in input_node.dim_order]
        size_attr = [size if i == dim else shape[i] for i in input_node.dim_order]
        attr.SliceAttribute(start_attr, size_attr)

        tosa_graph.addOperator(
            TosaOp.Op().SLICE, [input_node.name], [output.name], attr
        )
61