# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestMaxPool2d(unittest.TestCase):
    """Lowering tests for torch.nn.MaxPool2d through the XNNPACK delegate."""

    class MaxPool2d(torch.nn.Module):
        """Thin wrapper around torch.nn.MaxPool2d with configurable geometry."""

        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )

        def forward(self, x):
            return self.max_pool2d_module(x)

    class MaxPool2dUnsupported(torch.nn.Module):
        """MaxPool2d built with return_indices=True whose forward returns the indices.

        Because the indices output is consumed, aten.max_pool2d_with_indices
        stays in the graph; test_fp32_maxpool2d_unsupported expects that op
        not to be delegated.
        """

        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                return_indices=True,
            )

        def forward(self, x):
            # With return_indices=True the module yields (values, indices);
            # only the indices tensor is returned.
            return self.max_pool2d_module(x)[1]

    class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
        """2x2, stride-2 MaxPool2d with ceil_mode=True."""

        def __init__(self):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)

        def forward(self, x):
            return self.max_pool2d_module(x)

    def _test_maxpool2d(self, inputs):
        """
        Export, lower and run a 3x3 / stride-1 / no-padding MaxPool2d on `inputs`.

        After lowering, the graph must contain exactly one XNNPACK delegate call
        and no remaining edge max_pool2d_with_indices op. The executorch program's
        outputs are then compared against eager mode.

        Note that the export process generates aten.max_pool2d_with_indices. The remove_getitem_op
        pass transforms it into aten.max_pool2d (if supported).
        """
        (
            Tester(self.MaxPool2d(3, 1, 0, 1), inputs)
            .export()
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_maxpool2d(self):
        inputs = (torch.randn(4, 3, 24, 24).to(torch.float16),)
        self._test_maxpool2d(inputs)

    def test_fp32_maxpool2d(self):
        inputs = (torch.randn(4, 3, 24, 24),)
        self._test_maxpool2d(inputs)

    def test_fp32_maxpool2d_unsupported(self):
        """
        MaxPool2d with return_indices is not generally supported (see maxpool2d_with_indices constraint).
        """
        inputs = (torch.randn(4, 3, 24, 24),)
        (
            Tester(self.MaxPool2dUnsupported(), inputs)
            .export()
            .check_count({"torch.ops.aten.max_pool2d_with_indices.default": 1})
            .to_edge_transform_and_lower()
            # We expect it not to be delegated.
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
                }
            )
        )

    def test_fp32_maxpool2d_unsupported_ceilmode(self):
        """
        MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
        """
        inputs = (torch.randn(1, 32, 23, 23),)
        (
            Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
            .export()
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .to_edge_transform_and_lower()
            # We expect it not to be delegated.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
                }
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_maxpool2d(self):
        class MaxPool(torch.nn.Module):
            def __init__(self, maxpool_params):
                super().__init__()
                self.max = torch.nn.MaxPool2d(*maxpool_params)

            def forward(self, x):
                # NOTE(review): presumably the add gives the quantizer a
                # producer op ahead of the pool so its input is quantized —
                # confirm.
                z = x + x
                return self.max(z)

        # Parameter order is kernel_size, stride, padding.
        for maxpool_params in [(4,), (4, 2), (4, 2, 2)]:
            # subTest reports which parameter combination failed and keeps
            # checking the remaining combinations instead of stopping early.
            with self.subTest(maxpool_params=maxpool_params):
                inputs = (torch.randn(1, 2, 8, 8),)
                (
                    Tester(MaxPool(maxpool_params), inputs)
                    .quantize()
                    .export()
                    .check_count({"torch.ops.aten.max_pool2d.default": 1})
                    .check(["torch.ops.quantized_decomposed"])
                    .to_edge_transform_and_lower()
                    .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
                    .check_not(
                        [
                            "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default",
                            "torch.ops.quantized_decomposed",
                        ]
                    )
                    .to_executorch()
                    .serialize()
                    .run_method_and_compare_outputs()
                )