# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the unsqueeze op which copies the data of the input tensor (possibly with new data format)
#

import unittest
from typing import Sequence, Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


class TestSimpleUnsqueeze(unittest.TestCase):
    """Run ``Tensor.unsqueeze`` through the Arm backend test pipelines.

    Every test is expanded once per entry of ``Unsqueeze.test_parameters``.
    The TOSA MI test sweeps every valid ``dim`` for the given tensor; the
    quantized (BI) tests only use ``dim=0``.
    """

    class Unsqueeze(torch.nn.Module):
        """Minimal module whose forward pass is a single ``unsqueeze`` call."""

        # Shapes of the random input tensors: one rank-1, two rank-2 and one
        # rank-3 tensor.
        shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 4), (5, 4, 3)]
        # One single-element tuple per shape so the list can be handed straight
        # to ``parameterized.expand``. The tensors are drawn once, when the
        # class body is executed.
        test_parameters: list[tuple[torch.Tensor]] = [(torch.randn(n),) for n in shapes]

        def forward(self, x: torch.Tensor, dim: int):
            """Return ``x`` with a size-one dimension inserted at ``dim``."""
            return x.unsqueeze(dim)

    def _test_unsqueeze_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI spec and compare outputs.

        Args:
            module: Module under test.
            test_data: ``(tensor, dim)`` pair, used both as example inputs for
                the tester and as the inputs for the output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            # One unsqueeze op is expected in the exported program.
            .check_count({"torch.ops.aten.unsqueeze.default": 1})
            .to_edge()
            .partition()
            # One delegate call is expected after partitioning.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_unsqueeze_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int]
    ):
        """Quantize and lower ``module`` with TOSA-0.80.0+BI, then compare outputs.

        Same stages as the MI pipeline with an additional ``quantize()`` stage
        up front; outputs are compared with ``qtol=1``.

        Args:
            module: Module under test.
            test_data: ``(tensor, dim)`` pair, used both as example inputs for
                the tester and as the inputs for the output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.unsqueeze.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_unsqueeze_ethosu_BI_pipeline(
        self,
        compile_spec: CompileSpec,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, int],
    ):
        """Quantize and lower ``module`` with the given Ethos-U compile spec.

        Unlike the TOSA pipelines this one stops after ``to_executorch()``:
        it checks that lowering succeeds but does not run the result or
        compare outputs.

        Args:
            compile_spec: Compile spec passed to ``ArmTester``.
            module: Module under test.
            test_data: ``(tensor, dim)`` pair used as example inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.unsqueeze.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(Unsqueeze.test_parameters)
    def test_unsqueeze_tosa_MI(self, test_tensor: torch.Tensor):
        """Run the MI pipeline for every valid ``dim`` of ``test_tensor``."""
        rank = test_tensor.dim()
        # unsqueeze accepts dim in [-rank - 1, rank], both ends inclusive.
        for dim in range(-rank - 1, rank + 1):
            # A sub-test per dim so that one failing dim is reported with its
            # value and does not prevent the remaining dims from running.
            with self.subTest(dim=dim):
                self._test_unsqueeze_tosa_MI_pipeline(
                    self.Unsqueeze(), (test_tensor, dim)
                )

    @parameterized.expand(Unsqueeze.test_parameters)
    def test_unsqueeze_tosa_BI(self, test_tensor: torch.Tensor):
        """Run the BI pipeline with ``dim=0``."""
        self._test_unsqueeze_tosa_BI_pipeline(self.Unsqueeze(), (test_tensor, 0))

    # The last parameter (the rank-3 tensor) is left out for U55.
    # NOTE(review): presumably a U55 limitation - confirm and record the reason.
    @parameterized.expand(Unsqueeze.test_parameters[:-1])
    def test_unsqueeze_u55_BI(self, test_tensor: torch.Tensor):
        """Lower for U55 with ``dim=0`` (rank-3 input excluded)."""
        self._test_unsqueeze_ethosu_BI_pipeline(
            common.get_u55_compile_spec(),
            self.Unsqueeze(),
            (test_tensor, 0),
        )

    @parameterized.expand(Unsqueeze.test_parameters)
    def test_unsqueeze_u85_BI(self, test_tensor: torch.Tensor):
        """Lower for U85 with ``dim=0``."""
        self._test_unsqueeze_ethosu_BI_pipeline(
            common.get_u85_compile_spec(),
            self.Unsqueeze(),
            (test_tensor, 0),
        )