xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_squeeze.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7#
# Tests the squeeze op, which removes dimensions of size 1 (a single given dim, a tuple of dims, or all size-1 dims) to produce a lower-ranked tensor.
9#
10
11import unittest
12from typing import Optional, Tuple
13
14import torch
15
16from executorch.backends.arm.test import common
17from executorch.backends.arm.test.tester.arm_tester import ArmTester
18
19from executorch.exir.backend.compile_spec_schema import CompileSpec
20from parameterized import parameterized
21
22
class TestSqueeze(unittest.TestCase):
    """Lowering tests for the ``aten.squeeze`` variants on the Arm backend.

    Three variants are covered:
    - ``squeeze()`` with no dim (``torch.ops.aten.squeeze.default``);
    - ``squeeze(dim)`` with a single int (``torch.ops.aten.squeeze.dim``);
    - ``squeeze(dims)`` with a tuple of ints (``torch.ops.aten.squeeze.dims``).

    Each variant is run through four pipelines:
    - TOSA MI and TOSA BI, which also run the program and compare outputs;
    - Ethos-U55 and Ethos-U85 BI, which only check export and delegation.
    """

    class SqueezeDim(torch.nn.Module):
        """Squeeze a single dimension given by ``dim``."""

        # (input, dim) pairs; every dim (possibly negative) indexes a size-1 axis.
        test_parameters: list[tuple[torch.Tensor, int]] = [
            (torch.randn(1, 1, 5), -2),
            (torch.randn(1, 2, 3, 1), 3),
            (torch.randn(1, 5, 1, 5), -2),
        ]

        def forward(self, x: torch.Tensor, dim: int):
            return x.squeeze(dim)

    class SqueezeDims(torch.nn.Module):
        """Squeeze several dimensions at once, given as the tuple ``dims``."""

        # (input, dims) pairs; every listed dim indexes a size-1 axis.
        test_parameters: list[tuple[torch.Tensor, tuple[int, ...]]] = [
            (torch.randn(1, 1, 5), (0, 1)),
            (torch.randn(1, 5, 5, 1), (0, -1)),
            (torch.randn(1, 5, 1, 5), (0, -2)),
        ]

        def forward(self, x: torch.Tensor, dims: tuple[int, ...]):
            return x.squeeze(dims)

    class Squeeze(torch.nn.Module):
        """Squeeze every size-1 dimension (no explicit dim argument)."""

        # Single-input cases.
        test_parameters: list[tuple[torch.Tensor]] = [
            (torch.randn(1, 1, 5),),
            (torch.randn(1, 5, 5, 1),),
            (torch.randn(1, 5, 1, 5),),
        ]

        def forward(self, x: torch.Tensor):
            return x.squeeze()

    def _lower_and_check_delegation(self, tester: ArmTester, export_target: str):
        """Export and lower ``tester``'s module, checking full delegation.

        Asserts that ``export_target`` occurs exactly once in the exported
        graph. After partitioning, asserts exactly one
        ``executorch_call_delegate`` node. Finally runs ``to_executorch``.

        Returns the result of ``to_executorch()`` so callers can keep
        chaining, e.g. to run and compare outputs.
        """
        return (
            tester.export()
            .check_count({export_target: 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_squeeze_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: tuple,
        export_target: str,
    ):
        """Lower ``module`` for TOSA MI and compare outputs against eager.

        ``test_data`` is the positional input tuple: ``(tensor,)``,
        ``(tensor, dim)`` or ``(tensor, dims)`` depending on the module.
        """
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
        )
        lowered = self._lower_and_check_delegation(tester, export_target)
        lowered.run_method_and_compare_outputs(inputs=test_data)

    def _test_squeeze_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: tuple,
        export_target: str,
    ):
        """Quantize and lower ``module`` for TOSA BI, then compare outputs.

        ``test_data`` is the positional input tuple, as in the MI pipeline.
        """
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        ).quantize()
        lowered = self._lower_and_check_delegation(tester, export_target)
        # NOTE(review): qtol presumably tolerates a small quantized-output
        # difference; confirm its exact meaning against ArmTester.
        lowered.run_method_and_compare_outputs(inputs=test_data, qtol=1)

    def _test_squeeze_ethosu_BI_pipeline(
        self,
        compile_spec: CompileSpec,
        module: torch.nn.Module,
        test_data: tuple,
        export_target: str,
    ):
        """Quantize and lower ``module`` for an Ethos-U target.

        Only export and delegation are checked; the program is not run and
        its outputs are not compared.
        """
        tester = ArmTester(
            module, example_inputs=test_data, compile_spec=compile_spec
        ).quantize()
        self._lower_and_check_delegation(tester, export_target)

    @parameterized.expand(Squeeze.test_parameters)
    def test_squeeze_tosa_MI(
        self,
        test_tensor: torch.Tensor,
    ):
        self._test_squeeze_tosa_MI_pipeline(
            self.Squeeze(), (test_tensor,), "torch.ops.aten.squeeze.default"
        )

    @parameterized.expand(Squeeze.test_parameters)
    def test_squeeze_tosa_BI(
        self,
        test_tensor: torch.Tensor,
    ):
        self._test_squeeze_tosa_BI_pipeline(
            self.Squeeze(), (test_tensor,), "torch.ops.aten.squeeze.default"
        )

    @parameterized.expand(Squeeze.test_parameters)
    def test_squeeze_u55_BI(
        self,
        test_tensor: torch.Tensor,
    ):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
            self.Squeeze(),
            (test_tensor,),
            "torch.ops.aten.squeeze.default",
        )

    @parameterized.expand(Squeeze.test_parameters)
    def test_squeeze_u85_BI(
        self,
        test_tensor: torch.Tensor,
    ):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            self.Squeeze(),
            (test_tensor,),
            "torch.ops.aten.squeeze.default",
        )

    @parameterized.expand(SqueezeDim.test_parameters)
    def test_squeeze_dim_tosa_MI(self, test_tensor: torch.Tensor, dim: int):
        self._test_squeeze_tosa_MI_pipeline(
            self.SqueezeDim(), (test_tensor, dim), "torch.ops.aten.squeeze.dim"
        )

    @parameterized.expand(SqueezeDim.test_parameters)
    def test_squeeze_dim_tosa_BI(self, test_tensor: torch.Tensor, dim: int):
        self._test_squeeze_tosa_BI_pipeline(
            self.SqueezeDim(), (test_tensor, dim), "torch.ops.aten.squeeze.dim"
        )

    @parameterized.expand(SqueezeDim.test_parameters)
    def test_squeeze_dim_u55_BI(self, test_tensor: torch.Tensor, dim: int):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
            self.SqueezeDim(),
            (test_tensor, dim),
            "torch.ops.aten.squeeze.dim",
        )

    @parameterized.expand(SqueezeDim.test_parameters)
    def test_squeeze_dim_u85_BI(self, test_tensor: torch.Tensor, dim: int):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            self.SqueezeDim(),
            (test_tensor, dim),
            "torch.ops.aten.squeeze.dim",
        )

    @parameterized.expand(SqueezeDims.test_parameters)
    def test_squeeze_dims_tosa_MI(
        self, test_tensor: torch.Tensor, dims: tuple[int, ...]
    ):
        self._test_squeeze_tosa_MI_pipeline(
            self.SqueezeDims(), (test_tensor, dims), "torch.ops.aten.squeeze.dims"
        )

    @parameterized.expand(SqueezeDims.test_parameters)
    def test_squeeze_dims_tosa_BI(
        self, test_tensor: torch.Tensor, dims: tuple[int, ...]
    ):
        self._test_squeeze_tosa_BI_pipeline(
            self.SqueezeDims(), (test_tensor, dims), "torch.ops.aten.squeeze.dims"
        )

    @parameterized.expand(SqueezeDims.test_parameters)
    def test_squeeze_dims_u55_BI(
        self, test_tensor: torch.Tensor, dims: tuple[int, ...]
    ):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
            self.SqueezeDims(),
            (test_tensor, dims),
            "torch.ops.aten.squeeze.dims",
        )

    @parameterized.expand(SqueezeDims.test_parameters)
    def test_squeeze_dims_u85_BI(
        self, test_tensor: torch.Tensor, dims: tuple[int, ...]
    ):
        # Pass the NHWC permutation flag explicitly, as the other u85 tests do.
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            self.SqueezeDims(),
            (test_tensor, dims),
            "torch.ops.aten.squeeze.dims",
        )
216