xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_repeat.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
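#
# Tests the repeat op, which copies the data of the input tensor by tiling it
# along each dimension according to the given multiples (prepending new leading
# dimensions when more multiples than input dimensions are given).
#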
10
11import unittest
12from typing import Sequence, Tuple
13
14import torch
15
16from executorch.backends.arm.quantizer.arm_quantizer import (
17    ArmQuantizer,
18    get_symmetric_quantization_config,
19)
20from executorch.backends.arm.test import common
21from executorch.backends.arm.test.tester.arm_tester import ArmTester
22
23from executorch.backends.xnnpack.test.tester.tester import Quantize
24from executorch.exir.backend.backend_details import CompileSpec
25from parameterized import parameterized
26
27
class TestSimpleRepeat(unittest.TestCase):
    """Tests Tensor.repeat for different ranks and dimensions.

    Every test lowers the ``Repeat`` module through one Arm pipeline (TOSA MI,
    TOSA BI, Ethos-U55 or Ethos-U85). It checks that the single aten.repeat
    node is consumed by exactly one delegate.
    """

    class Repeat(torch.nn.Module):
        """Module whose forward is a single ``Tensor.repeat`` call."""

        # (input tensor, multiples). The cases with more multiples than input
        # dimensions exercise the rank-increasing form of repeat.
        test_parameters = [
            (torch.randn(3), (2,)),
            (torch.randn(3, 4), (2, 1)),
            (torch.randn(1, 1, 2, 2), (1, 2, 3, 4)),
            (torch.randn(3), (2, 2)),
            (torch.randn(3), (1, 2, 3)),
            (torch.randn(3, 3), (2, 2, 2)),
        ]

        def forward(self, x: torch.Tensor, multiples: Sequence):
            return x.repeat(multiples)

    @staticmethod
    def _quantize_stage() -> Quantize:
        """Build the symmetric-quantization stage shared by all BI pipelines.

        The same config is used both for the quantizer's IO and for the
        stage's config, as in the original per-pipeline setup.
        """
        quantization_config = get_symmetric_quantization_config()
        quantizer = ArmQuantizer().set_io(quantization_config)
        return Quantize(quantizer, quantization_config)

    @staticmethod
    def _lower_and_check_delegation(tester: ArmTester) -> ArmTester:
        """Export and partition the tester's module, asserting full delegation.

        The helper checks three things:
        - the exported graph has exactly one aten.repeat;
        - no aten.repeat remains after partitioning;
        - exactly one delegate call was created.

        It then runs the ``to_executorch`` stage and returns the tester so
        callers can chain output comparisons.
        """
        return (
            tester.export()
            .check_count({"torch.ops.aten.repeat.default": 1})
            .to_edge()
            .partition()
            .check_not(["torch.ops.aten.repeat.default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_repeat_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Lower ``module`` to TOSA MI (no quantization) and compare outputs."""
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
        )
        self._lower_and_check_delegation(tester).run_method_and_compare_outputs(
            inputs=test_data
        )

    def _test_repeat_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Quantize ``module``, lower it to TOSA BI and compare outputs.

        ``qtol=1`` is passed to the comparison to allow for quantization error.
        """
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        ).quantize(self._quantize_stage())
        self._lower_and_check_delegation(tester).run_method_and_compare_outputs(
            inputs=test_data, qtol=1
        )

    def _test_repeat_ethosu_pipeline(
        self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: Tuple
    ):
        """Quantize ``module`` and lower it for the Ethos-U target in ``compile_spec``.

        Only lowering and delegation are checked; the resulting program is not
        executed here.
        """
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=compile_spec,
        ).quantize(self._quantize_stage())
        self._lower_and_check_delegation(tester)

    @parameterized.expand(Repeat.test_parameters)
    def test_repeat_tosa_MI(self, test_input, multiples):
        self._test_repeat_tosa_MI_pipeline(self.Repeat(), (test_input, multiples))

    @parameterized.expand(Repeat.test_parameters)
    def test_repeat_tosa_BI(self, test_input, multiples):
        self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples))

    @parameterized.expand(Repeat.test_parameters)
    def test_repeat_u55_BI(self, test_input, multiples):
        self._test_repeat_ethosu_pipeline(
            common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
        )

    @parameterized.expand(Repeat.test_parameters)
    def test_repeat_u85_BI(self, test_input, multiples):
        self._test_repeat_ethosu_pipeline(
            common.get_u85_compile_spec(), self.Repeat(), (test_input, multiples)
        )
120