# xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_mean_dim.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
7
8import unittest
9
10from typing import Tuple
11
12import torch
13from executorch.backends.arm.test import common
14from executorch.backends.arm.test.tester.arm_tester import ArmTester
15from executorch.exir.backend.backend_details import CompileSpec
16from parameterized import parameterized
17
18
class TestMeanDim(unittest.TestCase):
    """Tests MeanDim, called AdaptiveAvgPool2d in Pytorch.

    Two small modules are exercised:

    * ``AdaptiveAveragePool2d`` wraps ``torch.nn.AdaptiveAvgPool2d`` with a
      ``(1, 1)`` output size.
    * ``MeanDim`` calls ``Tensor.mean`` over a configurable ``dim``.

    Each module is run through three ``ArmTester`` pipelines: a TOSA "MI"
    pipeline (no quantize stage), a TOSA "BI" pipeline (with a quantize stage)
    and an Ethos-U BI pipeline that takes the compile spec as an argument and
    is used for both U55 and U85.
    """

    class AdaptiveAveragePool2d(torch.nn.Module):
        """Module under test: adaptive average pooling down to a 1x1 output."""

        # NOTE: the "rand"/"randn" tensors are drawn once, when the class body
        # is evaluated, and no seed is set here.
        test_data_suite = [
            # (test_name, test_data)
            (
                "zeros",
                torch.zeros(1, 1280, 7, 7),
            ),
            (
                "ones",
                torch.ones(1, 1280, 7, 7),
            ),
            (
                "rand",
                torch.rand(1, 1280, 7, 7),
            ),
            (
                "randn",
                torch.randn(1, 1280, 7, 7),
            ),
        ]

        def __init__(self):
            super().__init__()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))

        def forward(self, x):
            """Return ``x`` pooled by the wrapped ``AdaptiveAvgPool2d`` layer."""
            return self.adaptive_avg_pool2d(x)

    class MeanDim(torch.nn.Module):
        """Module under test: ``Tensor.mean`` over the configured dim(s)."""

        test_data_suite = [
            # (test_name, test_data, dim, keepdim)
            ("zeros", torch.zeros(1, 1280, 7, 7), -1, True),
            ("ones", torch.ones(1, 1280, 7, 7), (-1, 2), True),
            (
                "rand",
                torch.rand(1, 1280, 7, 7),
                # NOTE(review): `(-1)` has no trailing comma, so it is the plain
                # int -1 (same dim as the "zeros" case). Confirm whether the
                # one-element tuple `(-1,)` was intended.
                (-1),
                True,
            ),
            (
                "randn",
                torch.randn(1, 1280, 7, 7),
                (-1, -2, -3),
                True,
            ),
        ]

        def __init__(
            self,
            dim: int | list[int] | tuple[int, ...] = -1,
            keepdim: bool = True,
        ):
            """Store the reduction arguments forwarded to ``Tensor.mean``.

            Args:
                dim: Dimension, or dimensions, to reduce over. The test suite
                    above supplies both ints and tuples of ints.
                keepdim: Passed through as ``keepdim`` to ``Tensor.mean``.
            """
            super().__init__()
            self.dim = dim
            self.keepdim = keepdim

        def forward(self, x: torch.Tensor):
            """Return ``x.mean`` over ``self.dim`` with ``self.keepdim``."""
            return x.mean(dim=self.dim, keepdim=self.keepdim)

    def _test_adaptive_avg_pool2d_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the "TOSA-0.80.0+MI" pipeline and compare outputs.

        The pipeline has no quantize stage. After export it checks for the
        adaptive_avg_pool2d aten op and for the absence of quantized_decomposed
        ops; after partitioning it checks that no edge mean_dim op is listed
        and that exactly one ``executorch_call_delegate`` is counted.

        Args:
            module: The module to export and lower.
            test_data: Input tuple, used both as example inputs and as the
                inputs for the final output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.adaptive_avg_pool2d.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_adaptive_avg_pool2d_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the "TOSA-0.80.0+BI" pipeline and compare outputs.

        Same stages as the MI variant, but with a quantize stage before export.
        After export it checks for exactly one adaptive_avg_pool2d aten op and
        for the presence of quantized_decomposed ops.

        Args:
            module: The module to quantize, export and lower.
            test_data: Input tuple, used both as example inputs and as the
                inputs for the final output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.adaptive_avg_pool2d.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_adaptive_avg_pool2d_tosa_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ):
        """Lower ``module`` with the given Ethos-U compile spec.

        The pipeline quantizes, exports, partitions and checks the graph, then
        stops at ``to_executorch()``. Unlike the TOSA pipelines, it does not
        call ``run_method_and_compare_outputs``.

        Args:
            module: The module to quantize, export and lower.
            compile_spec: Compile spec passed to ``ArmTester`` (the callers in
                this class pass the U55 or U85 spec from ``common``).
            test_data: Input tuple used as example inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check(["torch.ops.aten.adaptive_avg_pool2d.default"])
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mean_dim",
                    "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_meandim_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run a ``MeanDim`` module through the "TOSA-0.80.0+MI" pipeline.

        There is no quantize stage. The exported graph must not contain
        quantized_decomposed ops; after partitioning, no edge mean_dim op may
        be listed and exactly one ``executorch_call_delegate`` is counted.
        Outputs are then compared for ``test_data``.

        Args:
            module: The module to export and lower.
            test_data: Input tuple, used both as example inputs and as the
                inputs for the final output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_meandim_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run a ``MeanDim`` module through the "TOSA-0.80.0+BI" pipeline.

        Adds a quantize stage before export and requires quantized_decomposed
        ops in the exported graph. The final output comparison passes
        ``qtol=1.0``, which the other pipelines in this class do not.

        Args:
            module: The module to quantize, export and lower.
            test_data: Input tuple, used both as example inputs and as the
                inputs for the final output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1.0)
        )

    def _test_meandim_tosa_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ):
        """Lower a ``MeanDim`` module with the given Ethos-U compile spec.

        The pipeline quantizes, exports, partitions and checks the graph, then
        stops at ``to_executorch()``; no outputs are run or compared here.

        Args:
            module: The module to quantize, export and lower.
            compile_spec: Compile spec passed to ``ArmTester`` (the callers in
                this class pass the U55 or U85 spec from ``common``).
            test_data: Input tuple used as example inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mean_dim",
                    "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # The parameterized tests below are documented with comments rather than
    # docstrings so that the names/descriptions reported by unittest and
    # parameterized stay exactly as before.

    # AdaptiveAvgPool2d, TOSA MI pipeline: one case per test_data_suite entry.
    @parameterized.expand(AdaptiveAveragePool2d.test_data_suite)
    def test_adaptive_avg_pool2d_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_adaptive_avg_pool2d_tosa_MI_pipeline(
            self.AdaptiveAveragePool2d(), (test_data,)
        )

    # AdaptiveAvgPool2d, TOSA BI (quantized) pipeline.
    @parameterized.expand(AdaptiveAveragePool2d.test_data_suite)
    def test_adaptive_avg_pool2d_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_adaptive_avg_pool2d_tosa_BI_pipeline(
            self.AdaptiveAveragePool2d(), (test_data,)
        )

    # AdaptiveAvgPool2d, Ethos-U pipeline with the U55 compile spec.
    @parameterized.expand(AdaptiveAveragePool2d.test_data_suite)
    def test_adaptive_avg_pool2d_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_adaptive_avg_pool2d_tosa_ethosu_BI_pipeline(
            self.AdaptiveAveragePool2d(), common.get_u55_compile_spec(), (test_data,)
        )

    # AdaptiveAvgPool2d, Ethos-U pipeline with the U85 compile spec.
    @parameterized.expand(AdaptiveAveragePool2d.test_data_suite)
    def test_adaptive_avg_pool2d_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_adaptive_avg_pool2d_tosa_ethosu_BI_pipeline(
            self.AdaptiveAveragePool2d(), common.get_u85_compile_spec(), (test_data,)
        )

    # Tensor.mean, TOSA MI pipeline: one case per MeanDim.test_data_suite entry.
    @parameterized.expand(MeanDim.test_data_suite)
    def test_meandim_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int | list[int] | tuple[int, ...] = -1,
        keepdim: bool = True,
    ):
        self._test_meandim_tosa_MI_pipeline(self.MeanDim(dim, keepdim), (test_data,))

    # Tensor.mean, TOSA BI (quantized) pipeline.
    @parameterized.expand(MeanDim.test_data_suite)
    def test_meandim_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int | list[int] | tuple[int, ...] = -1,
        keepdim: bool = True,
    ):
        self._test_meandim_tosa_BI_pipeline(self.MeanDim(dim, keepdim), (test_data,))

    # Tensor.mean, Ethos-U pipeline with the U55 compile spec.
    @parameterized.expand(MeanDim.test_data_suite)
    def test_meandim_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int | list[int] | tuple[int, ...] = -1,
        keepdim: bool = True,
    ):
        self._test_meandim_tosa_ethosu_BI_pipeline(
            self.MeanDim(dim, keepdim),
            common.get_u55_compile_spec(),
            (test_data,),
        )

    # Tensor.mean, Ethos-U pipeline with the U85 compile spec.
    @parameterized.expand(MeanDim.test_data_suite)
    def test_meandim_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int | list[int] | tuple[int, ...] = -1,
        keepdim: bool = True,
    ):
        self._test_meandim_tosa_ethosu_BI_pipeline(
            self.MeanDim(dim, keepdim),
            common.get_u85_compile_spec(),
            (test_data,),
        )
299