# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestAdd(unittest.TestCase):
    """Tests that aten.add.Tensor graphs are fully lowered to the XNNPACK
    delegate (fp16/fp32 and quantized), and that outputs match eager mode."""

    class Add(torch.nn.Module):
        # Four chained adds exercise multi-node partitioning.
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            z = x + y
            z = z + x
            z = z + x
            z = z + z
            return z

    class Add2(torch.nn.Module):
        # Single add where both operands are the same tensor.
        def __init__(self):
            super().__init__()

        def forward(self, x):
            z = x + x
            return z

    class AddConstant(torch.nn.Module):
        # Adds the same constant held three ways: as a plain attribute, as a
        # non-persistent buffer, and as a Parameter — all must lower correctly.
        def __init__(self, constant):
            super().__init__()
            self._constant1 = constant
            self.register_buffer("_constant2", constant, persistent=False)
            self.register_parameter("_constant3", torch.nn.Parameter(constant))

        def forward(self, x):
            out1 = x + self._constant1 + torch.ones(1, 1, 1)
            out2 = x + self._constant2 + self._constant3
            return out1, out2

    def _test_add(self, inputs):
        """Lower Add end-to-end and check all 4 adds land in one delegate."""
        (
            Tester(self.Add(), inputs)
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 4})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_add(self):
        inputs = (torch.randn(1).to(torch.float16), torch.randn(1).to(torch.float16))
        self._test_add(inputs)

    def test_fp32_add(self):
        inputs = (torch.randn(1), torch.randn(1))
        self._test_add(inputs)

    def test_fp32_add_constant(self):
        inputs = (torch.randn(4, 4, 4),)
        (
            Tester(self.AddConstant(torch.randn(4, 4, 4)), inputs)
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 4})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_add_constant(self):
        inputs = (torch.randn(4, 4, 4),)
        (
            Tester(self.AddConstant(torch.randn(4, 4, 4)), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 4})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_add(self):
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
        (
            Tester(self.Add(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 4})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_add2(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        (
            Tester(self.Add2(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_add3(self):
        # Second input broadcasts along the last dimension (4x4 + 4x1).
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
        (
            Tester(self.Add(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 4})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    class AddRelu(torch.nn.Module):
        # Add followed by ReLU; the pair should fuse inside the delegate.
        def forward(self, x, y):
            z = x + y
            return torch.nn.functional.relu(z)

    def test_fp32_add_relu(self):
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
        (
            Tester(self.AddRelu(), inputs)
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 1})
            .check_count({"torch.ops.aten.relu.default": 1})
            .to_edge_transform_and_lower()
            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
            .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_add_relu(self):
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
        (
            Tester(self.AddRelu(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 1})
            .check_count({"torch.ops.aten.relu.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_add_relu_seq(self):
        """Quantized sequence of two interleaved add+relu pairs lowers to one delegate."""

        class AddReLU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x, z):
                y = x + z
                y = self.relu(y)
                y = y + y
                y = self.relu(y)
                return y

        inputs = (
            torch.randn(1, 1, 20, 20),
            torch.randn(1, 1, 20, 20),
        )

        (
            # Fix: previously this instantiated self.AddRelu(), leaving the
            # locally defined AddReLU sequence module as dead code and testing
            # only a single add+relu pair. Use the sequence module and expect
            # 2 adds and 2 relus in the exported graph.
            Tester(AddReLU(), inputs)
            .quantize()
            .export()
            .check_count(
                {"torch.ops.aten.add.Tensor": 2, "torch.ops.aten.relu.default": 2}
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )