# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the repeat op, which tiles the input tensor along each dimension
# according to the given multiples.
#

import unittest
from typing import Sequence, Tuple

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized


class TestSimpleRepeat(unittest.TestCase):
    """Tests Tensor.repeat for different ranks and dimensions."""

    class Repeat(torch.nn.Module):
        # Each entry is (input tensor, multiples). Entries whose multiples are
        # longer than the input's rank exercise repeat's implicit addition of
        # leading dimensions.
        test_parameters = [
            (torch.randn(3), (2,)),
            (torch.randn(3, 4), (2, 1)),
            (torch.randn(1, 1, 2, 2), (1, 2, 3, 4)),
            (torch.randn(3), (2, 2)),
            (torch.randn(3), (1, 2, 3)),
            (torch.randn((3, 3)), (2, 2, 2)),
        ]

        def forward(self, x: torch.Tensor, multiples: Sequence):
            return x.repeat(multiples)

    @staticmethod
    def _make_quantize_stage() -> Quantize:
        """Build the quantization stage shared by the BI and Ethos-U pipelines."""
        # Only the I/O config is set on the quantizer.
        # NOTE(review): presumably sufficient because repeat only moves data —
        # confirm against ArmQuantizer.
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        return Quantize(quantizer, get_symmetric_quantization_config())

    @staticmethod
    def _lower_to_executorch(tester: ArmTester) -> ArmTester:
        """Export and lower ``tester``'s module, checking that repeat is delegated.

        Asserts that the exported graph contains exactly one
        ``aten.repeat.default``, that none remains after partitioning, and that
        exactly one delegate call is produced. Returns the result of
        ``to_executorch()`` so callers can chain further stages.
        """
        return (
            tester.export()
            .check_count({"torch.ops.aten.repeat.default": 1})
            .to_edge()
            .partition()
            .check_not(["torch.ops.aten.repeat.default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_repeat_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Lower ``module`` for TOSA MI, then run it on ``test_data`` and compare outputs."""
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
        )
        self._lower_to_executorch(tester).run_method_and_compare_outputs(
            inputs=test_data
        )

    def _test_repeat_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Quantize and lower ``module`` for TOSA BI, then run it and compare outputs."""
        quantize_stage = self._make_quantize_stage()
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        ).quantize(quantize_stage)
        # NOTE(review): qtol=1 presumably tolerates one quantization step of
        # difference in the comparison — confirm against ArmTester.
        self._lower_to_executorch(tester).run_method_and_compare_outputs(
            inputs=test_data, qtol=1
        )

    def _test_repeat_ethosu_pipeline(
        self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: Tuple
    ):
        """Quantize and lower ``module`` for the Ethos-U target in ``compile_spec``.

        Only lowering and delegation are checked; the resulting program is not
        executed here.
        """
        quantize_stage = self._make_quantize_stage()
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=compile_spec,
        ).quantize(quantize_stage)
        self._lower_to_executorch(tester)

    @parameterized.expand(Repeat.test_parameters)
    def test_repeat_tosa_MI(self, test_input, multiples):
        self._test_repeat_tosa_MI_pipeline(self.Repeat(), (test_input, multiples))

    @parameterized.expand(Repeat.test_parameters)
    def test_repeat_tosa_BI(self, test_input, multiples):
        self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples))

    @parameterized.expand(Repeat.test_parameters)
    def test_repeat_u55_BI(self, test_input, multiples):
        self._test_repeat_ethosu_pipeline(
            common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
        )

    @parameterized.expand(Repeat.test_parameters)
    def test_repeat_u85_BI(self, test_input, multiples):
        self._test_repeat_ethosu_pipeline(
            common.get_u85_compile_spec(), self.Repeat(), (test_input, multiples)
        )