xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/maxpool2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11
12
class TestMaxPool2d(unittest.TestCase):
    """Tests lowering of ``torch.nn.MaxPool2d`` through the XNNPACK Tester pipeline.

    The fp16, fp32 and quantized (qs8) cases are expected to be delegated as a
    single ``executorch_call_delegate``. Configurations using ``return_indices=True``
    or ``ceil_mode=True`` are expected to remain outside the delegate.
    """

    class MaxPool2d(torch.nn.Module):
        """Thin wrapper around ``torch.nn.MaxPool2d`` with configurable geometry."""

        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )

        def forward(self, x):
            return self.max_pool2d_module(x)

    class MaxPool2dUnsupported(torch.nn.Module):
        """MaxPool2d built with ``return_indices=True`` that returns only the indices.

        Because the indices output is consumed, ``max_pool2d_with_indices`` stays in
        the graph; the test using this module expects it not to be delegated.
        """

        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                return_indices=True,
            )

        def forward(self, x):
            # With return_indices=True the module returns (values, indices);
            # element 1 is the indices tensor.
            return self.max_pool2d_module(x)[1]

    class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
        """2x2, stride-2 MaxPool2d with ``ceil_mode=True``."""

        def __init__(self):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)

        def forward(self, x):
            return self.max_pool2d_module(x)

    def _test_maxpool2d(self, inputs, module=None):
        """Export, lower, serialize, run and compare a MaxPool2d module.

        The exported graph is checked to contain ``aten.max_pool2d``. After lowering,
        the edge-dialect ``max_pool2d_with_indices`` op must not remain outside the
        delegate; the remove_getitem_op pass is what turns the with-indices form back
        into ``aten.max_pool2d`` (if supported) so it can be delegated.

        Args:
            inputs: Tuple of example input tensors handed to the Tester.
            module: Module under test. Defaults to ``self.MaxPool2d(3, 1, 0, 1)``
                (kernel 3, stride 1, padding 0, dilation 1).
        """
        if module is None:
            module = self.MaxPool2d(3, 1, 0, 1)
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_maxpool2d(self):
        inputs = (torch.randn(4, 3, 24, 24).to(torch.float16),)
        self._test_maxpool2d(inputs)

    def test_fp32_maxpool2d(self):
        inputs = (torch.randn(4, 3, 24, 24),)
        self._test_maxpool2d(inputs)

    def test_fp32_maxpool2d_unsupported(self):
        """
        MaxPool2d with return_indices is not generally supported (see maxpool2d_with_indices constraint).
        """
        inputs = (torch.randn(4, 3, 24, 24),)
        (
            Tester(self.MaxPool2dUnsupported(), inputs)
            .export()
            .check_count({"torch.ops.aten.max_pool2d_with_indices.default": 1})
            .to_edge_transform_and_lower()
            # We expect it not to be delegated.
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
                }
            )
        )

    def test_fp32_maxpool2d_unsupported_ceilmode(self):
        """
        MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
        """
        inputs = (torch.randn(1, 32, 23, 23),)
        (
            Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
            .export()
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .to_edge_transform_and_lower()
            # We expect it not to be delegated.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
                }
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_maxpool2d(self):
        class MaxPool(torch.nn.Module):
            def __init__(self, maxpool_params):
                super().__init__()
                self.max = torch.nn.MaxPool2d(*maxpool_params)

            def forward(self, x):
                # NOTE(review): presumably the add exists so the pool's input is
                # produced by a quantized op; confirm against the quantizer.
                z = x + x
                return self.max(z)

        # Parameter order is kernel_size, stride, padding.
        for maxpool_params in [(4,), (4, 2), (4, 2, 2)]:
            # subTest runs every configuration even if one fails, and the failure
            # report names the parameters that broke.
            with self.subTest(maxpool_params=maxpool_params):
                inputs = (torch.randn(1, 2, 8, 8),)
                (
                    Tester(MaxPool(maxpool_params), inputs)
                    .quantize()
                    .export()
                    .check_count({"torch.ops.aten.max_pool2d.default": 1})
                    .check(["torch.ops.quantized_decomposed"])
                    .to_edge_transform_and_lower()
                    .check_count(
                        {"torch.ops.higher_order.executorch_call_delegate": 1}
                    )
                    .check_not(
                        [
                            "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default",
                            "torch.ops.quantized_decomposed",
                        ]
                    )
                    .to_executorch()
                    .serialize()
                    .run_method_and_compare_outputs()
                )
149