xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_repeat.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8import serializer.tosa_serializer as ts
9import torch
10from executorch.backends.arm.operators.node_visitor import (
11    NodeVisitor,
12    register_node_visitor,
13)
14from executorch.backends.arm.tosa_mapping import TosaArg
15from executorch.backends.arm.tosa_utils import tosa_shape
16from serializer.tosa_serializer import TosaOp
17
18
@register_node_visitor
class RepeatVisitor(NodeVisitor):
    """Lower ``aten.repeat.default`` to a TOSA TILE, preceded by a RESHAPE if needed.

    ``repeat`` may be given more repeat factors than the input has dimensions,
    in which case new leading dimensions are added. TOSA TILE requires
    rank(in) == rank(out), so in that case the input is first reshaped to carry
    size-1 leading dimensions.
    """

    target = "aten.repeat.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: list[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the RESHAPE (optional) and TILE operators for ``node``.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer the operators are appended to.
            inputs: ``inputs[0]`` is the tensor to repeat; ``inputs[1].special``
                holds the per-dimension repeat factors ("multiples").
            output: Output tensor argument; TILE writes to ``output.name`` and
                ``output.dim_order`` is passed to ``tosa_shape`` for every
                emitted shape.
            is_quant_node: Whether the node is lowered in the quantized domain;
                selects INT8 as the dtype of the intermediate reshape tensor.

        Raises:
            AssertionError: If fewer multiples than input dimensions are given.
        """
        item_name = inputs[0].name
        shape = tuple(inputs[0].shape)
        multiples = inputs[1].special
        rank = len(shape)
        new_rank = len(multiples)

        assert new_rank >= rank, (
            f"repeat needs at least {rank} multiples for a rank-{rank} input, "
            f"got {new_rank}"
        )

        # TILE only supports rank(in) == rank(out). To add more dims, we need a reshape first.
        if new_rank > rank:
            # Prepend length-1 dimensions so the input rank matches `multiples`.
            expanded_shape = tosa_shape(
                (1,) * (new_rank - rank) + shape, output.dim_order
            )
            # RESHAPE keeps the element type, and TILE's input must have the
            # same dtype as its output. Quantized nodes are lowered as INT8;
            # otherwise use the output's own dtype rather than assuming FP32,
            # which produced an invalid graph for integer/boolean tensors.
            dtype = ts.dtype_str_to_val("INT8") if is_quant_node else output.dtype

            reshape_out = tosa_graph.addIntermediate(expanded_shape, dtype)
            reshape_attr = ts.TosaSerializerAttribute()
            reshape_attr.ReshapeAttribute(expanded_shape)
            tosa_graph.addOperator(
                TosaOp.Op().RESHAPE, [item_name], [reshape_out.name], reshape_attr
            )
            item_name = reshape_out.name

        tile_attr = ts.TosaSerializerAttribute()
        tile_attr.TileAttribute(tosa_shape(multiples, output.dim_order))
        tosa_graph.addOperator(
            TosaOp.Op().TILE, [item_name], [output.name], tile_attr
        )
69