xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/add.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


13class TestAdd(unittest.TestCase):
14    class Add(torch.nn.Module):
15        def __init__(self):
16            super().__init__()
17
18        def forward(self, x, y):
19            z = x + y
20            z = z + x
21            z = z + x
22            z = z + z
23            return z
24
25    class Add2(torch.nn.Module):
26        def __init__(self):
27            super().__init__()
28
29        def forward(self, x):
30            z = x + x
31            return z
32
33    class AddConstant(torch.nn.Module):
34        def __init__(self, constant):
35            super().__init__()
36            self._constant1 = constant
37            self.register_buffer("_constant2", constant, persistent=False)
38            self.register_parameter("_constant3", torch.nn.Parameter(constant))
39
40        def forward(self, x):
41            out1 = x + self._constant1 + torch.ones(1, 1, 1)
42            out2 = x + self._constant2 + self._constant3
43            return out1, out2
44
45    def _test_add(self, inputs):
46        (
47            Tester(self.Add(), inputs)
48            .export()
49            .check_count({"torch.ops.aten.add.Tensor": 4})
50            .to_edge_transform_and_lower()
51            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
52            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
53            .to_executorch()
54            .serialize()
55            .run_method_and_compare_outputs()
56        )
57
58    def test_fp16_add(self):
59        inputs = (torch.randn(1).to(torch.float16), torch.randn(1).to(torch.float16))
60        self._test_add(inputs)
61
62    def test_fp32_add(self):
63        inputs = (torch.randn(1), torch.randn(1))
64        self._test_add(inputs)
65
66    def test_fp32_add_constant(self):
67        inputs = (torch.randn(4, 4, 4),)
68        (
69            Tester(self.AddConstant(torch.randn(4, 4, 4)), inputs)
70            .export()
71            .check_count({"torch.ops.aten.add.Tensor": 4})
72            .to_edge_transform_and_lower()
73            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
74            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
75            .to_executorch()
76            .serialize()
77            .run_method_and_compare_outputs()
78        )
79
80    def test_qs8_add_constant(self):
81        inputs = (torch.randn(4, 4, 4),)
82        (
83            Tester(self.AddConstant(torch.randn(4, 4, 4)), inputs)
84            .quantize()
85            .export()
86            .check_count({"torch.ops.aten.add.Tensor": 4})
87            .to_edge_transform_and_lower()
88            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
89            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
90            .to_executorch()
91            .serialize()
92            .run_method_and_compare_outputs()
93        )
94
95    def test_qs8_add(self):
96        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
97        (
98            Tester(self.Add(), inputs)
99            .quantize()
100            .export()
101            .check_count({"torch.ops.aten.add.Tensor": 4})
102            .check(["torch.ops.quantized_decomposed"])
103            .to_edge_transform_and_lower()
104            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
105            .check_not(
106                [
107                    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
108                    "torch.ops.quantized_decomposed",
109                ]
110            )
111            .to_executorch()
112            .serialize()
113            .run_method_and_compare_outputs()
114        )
115
116    def test_qs8_add2(self):
117        inputs = (torch.randn(1, 1, 4, 4),)
118        (
119            Tester(self.Add2(), inputs)
120            .quantize()
121            .export()
122            .check_count({"torch.ops.aten.add.Tensor": 1})
123            .check(["torch.ops.quantized_decomposed"])
124            .to_edge_transform_and_lower()
125            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
126            .check_not(
127                [
128                    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
129                    "torch.ops.quantized_decomposed",
130                ]
131            )
132            .to_executorch()
133            .serialize()
134            .run_method_and_compare_outputs()
135        )
136
137    def test_qs8_add3(self):
138        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
139        (
140            Tester(self.Add(), inputs)
141            .quantize()
142            .export()
143            .check_count({"torch.ops.aten.add.Tensor": 4})
144            .check(["torch.ops.quantized_decomposed"])
145            .to_edge_transform_and_lower()
146            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
147            .check_not(
148                [
149                    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
150                    "torch.ops.quantized_decomposed",
151                ]
152            )
153            .to_executorch()
154            .serialize()
155            .run_method_and_compare_outputs()
156        )
157
158    class AddRelu(torch.nn.Module):
159        def forward(self, x, y):
160            z = x + y
161            return torch.nn.functional.relu(z)
162
163    def test_fp32_add_relu(self):
164        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
165        (
166            Tester(self.AddRelu(), inputs)
167            .export()
168            .check_count({"torch.ops.aten.add.Tensor": 1})
169            .check_count({"torch.ops.aten.relu.default": 1})
170            .to_edge_transform_and_lower()
171            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
172            .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
173            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
174            .to_executorch()
175            .serialize()
176            .run_method_and_compare_outputs()
177        )
178
179    def test_qs8_add_relu(self):
180        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
181        (
182            Tester(self.AddRelu(), inputs)
183            .quantize()
184            .export()
185            .check_count({"torch.ops.aten.add.Tensor": 1})
186            .check_count({"torch.ops.aten.relu.default": 1})
187            .check(["torch.ops.quantized_decomposed"])
188            .to_edge_transform_and_lower()
189            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
190            .to_executorch()
191            .serialize()
192            .run_method_and_compare_outputs()
193        )
194
195    def test_qs8_add_relu_seq(self):
196        class AddReLU(torch.nn.Module):
197            def __init__(self):
198                super().__init__()
199                self.relu = torch.nn.ReLU()
200
201            def forward(self, x, z):
202                y = x + z
203                y = self.relu(y)
204                y = y + y
205                y = self.relu(y)
206                return y
207
208        inputs = (
209            torch.randn(
210                1,
211                1,
212                20,
213                20,
214            ),
215            torch.randn(
216                1,
217                1,
218                20,
219                20,
220            ),
221        )
222
223        (
224            Tester(self.AddRelu(), inputs)
225            .quantize()
226            .export()
227            .check_count(
228                {"torch.ops.aten.add.Tensor": 1, "torch.ops.aten.relu.default": 1}
229            )
230            .check(["torch.ops.quantized_decomposed"])
231            .to_edge_transform_and_lower()
232            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
233            .to_executorch()
234            .serialize()
235            .run_method_and_compare_outputs()
236        )
237