# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the squeeze op, which removes dimensions of size 1 from a tensor and
# thereby produces a lower-ranked tensor.
#

import unittest
from typing import Optional, Tuple, Union

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized

# Example inputs forwarded positionally to the module under test: the input
# tensor, optionally followed by the ``dim`` (int) or ``dims`` (tuple of ints)
# argument of ``Tensor.squeeze``.
SqueezeInputs = Union[
    Tuple[torch.Tensor],
    Tuple[torch.Tensor, int],
    Tuple[torch.Tensor, Tuple[int, ...]],
]


class TestSqueeze(unittest.TestCase):
    """Lower the squeeze variants through the Arm backend test pipelines.

    Each variant (``squeeze()``, ``squeeze(dim)`` and ``squeeze(dims)``) is
    exercised on the TOSA MI and TOSA BI profiles, where the lowered program is
    run and its outputs compared, and on Ethos-U55 / Ethos-U85, where the module
    is only lowered to an ExecuTorch program.
    """

    class SqueezeDim(torch.nn.Module):
        """Squeeze one dimension, given as an int (negative indices allowed)."""

        # (input tensor, dim to squeeze); the selected dim has size 1 in every case.
        test_parameters: list[tuple[torch.Tensor, int]] = [
            (torch.randn(1, 1, 5), -2),
            (torch.randn(1, 2, 3, 1), 3),
            (torch.randn(1, 5, 1, 5), -2),
        ]

        def forward(self, x: torch.Tensor, dim: int):
            return x.squeeze(dim)

    class SqueezeDims(torch.nn.Module):
        """Squeeze several dimensions at once, given as a tuple of ints."""

        # (input tensor, dims to squeeze); each tuple holds two dims, hence the
        # variadic ``tuple[int, ...]`` annotation rather than ``tuple[int]``.
        test_parameters: list[tuple[torch.Tensor, tuple[int, ...]]] = [
            (torch.randn(1, 1, 5), (0, 1)),
            (torch.randn(1, 5, 5, 1), (0, -1)),
            (torch.randn(1, 5, 1, 5), (0, -2)),
        ]

        def forward(self, x: torch.Tensor, dims: tuple[int, ...]):
            return x.squeeze(dims)

    class Squeeze(torch.nn.Module):
        """Squeeze every size-1 dimension (no explicit dim argument)."""

        # (input tensor,) — each input contains at least one size-1 dimension.
        test_parameters: list[tuple[torch.Tensor]] = [
            (torch.randn(1, 1, 5),),
            (torch.randn(1, 5, 5, 1),),
            (torch.randn(1, 5, 1, 5),),
        ]

        def forward(self, x: torch.Tensor):
            return x.squeeze()

    def _test_squeeze_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: SqueezeInputs,
        export_target: str,
    ):
        """Lower ``module`` for the TOSA-0.80.0+MI profile and compare outputs.

        Args:
            module: The squeeze module under test.
            test_data: Example inputs forwarded positionally to ``module``.
            export_target: Op name that must occur exactly once in the exported
                graph, e.g. ``"torch.ops.aten.squeeze.dim"``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({export_target: 1})
            .to_edge()
            .partition()
            # The whole graph is expected to land in a single delegate call.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_squeeze_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: SqueezeInputs,
        export_target: str,
    ):
        """Quantize ``module``, lower it for TOSA-0.80.0+BI and compare outputs.

        Args:
            module: The squeeze module under test.
            test_data: Example inputs forwarded positionally to ``module``.
            export_target: Op name that must occur exactly once in the exported
                graph, e.g. ``"torch.ops.aten.squeeze.dim"``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({export_target: 1})
            .to_edge()
            .partition()
            # The whole graph is expected to land in a single delegate call.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            # qtol=1 tolerates a difference of one quantization step.
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_squeeze_ethosu_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: SqueezeInputs,
        export_target: str,
    ):
        """Quantize ``module`` and lower it for an Ethos-U target.

        NOTE(review): unlike the TOSA pipelines, this one stops after
        ``to_executorch()`` — the program is not run and no outputs are compared.

        Args:
            compile_spec: Compile spec list for the Ethos-U target, as produced
                by ``common.get_u55_compile_spec`` / ``get_u85_compile_spec``.
            module: The squeeze module under test.
            test_data: Example inputs forwarded positionally to ``module``.
            export_target: Op name that must occur exactly once in the exported
                graph, e.g. ``"torch.ops.aten.squeeze.dim"``.
        """
        (
            ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
            .quantize()
            .export()
            .check_count({export_target: 1})
            .to_edge()
            .partition()
            # The whole graph is expected to land in a single delegate call.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # --- squeeze() ---------------------------------------------------------

    @parameterized.expand(Squeeze.test_parameters)
    def test_squeeze_tosa_MI(
        self,
        test_tensor: torch.Tensor,
    ):
        self._test_squeeze_tosa_MI_pipeline(
            self.Squeeze(), (test_tensor,), "torch.ops.aten.squeeze.default"
        )

    @parameterized.expand(Squeeze.test_parameters)
    def test_squeeze_tosa_BI(
        self,
        test_tensor: torch.Tensor,
    ):
        self._test_squeeze_tosa_BI_pipeline(
            self.Squeeze(), (test_tensor,), "torch.ops.aten.squeeze.default"
        )

    @parameterized.expand(Squeeze.test_parameters)
    def test_squeeze_u55_BI(
        self,
        test_tensor: torch.Tensor,
    ):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
            self.Squeeze(),
            (test_tensor,),
            "torch.ops.aten.squeeze.default",
        )

    @parameterized.expand(Squeeze.test_parameters)
    def test_squeeze_u85_BI(
        self,
        test_tensor: torch.Tensor,
    ):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            self.Squeeze(),
            (test_tensor,),
            "torch.ops.aten.squeeze.default",
        )

    # --- squeeze(dim) ------------------------------------------------------

    @parameterized.expand(SqueezeDim.test_parameters)
    def test_squeeze_dim_tosa_MI(self, test_tensor: torch.Tensor, dim: int):
        self._test_squeeze_tosa_MI_pipeline(
            self.SqueezeDim(), (test_tensor, dim), "torch.ops.aten.squeeze.dim"
        )

    @parameterized.expand(SqueezeDim.test_parameters)
    def test_squeeze_dim_tosa_BI(self, test_tensor: torch.Tensor, dim: int):
        self._test_squeeze_tosa_BI_pipeline(
            self.SqueezeDim(), (test_tensor, dim), "torch.ops.aten.squeeze.dim"
        )

    @parameterized.expand(SqueezeDim.test_parameters)
    def test_squeeze_dim_u55_BI(self, test_tensor: torch.Tensor, dim: int):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
            self.SqueezeDim(),
            (test_tensor, dim),
            "torch.ops.aten.squeeze.dim",
        )

    @parameterized.expand(SqueezeDim.test_parameters)
    def test_squeeze_dim_u85_BI(self, test_tensor: torch.Tensor, dim: int):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            self.SqueezeDim(),
            (test_tensor, dim),
            "torch.ops.aten.squeeze.dim",
        )

    # --- squeeze(dims) -----------------------------------------------------

    @parameterized.expand(SqueezeDims.test_parameters)
    def test_squeeze_dims_tosa_MI(
        self, test_tensor: torch.Tensor, dims: tuple[int, ...]
    ):
        self._test_squeeze_tosa_MI_pipeline(
            self.SqueezeDims(), (test_tensor, dims), "torch.ops.aten.squeeze.dims"
        )

    @parameterized.expand(SqueezeDims.test_parameters)
    def test_squeeze_dims_tosa_BI(
        self, test_tensor: torch.Tensor, dims: tuple[int, ...]
    ):
        self._test_squeeze_tosa_BI_pipeline(
            self.SqueezeDims(), (test_tensor, dims), "torch.ops.aten.squeeze.dims"
        )

    @parameterized.expand(SqueezeDims.test_parameters)
    def test_squeeze_dims_u55_BI(
        self, test_tensor: torch.Tensor, dims: tuple[int, ...]
    ):
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
            self.SqueezeDims(),
            (test_tensor, dims),
            "torch.ops.aten.squeeze.dims",
        )

    @parameterized.expand(SqueezeDims.test_parameters)
    def test_squeeze_dims_u85_BI(
        self, test_tensor: torch.Tensor, dims: tuple[int, ...]
    ):
        # Same U85 configuration as the squeeze() and squeeze(dim) U85 tests.
        self._test_squeeze_ethosu_BI_pipeline(
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            self.SqueezeDims(),
            (test_tensor, dims),
            "torch.ops.aten.squeeze.dims",
        )