# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the full op which creates a tensor of a given shape filled with a given value.
# The shape and value are set at compile time, i.e. can't be set by a tensor input.
#

import unittest
from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


class TestFull(unittest.TestCase):
    """Tests the full op which creates a tensor of a given shape filled with a given value."""

    class Full(torch.nn.Module):
        # A single full op with no inputs.
        def forward(self):
            return torch.full((3, 3), 4.5)

    class AddConstFull(torch.nn.Module):
        # Input + a full with a constant shape and value.
        def forward(self, x: torch.Tensor):
            return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x

    class AddVariableFull(torch.nn.Module):
        # Input shapes to test. Every entry is a tuple; note that `(5)` would be
        # the int 5 rather than a 1-tuple, hence the trailing comma in `(5,)`.
        sizes = [
            (5,),
            (5, 5),
            (5, 5, 5),
            (1, 5, 5, 5),
        ]
        # Each parameter set is one argument: an (input tensor, fill value) tuple
        # that is passed as the module's example/test inputs.
        test_parameters = [((torch.randn(n) * 10 - 5, 3.2),) for n in sizes]

        def forward(self, x: torch.Tensor, y):
            # Input + a full with the shape from the input and a given value 'y'.
            return x + torch.full(x.shape, y)

    def _test_full_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        example_data: Tuple,
        test_data: Tuple | None = None,
    ):
        """Lower `module` with the TOSA MI profile and compare outputs.

        Checks that exactly one `aten.full` is exported, that it is consumed by
        the partitioner into a single delegate, and that the lowered program
        matches eager mode.

        Args:
            module: Module under test.
            example_data: Inputs used for export. A non-tensor fill value in
                these inputs is baked into the program at compile time.
            test_data: Inputs used at runtime for the output comparison.
                Defaults to `example_data`.
        """
        if test_data is None:
            test_data = example_data
        (
            ArmTester(
                module,
                example_inputs=example_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple,
        permute_memory_to_nhwc: bool,
    ):
        """Quantize and lower `module` with the TOSA BI profile and compare outputs.

        Args:
            module: Module under test.
            test_data: Inputs used both for export and for the output comparison.
            permute_memory_to_nhwc: Forwarded to the TOSA compile spec.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute_memory_to_nhwc
                ),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_ethos_pipeline(
        self, compile_spec: list[CompileSpec], module: torch.nn.Module, test_data: Tuple
    ):
        """Quantize and lower `module` for an Ethos target given by `compile_spec`.

        Only checks delegation and produces an ExecuTorch program; the program
        is not run and outputs are not compared here.
        """
        (
            ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_full_tosa_u55_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Run the Ethos lowering pipeline with the U55 compile spec."""
        self._test_full_tosa_ethos_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_full_tosa_u85_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Run the Ethos lowering pipeline with the U85 compile spec."""
        self._test_full_tosa_ethos_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    # A module consisting of only a full op, with no inputs.
    def test_only_full_tosa_MI(self):
        self._test_full_tosa_MI_pipeline(self.Full(), ())

    # Constant full added to an input, MI profile.
    def test_const_full_tosa_MI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,))

    # Constant full added to an input, BI profile with NHWC memory permutation.
    def test_const_full_nhwc_tosa_BI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_BI_pipeline(
            self.AddConstFull(), (_input,), permute_memory_to_nhwc=True
        )

    # Full shaped like the input, MI profile.
    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_MI(self, test_tensor: Tuple):
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=test_tensor
        )

    # Full shaped like the input, BI profile without memory permutation.
    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_BI(self, test_tensor: Tuple):
        self._test_full_tosa_BI_pipeline(
            self.AddVariableFull(), test_tensor, permute_memory_to_nhwc=False
        )

    # Full shaped like the input, lowered for U55.
    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_u55_BI(self, test_tensor: Tuple):
        self._test_full_tosa_u55_pipeline(
            self.AddVariableFull(),
            test_tensor,
        )

    # Full shaped like the input, lowered for U85.
    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_u85_BI(self, test_tensor: Tuple):
        self._test_full_tosa_u85_pipeline(
            self.AddVariableFull(),
            test_tensor,
        )

    # This fails since full outputs int64 by default if 'fill_value' is integer, which our backend doesn't support.
    @unittest.expectedFailure
    def test_integer_value(self):
        _input = torch.ones((2, 2))
        integer_fill_value = 1
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=(_input, integer_fill_value)
        )

    # This fails since the fill value in the full tensor is set at compile time by the example data (1.).
    # Test data tries to set it again at runtime (to 2.) but it doesn't do anything.
    # In eager mode, the fill value can be set at runtime, causing the outputs to not match.
    @unittest.expectedFailure
    def test_set_value_at_runtime(self):
        _input = torch.ones((2, 2))
        example_fill_value = 1.0
        test_fill_value = 2.0
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(),
            example_data=(_input, example_fill_value),
            test_data=(_input, test_fill_value),
        )