# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import unittest
from typing import Optional

import torch
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig


class Conv2d(torch.nn.Module):
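    """Configurable single-Conv2d wrapper used as the basic test module;
    get_inputs() returns a random NCHW tensor matching the configured
    batches, channels, height, width, and dtype."""
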
    def __init__(
        self,
        in_channels=2,
        out_channels=1,
        kernel_size=(3, 3),
        stride=(2, 2),
        padding=(1, 1),
        dilation=(1, 1),
        groups=1,
        bias=True,
        padding_mode="zeros",
        batches=1,
        width=8,
        height=8,
        dtype=torch.float,
    ):
        super().__init__()
        self.batches = batches
        self.width = width
        self.height = height
        self.in_channels = in_channels
        self.dtype = dtype

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        ).to(dtype)

    def forward(self, x):
        return self.conv(x)

    def get_inputs(self):
        return (
            torch.randn(self.batches, self.in_channels, self.height, self.width).to(
                self.dtype
            ),
        )


class Conv2dSeq(torch.nn.Module):
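    """Two back-to-back 3x3 convolutions, for exercising sequential conv lowering."""
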
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=False,
        )
        self.second = torch.nn.Conv2d(
            in_channels=3,
            out_channels=2,
            kernel_size=(3, 3),
            padding=1,
            bias=False,
        )

    def forward(self, x):
        y = self.first(x)
        return self.second(y)

    def get_inputs(self):
        return (torch.randn(1, 1, 3, 3),)


class Conv2dBatchNorm(torch.nn.Module):
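    """Conv -> BatchNorm -> Hardtanh applied twice; the same BatchNorm module
    is reused after both convolutions."""
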
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            2,
            2,
            (2, 2),
            bias=False,
            padding=[1, 1],
            stride=[4, 4],
        )
        self.bn = randomize_bn(2)
        self.hardtanh = torch.nn.Hardtanh()
        self.conv2 = torch.nn.Conv2d(
            2,
            2,
            (2, 2),
            bias=False,
            padding=[1, 1],
            stride=[4, 4],
        )

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn(y)
        y = self.hardtanh(y)
        y = self.conv2(y)
        y = self.bn(y)
        y = self.hardtanh(y)
        return y

    def get_inputs(self):
        return (torch.randn(2, 2, 4, 4),)


class Conv2dPermute(torch.nn.Module):
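    """Conv2d whose output is permuted in the given order."""
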
    def __init__(self, permute_order):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            2,
            2,
            (2, 2),
            bias=False,
            padding=[2, 2],
            stride=[2, 2],
        )
        self.permute_order = permute_order

    def forward(self, x):
        result = self.conv(x)
        # permute_order is arbitrary (the test sweeps all permutations), so
        # the result is not necessarily channels-last.
        permuted = torch.permute(result, self.permute_order)
        return permuted

    def get_inputs(self):
        return (torch.randn(2, 2, 4, 4),)


class TestConv2d(unittest.TestCase):
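    """Export-and-lower tests for Conv2d patterns on the XNNPACK delegate."""
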
    def _test(
        self,
        m: torch.nn.Module,
        quant_config: Optional[QuantizationConfig] = None,
        conv_count=1,
        dtype: torch.dtype = torch.float,
    ):
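        """Export m, lower it via to_edge_transform_and_lower, and check that:

        - `conv_count` aten.conv2d nodes are present after export,
        - no convolution or batch norm ops remain outside the delegate,
        - exactly one XNNPACK delegate call is emitted,
        - the serialized program matches eager-mode outputs (qtol=1).

        If quant_config is given, the module is quantized first and the graph
        is checked for quantized_decomposed ops.
        """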
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        tester = Tester(m.eval(), m.get_inputs())

        if quant_config is not None:
            tester = tester.quantize(Quantize(quantization_config=quant_config))
            tester.check(["torch.ops.quantized_decomposed"])

        (
            tester.export()
            .check_count({"torch.ops.aten.conv2d": conv_count})
            .to_edge_transform_and_lower()
            .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(qtol=1)
        )

    def test_fp16_conv2d(self) -> None:
        for has_bias in (True, False):
            self._test(Conv2d(bias=has_bias, dtype=torch.float16))

    def test_fp32_conv2d(self) -> None:
        for has_bias in (True, False):
            self._test(Conv2d(bias=has_bias))

    def test_fp32_conv2d_permute(self) -> None:
        for perm_order in itertools.permutations([0, 1, 2, 3]):
            self._test(Conv2dPermute(perm_order))

    def test_qs8_conv2d(self) -> None:
        for has_bias in (True, False):
            self._test(
                Conv2d(bias=has_bias), quant_config=get_symmetric_quantization_config()
            )

    def test_qs8_conv2d_per_channel(self) -> None:
        self._test(
            Conv2d(),
            quant_config=get_symmetric_quantization_config(is_per_channel=True),
        )

    def test_fp32_conv2d_seq(self) -> None:
        self._test(Conv2dSeq(), conv_count=2)

    def test_qs8_conv2d_seq(self) -> None:
        self._test(
            Conv2dSeq(), conv_count=2, quant_config=get_symmetric_quantization_config()
        )

    def test_fp32_conv2d_single_int_params(self):
        self._test(
            Conv2d(
                kernel_size=3,
                stride=2,
                padding="valid",
                dilation=1,
            )
        )

    def test_fp32_conv2d_depthwise(self):
        # Depthwise convolution requirements:
        # - groups must equal in_channels
        # - out_channels must be a positive multiple of in_channels
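        # Here groups == in_channels == 2 and out_channels == 6 == 3 * in_channels.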
        self._test(Conv2d(groups=2, in_channels=2, out_channels=6))

    def test_qs8_conv2d_depthwise(self):
        self._test(
            Conv2d(groups=2, in_channels=2, out_channels=6),
            quant_config=get_symmetric_quantization_config(),
        )

    def test_fp32_conv2d_bn(self):
        class Conv2dBatchNorm(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
                self.bn = randomize_bn(out_features)
                self.in_features = in_features
                self.kernel_size = kernel_size

            def forward(self, x):
                y = self.conv2d(x)
                y = self.bn(y)
                return y

            def get_inputs(self):
                return (
                    torch.randn(
                        2,
                        self.in_features,
                        self.kernel_size[0] * 2,
                        self.kernel_size[1] * 2,
                    ),
                )

        self._test(Conv2dBatchNorm(in_features=2, out_features=2, kernel_size=(2, 2)))

    def test_fp32_conv2d_bn_hardtanh_mean_sequence(self):
        """
        Verify that batch norm and hardtanh are fused even when copy nodes
        are inserted at some points in the graph to change the memory format.
        """

        class Conv2dBatchNormHardTanh(torch.nn.Module):
            def __init__(self, in_channels: int, out_channels: int, kernel_size):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=[1, 1],
                    stride=[2, 2],
                )
                self.in_channels = in_channels
                self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
                self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)

            def forward(self, x):
                x = self.conv(x)
                x = self.native_batchnorm(x)
                x = self.hardtanh(x)
                x = torch.mean(x, (-1, -2), keepdim=True)
                return x

            def get_inputs(self):
                return (torch.randn(2, self.in_channels, 8, 8),)

        self._test(
            Conv2dBatchNormHardTanh(in_channels=2, out_channels=1, kernel_size=(2, 2))
        )

    def test_qs8_conv2d_bn(self):
        self._test(
            Conv2dBatchNorm(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )

    def test_qs8_conv2d_relu(self):
        class ConvReLU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    2,
                    2,
                    (2, 2),
                    bias=False,
                    padding=[1, 1],
                    stride=[4, 4],
                )
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.conv1(x)
                y = self.relu(y)
                return y

            def get_inputs(self):
                return (torch.randn(2, 2, 4, 4),)

        self._test(
            ConvReLU(),
            quant_config=get_symmetric_quantization_config(),
        )

    def test_qs8_conv2d_dw_relu(self):
        # Depthwise convolution requirements:
        # - groups must equal in_channels
        # - out_channels must be a positive multiple of in_channels
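        # The parameters below satisfy both: in_channels = groups, and
        # out_channels = 3 * in_channels.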
        groups = 2
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = groups
        out_channels = 3 * in_channels
        width = 8
        height = 8
        batches = 1

        class ModelConvReLU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=(3, 3),
                    stride=stride,
                    padding=padding,
                    groups=groups,
                    dilation=dilation,
                    bias=True,
                )
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.conv1(x)
                y = self.relu(y)
                return y

            def get_inputs(self):
                return (torch.randn(batches, in_channels, height, width) * 11,)

        for per_channel_quant in (False, True):
            model = ModelConvReLU()
            self._test(
                model,
                quant_config=get_symmetric_quantization_config(
                    is_per_channel=per_channel_quant
                ),
            )

    def test_qs8_conv2d_relu_seq(self):
        class ConvReLUSeq(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = torch.nn.Sequential(
                    torch.nn.Conv2d(1, 1, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(1, 64, 1),
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.model(x)

            def get_inputs(self):
                return (torch.randn(1, 1, 1, 1),)

        self._test(
            ConvReLUSeq(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )

    def test_qs8_conv2d_relu_multi_users(self):
        class Conv2dReluMultiUsers(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 64, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                conv_default = self.conv1(x)
                y = self.relu(conv_default)
                conv_default_2 = self.conv2(y)
                return conv_default + conv_default_2

            def get_inputs(self):
                return (torch.randn(1, 1, 1, 1),)

        self._test(
            Conv2dReluMultiUsers(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )
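

# Entry point for running this file directly; normally these tests are
# collected by the project's test runner, so this guard is optional.
if __name__ == "__main__":
    unittest.main()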
422