# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Tests the rsqrt op.
#

import unittest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


class TestRsqrt(unittest.TestCase):
    """Run a small rsqrt module through the ArmTester pipelines.

    Each ``test_*`` method is expanded once per entry of
    ``Rsqrt.test_parameters`` and delegates to one of the
    ``_test_rsqrt_*_pipeline`` helpers below.
    """

    class Rsqrt(torch.nn.Module):
        """Module whose forward pass calls ``rsqrt`` on its single input."""

        # One 1-tuple per test case; ``parameterized.expand`` unpacks each
        # tuple into the ``test_tensor`` argument of the test methods.
        # The ``torch.rand`` entries are unseeded here, so their values
        # differ from run to run.
        test_parameters = [
            (torch.ones(1, 10, 10, 10),),
            (torch.rand(1, 10, 10, 10),),
            (torch.rand(1, 5, 10, 20),),
            (torch.rand(5, 10, 20),),
        ]

        def forward(self, x: torch.Tensor):
            return x.rsqrt()

    def _lower_to_executorch(
        self,
        module: torch.nn.Module,
        test_data: tuple[torch.Tensor],
        compile_spec,
        *,
        quantize: bool,
    ):
        """Build an ArmTester and step it through to ``to_executorch()``.

        The stage calls are issued in this order, each one on the value
        returned by the previous call: optional ``quantize()``,
        ``export()``, a ``check_count`` for exactly one
        ``aten.rsqrt.default``, ``to_edge()``, ``partition()``, a
        ``check_count`` for exactly one ``executorch_call_delegate``, and
        finally ``to_executorch()``.

        Args:
            module: Module handed to ``ArmTester``.
            test_data: Passed to ``ArmTester`` as ``example_inputs``.
            compile_spec: Forwarded unchanged as ``ArmTester``'s
                ``compile_spec`` argument.
            quantize: When true, ``quantize()`` is called before
                ``export()``.

        Returns:
            Whatever ``to_executorch()`` returns, so that callers may keep
            chaining on it.
        """
        expected_rsqrt_ops = {"torch.ops.aten.rsqrt.default": 1}
        expected_delegates = {
            "torch.ops.higher_order.executorch_call_delegate": 1
        }

        stage = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=compile_spec,
        )
        if quantize:
            stage = stage.quantize()
        stage = stage.export()
        stage = stage.check_count(expected_rsqrt_ops)
        stage = stage.to_edge()
        stage = stage.partition()
        stage = stage.check_count(expected_delegates)
        return stage.to_executorch()

    def _test_rsqrt_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
    ):
        """Lower ``module`` with the ``TOSA-0.80.0+MI`` spec and compare outputs.

        No ``quantize()`` stage is run for this profile.
        """
        spec = common.get_tosa_compile_spec("TOSA-0.80.0+MI")
        lowered = self._lower_to_executorch(
            module, test_data, spec, quantize=False
        )
        lowered.run_method_and_compare_outputs(inputs=test_data)

    def _test_rsqrt_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with ``TOSA-0.80.0+BI``, then compare outputs."""
        spec = common.get_tosa_compile_spec("TOSA-0.80.0+BI")
        lowered = self._lower_to_executorch(
            module, test_data, spec, quantize=True
        )
        lowered.run_method_and_compare_outputs(inputs=test_data)

    def _test_rsqrt_ethosu_BI_pipeline(
        self,
        compile_spec: CompileSpec,
        module: torch.nn.Module,
        test_data: tuple[torch.Tensor],
    ):
        """Quantize and lower ``module`` with the given Ethos-U compile spec.

        Unlike the TOSA helpers, this one stops after ``to_executorch()``
        and never calls ``run_method_and_compare_outputs``.
        NOTE(review): presumably because no Ethos-U runtime is available
        to execute the result here -- confirm.
        """
        self._lower_to_executorch(
            module, test_data, compile_spec, quantize=True
        )

    # The test methods below intentionally carry no docstrings, so the
    # descriptions reported by unittest/parameterized stay as they were.

    @parameterized.expand(Rsqrt.test_parameters)
    def test_rsqrt_tosa_MI(self, test_tensor: torch.Tensor):
        inputs = (test_tensor,)
        self._test_rsqrt_tosa_MI_pipeline(self.Rsqrt(), inputs)

    @parameterized.expand(Rsqrt.test_parameters)
    def test_rsqrt_tosa_BI(self, test_tensor: torch.Tensor):
        inputs = (test_tensor,)
        self._test_rsqrt_tosa_BI_pipeline(self.Rsqrt(), inputs)

    @parameterized.expand(Rsqrt.test_parameters)
    def test_rsqrt_u55_BI(self, test_tensor: torch.Tensor):
        u55_spec = common.get_u55_compile_spec()
        self._test_rsqrt_ethosu_BI_pipeline(
            u55_spec, self.Rsqrt(), (test_tensor,)
        )

    @parameterized.expand(Rsqrt.test_parameters)
    def test_rsqrt_u85_BI(self, test_tensor: torch.Tensor):
        u85_spec = common.get_u85_compile_spec()
        self._test_rsqrt_ethosu_BI_pipeline(
            u85_spec, self.Rsqrt(), (test_tensor,)
        )