# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the expand op, which broadcasts size-1 dimensions of the input tensor
# to a larger target shape (in eager PyTorch, Tensor.expand returns a view
# rather than copying data). The lowering is expected to turn it into a repeat
# op (see the class docstring below).
#

import unittest
from typing import Sequence, Tuple

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized


class TestSimpleExpand(unittest.TestCase):
    """Tests Tensor.expand, which should be converted to a repeat op by a pass."""

    class Expand(torch.nn.Module):
        # (input tensor, target sizes)
        # NOTE: the second element is named `multiples` below for historical
        # reasons, but it is the target shape handed to Tensor.expand, not a
        # repeat multiplier; a -1 entry keeps that dimension's existing size.
        test_parameters = [
            (torch.ones(1), (2,)),
            (torch.ones(1, 4), (1, -1)),
            (torch.ones(1, 1, 2, 2), (4, 3, -1, 2)),
            (torch.ones(1), (2, 2, 4)),
            (torch.ones(3, 2, 4, 1), (-1, -1, -1, 3)),
        ]

        def forward(self, x: torch.Tensor, multiples: Sequence) -> torch.Tensor:
            """Expand ``x`` to the target shape ``multiples`` (-1 keeps a dim)."""
            return x.expand(multiples)

    @staticmethod
    def _quantize_stage() -> Quantize:
        """Build the symmetric-quantization stage shared by the BI pipelines."""
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        return Quantize(quantizer, get_symmetric_quantization_config())

    @staticmethod
    def _lower_and_check(tester: ArmTester) -> ArmTester:
        """Export and lower ``tester``'s module, checking that expand is delegated.

        Checks for exactly one aten.expand node after export, none remaining
        after partitioning, and a single delegate call. Returns the tester
        after ``to_executorch()`` so callers may additionally run the program.
        """
        return (
            tester.export()
            .check_count({"torch.ops.aten.expand.default": 1})
            .to_edge()
            .partition()
            .check_not(["torch.ops.aten.expand.default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_expand_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Lower ``module`` for TOSA MI (floating point) and compare outputs."""
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
        )
        self._lower_and_check(tester).run_method_and_compare_outputs(inputs=test_data)

    def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Quantize and lower ``module`` for TOSA BI, comparing outputs with qtol=1."""
        quantize = self._quantize_stage()
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        ).quantize(quantize)
        self._lower_and_check(tester).run_method_and_compare_outputs(
            inputs=test_data, qtol=1
        )

    def _test_expand_ethosu_BI_pipeline(
        self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: Tuple
    ):
        """Quantize and lower ``module`` with an Ethos-U ``compile_spec``.

        Only the lowering is checked here; the resulting program is not run
        and no outputs are compared.
        """
        quantize = self._quantize_stage()
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=compile_spec,
        ).quantize(quantize)
        self._lower_and_check(tester)

    @parameterized.expand(Expand.test_parameters)
    def test_expand_tosa_MI(self, test_input, multiples):
        self._test_expand_tosa_MI_pipeline(self.Expand(), (test_input, multiples))

    @parameterized.expand(Expand.test_parameters)
    def test_expand_tosa_BI(self, test_input, multiples):
        self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))

    @parameterized.expand(Expand.test_parameters)
    def test_expand_u55_BI(self, test_input, multiples):
        self._test_expand_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
        )

    @parameterized.expand(Expand.test_parameters)
    def test_expand_u85_BI(self, test_input, multiples):
        self._test_expand_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
        )