xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_expand.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7#
# Tests the expand op, which broadcasts the input tensor to a larger shape (returning a view rather than copying its data)
9#
10
11import unittest
12from typing import Sequence, Tuple
13
14import torch
15
16from executorch.backends.arm.quantizer.arm_quantizer import (
17    ArmQuantizer,
18    get_symmetric_quantization_config,
19)
20from executorch.backends.arm.test import common
21from executorch.backends.arm.test.tester.arm_tester import ArmTester
22
23from executorch.backends.xnnpack.test.tester.tester import Quantize
24from executorch.exir.backend.backend_details import CompileSpec
25from parameterized import parameterized
26
27
class TestSimpleExpand(unittest.TestCase):
    """Tests the Tensor.expand which should be converted to a repeat op by a pass."""

    class Expand(torch.nn.Module):
        # Each entry is (input tensor, target sizes). The second element is the
        # `sizes` argument of Tensor.expand: the desired output shape, NOT a
        # per-dimension repeat factor. A -1 keeps that dimension's existing
        # size; extra leading entries add new dimensions.
        # e.g. ones(1, 1, 2, 2) with (4, 3, -1, 2) expands to shape (4, 3, 2, 2).
        test_parameters = [
            (torch.ones(1), (2,)),
            (torch.ones(1, 4), (1, -1)),
            (torch.ones(1, 1, 2, 2), (4, 3, -1, 2)),
            (torch.ones(1), (2, 2, 4)),
            (torch.ones(3, 2, 4, 1), (-1, -1, -1, 3)),
        ]

        def forward(self, x: torch.Tensor, multiples: Sequence):
            """Return ``x.expand(multiples)``.

            Despite its (historical) name, ``multiples`` holds the target
            sizes passed straight to ``Tensor.expand``, not repeat counts.
            """
            return x.expand(multiples)

    def _lower_and_check(self, tester: ArmTester):
        """Run the lowering stages shared by every pipeline in this class.

        Exports the module and checks exactly one ``aten.expand`` was captured,
        lowers to edge and partitions, then checks the expand is gone and a
        single delegate call remains, and finally converts to an ExecuTorch
        program. Returns the result of ``to_executorch()`` so callers can keep
        chaining (e.g. ``run_method_and_compare_outputs``).
        """
        return (
            tester.export()
            .check_count({"torch.ops.aten.expand.default": 1})
            .to_edge()
            .partition()
            .check_not(["torch.ops.aten.expand.default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_expand_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Lower ``module`` with the TOSA MI compile spec and compare outputs."""
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
        )
        self._lower_and_check(tester).run_method_and_compare_outputs(inputs=test_data)

    def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Quantize ``module``, lower it with the TOSA BI compile spec, compare outputs."""
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        ).quantize(Quantize(quantizer, get_symmetric_quantization_config()))
        # qtol=1: presumably tolerates a 1-step quantization difference — see
        # ArmTester.run_method_and_compare_outputs to confirm.
        self._lower_and_check(tester).run_method_and_compare_outputs(
            inputs=test_data, qtol=1
        )

    def _test_expand_ethosu_BI_pipeline(
        self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: Tuple
    ):
        """Quantize ``module`` and lower it for an Ethos-U target.

        Only lowers to an ExecuTorch program; outputs are not executed or
        compared here.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=compile_spec,
        ).quantize(Quantize(quantizer, get_symmetric_quantization_config()))
        self._lower_and_check(tester)

    @parameterized.expand(Expand.test_parameters)
    def test_expand_tosa_MI(self, test_input, multiples):
        self._test_expand_tosa_MI_pipeline(self.Expand(), (test_input, multiples))

    @parameterized.expand(Expand.test_parameters)
    def test_expand_tosa_BI(self, test_input, multiples):
        self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))

    @parameterized.expand(Expand.test_parameters)
    def test_expand_u55_BI(self, test_input, multiples):
        self._test_expand_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
        )

    @parameterized.expand(Expand.test_parameters)
    def test_expand_u85_BI(self, test_input, multiples):
        self._test_expand_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
        )
119