xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_var.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7#
# Tests the var op, which computes the variance of a Tensor over all elements or along the given dimension(s).
9#
10
11import unittest
12
13import torch
14from executorch.backends.arm.quantizer.arm_quantizer import (
15    ArmQuantizer,
16    get_symmetric_quantization_config,
17)
18
19from executorch.backends.arm.test import common
20from executorch.backends.arm.test.tester.arm_tester import ArmTester
21from executorch.backends.xnnpack.test.tester.tester import Quantize
22from executorch.exir.backend.backend_details import CompileSpec
23
24from parameterized import parameterized
25
26
class TestVar(unittest.TestCase):
    """Arm backend tests for ``torch.Tensor.var``.

    Three small modules wrap the ``var`` call variants exercised here:
    no ``dim`` with ``correction``, ``dim`` with ``unbiased``, and ``dim``
    with ``correction``. Every module is run through the TOSA MI, TOSA BI,
    U55 and U85 pipelines defined below.
    """

    class Var(torch.nn.Module):
        """Calls ``x.var`` without a ``dim`` argument."""

        # Each entry is (input tensor, keepdim, correction).
        test_parameters = [
            (torch.randn(1, 50, 10, 20), True, 0),
            (torch.rand(1, 50, 10), True, 0),
            (torch.randn(1, 30, 15, 20), True, 1),
            (torch.rand(1, 50, 10, 20), True, 0.5),
        ]

        def forward(
            self,
            x: torch.Tensor,
            keepdim: bool = True,
            correction: int | float = 0,
        ):
            return x.var(keepdim=keepdim, correction=correction)

    class VarDim(torch.nn.Module):
        """Calls ``x.var`` along ``dim`` using the boolean ``unbiased`` flag."""

        # Each entry is (input tensor, dim, keepdim, unbiased).
        test_parameters = [
            (torch.randn(1, 50, 10, 20), 1, True, False),
            (torch.rand(1, 50, 10), -2, True, False),
            (torch.randn(1, 30, 15, 20), -3, True, True),
            (torch.rand(1, 50, 10, 20), -1, True, True),
        ]

        def forward(
            self,
            x: torch.Tensor,
            dim: int = -1,
            keepdim: bool = True,
            unbiased: bool = False,
        ):
            return x.var(dim=dim, keepdim=keepdim, unbiased=unbiased)

    class VarCorrection(torch.nn.Module):
        """Calls ``x.var`` along one or more dims with a ``correction`` value."""

        # Each entry is (input tensor, dim, keepdim, correction).
        # NOTE: ``(-2)`` in the second entry is a plain int, not a one-element
        # tuple, so that case exercises the single-int ``dim`` path.
        test_parameters = [
            (torch.randn(1, 50, 10, 20), (-1, -2), True, 0),
            (torch.rand(1, 50, 10), (-2), True, 0),
            (torch.randn(1, 30, 15, 20), (-1, -2, -3), True, 1),
            (torch.rand(1, 50, 10, 20), (-1, -2), True, 0.5),
        ]

        def forward(
            self,
            x: torch.Tensor,
            dim: int | tuple[int, ...] = -1,
            keepdim: bool = True,
            correction: int | float = 0,
        ):
            return x.var(dim=dim, keepdim=keepdim, correction=correction)

    def _test_var_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: tuple,
        target_str: str | None = None,
    ):
        """Run ``module`` through the "TOSA-0.80.0+MI" pipeline and compare outputs.

        Args:
            module: Module under test.
            test_data: Tuple of positional ``forward`` arguments; used both as
                the example inputs and as the inputs for the output comparison.
            target_str: Not used by this helper; kept so existing callers
                that pass it keep working.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .to_edge()
            .partition()
            # Expect exactly one delegate call in the partitioned graph.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_var_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: tuple,
        target_str: str | None = None,
    ):
        """Quantize ``module``, run the "TOSA-0.80.0+BI" pipeline and compare outputs.

        Args:
            module: Module under test.
            test_data: Tuple of positional ``forward`` arguments; used both as
                the example inputs and as the inputs for the output comparison.
            target_str: Not used by this helper; kept so existing callers
                that pass it keep working.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .to_edge()
            .partition()
            # Expect exactly one delegate call in the partitioned graph.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_var_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: tuple,
        target_str: str | None = None,
    ):
        """Quantize ``module`` and lower it with the given Ethos-U ``compile_spec``.

        Unlike the TOSA helpers, this pipeline stops after ``to_executorch()``;
        no output comparison is performed.

        Args:
            module: Module under test.
            compile_spec: Compile spec passed straight to ``ArmTester``.
            test_data: Tuple of positional ``forward`` arguments used as the
                example inputs.
            target_str: Not used by this helper; kept so existing callers
                that pass it keep working.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .to_edge()
            .partition()
            # Expect exactly one delegate call in the partitioned graph.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # --- Var: x.var(keepdim=..., correction=...) ---

    @parameterized.expand(Var.test_parameters)
    def test_var_tosa_MI(self, test_tensor: torch.Tensor, keepdim, correction):
        self._test_var_tosa_MI_pipeline(self.Var(), (test_tensor, keepdim, correction))

    @parameterized.expand(Var.test_parameters)
    def test_var_tosa_BI(self, test_tensor: torch.Tensor, keepdim, correction):
        self._test_var_tosa_BI_pipeline(self.Var(), (test_tensor, keepdim, correction))

    @parameterized.expand(Var.test_parameters)
    def test_var_u55_BI(self, test_tensor: torch.Tensor, keepdim, correction):
        self._test_var_ethosu_BI_pipeline(
            self.Var(),
            common.get_u55_compile_spec(),
            (test_tensor, keepdim, correction),
        )

    @parameterized.expand(Var.test_parameters)
    def test_var_u85_BI(self, test_tensor: torch.Tensor, keepdim, correction):
        self._test_var_ethosu_BI_pipeline(
            self.Var(),
            common.get_u85_compile_spec(),
            (test_tensor, keepdim, correction),
        )

    # --- VarDim: x.var(dim=..., keepdim=..., unbiased=...) ---
    # The fourth parameter is the boolean ``unbiased`` flag of VarDim.forward.

    @parameterized.expand(VarDim.test_parameters)
    def test_var_dim_tosa_MI(self, test_tensor: torch.Tensor, dim, keepdim, unbiased):
        self._test_var_tosa_MI_pipeline(
            self.VarDim(), (test_tensor, dim, keepdim, unbiased)
        )

    @parameterized.expand(VarDim.test_parameters)
    def test_var_dim_tosa_BI(self, test_tensor: torch.Tensor, dim, keepdim, unbiased):
        self._test_var_tosa_BI_pipeline(
            self.VarDim(), (test_tensor, dim, keepdim, unbiased)
        )

    @parameterized.expand(VarDim.test_parameters)
    def test_var_dim_u55_BI(self, test_tensor: torch.Tensor, dim, keepdim, unbiased):
        self._test_var_ethosu_BI_pipeline(
            self.VarDim(),
            common.get_u55_compile_spec(),
            (test_tensor, dim, keepdim, unbiased),
        )

    @parameterized.expand(VarDim.test_parameters)
    def test_var_dim_u85_BI(self, test_tensor: torch.Tensor, dim, keepdim, unbiased):
        self._test_var_ethosu_BI_pipeline(
            self.VarDim(),
            common.get_u85_compile_spec(),
            (test_tensor, dim, keepdim, unbiased),
        )

    # --- VarCorrection: x.var(dim=..., keepdim=..., correction=...) ---

    @parameterized.expand(VarCorrection.test_parameters)
    def test_var_correction_tosa_MI(
        self, test_tensor: torch.Tensor, dim, keepdim, correction
    ):
        self._test_var_tosa_MI_pipeline(
            self.VarCorrection(), (test_tensor, dim, keepdim, correction)
        )

    @parameterized.expand(VarCorrection.test_parameters)
    def test_var_correction_tosa_BI(
        self, test_tensor: torch.Tensor, dim, keepdim, correction
    ):
        self._test_var_tosa_BI_pipeline(
            self.VarCorrection(), (test_tensor, dim, keepdim, correction)
        )

    @parameterized.expand(VarCorrection.test_parameters)
    def test_var_correction_u55_BI(
        self, test_tensor: torch.Tensor, dim, keepdim, correction
    ):
        self._test_var_ethosu_BI_pipeline(
            self.VarCorrection(),
            common.get_u55_compile_spec(),
            (test_tensor, dim, keepdim, correction),
        )

    @parameterized.expand(VarCorrection.test_parameters)
    def test_var_correction_u85_BI(
        self, test_tensor: torch.Tensor, dim, keepdim, correction
    ):
        self._test_var_ethosu_BI_pipeline(
            self.VarCorrection(),
            common.get_u85_compile_spec(),
            (test_tensor, dim, keepdim, correction),
        )
230