# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the var op (variance of a tensor, over all elements or along the
# given dim(s)) through the TOSA MI/BI and Ethos-U55/U85 lowering pipelines.
#

import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.backend_details import CompileSpec

from parameterized import parameterized


class TestVar(unittest.TestCase):
    """Lower ``Tensor.var`` variants through the Arm backend test pipelines."""

    class Var(torch.nn.Module):
        """Variance over all elements (no ``dim`` argument)."""

        # Each entry: (input, keepdim, correction). ``correction`` may be
        # fractional, hence the ``float`` annotation on ``forward``.
        test_parameters = [
            (torch.randn(1, 50, 10, 20), True, 0),
            (torch.rand(1, 50, 10), True, 0),
            (torch.randn(1, 30, 15, 20), True, 1),
            (torch.rand(1, 50, 10, 20), True, 0.5),
        ]

        def forward(
            self,
            x: torch.Tensor,
            keepdim: bool = True,
            correction: float = 0,
        ):
            return x.var(keepdim=keepdim, correction=correction)

    class VarDim(torch.nn.Module):
        """Variance along a single dim, using the boolean ``unbiased`` flag."""

        # Each entry: (input, dim, keepdim, unbiased).
        test_parameters = [
            (torch.randn(1, 50, 10, 20), 1, True, False),
            (torch.rand(1, 50, 10), -2, True, False),
            (torch.randn(1, 30, 15, 20), -3, True, True),
            (torch.rand(1, 50, 10, 20), -1, True, True),
        ]

        def forward(
            self,
            x: torch.Tensor,
            dim: int = -1,
            keepdim: bool = True,
            unbiased: bool = False,
        ):
            return x.var(dim=dim, keepdim=keepdim, unbiased=unbiased)

    class VarCorrection(torch.nn.Module):
        """Variance along one or more dims with a numeric ``correction``."""

        # Each entry: (input, dim, keepdim, correction). ``dim`` is either a
        # single int or a tuple of ints.
        test_parameters = [
            (torch.randn(1, 50, 10, 20), (-1, -2), True, 0),
            # A plain int dim; the original ``(-2)`` was never a tuple.
            (torch.rand(1, 50, 10), -2, True, 0),
            (torch.randn(1, 30, 15, 20), (-1, -2, -3), True, 1),
            (torch.rand(1, 50, 10, 20), (-1, -2), True, 0.5),
        ]

        def forward(
            self,
            x: torch.Tensor,
            dim: int | tuple[int, ...] = -1,
            keepdim: bool = True,
            correction: float = 0,
        ):
            return x.var(dim=dim, keepdim=keepdim, correction=correction)

    def _test_var_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: tuple,
        target_str: str | None = None,  # unused; kept for call compatibility
    ):
        """Lower ``module`` for TOSA MI and compare outputs against eager.

        Args:
            module: The module under test.
            test_data: Positional inputs to ``module.forward`` (tensor first).
            target_str: Unused.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .to_edge()
            .partition()
            # The whole graph is expected to land in a single delegate.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_var_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: tuple,
        target_str: str | None = None,  # unused; kept for call compatibility
    ):
        """Quantize, lower ``module`` for TOSA BI, and compare outputs.

        Args:
            module: The module under test.
            test_data: Positional inputs to ``module.forward`` (tensor first).
            target_str: Unused.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            # qtol=1 loosens the comparison for quantized outputs
            # (presumably one quantization step) — TODO confirm semantics.
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_var_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: tuple,
        target_str: str | None = None,  # unused; kept for call compatibility
    ):
        """Quantize and lower ``module`` for an Ethos-U target.

        Only checks that lowering succeeds with a single delegate; outputs
        are not executed or compared here.

        Args:
            module: The module under test.
            compile_spec: Target compile spec (e.g. U55 or U85).
            test_data: Positional inputs to ``module.forward`` (tensor first).
            target_str: Unused.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(Var.test_parameters)
    def test_var_tosa_MI(self, test_tensor: torch.Tensor, keepdim, correction):
        self._test_var_tosa_MI_pipeline(self.Var(), (test_tensor, keepdim, correction))

    @parameterized.expand(Var.test_parameters)
    def test_var_tosa_BI(self, test_tensor: torch.Tensor, keepdim, correction):
        self._test_var_tosa_BI_pipeline(self.Var(), (test_tensor, keepdim, correction))

    @parameterized.expand(Var.test_parameters)
    def test_var_u55_BI(self, test_tensor: torch.Tensor, keepdim, correction):
        self._test_var_ethosu_BI_pipeline(
            self.Var(),
            common.get_u55_compile_spec(),
            (test_tensor, keepdim, correction),
        )

    @parameterized.expand(Var.test_parameters)
    def test_var_u85_BI(self, test_tensor: torch.Tensor, keepdim, correction):
        self._test_var_ethosu_BI_pipeline(
            self.Var(),
            common.get_u85_compile_spec(),
            (test_tensor, keepdim, correction),
        )

    @parameterized.expand(VarDim.test_parameters)
    def test_var_dim_tosa_MI(self, test_tensor: torch.Tensor, dim, keepdim, unbiased):
        self._test_var_tosa_MI_pipeline(
            self.VarDim(), (test_tensor, dim, keepdim, unbiased)
        )

    @parameterized.expand(VarDim.test_parameters)
    def test_var_dim_tosa_BI(self, test_tensor: torch.Tensor, dim, keepdim, unbiased):
        self._test_var_tosa_BI_pipeline(
            self.VarDim(), (test_tensor, dim, keepdim, unbiased)
        )

    @parameterized.expand(VarDim.test_parameters)
    def test_var_dim_u55_BI(self, test_tensor: torch.Tensor, dim, keepdim, unbiased):
        self._test_var_ethosu_BI_pipeline(
            self.VarDim(),
            common.get_u55_compile_spec(),
            (test_tensor, dim, keepdim, unbiased),
        )

    @parameterized.expand(VarDim.test_parameters)
    def test_var_dim_u85_BI(self, test_tensor: torch.Tensor, dim, keepdim, unbiased):
        self._test_var_ethosu_BI_pipeline(
            self.VarDim(),
            common.get_u85_compile_spec(),
            (test_tensor, dim, keepdim, unbiased),
        )

    @parameterized.expand(VarCorrection.test_parameters)
    def test_var_correction_tosa_MI(
        self, test_tensor: torch.Tensor, dim, keepdim, correction
    ):
        self._test_var_tosa_MI_pipeline(
            self.VarCorrection(), (test_tensor, dim, keepdim, correction)
        )

    @parameterized.expand(VarCorrection.test_parameters)
    def test_var_correction_tosa_BI(
        self, test_tensor: torch.Tensor, dim, keepdim, correction
    ):
        self._test_var_tosa_BI_pipeline(
            self.VarCorrection(), (test_tensor, dim, keepdim, correction)
        )

    @parameterized.expand(VarCorrection.test_parameters)
    def test_var_correction_u55_BI(
        self, test_tensor: torch.Tensor, dim, keepdim, correction
    ):
        self._test_var_ethosu_BI_pipeline(
            self.VarCorrection(),
            common.get_u55_compile_spec(),
            (test_tensor, dim, keepdim, correction),
        )

    @parameterized.expand(VarCorrection.test_parameters)
    def test_var_correction_u85_BI(
        self, test_tensor: torch.Tensor, dim, keepdim, correction
    ):
        self._test_var_ethosu_BI_pipeline(
            self.VarCorrection(),
            common.get_u85_compile_spec(),
            (test_tensor, dim, keepdim, correction),
        )