xref: /aosp_15_r20/external/executorch/backends/arm/_passes/convert_expand_copy_to_repeat.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9from typing import cast
10
11from executorch.backends.arm.tosa_mapping import extract_tensor_meta
12from executorch.exir.dialects._ops import ops as exir_ops
13from executorch.exir.pass_base import ExportPass
14
15
def _expand_shape_to_repeat_multiples(input_shape, expand_shape: list[int]) -> list[int]:
    """Translate the ``size`` argument of ``expand`` into ``repeat`` multiples.

    ``expand`` may prepend new dimensions and may broadcast existing
    dimensions of size 1; every other dimension keeps its size. ``repeat``
    takes a per-dimension multiplier instead, and also treats extra leading
    entries as new dimensions of size 1.

    Args:
        input_shape: shape of the tensor being expanded (the *input*, not the
            expanded output).
        expand_shape: the target ``size`` passed to ``expand_copy``. An entry
            of ``-1`` means "keep the size of this dimension".

    Returns:
        A list with ``len(expand_shape)`` entries to use as the ``repeats``
        argument of ``repeat``.
    """
    padding = len(expand_shape) - len(input_shape)
    # Input shape front-padded with ones up to the rank of the expanded shape.
    extended_shape = [1] * padding + list(input_shape)

    # Only singleton dimensions can actually be repeated; all others (and any
    # dimension whose target size is -1, i.e. unchanged) get multiple 1.
    return [
        size if size != -1 and dim == 1 else 1
        for size, dim in zip(expand_shape, extended_shape)
    ]


class ConvertExpandCopyToRepeatPass(ExportPass):
    """
    Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.

    The ``size`` argument of ``expand_copy`` is a target shape, while
    ``repeat`` expects per-dimension multiples. The conversion therefore
    depends on the shape of the tensor being expanded.
    """

    expand_copy = exir_ops.edge.aten.expand_copy.default
    repeat = exir_ops.edge.aten.repeat.default

    def call_operator(self, op, args, kwargs, meta):
        """Rewrite ``expand_copy(x, size)`` as ``repeat(x, multiples)``.

        All other operators are forwarded unchanged to ``ExportPass``.
        """
        if op != self.expand_copy:
            return super().call_operator(op, args, kwargs, meta)

        # `meta` belongs to the expand_copy node itself, i.e. it describes the
        # already-expanded *output*. The multiples must be derived from the
        # shape of the input tensor instead.
        input_shape = args[0].data.shape
        expand_shape = cast(list[int], args[1])
        multiples = _expand_shape_to_repeat_multiples(input_shape, expand_shape)

        # expand_copy's keyword argument (`implicit`) is not part of repeat's
        # schema, so it must not be forwarded.
        return super().call_operator(
            op=self.repeat, args=(args[0], multiples), kwargs={}, meta=meta
        )
45        )
46