# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


class TestSimpleSub(unittest.TestCase):
    """Tests lowering of tensor subtraction (aten.sub.Tensor) via the Arm backend.

    Every test case is run through the TOSA MI pipeline, the TOSA BI
    (quantized) pipeline, and the quantized Ethos-U55 / Ethos-U85 pipelines.
    """

    class Sub(torch.nn.Module):
        """Single-input module computing ``x - x``."""

        # Each entry is a 1-tuple holding the single forward() argument.
        test_parameters = [
            (torch.ones(5),),
            (3 * torch.ones(8),),
            (10 * torch.randn(8),),
        ]

        def forward(self, x):
            # NOTE(review): x - x is zero for every finite input, so the output
            # comparison cannot detect arithmetic errors in the lowered sub;
            # mainly the graph structure is exercised here. Sub2 uses two
            # distinct operands.
            return x - x

    class Sub2(torch.nn.Module):
        """Two-input module computing ``x - y``.

        The test data subtracts a (1, 1, 4, 1) tensor from a (1, 1, 4, 4)
        tensor, so the second operand is broadcast along the last dimension.
        """

        # Each entry is a 2-tuple (x, y) matching the forward() arguments.
        test_parameters = [
            (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
        ]

        def forward(self, x, y):
            return x - y

    def _test_sub_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI spec and compare outputs.

        The export must contain exactly one aten.sub.Tensor and no
        quantized_decomposed ops. After partitioning, the aten sub must be
        gone and exactly one delegate call must remain. The emitted ExecuTorch
        program is then run on ``test_data`` and its outputs are compared
        via ``run_method_and_compare_outputs``.

        Args:
            module: The module under test.
            test_data: Positional inputs for ``module.forward``. Its length
                depends on the module (one tensor for Sub, two for Sub2),
                hence the variadic tuple annotation.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.sub.Tensor": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["torch.ops.aten.sub.Tensor"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_sub_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ):
        """Quantize and lower ``module`` with the TOSA-0.80.0+BI spec, then compare outputs.

        After quantize + export, the graph must contain exactly one
        aten.sub.Tensor plus quantized_decomposed ops. After partitioning,
        exactly one delegate call must remain. Outputs of the emitted program
        on ``test_data`` are compared with ``qtol=1``.

        Args:
            module: The module under test.
            test_data: Positional inputs for ``module.forward`` (variable
                length, one tensor per forward() argument).
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.sub.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_sub_ethosu_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, ...],
    ):
        """Quantize and lower ``module`` for an Ethos-U target given by ``compile_spec``.

        Performs the same graph checks as the TOSA BI pipeline but stops after
        ``to_executorch()``. The program is not executed, so no output
        comparison happens here.

        Args:
            compile_spec: Target compile spec (e.g. from
                ``common.get_u55_compile_spec()`` or
                ``common.get_u85_compile_spec()``).
            module: The module under test.
            test_data: Positional inputs for ``module.forward`` (variable
                length, one tensor per forward() argument).
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.sub.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # parameterized.expand unpacks each test_parameters tuple into the test
    # method's arguments, so the single-input cases are re-wrapped into a
    # 1-tuple before being handed to the pipelines.

    @parameterized.expand(Sub.test_parameters)
    def test_sub_tosa_MI(self, test_data: torch.Tensor):
        test_data = (test_data,)
        self._test_sub_tosa_MI_pipeline(self.Sub(), test_data)

    @parameterized.expand(Sub.test_parameters)
    def test_sub_tosa_BI(self, test_data: torch.Tensor):
        test_data = (test_data,)
        self._test_sub_tosa_BI_pipeline(self.Sub(), test_data)

    @parameterized.expand(Sub.test_parameters)
    def test_sub_u55_BI(self, test_data: torch.Tensor):
        test_data = (test_data,)
        self._test_sub_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Sub(), test_data
        )

    @parameterized.expand(Sub.test_parameters)
    def test_sub_u85_BI(self, test_data: torch.Tensor):
        test_data = (test_data,)
        self._test_sub_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Sub(), test_data
        )

    @parameterized.expand(Sub2.test_parameters)
    def test_sub2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_sub_tosa_MI_pipeline(self.Sub2(), test_data)

    @parameterized.expand(Sub2.test_parameters)
    def test_sub2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_sub_tosa_BI_pipeline(self.Sub2(), test_data)

    @parameterized.expand(Sub2.test_parameters)
    def test_sub2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_sub_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Sub2(), test_data
        )

    @parameterized.expand(Sub2.test_parameters)
    def test_sub2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_sub_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Sub2(), test_data
        )