# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import unittest
from typing import Optional

import torch
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig


class Conv2d(torch.nn.Module):
    """Single ``torch.nn.Conv2d`` layer, cast to ``dtype``, with matching inputs.

    All convolution arguments are forwarded unchanged to ``torch.nn.Conv2d``.
    ``batches``, ``height`` and ``width`` are only used by ``get_inputs`` to
    size the random sample tensor.
    """

    def __init__(
        self,
        in_channels=2,
        out_channels=1,
        kernel_size=(3, 3),
        stride=(2, 2),
        padding=(1, 1),
        dilation=(1, 1),
        groups=1,
        bias=True,
        padding_mode="zeros",
        batches=1,
        width=8,
        height=8,
        dtype=torch.float,
    ):
        super().__init__()
        self.batches = batches
        self.width = width
        self.height = height
        self.in_channels = in_channels
        self.dtype = dtype

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        ).to(dtype)

    def forward(self, x):
        """Apply the convolution to ``x``."""
        return self.conv(x)

    def get_inputs(self):
        """Return a 1-tuple holding a random ``(batches, in_channels, height, width)`` tensor in ``self.dtype``."""
        return (
            torch.randn(self.batches, self.in_channels, self.height, self.width).to(
                self.dtype
            ),
        )


class Conv2dSeq(torch.nn.Module):
    """Two bias-free 3x3 convolutions (padding 1) applied back to back: 1 -> 3 -> 2 channels."""

    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=False,
        )
        self.second = torch.nn.Conv2d(
            in_channels=3,
            out_channels=2,
            kernel_size=(3, 3),
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Feed ``x`` through ``first`` and then ``second``."""
        y = self.first(x)
        return self.second(y)

    def get_inputs(self):
        """Return a 1-tuple holding a random ``(1, 1, 3, 3)`` tensor."""
        return (torch.randn(1, 1, 3, 3),)


class Conv2dBatchNorm(torch.nn.Module):
    """conv1 -> bn -> hardtanh -> conv2 -> bn -> hardtanh.

    The same ``self.bn`` module is applied after both convolutions.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            2,
            2,
            (2, 2),
            bias=False,
            padding=[1, 1],
            stride=[4, 4],
        )
        # NOTE(review): randomize_bn is defined in test_xnnpack_utils; presumably
        # it builds a batch-norm layer for 2 channels with randomized
        # parameters/statistics -- confirm against the helper.
        self.bn = randomize_bn(2)
        self.hardtanh = torch.nn.Hardtanh()
        self.conv2 = torch.nn.Conv2d(
            2,
            2,
            (2, 2),
            bias=False,
            padding=[1, 1],
            stride=[4, 4],
        )

    def forward(self, x):
        """Run both conv/bn/hardtanh stages in order and return the result."""
        y = self.conv1(x)
        y = self.bn(y)
        y = self.hardtanh(y)
        y = self.conv2(y)
        y = self.bn(y)
        y = self.hardtanh(y)
        return y

    def get_inputs(self):
        """Return a 1-tuple holding a random ``(2, 2, 4, 4)`` tensor."""
        return (torch.randn(2, 2, 4, 4),)


class Conv2dPermute(torch.nn.Module):
    """Bias-free 2x2 convolution whose output is permuted with ``permute_order``."""

    def __init__(self, permute_order):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            2,
            2,
            (2, 2),
            bias=False,
            padding=[2, 2],
            stride=[2, 2],
        )
        self.permute_order = permute_order

    def forward(self, x):
        """Convolve ``x`` and reorder the result's dimensions via ``torch.permute``."""
        result = self.conv(x)
        channels_last = torch.permute(result, self.permute_order)
        return channels_last

    def get_inputs(self):
        """Return a 1-tuple holding a random ``(2, 2, 4, 4)`` tensor."""
        return (torch.randn(2, 2, 4, 4),)


class TestConv2d(unittest.TestCase):
    """Conv2d test cases driven through the XNNPACK test ``Tester`` pipeline."""

    def _test(
        self,
        m: torch.nn.Module,
        quant_config: Optional[QuantizationConfig] = None,
        conv_count=1,
        dtype: torch.dtype = torch.float,
    ):
        """Push ``m`` through the ``Tester`` stages and run the pipeline's checks.

        Args:
            m: Module under test. It is switched to eval mode and must provide
                a ``get_inputs()`` method returning the example inputs.
            quant_config: When not ``None``, a ``Quantize`` stage built from
                this config is run first, followed by a check for
                ``torch.ops.quantized_decomposed``.
            conv_count: Count passed to ``check_count`` for
                ``torch.ops.aten.conv2d`` after export.
            dtype: Currently not read by this helper; kept so existing callers
                that pass it keep working.
        """
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        tester = Tester(m.eval(), m.get_inputs())

        if quant_config is not None:
            tester = tester.quantize(Quantize(quantization_config=quant_config))
            tester.check(["torch.ops.quantized_decomposed"])

        (
            tester.export()
            .check_count({"torch.ops.aten.conv2d": conv_count})
            .to_edge_transform_and_lower()
            # After lowering, neither the edge convolution op nor the edge
            # batch-norm op is expected in the graph, and a single delegate
            # call is expected.
            .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(qtol=1)
        )

    def test_fp16_conv2d(self) -> None:
        """float16 convolution, with and without bias."""
        for has_bias in (True, False):
            # subTest labels a failure with the parameter that triggered it.
            with self.subTest(has_bias=has_bias):
                self._test(Conv2d(bias=has_bias, dtype=torch.float16))

    def test_fp32_conv2d(self) -> None:
        """float32 convolution, with and without bias."""
        for has_bias in (True, False):
            with self.subTest(has_bias=has_bias):
                self._test(Conv2d(bias=has_bias))

    def test_fp32_conv2d_permute(self) -> None:
        """Convolution followed by every permutation of the four output dims."""
        for perm_order in itertools.permutations([0, 1, 2, 3]):
            with self.subTest(perm_order=perm_order):
                self._test(Conv2dPermute(perm_order))

    def test_qs8_conv2d_test(self) -> None:
        """Quantized convolution (symmetric config), with and without bias."""
        for has_bias in (True, False):
            with self.subTest(has_bias=has_bias):
                self._test(
                    Conv2d(bias=has_bias),
                    quant_config=get_symmetric_quantization_config(),
                )

    def test_qs8_conv2d_per_channel(self) -> None:
        """Quantized convolution using ``is_per_channel=True``."""
        self._test(
            Conv2d(),
            quant_config=get_symmetric_quantization_config(is_per_channel=True),
        )

    def test_fp32_conv2d_seq(self) -> None:
        """Two sequential float32 convolutions."""
        self._test(Conv2dSeq(), conv_count=2)

    def test_qs8_conv2d_seq(self) -> None:
        """Two sequential convolutions, quantized."""
        self._test(
            Conv2dSeq(), conv_count=2, quant_config=get_symmetric_quantization_config()
        )

    def test_fp32_conv2d_single_int_params(self) -> None:
        """Convolution configured with scalar ints and a string padding mode."""
        self._test(
            Conv2d(
                kernel_size=3,
                stride=2,
                padding="valid",
                dilation=1,
            )
        )

    def test_fp32_conv2d_depthwise(self) -> None:
        """float32 depthwise convolution (groups == in_channels)."""
        # Depthwise Convolution Requirements:
        # - Groups must equal In Channels
        # - Out Channels must be a positive multiple of In Channels
        self._test(Conv2d(groups=2, in_channels=2, out_channels=6))

    def test_qs8_conv2d_depthwise(self) -> None:
        """Quantized depthwise convolution (groups == in_channels)."""
        self._test(
            Conv2d(groups=2, in_channels=2, out_channels=6),
            quant_config=get_symmetric_quantization_config(),
        )

    def test_fp32_conv2d_bn(self) -> None:
        """float32 convolution followed by a batch-norm layer."""

        # NOTE: inside this method the local class below shadows the
        # module-level Conv2dBatchNorm; other tests still see the module-level one.
        class Conv2dBatchNorm(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
                self.bn = randomize_bn(out_features)
                self.in_features = in_features
                self.kernel_size = kernel_size

            def forward(self, x):
                y = self.conv2d(x)
                y = self.bn(y)
                return y

            def get_inputs(self):
                # Spatial size is twice the kernel size in each dimension.
                return (
                    torch.randn(
                        2,
                        self.in_features,
                        self.kernel_size[0] * 2,
                        self.kernel_size[1] * 2,
                    ),
                )

        self._test(Conv2dBatchNorm(in_features=2, out_features=2, kernel_size=(2, 2)))

    def test_fp32_conv2d_bn_hardtanh_mean_sequence(self) -> None:
        """
        This test makes sure that we can fuse batchnorm and hardtanh
        even with inserting copy nodes at some spots in the graph to change
        memory format
        """

        class Conv2dBatchNormHardTanh(torch.nn.Module):
            def __init__(self, in_channels: int, out_channels: int, kernel_size):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=[1, 1],
                    stride=[2, 2],
                )
                self.in_channels = in_channels
                self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
                self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)

            def forward(self, x):
                x = self.conv(x)
                x = self.native_batchnorm(x)
                x = self.hardtanh(x)
                # Mean over the last two dimensions, keeping them as size 1.
                x = torch.mean(x, (-1, -2), keepdim=True)
                return x

            def get_inputs(self):
                return (torch.randn(2, self.in_channels, 8, 8),)

        self._test(
            Conv2dBatchNormHardTanh(in_channels=2, out_channels=1, kernel_size=(2, 2))
        )

    def test_qs8_conv2d_bn(self) -> None:
        """Quantized variant of the module-level ``Conv2dBatchNorm`` model."""
        self._test(
            Conv2dBatchNorm(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )

    def test_qs8_conv2d_relu(self) -> None:
        """Quantized convolution followed by ReLU."""

        class ConvReLU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    2,
                    2,
                    (2, 2),
                    bias=False,
                    padding=[1, 1],
                    stride=[4, 4],
                )
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.conv1(x)
                y = self.relu(y)
                return y

            def get_inputs(self):
                return (torch.randn(2, 2, 4, 4),)

        self._test(
            ConvReLU(),
            quant_config=get_symmetric_quantization_config(),
        )

    def test_qs8_conv2d_dw_relu(self) -> None:
        """Quantized depthwise convolution + ReLU, per-tensor and per-channel."""
        # Depthwise Convolution Requirements:
        # - Groups must equal In Channels
        # - Out Channels must be a positive multiple of In Channels
        groups = 2
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = groups
        out_channels = 3 * in_channels
        width = 8
        height = 8
        batches = 1

        class ModelConvReLU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=(3, 3),
                    stride=stride,
                    padding=padding,
                    groups=groups,
                    dilation=dilation,
                    bias=True,
                )
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.conv1(x)
                y = self.relu(y)
                return y

            def get_inputs(self):
                # Random input scaled by 11.
                return (torch.randn(batches, in_channels, height, width) * 11,)

        for per_channel_quant in (False, True):
            with self.subTest(per_channel_quant=per_channel_quant):
                model = ModelConvReLU()
                self._test(
                    model,
                    quant_config=get_symmetric_quantization_config(
                        is_per_channel=per_channel_quant
                    ),
                )

    def test_qs8_conv2d_relu_seq(self) -> None:
        """Quantized Sequential of conv -> ReLU -> conv -> ReLU."""

        class ConvReLUSeq(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = torch.nn.Sequential(
                    torch.nn.Conv2d(1, 1, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(1, 64, 1),
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.model(x)

            def get_inputs(self):
                return (torch.randn(1, 1, 1, 1),)

        self._test(
            ConvReLUSeq(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )

    def test_qs8_conv2d_relu_multi_users(self) -> None:
        """Quantized model where the first conv output feeds both a ReLU and the final add."""

        class Conv2dReluMultiUsers(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 64, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                conv_default = self.conv1(x)
                y = self.relu(conv_default)
                conv_default_2 = self.conv2(y)
                # conv_default is consumed twice: by relu above and by this add.
                return conv_default + conv_default_2

            def get_inputs(self):
                return (torch.randn(1, 1, 1, 1),)

        self._test(
            Conv2dReluMultiUsers(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )