1# Owner(s): ["oncall: cpu inductor"] 2import contextlib 3import copy 4import itertools 5import unittest 6 7import torch 8import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq 9 10from torch._dynamo import config as dynamo_config 11from torch._dynamo.utils import counters 12from torch._export import capture_pre_autograd_graph 13from torch._inductor import config, metrics 14from torch._inductor.test_case import run_tests, TestCase 15from torch._inductor.utils import run_and_get_code 16from torch.ao.quantization.quantize_pt2e import ( 17 convert_pt2e, 18 prepare_pt2e, 19 prepare_qat_pt2e, 20) 21from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer 22from torch.nn import functional as F 23from torch.testing._internal.common_quantization import ( 24 skipIfNoDynamoSupport, 25 skipIfNoONEDNN, 26 skipIfNoONEDNNBF16, 27) 28from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, TEST_MKL 29from torch.testing._internal.inductor_utils import _check_has_dynamic_shape, HAS_CPU 30 31 32# The dict value is match_nodes(computation_op+unary_op) 33 34unary_list = { 35 torch.nn.ReLU(): 2, 36 torch.nn.Sigmoid(): 2, 37 torch.nn.Tanh(): 2, 38 torch.nn.Hardswish(): 6, 39 torch.nn.LeakyReLU(0.1, inplace=False): 4, 40 torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False): 3, 41 torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3, 42 torch.nn.GELU(approximate="none"): 6, 43 torch.nn.GELU(approximate="tanh"): 10, 44 torch.nn.ReLU6(): 3, 45 torch.nn.SiLU(): 3, 46 torch.nn.Hardsigmoid(): 5, 47} 48 49non_decomposed_unary_list = [ 50 torch.nn.ReLU, 51 torch.nn.Sigmoid, 52 torch.nn.Tanh, 53] 54 55# The dict value is (match_count, match_nodes, inplace) 56binary_list = { 57 lambda x, y: torch.add(x, y): (1, 2, False), # call_function 58 lambda x, y: torch.add(y, x): (1, 2, False), # call_function 59 lambda x, y: x.add(y): (1, 2, False), # call_method 60 lambda x, y: x.add_(y): (1, 2, True), # call_method 61 lambda x, y: torch.sub(x, y): (1, 2, False), # call_function 62 lambda x, y: x.sub(y): (1, 2, False), # call_method 63 lambda x, y: x.sub_(y): (1, 2, True), # call_method 64} 65 66quantization_add_fn_list = [ 67 lambda x, y: torch.add(x, y), 68 lambda x, y: x.add(y), 69] 70 71quantization_inplace_add_fn_list = [ 72 lambda x, y: x.add_(y), 73] 74 75 76def get_default_quantizer(is_qat, is_dynamic): 77 quantizer = X86InductorQuantizer() 78 quantizer.set_global( 79 xiq.get_default_x86_inductor_quantization_config( 80 is_qat=is_qat, is_dynamic=is_dynamic 81 ) 82 ) 83 return quantizer 84 85 86def cal_conv_generated_kernel_number(mod, input, dtype): 87 # this function is to decide how many kernels are generated 88 # while testing conv2d/3d/deconv2d 89 # the assumption is: 90 # (1) There will be a to_dtype kernel for input for lp 91 # (2) inductor always use channe_last format, there will 92 # be a to_channel_last format for input 93 # (3) to_dtype and to_channel_last for input can be fused 94 # (4) inductor always get channel last format from mkldnn_conv_pointwise(binary), 95 # and force the output to have same stride with eager. 96 # So there will be a to_contiguous for output if eager output is contiguouse 97 mod = copy.deepcopy(mod) 98 input = input.clone() 99 if dtype == torch.float32: 100 maybe_autocast = contextlib.nullcontext() 101 else: 102 maybe_autocast = torch.cpu.amp.autocast(dtype=dtype) 103 with torch.no_grad(), maybe_autocast: 104 output = mod(input) 105 input_kernel, output_kernel = 0, 0 106 if ( 107 input.is_contiguous(memory_format=torch.contiguous_format) 108 or dtype != torch.float32 109 ): 110 input_kernel = 1 111 if output.is_contiguous(memory_format=torch.contiguous_format): 112 output_kernel = 1 113 return input_kernel + output_kernel 114 115 116@config.patch({"freezing": True}) 117class TestPatternMatcherBase(TestCase): 118 def _check_unary_is_decomposed(self, unary_fn): 119 return not any( 120 isinstance(unary_fn, fn) 121 for fn in [torch.nn.ReLU, torch.nn.Sigmoid, torch.nn.Tanh] 122 ) 123 124 def _clone_inputs(self, inputs): 125 def clone(x): 126 if not isinstance(x, torch.Tensor): 127 return x 128 return x.clone() 129 130 return tuple(clone(x) for x in inputs) 131 132 def _generate_qdq_quantized_model( 133 self, mod, inputs, is_qat=False, is_dynamic=False, quantizer=None 134 ): 135 maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() 136 with maybe_no_grad: 137 export_model = capture_pre_autograd_graph( 138 mod, 139 inputs, 140 ) 141 quantizer = ( 142 quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic) 143 ) 144 prepare_model = ( 145 prepare_qat_pt2e(export_model, quantizer) 146 if is_qat 147 else prepare_pt2e(export_model, quantizer) 148 ) 149 prepare_model(*inputs) 150 convert_model = convert_pt2e(prepare_model) 151 torch.ao.quantization.move_exported_model_to_eval(convert_model) 152 return convert_model 153 154 def _test_common( 155 self, 156 mod, 157 inputs, 158 matcher_count=None, 159 matcher_nodes=None, 160 atol=1e-5, 161 rtol=1.3e-6, 162 check_autocast=torch.float32, 163 check_quantization=False, 164 is_qat=False, 165 matcher_check_fn=None, 166 dtype=None, 167 is_dynamic=False, 168 quantizer=None, 169 ): 170 counters.clear() 171 torch._dynamo.reset() 172 assert matcher_check_fn is not None or ( 173 matcher_count is not None and matcher_nodes is not None 174 ) 175 if ( 176 check_autocast == torch.bfloat16 177 and torch.ops.mkldnn._is_mkldnn_bf16_supported() 178 ): 179 maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16) 180 atol, rtol = 1e-2, 1e-2 181 elif ( 182 check_autocast == torch.float16 183 and torch.ops.mkldnn._is_mkldnn_fp16_supported() 184 ): 185 maybe_autocast = torch.cpu.amp.autocast(dtype=torch.float16) 186 atol, rtol = 1e-2, 1e-2 187 else: 188 assert check_autocast == torch.float32 189 maybe_autocast = contextlib.nullcontext() 190 191 if check_quantization: 192 convert_model = self._generate_qdq_quantized_model( 193 mod, inputs, is_qat, is_dynamic, quantizer 194 ) 195 with torch.no_grad(), maybe_autocast: 196 _ = torch.compile(convert_model)(*inputs) 197 if matcher_count is not None: 198 self.assertEqual( 199 counters["inductor"]["pattern_matcher_count"], matcher_count 200 ) 201 if matcher_nodes is not None: 202 self.assertEqual( 203 counters["inductor"]["pattern_matcher_nodes"], 204 matcher_nodes, 205 ) 206 if matcher_check_fn is not None: 207 matcher_check_fn() 208 else: 209 with torch.no_grad(), maybe_autocast: 210 clone_inputs = self._clone_inputs(inputs) 211 expected = mod(*inputs) 212 actual = torch.compile(mod)(*clone_inputs) 213 torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) 214 if matcher_count is not None: 215 self.assertEqual( 216 counters["inductor"]["pattern_matcher_count"], matcher_count 217 ) 218 if matcher_nodes is not None: 219 self.assertEqual( 220 counters["inductor"]["pattern_matcher_nodes"], 221 matcher_nodes, 222 ) 223 if matcher_check_fn is not None: 224 matcher_check_fn() 225 226 def _test_code_common( 227 self, 228 mod, 229 inputs, 230 include_ops, 231 exclude_ops, 232 atol=1e-5, 233 rtol=1.3e-6, 234 check_quantization=False, 235 check_dynamic=None, 236 num_include_ops=None, 237 ): 238 with torch.no_grad(): 239 clone_inputs = self._clone_inputs(inputs) 240 if check_quantization: 241 mod = self._generate_qdq_quantized_model(mod, inputs) 242 expected = mod(*inputs) 243 actual, (source_code,) = run_and_get_code( 244 torch.compile(mod, fullgraph=True, dynamic=check_dynamic), 245 *clone_inputs, 246 ) 247 for op in include_ops: 248 self.assertIn(op, source_code) 249 if num_include_ops is not None: 250 assert len(include_ops) == len(num_include_ops) 251 for i in range(len(include_ops)): 252 self.assertEqual( 253 source_code.count(include_ops[i]), num_include_ops[i] 254 ) 255 for op in exclude_ops: 256 self.assertNotIn(op, source_code) 257 if check_dynamic is not None: 258 _check_has_dynamic_shape(self, source_code) 259 if not check_quantization: 260 # Skip due to reduce range setting for Quantization on preCI system. 261 torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) 262 263 264class TestPatternMatcher(TestPatternMatcherBase): 265 def _test_conv_unary_cpu_base(self, dim=4): 266 assert dim == 4 or dim == 5 267 268 class M(torch.nn.Module): 269 def __init__( 270 self, 271 unary_fn, 272 **kwargs, 273 ): 274 super().__init__() 275 if dim == 4: 276 self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1) 277 else: 278 self.conv = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1) 279 self.unary_fn = unary_fn 280 281 def forward(self, x): 282 x = self.conv(x) 283 return self.unary_fn(x) 284 285 dtypes = [ 286 torch.float, 287 ] 288 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 289 dtypes.append(torch.bfloat16) 290 if torch.ops.mkldnn._is_mkldnn_fp16_supported(): 291 dtypes.append(torch.float16) 292 cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d 293 options = itertools.product( 294 unary_list.keys(), 295 [torch.contiguous_format, cl_format], 296 dtypes, 297 ) 298 299 for ( 300 unary_fn, 301 memory_format, 302 dtype, 303 ) in options: 304 metrics.reset() 305 if dim == 4: 306 x_shape = (1, 3, 56, 56) 307 else: 308 x_shape = (1, 3, 20, 56, 56) 309 mod = M(unary_fn).to(memory_format=memory_format).eval() 310 311 v = ( 312 torch.randn(x_shape, dtype=torch.float32) 313 .add(1) 314 .to(memory_format=memory_format) 315 ) 316 # Add 1 for weight packing pass. 317 match_nodes = unary_list[unary_fn] + 1 318 if dtype in ( 319 torch.float16, 320 torch.bfloat16, 321 ) and self._check_unary_is_decomposed(unary_fn): 322 # Has extra dtype conversion nodes for autocast. 323 match_nodes += 2 324 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) 325 generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype) 326 self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) 327 328 @skipIfNoDynamoSupport 329 @skipIfNoONEDNN 330 @skipIfRocm 331 def test_conv2d_unary_cpu(self): 332 self._test_conv_unary_cpu_base(dim=4) 333 334 @skipIfNoDynamoSupport 335 @skipIfNoONEDNN 336 @skipIfRocm 337 def test_conv3d_unary_cpu(self): 338 self._test_conv_unary_cpu_base(dim=5) 339 340 def test_linear_unary(self): 341 class M(torch.nn.Module): 342 def __init__( 343 self, 344 unary_fn, 345 in_features, 346 out_features, 347 bias, 348 **kwargs, 349 ): 350 super().__init__() 351 self.linear = torch.nn.Linear( 352 in_features, 353 out_features, 354 bias, 355 **kwargs, 356 ) 357 self.unary_fn = unary_fn 358 359 def forward(self, x): 360 x = self.linear(x) 361 return self.unary_fn(x) 362 363 dtypes = [] 364 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 365 dtypes.append(torch.bfloat16) 366 if torch.ops.mkldnn._is_mkldnn_fp16_supported(): 367 dtypes.append(torch.float16) 368 options = itertools.product(unary_list, [True, False], dtypes) 369 for unary_fn, bias, dtype in options: 370 metrics.reset() 371 mod = M(unary_fn, 10, 30, bias=bias).eval() 372 # only fuse for linear when the dtype is bf16 373 mod = mod 374 v = torch.randn(2, 10) 375 # packing pass + unary fusion. 376 matcher_count = 2 377 # Add 1 for weight packing pass. 378 matcher_nodes = unary_list[unary_fn] + 1 379 if self._check_unary_is_decomposed(unary_fn): 380 # Has extra dtype conversion nodes for autocast. 381 matcher_nodes += 2 382 self._test_common( 383 mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype 384 ) 385 # only generated 1 kernel for "to" 386 self.assertEqual(metrics.generated_kernel_count, 1) 387 388 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 389 def test_linear_fp32(self): 390 class M(torch.nn.Module): 391 def __init__(self, bias): 392 super().__init__() 393 self.linear = torch.nn.Linear(10, 30, bias) 394 395 def forward(self, x): 396 return self.linear(x) 397 398 for bias in [True, False]: 399 mod = M(bias=bias).eval() 400 v = torch.randn(2, 10) 401 # packing pass. 402 matcher_count = 1 403 matcher_nodes = 1 404 self._test_common(mod, (v,), matcher_count, matcher_nodes) 405 406 def test_linear_add_bias(self): 407 class M(torch.nn.Module): 408 def __init__(self, dtype, unary_fn): 409 super().__init__() 410 self.linear1 = torch.nn.Linear(10, 64, bias=False) 411 self.bias1 = torch.randn(64).to(dtype=dtype) 412 self.linear2 = torch.nn.Linear(10, 64, bias=False) 413 self.bias2 = torch.randn(64).to(dtype=dtype) 414 self.unary_fn = unary_fn 415 416 def forward(self, x): 417 a = self.linear1(x) + self.bias1 418 b = self.linear2(x) + self.bias2 419 return self.unary_fn(a), self.unary_fn(b) 420 421 dtypes = [] 422 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 423 dtypes.append(torch.bfloat16) 424 if torch.ops.mkldnn._is_mkldnn_fp16_supported(): 425 dtypes.append(torch.float16) 426 options = itertools.product(unary_list, dtypes) 427 for unary_fn, dtype in options: 428 metrics.reset() 429 mod = M(dtype, unary_fn).eval() 430 v = torch.randn(2, 10) 431 matcher_count = 3 432 # Add 1 for weight packing pass, add 2 for bias folding pass per linear. 433 matcher_nodes = unary_list[unary_fn] + 3 434 if self._check_unary_is_decomposed(unary_fn): 435 # Has extra dtype conversion nodes for autocast. 436 matcher_nodes += 2 437 # we have 2 linears, so we double the matcher_count/nodes 438 self._test_common( 439 mod, (v,), matcher_count * 2, matcher_nodes * 2, check_autocast=dtype 440 ) 441 self.assertEqual(metrics.generated_kernel_count, 1) 442 443 @skipIfNoDynamoSupport 444 @skipIfNoONEDNN 445 @skipIfRocm 446 def test_conv_transpose2d_unary(self): 447 class M(torch.nn.Module): 448 def __init__( 449 self, 450 unary_fn, 451 **kwargs, 452 ): 453 super().__init__() 454 self.conv_transpose2d = torch.nn.ConvTranspose2d( 455 3, 16, 3, stride=2, padding=1 456 ) 457 self.unary_fn = unary_fn 458 459 def forward(self, x): 460 x = self.conv_transpose2d(x) 461 return self.unary_fn(x) 462 463 dtypes = [ 464 torch.float, 465 ] 466 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 467 dtypes.append(torch.bfloat16) 468 if torch.ops.mkldnn._is_mkldnn_fp16_supported(): 469 dtypes.append(torch.float16) 470 471 options = itertools.product( 472 unary_list, 473 [torch.contiguous_format, torch.channels_last], 474 dtypes, 475 ) 476 477 for unary_fn, memory_format, dtype in options: 478 metrics.reset() 479 x_shape = (1, 3, 28, 28) 480 mod = M(unary_fn).eval() 481 482 v = torch.randn(x_shape, dtype=torch.float32).to( 483 memory_format=memory_format 484 ) 485 # Add 1 for weight packing pass. 486 match_nodes = unary_list[unary_fn] + 1 487 if dtype in ( 488 torch.float16, 489 torch.bfloat16, 490 ) and self._check_unary_is_decomposed(unary_fn): 491 # Has extra dtype conversion nodes for autocast. 492 match_nodes += 2 493 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) 494 generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype) 495 self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) 496 497 def _test_conv_binary_base(self, dim=4): 498 assert dim == 4 or dim == 5 499 500 class M(torch.nn.Module): 501 def __init__( 502 self, 503 binary_fn, 504 has_relu, 505 **kwargs, 506 ): 507 super().__init__() 508 if dim == 4: 509 self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1) 510 self.conv2 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1) 511 else: 512 self.conv1 = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1) 513 self.conv2 = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1) 514 self.binary_fn = binary_fn 515 self.has_relu = has_relu 516 517 def forward(self, x): 518 x1 = self.conv1(x) 519 x2 = self.conv2(x) 520 if has_relu: 521 return self.binary_fn(x1, x2).relu() 522 else: 523 return self.binary_fn(x1, x2) 524 525 dtypes = [ 526 torch.float, 527 ] 528 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 529 dtypes.append(torch.bfloat16) 530 if torch.ops.mkldnn._is_mkldnn_fp16_supported(): 531 dtypes.append(torch.float16) 532 cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d 533 test_memory_format = [torch.contiguous_format, cl_format] 534 options = itertools.product( 535 binary_list, 536 [True, False], 537 test_memory_format, 538 dtypes, 539 ) 540 541 for ( 542 binary_fn, 543 has_relu, 544 memory_format, 545 dtype, 546 ) in options: 547 metrics.reset() 548 if dim == 4: 549 x_shape = (1, 3, 56, 56) 550 else: 551 x_shape = (1, 3, 20, 56, 56) 552 mod = M(binary_fn, has_relu).eval() 553 v = ( 554 torch.randn(x_shape, dtype=torch.float32, requires_grad=True) 555 .add(1) 556 .to(memory_format=memory_format) 557 ) 558 match_count = binary_list[binary_fn][0] + 2 559 match_nodes = binary_list[binary_fn][1] 560 if has_relu: 561 match_nodes += 1 562 self._test_common( 563 mod, (v,), match_count, match_nodes + 2, check_autocast=dtype 564 ) 565 generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype) 566 self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) 567 568 @skipIfNoDynamoSupport 569 @skipIfNoONEDNN 570 @skipIfRocm 571 def test_conv2d_binary(self): 572 self._test_conv_binary_base(dim=4) 573 574 @skipIfNoDynamoSupport 575 @skipIfNoONEDNN 576 @skipIfRocm 577 def test_conv3d_binary(self): 578 self._test_conv_binary_base(dim=5) 579 580 def test_linear_binary(self): 581 class M(torch.nn.Module): 582 def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs): 583 super().__init__() 584 self.linear = torch.nn.Linear( 585 in_channels, out_channels, bias=bias, **kwargs 586 ) 587 self.binary_fn = binary_fn 588 589 def forward(self, x, y): 590 x = self.linear(x) 591 x = self.binary_fn(x, y.clone()) 592 return x 593 594 dtypes = [] 595 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 596 dtypes.append(torch.bfloat16) 597 if torch.ops.mkldnn._is_mkldnn_fp16_supported(): 598 dtypes.append(torch.float16) 599 options = itertools.product( 600 binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes 601 ) 602 out_feature = 30 603 for binary_fn, input_shape, bias, dtype in options: 604 metrics.reset() 605 # addmm(mm) + (linear+add) 606 match_count = 2 607 match_nodes = 3 608 if len(input_shape) == 3: 609 is_inplace = binary_list[binary_fn][2] 610 # view + linear + view(joint_graph+freeze pass) 611 match_count = match_count + 5 if is_inplace else match_count + 3 612 match_nodes = match_nodes + 7 if is_inplace else match_nodes + 5 613 mod = M(binary_fn, input_shape[-1], out_feature, bias).eval() 614 v = torch.randn(input_shape) 615 other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype) 616 self._test_common( 617 mod, 618 ( 619 v, 620 other, 621 ), 622 match_count, 623 match_nodes, 624 check_autocast=dtype, 625 ) 626 self.assertEqual(metrics.generated_kernel_count, 1) 627 628 def test_multi_linear_share_same_input(self): 629 # llama pattern. 630 class M(torch.nn.Module): 631 def __init__( 632 self, 633 ): 634 super().__init__() 635 self.w1 = torch.nn.Linear(16, 16, bias=False) 636 self.w2 = torch.nn.Linear(16, 16, bias=False) 637 638 def forward(self, x): 639 return F.silu(self.w1(x)) * F.relu(self.w2(x)) 640 641 dtypes = [] 642 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 643 dtypes.append(torch.bfloat16) 644 if torch.ops.mkldnn._is_mkldnn_fp16_supported(): 645 dtypes.append(torch.float16) 646 for dtype in dtypes: 647 mod = M().to(dtype).eval() 648 v = torch.randn(2, 4, 16).to(dtype) 649 # 1. view(match_count=4, match_nodes=4). 650 # 2. mm to packed linear(match_count=2, match_nodes=2). 651 # 3. view+linear+view to linear(match_count=2, match_nodes=6). 652 # 4. linear+silu fusion(match_count=1, match_nodes=5) 653 # 5. linear+relu fusion(match_count=1, match_nodes=2) 654 655 match_count = 10 656 match_nodes = 19 657 self._test_common(mod, (v,), match_count, match_nodes, rtol=1e-2, atol=1e-2) 658 659 def _qconv2d_cpu_test_helper(self, int8_mixed_bf16=False): 660 class M(torch.nn.Module): 661 def __init__( 662 self, 663 **kwargs, 664 ): 665 super().__init__() 666 self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) 667 self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1) 668 669 def forward(self, x): 670 return self.conv2(self.conv(x)) 671 672 mod = M().eval() 673 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) 674 675 def matcher_check_fn(): 676 # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1 677 # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution] 678 # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4), 679 # dequantize_per_channel, optional(convert_element_type_3), clone, convolution] 680 self.assertEqual( 681 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 682 ) 683 self.assertEqual( 684 counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 685 12 if int8_mixed_bf16 else 8, 686 ) 687 688 self._test_common( 689 mod, 690 (v,), 691 check_quantization=True, 692 check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, 693 matcher_check_fn=matcher_check_fn, 694 ) 695 696 @skipIfNoDynamoSupport 697 @skipIfNoONEDNN 698 @skipIfRocm 699 def test_qconv2d_cpu(self): 700 r""" 701 This testcase will quantize a single Conv2d module. 702 """ 703 self._qconv2d_cpu_test_helper() 704 705 @skipIfNoDynamoSupport 706 @skipIfNoONEDNNBF16 707 @skipIfNoONEDNN 708 @skipIfRocm 709 def test_qconv2d_int8_mixed_bf16(self): 710 r""" 711 This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. 712 """ 713 self._qconv2d_cpu_test_helper(int8_mixed_bf16=True) 714 715 def _qconv2d_unary_cpu_test_helper( 716 self, 717 int8_mixed_bf16=False, 718 unary_op=torch.nn.ReLU(), 719 qconv2d_unary_matcher_nodes=None, 720 ): 721 class M(torch.nn.Module): 722 def __init__( 723 self, 724 **kwargs, 725 ): 726 super().__init__() 727 self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) 728 self.unary_fn = copy.deepcopy(unary_op) 729 self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1) 730 self.unary_fn2 = copy.deepcopy(unary_op) 731 732 def forward(self, x): 733 tmp = self.unary_fn(self.conv(x)) 734 return self.unary_fn2(self.conv2(tmp)) 735 736 mod = M().eval() 737 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) 738 739 def matcher_check_fn(): 740 # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 741 self.assertEqual( 742 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 743 ) 744 # 2. QConv2D Unary fusion in post-grad fusion pass * 2 745 self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2) 746 if qconv2d_unary_matcher_nodes: 747 self.assertEqual( 748 counters["inductor"]["qconv2d_unary_matcher_nodes"], 749 qconv2d_unary_matcher_nodes, 750 ) 751 752 self._test_common( 753 mod, 754 (v,), 755 check_quantization=True, 756 check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, 757 matcher_check_fn=matcher_check_fn, 758 ) 759 760 @skipIfNoDynamoSupport 761 @skipIfNoONEDNN 762 @skipIfRocm 763 def test_qconv2d_relu_cpu(self): 764 r""" 765 This testcase will quantize Conv2d->ReLU pattern. 766 """ 767 self._qconv2d_unary_cpu_test_helper() 768 769 @skipIfNoDynamoSupport 770 @skipIfNoONEDNNBF16 771 @skipIfNoONEDNN 772 @skipIfRocm 773 def test_qconv2d_relu_int8_mixed_bf16(self): 774 r""" 775 This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization. 776 """ 777 self._qconv2d_unary_cpu_test_helper(int8_mixed_bf16=True) 778 779 @skipIfNoDynamoSupport 780 @skipIfNoONEDNN 781 @skipIfRocm 782 def test_qconv2d_relu6_cpu(self): 783 r""" 784 This testcase will quantize Conv2d->ReLU6 pattern. 785 """ 786 self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6()) 787 788 @skipIfNoDynamoSupport 789 @skipIfNoONEDNN 790 @skipIfRocm 791 def test_qconv2d_hardtanh_cpu(self): 792 r""" 793 This testcase will quantize Conv2d->Hardtanh pattern. 794 """ 795 self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh()) 796 797 @skipIfNoDynamoSupport 798 @skipIfNoONEDNNBF16 799 @skipIfNoONEDNN 800 @skipIfRocm 801 def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): 802 r""" 803 This testcase will quantize Conv2d->Hardtanh pattern. 804 Match.nodes: 805 [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor] 806 [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type] 807 """ 808 self._qconv2d_unary_cpu_test_helper( 809 unary_op=torch.nn.Hardtanh(), 810 int8_mixed_bf16=True, 811 qconv2d_unary_matcher_nodes=11, 812 ) 813 814 @skipIfNoDynamoSupport 815 @skipIfNoONEDNN 816 @skipIfRocm 817 def test_qconv2d_hardswish_cpu(self): 818 r""" 819 This testcase will quantize Conv2d->Hardswish pattern. 820 """ 821 self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish()) 822 823 @skipIfNoDynamoSupport 824 @skipIfNoONEDNNBF16 825 @skipIfNoONEDNN 826 @skipIfRocm 827 def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): 828 r""" 829 This testcase will quantize Conv2d->Hardswish pattern. 830 Match.nodes: 831 [qconv2d_pointwise_default, convert_element_type, add, clamp_min, 832 clamp_max, mul, div, convert_element_type, quantize_per_tensor] 833 [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type] 834 """ 835 self._qconv2d_unary_cpu_test_helper( 836 unary_op=torch.nn.Hardswish(), 837 int8_mixed_bf16=True, 838 qconv2d_unary_matcher_nodes=17, 839 ) 840 841 @skipIfNoDynamoSupport 842 @skipIfNoONEDNN 843 @skipIfRocm 844 def test_qconv2d_silu_cpu(self): 845 r""" 846 This testcase will quantize Conv2d->SiLU pattern. 847 """ 848 self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU()) 849 850 @skipIfNoDynamoSupport 851 @skipIfNoONEDNNBF16 852 @skipIfNoONEDNN 853 @skipIfRocm 854 def test_qconv2d_silu_int8_mixed_bf16_cpu(self): 855 r""" 856 This testcase will quantize Conv2d->SiLU pattern. 857 Match.nodes: 858 [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, 859 convert_element_type, quantize_per_tensor] 860 [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type] 861 """ 862 self._qconv2d_unary_cpu_test_helper( 863 unary_op=torch.nn.SiLU(), 864 int8_mixed_bf16=True, 865 qconv2d_unary_matcher_nodes=11, 866 ) 867 868 def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False): 869 r""" 870 This testcase will quantize a Conv2d->Add pattern as: 871 X 872 / \ 873 Conv1(X) Conv2(X) 874 \ / 875 Add 876 | 877 Optional(relu) 878 | 879 Y 880 """ 881 882 class M(torch.nn.Module): 883 def __init__( 884 self, 885 add_fn, 886 use_relu, 887 **kwargs, 888 ): 889 super().__init__() 890 self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 891 self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 892 self.add_fn = add_fn 893 self.relu = torch.nn.ReLU() 894 self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 895 self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 896 self.add_fn2 = add_fn 897 self.relu2 = torch.nn.ReLU() 898 self.use_relu = use_relu 899 900 def forward(self, x): 901 x1 = self.conv1(x) 902 x2 = self.conv2(x) 903 tmp = self.add_fn(x1, x2) 904 if self.use_relu: 905 tmp = self.relu(tmp) 906 tmp1 = self.conv3(tmp) 907 tmp2 = self.conv4(tmp) 908 res = self.add_fn2(tmp1, tmp2) 909 if self.use_relu: 910 res = self.relu2(res) 911 return res 912 913 for add_fn in quantization_add_fn_list + quantization_inplace_add_fn_list: 914 mod = M(add_fn, use_relu).eval() 915 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( 916 1 917 ) 918 919 def matcher_check_fn(): 920 # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 4 921 self.assertEqual( 922 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 4 923 ) 924 # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 2 925 self.assertEqual( 926 counters["inductor"]["qconv2d_binary_matcher_count"], 2 927 ) 928 929 self._test_common( 930 mod, 931 (v,), 932 check_quantization=True, 933 check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, 934 matcher_check_fn=matcher_check_fn, 935 ) 936 937 @skipIfNoDynamoSupport 938 @skipIfNoONEDNN 939 @skipIfRocm 940 def test_qconv2d_add_cpu(self): 941 self._qconv2d_add_cpu_test_helper() 942 943 @skipIfNoDynamoSupport 944 @skipIfNoONEDNNBF16 945 @skipIfNoONEDNN 946 @skipIfRocm 947 def test_qconv2d_add_int8_mixed_bf16(self): 948 self._qconv2d_add_cpu_test_helper(int8_mixed_bf16=True) 949 950 @skipIfNoDynamoSupport 951 @skipIfNoONEDNN 952 @skipIfRocm 953 def test_qconv2d_add_relu_cpu(self): 954 self._qconv2d_add_cpu_test_helper(use_relu=True) 955 956 @skipIfNoDynamoSupport 957 @skipIfNoONEDNNBF16 958 @skipIfNoONEDNN 959 @skipIfRocm 960 def test_qconv2d_add_relu_int8_mixed_bf16(self): 961 self._qconv2d_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True) 962 963 @skipIfNoDynamoSupport 964 @skipIfNoONEDNN 965 @skipIfRocm 966 def test_qconv2d_add_broadcast_shapes_cpu(self): 967 r""" 968 This testcase will quantize Conv2d->add pattern using broadcast shape inputs. 969 Conv2d->Add fusion will fail for the broadcast shape inputs case. 970 """ 971 972 class M(torch.nn.Module): 973 def __init__(self, use_bias): 974 super().__init__() 975 self.conv = torch.nn.Conv2d(32, 32, kernel_size=3, stride=1) 976 977 def forward(self, x1, x2): 978 return torch.add(self.conv(x1), x2) 979 980 bias_list = [True, False] 981 for bias in bias_list: 982 mod = M(bias).eval() 983 x1 = torch.randn((2, 32, 9, 9)) 984 x2 = torch.randn((2, 32, 1, 1)) 985 986 def matcher_check_fn(): 987 # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 1 988 self.assertEqual( 989 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 990 ) 991 # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 0 992 self.assertEqual( 993 counters["inductor"]["qconv2d_binary_matcher_count"], 0 994 ) 995 996 self._test_common( 997 mod, 998 (x1, x2), 999 check_quantization=True, 1000 matcher_check_fn=matcher_check_fn, 1001 ) 1002 1003 @skipIfNoDynamoSupport 1004 @skipIfNoONEDNN 1005 @skipIfRocm 1006 def test_qconv2d_add_2(self): 1007 r""" 1008 This testcase prevents this pattern be matched as a conv_binary fusion by mistake. 1009 Conv(X) 3 1010 \ / 1011 Add 1012 We see this pattern in Mobilenet v3 large which add is decomposed from torch.nn.Hardswish or torch.nn.Hardsigmoid. 1013 """ 1014 1015 class M(torch.nn.Module): 1016 def __init__( 1017 self, 1018 post_op, 1019 ): 1020 super().__init__() 1021 self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 1022 self.post_op = post_op 1023 1024 def forward(self, x): 1025 return self.post_op(self.conv(x)) 1026 1027 for post_op in [ 1028 torch.nn.Hardswish(inplace=True), 1029 torch.nn.Hardsigmoid(inplace=True), 1030 ]: 1031 mod = M(post_op).eval() 1032 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( 1033 1 1034 ) 1035 1036 def matcher_check_fn(): 1037 # Shouldn't hit conv binary fusion 1038 self.assertEqual( 1039 counters["inductor"]["qconv2d_binary_matcher_count"], 0 1040 ) 1041 1042 self._test_common( 1043 mod, 1044 (v,), 1045 check_quantization=True, 1046 matcher_check_fn=matcher_check_fn, 1047 ) 1048 1049 @skipIfNoDynamoSupport 1050 @skipIfNoONEDNN 1051 @skipIfRocm 1052 def test_qconv2d_add_3(self): 1053 r""" 1054 This testcase will test below model: 1055 x 1056 / \ 1057 conv1 maxpool 1058 \ / \ 1059 add conv2 1060 \ / 1061 cat 1062 Based on default recipe of x86InductorQuantizer, we will see this pattern after convert: 1063 qconv1 maxpool 1064 \ | 1065 \ q1 1066 \ / \ 1067 \ dq1 qconv2 1068 \ / 1069 add 1070 | 1071 q2 1072 Since q1 has 2 users and qconv2 is not ancestor node of qconv1, we shouldn't fuse: 1073 int8 1074 / 1075 qconv1 dq1 1076 \ / 1077 add 1078 | 1079 q2 1080 | 1081 int8 1082 Instead we can match and fuse this pattern into qconv_binary: 1083 qconv1 fp32 1084 \ / 1085 add 1086 | 1087 fp32 1088 """ 1089 1090 class M(torch.nn.Module): 1091 def __init__( 1092 self, 1093 ): 1094 super().__init__() 1095 self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) 1096 self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1) 1097 self.maxpool = torch.nn.MaxPool2d( 1098 kernel_size=3, stride=1, padding=0, dilation=1 1099 ) 1100 1101 def forward(self, x): 1102 tmp1 = self.conv1(x) 1103 tmp2 = self.maxpool(x) 1104 add = torch.add(tmp1, tmp2) 1105 tmp3 = self.conv2(tmp2) 1106 return torch.cat((add, tmp3), dim=1) 1107 1108 mod = M().eval() 1109 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) 1110 1111 def matcher_check_fn(): 1112 self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1) 1113 # The matched qconv binary pattern should have 2 nodes [qconv, add] 1114 # instead of 11 which has dequant in binary input and output quant 1115 self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 2) 1116 1117 self._test_common( 1118 mod, 1119 (v,), 1120 check_quantization=True, 1121 matcher_check_fn=matcher_check_fn, 1122 ) 1123 1124 @skipIfNoDynamoSupport 1125 @skipIfNoONEDNN 1126 @skipIfRocm 1127 def test_qat_qconv2d(self): 1128 r""" 1129 This testcase will quantize a single Conv2d module with qat flow. 1130 """ 1131 1132 class M(torch.nn.Module): 1133 def __init__( 1134 self, 1135 **kwargs, 1136 ): 1137 super().__init__() 1138 self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) 1139 self.bn = torch.nn.BatchNorm2d(128) 1140 1141 def forward(self, x): 1142 return self.bn(self.conv(x)) 1143 1144 mod = M().train() 1145 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) 1146 1147 def matcher_check_fn(): 1148 # 1. Dequant-conv pattern matched in quantization weight prepack * 1 1149 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] 1150 self.assertEqual( 1151 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 1152 ) 1153 self.assertEqual( 1154 counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 4 1155 ) 1156 # 2. QConv2D Unary fusion in post-grad fusion pass * 1 1157 # [qconv2d_pointwise_default, quantize_per_tensor] 1158 self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1) 1159 self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 2) 1160 1161 self._test_common( 1162 mod, 1163 (v,), 1164 check_quantization=True, 1165 is_qat=True, 1166 matcher_check_fn=matcher_check_fn, 1167 ) 1168 1169 def _qat_qconv2d_unary_cpu_test_helper( 1170 self, 1171 unary_op=torch.nn.ReLU(), 1172 ): 1173 class M(torch.nn.Module): 1174 def __init__( 1175 self, 1176 **kwargs, 1177 ): 1178 super().__init__() 1179 self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) 1180 self.unary_fn = copy.deepcopy(unary_op) 1181 self.bn = torch.nn.BatchNorm2d(3) 1182 self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) 1183 self.unary_fn2 = copy.deepcopy(unary_op) 1184 self.bn2 = torch.nn.BatchNorm2d(3) 1185 1186 def forward(self, x): 1187 tmp = self.unary_fn(self.bn(self.conv(x))) 1188 return self.unary_fn2(self.bn2(self.conv2(tmp))) 1189 1190 mod = M() 1191 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) 1192 1193 def matcher_check_fn(): 1194 # 1. Dequant-conv pattern matched in quantization weight prepack * 1 1195 # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] 1196 self.assertEqual( 1197 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 1198 ) 1199 # 2. QConv2D Unary fusion in post-grad fusion pass * 1 1200 # [qconv2d_pointwise_default, relu, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2] 1201 self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2) 1202 1203 self._test_common( 1204 mod, 1205 (v,), 1206 check_quantization=True, 1207 is_qat=True, 1208 matcher_check_fn=matcher_check_fn, 1209 ) 1210 1211 @skipIfNoDynamoSupport 1212 @skipIfNoONEDNN 1213 @skipIfRocm 1214 def test_qat_qconv2d_relu(self): 1215 r""" 1216 This testcase will quantize Conv2d->ReLU pattern with qat flow. 1217 """ 1218 1219 self._qat_qconv2d_unary_cpu_test_helper() 1220 1221 @skipIfNoDynamoSupport 1222 @skipIfNoONEDNN 1223 @skipIfRocm 1224 def test_qat_qconv2d_relu6(self): 1225 r""" 1226 This testcase will quantize Conv2d->ReLU6 pattern with qat flow. 1227 """ 1228 self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6()) 1229 1230 @skipIfNoDynamoSupport 1231 @skipIfNoONEDNN 1232 @skipIfRocm 1233 def test_qat_qconv2d_hardtanh(self): 1234 r""" 1235 This testcase will quantize Conv2d->Hardtanh pattern with qat flow. 1236 """ 1237 self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh()) 1238 1239 @skipIfNoDynamoSupport 1240 @skipIfNoONEDNN 1241 @skipIfRocm 1242 def test_qat_qconv2d_silu(self): 1243 r""" 1244 This testcase will quantize Conv2d->SiLU pattern with qat flow. 1245 """ 1246 self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU()) 1247 1248 @skipIfNoDynamoSupport 1249 @skipIfNoONEDNN 1250 @skipIfRocm 1251 def test_qat_qconv2d_hardswish(self): 1252 r""" 1253 This testcase will quantize Conv2d->Hardswish pattern with qat flow. 1254 """ 1255 self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish()) 1256 1257 @skipIfNoDynamoSupport 1258 @skipIfNoONEDNN 1259 @skipIfRocm 1260 def test_qat_qconv2d_add(self): 1261 r""" 1262 This testcase will quantize a Conv2d->Add pattern as: 1263 X 1264 / \ 1265 Conv1(X) Conv2(X) 1266 \ / 1267 Add 1268 | 1269 Y 1270 """ 1271 1272 class M(torch.nn.Module): 1273 def __init__( 1274 self, 1275 **kwargs, 1276 ): 1277 super().__init__() 1278 self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 1279 self.bn1 = torch.nn.BatchNorm2d(6) 1280 self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 1281 self.bn2 = torch.nn.BatchNorm2d(6) 1282 1283 def forward(self, x): 1284 x1 = self.bn1(self.conv1(x)) 1285 x2 = self.bn2(self.conv2(x)) 1286 return x1 + x2 1287 1288 mod = M().train() 1289 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) 1290 1291 def matcher_check_fn(): 1292 # 1. Dequant-conv pattern matched in quantization weight prepack * 2 1293 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] 1294 self.assertEqual( 1295 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 1296 ) 1297 self.assertEqual( 1298 counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8 1299 ) 1300 # 2. Qconv2d Binary fusion in post-grad fusion pass * 1 1301 # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor] 1302 self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1) 1303 self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 4) 1304 1305 self._test_common( 1306 mod, 1307 (v,), 1308 check_quantization=True, 1309 is_qat=True, 1310 matcher_check_fn=matcher_check_fn, 1311 ) 1312 1313 @skipIfNoDynamoSupport 1314 @skipIfNoONEDNN 1315 @skipIfRocm 1316 def test_qat_qconv2d_add_relu(self): 1317 r""" 1318 This testcase will quantize a Conv2d->Add->ReLU pattern as: 1319 X 1320 / \ 1321 Conv1(X) Conv2(X) 1322 \ / 1323 Add 1324 | 1325 ReLU 1326 | 1327 Y 1328 """ 1329 1330 class M(torch.nn.Module): 1331 def __init__( 1332 self, 1333 **kwargs, 1334 ): 1335 super().__init__() 1336 self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 1337 self.bn1 = torch.nn.BatchNorm2d(6) 1338 self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 1339 self.bn2 = torch.nn.BatchNorm2d(6) 1340 self.relu = torch.nn.ReLU() 1341 1342 def forward(self, x): 1343 x1 = self.bn1(self.conv1(x)) 1344 x2 = self.bn2(self.conv2(x)) 1345 return self.relu(x1 + x2) 1346 1347 mod = M().train() 1348 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) 1349 1350 def matcher_check_fn(): 1351 # 1. Dequant-conv pattern matched in quantization weight prepack * 2 1352 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] 1353 self.assertEqual( 1354 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 1355 ) 1356 self.assertEqual( 1357 counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8 1358 ) 1359 # 2. Qconv2d Binary fusion in post-grad fusion pass * 1 1360 # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor] 1361 self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1) 1362 self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 5) 1363 1364 self._test_common( 1365 mod, 1366 (v,), 1367 check_quantization=True, 1368 is_qat=True, 1369 matcher_check_fn=matcher_check_fn, 1370 ) 1371 1372 @skipIfNoDynamoSupport 1373 @skipIfNoONEDNN 1374 @skipIfRocm 1375 def test_qconv2d_dequant_promotion_cpu(self): 1376 r""" 1377 This testcase tests if dequant node before conv2d is promoted correctly: 1378 X 1379 | 1380 Conv1(X) 1381 / \ 1382 Conv2(X) Conv3(X) 1383 \ / 1384 Add 1385 | 1386 Y 1387 """ 1388 1389 class M(torch.nn.Module): 1390 def __init__( 1391 self, 1392 **kwargs, 1393 ): 1394 super().__init__() 1395 self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 1396 self.conv2 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 1397 self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 1398 1399 def forward(self, x): 1400 temp = self.conv1(x) 1401 temp = self.conv2(temp) + self.conv3(temp) 1402 return temp 1403 1404 mod = M().eval() 1405 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) 1406 1407 def matcher_check_fn(): 1408 # 1. Dequant pattern matcher for dequant promotion * 1 1409 # [dequantize_per_tensor] 1410 self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1) 1411 self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 1) 1412 # 2. Dequant-conv pattern matched in quantization weight prepack * 3 1413 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] 1414 self.assertEqual( 1415 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3 1416 ) 1417 self.assertEqual( 1418 counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12 1419 ) 1420 # 3. Qconv2d Binary fusion in post-grad fusion pass * 1 1421 # [qconv2d_pointwise_default_1, add_3] 1422 self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1) 1423 self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 2) 1424 1425 self._test_common( 1426 mod, 1427 (v,), 1428 check_quantization=True, 1429 matcher_check_fn=matcher_check_fn, 1430 ) 1431 1432 def _qlinear_cpu_test_helper( 1433 self, 1434 inputs, 1435 int8_mixed_bf16=False, 1436 do_permute=False, 1437 matcher_check_fn=None, 1438 bias=True, 1439 is_dynamic=False, 1440 is_qat=False, 1441 ): 1442 class M(torch.nn.Module): 1443 def __init__(self, use_bias, do_permute=False): 1444 super().__init__() 1445 self.linear = torch.nn.Linear(4, 3, use_bias) 1446 self.linear2 = torch.nn.Linear(3, 4, use_bias) 1447 self.do_permute = do_permute 1448 1449 def forward(self, x): 1450 if self.do_permute: 1451 x = torch.reshape(torch.permute(x, (0, 2, 3, 1)), (2, 12, 4)) 1452 return self.linear2(self.linear(x)) 1453 1454 mod = M(bias, do_permute=do_permute).eval() 1455 1456 def _default_matcher_check_fn(): 1457 self.assertEqual( 1458 counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 1459 ) 1460 1461 self._test_common( 1462 mod, 1463 inputs, 1464 check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, 1465 check_quantization=True, 1466 matcher_check_fn=matcher_check_fn 1467 if matcher_check_fn is not None 1468 else _default_matcher_check_fn, 1469 is_qat=is_qat, 1470 is_dynamic=is_dynamic, 1471 ) 1472 1473 @skipIfNoDynamoSupport 1474 @skipIfNoONEDNN 1475 @skipIfRocm 1476 def test_qlinear_cpu(self): 1477 r""" 1478 This testcase will quantize a single Linear Moduel. 1479 """ 1480 for bias in [True, False]: 1481 self._qlinear_cpu_test_helper((torch.randn((2, 4)),), bias=bias) 1482 1483 @skipIfNoDynamoSupport 1484 @skipIfNoONEDNN 1485 @skipIfRocm 1486 def test_dynamic_qlinear_cpu(self): 1487 r""" 1488 This testcase will quantize a single Linear Moduel. 1489 """ 1490 for bias in [True, False]: 1491 self._qlinear_cpu_test_helper( 1492 (torch.randn((2, 4)),), bias=bias, is_dynamic=True 1493 ) 1494 1495 @skipIfNoDynamoSupport 1496 @skipIfNoONEDNN 1497 @skipIfRocm 1498 def test_dynamic_qlinear_qat_cpu(self): 1499 r""" 1500 This testcase will quantize a single Linear Moduel. 1501 """ 1502 for bias in [True, False]: 1503 self._qlinear_cpu_test_helper( 1504 (torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True 1505 ) 1506 1507 @skipIfNoDynamoSupport 1508 @skipIfNoONEDNN 1509 @skipIfRocm 1510 def test_dynamic_qlinear_input_dim_exceeds_2(self): 1511 r""" 1512 This testcase will quantize a single Linear Moduel. 1513 """ 1514 for bias in [True, False]: 1515 self._qlinear_cpu_test_helper( 1516 (torch.randn((2, 3, 4)),), bias=bias, is_dynamic=True 1517 ) 1518 1519 @skipIfNoDynamoSupport 1520 @skipIfNoONEDNNBF16 1521 @skipIfNoONEDNN 1522 @skipIfRocm 1523 def test_qlinear_int8_mixed_bf16(self): 1524 r""" 1525 This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. 1526 """ 1527 for bias in [True, False]: 1528 self._qlinear_cpu_test_helper( 1529 (torch.randn((2, 4)),), int8_mixed_bf16=True, bias=bias 1530 ) 1531 1532 @skipIfNoDynamoSupport 1533 @skipIfNoONEDNN 1534 @skipIfRocm 1535 def test_qlinear_input_dim_exceeds_2(self): 1536 r""" 1537 This testcase will quantize a single Linear Moduel. 1538 """ 1539 for bias in [True, False]: 1540 self._qlinear_cpu_test_helper((torch.randn((2, 3, 4)),), bias=bias) 1541 1542 @skipIfNoDynamoSupport 1543 @skipIfNoONEDNNBF16 1544 @skipIfNoONEDNN 1545 @skipIfRocm 1546 def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self): 1547 r""" 1548 This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. 1549 """ 1550 for bias in [True, False]: 1551 self._qlinear_cpu_test_helper( 1552 (torch.randn((2, 3, 4)),), int8_mixed_bf16=True, bias=bias 1553 ) 1554 1555 @skipIfNoDynamoSupport 1556 @skipIfNoONEDNN 1557 @skipIfRocm 1558 def test_qlinear_input_dim_exceeds_2_and_not_contiguous(self): 1559 r""" 1560 This testcase will quantize a single Linear Module. 1561 * Input dim exceeds 2 1562 * Input not contiguous 1563 """ 1564 for bias in [True, False]: 1565 1566 def matcher_check_fn(): 1567 self.assertEqual( 1568 counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 1569 ) 1570 self.assertEqual( 1571 counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], 1572 13 if bias else 12, 1573 ) 1574 1575 self._qlinear_cpu_test_helper( 1576 (torch.randn((2, 4, 3, 4)),), 1577 do_permute=True, 1578 matcher_check_fn=matcher_check_fn, 1579 bias=bias, 1580 ) 1581 1582 @skipIfNoDynamoSupport 1583 @skipIfNoONEDNNBF16 1584 @skipIfNoONEDNN 1585 @skipIfRocm 1586 def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self): 1587 r""" 1588 This testcase will quantize a single Linear Module for int8_bf16. 1589 * Input dim exceeds 2 1590 * Input not contiguous 1591 """ 1592 for bias in [True, False]: 1593 1594 def matcher_check_fn(): 1595 self.assertEqual( 1596 counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 1597 ) 1598 self.assertEqual( 1599 counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], 1600 17 if bias else 16, 1601 ) 1602 1603 self._qlinear_cpu_test_helper( 1604 (torch.randn((2, 4, 3, 4)),), 1605 int8_mixed_bf16=True, 1606 do_permute=True, 1607 matcher_check_fn=matcher_check_fn, 1608 bias=bias, 1609 ) 1610 1611 def _qlinear_unary_cpu_test_helper( 1612 self, inputs, unary_op=torch.nn.ReLU(), int8_mixed_bf16=False 1613 ): 1614 class M(torch.nn.Module): 1615 def __init__(self, use_bias): 1616 super().__init__() 1617 self.linear = torch.nn.Linear(4, 4, use_bias) 1618 self.unary_fn = copy.deepcopy(unary_op) 1619 self.linear2 = torch.nn.Linear(4, 4, use_bias) 1620 self.unary_fn2 = copy.deepcopy(unary_op) 1621 1622 def forward(self, x): 1623 tmp = self.unary_fn(self.linear(x)) 1624 return self.unary_fn2(self.linear2(tmp)) 1625 1626 bias_list = [True, False] 1627 for bias in bias_list: 1628 mod = M(bias).eval() 1629 1630 def matcher_check_fn(): 1631 # 1. dequant-linear pattern matched in quantization weight prepack 1632 self.assertEqual( 1633 counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 1634 ) 1635 # 2. QLinear Unary fusion in post-grad fusion pass 1636 self.assertEqual(counters["inductor"]["qlinear_unary_matcher_count"], 2) 1637 1638 self._test_common( 1639 mod, 1640 inputs, 1641 check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, 1642 check_quantization=True, 1643 matcher_check_fn=matcher_check_fn, 1644 ) 1645 1646 @skipIfNoDynamoSupport 1647 @skipIfNoONEDNN 1648 @skipIfRocm 1649 def test_qlinear_relu_cpu(self): 1650 r""" 1651 This testcase will quantize a Linear->ReLU pattern. 1652 """ 1653 self._qlinear_unary_cpu_test_helper((torch.randn((2, 4)),)) 1654 1655 @skipIfNoDynamoSupport 1656 @skipIfNoONEDNNBF16 1657 @skipIfNoONEDNN 1658 @skipIfRocm 1659 def test_qlinear_relu_int8_mixed_bf16(self): 1660 r""" 1661 This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization. 1662 """ 1663 self._qlinear_unary_cpu_test_helper( 1664 (torch.randn((2, 4)),), int8_mixed_bf16=True 1665 ) 1666 1667 @skipIfNoDynamoSupport 1668 @skipIfNoONEDNN 1669 @skipIfRocm 1670 def test_qlinear_relu_input_dim_exceeds_2(self): 1671 r""" 1672 This testcase will quantize a Linear->ReLU pattern. 1673 """ 1674 self._qlinear_unary_cpu_test_helper((torch.randn((2, 3, 4)),)) 1675 1676 @skipIfNoDynamoSupport 1677 @skipIfNoONEDNNBF16 1678 @skipIfNoONEDNN 1679 @skipIfRocm 1680 def test_qlinear_relu_int8_mixed_bf16_input_dim_exceeds_2(self): 1681 r""" 1682 This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization. 1683 """ 1684 self._qlinear_unary_cpu_test_helper( 1685 (torch.randn((2, 3, 4)),), int8_mixed_bf16=True 1686 ) 1687 1688 @skipIfNoDynamoSupport 1689 @skipIfNoONEDNN 1690 @skipIfRocm 1691 def test_qlinear_gelu_cpu(self): 1692 r""" 1693 This testcase will quantize a Linear->GELU pattern. 1694 """ 1695 for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: 1696 self._qlinear_unary_cpu_test_helper((torch.randn((2, 4)),), gelu) 1697 1698 @skipIfNoDynamoSupport 1699 @skipIfNoONEDNNBF16 1700 @skipIfNoONEDNN 1701 @skipIfRocm 1702 def test_qlinear_gelu_int8_mixed_bf16(self): 1703 r""" 1704 This testcase will quantize a Linear->GELU pattern with int8_mixed_bf16 quantization. 1705 """ 1706 for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: 1707 self._qlinear_unary_cpu_test_helper( 1708 (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True 1709 ) 1710 1711 def _qlinear_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False): 1712 r""" 1713 This testcase will quantize two consecutive Linear->Add(->relu) patterns as: 1714 X 1715 / \ 1716 linear(X) linear(X) 1717 \ / 1718 Add 1719 | 1720 Optional(relu) 1721 / \ 1722 linear(X) linear(X) 1723 \ / 1724 Add 1725 | 1726 Optional(relu) 1727 | 1728 Y 1729 """ 1730 1731 def fake_quant(x): 1732 # to produce a float32 result as extra input 1733 qlib = torch.ops.quantized_decomposed 1734 x = qlib.quantize_per_tensor.default(x, 0.0166785, 42, 0, 255, torch.uint8) 1735 x = qlib.dequantize_per_tensor.default( 1736 x, 0.0166785, 42, 0, 255, torch.uint8 1737 ) 1738 return x 1739 1740 class M(torch.nn.Module): 1741 def __init__( 1742 self, 1743 add_fn, 1744 use_relu, 1745 fake_quant_before_extra_input, 1746 ): 1747 super().__init__() 1748 self.linear1 = torch.nn.Linear(4, 4) 1749 self.linear2 = torch.nn.Linear(4, 4) 1750 self.add_fn = add_fn 1751 self.relu = torch.nn.ReLU() 1752 self.linear3 = torch.nn.Linear(4, 4) 1753 self.linear4 = torch.nn.Linear(4, 4) 1754 self.add_fn2 = add_fn 1755 self.relu2 = torch.nn.ReLU() 1756 self.use_relu = use_relu 1757 self.fake_quant_before_extra_input = fake_quant_before_extra_input 1758 1759 def forward(self, x): 1760 x1 = self.linear1(x) 1761 x2 = self.linear2(x) 1762 if self.fake_quant_before_extra_input: 1763 x2 = fake_quant(x2) 1764 tmp = self.add_fn(x1, x2) 1765 if self.use_relu: 1766 tmp = self.relu(tmp) 1767 tmp1 = self.linear3(tmp) 1768 tmp2 = self.linear4(tmp) 1769 if self.fake_quant_before_extra_input: 1770 tmp2 = fake_quant(tmp2) 1771 res = self.add_fn2(tmp1, tmp2) 1772 if self.use_relu: 1773 res = self.relu2(res) 1774 return res 1775 1776 add_fn_list = [ 1777 lambda x, y: x + y, 1778 lambda x, y: y + x, 1779 lambda x, y: x.add_(y), 1780 lambda x, y: y.add_(x), 1781 ] 1782 fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False] 1783 cases = itertools.product(add_fn_list, fake_quant_x2_list) 1784 for add_fn, fq_x2 in cases: 1785 mod = M(add_fn, use_relu, fq_x2).eval() 1786 v = torch.randn((4, 4), dtype=torch.float32, requires_grad=False).add(1) 1787 1788 def matcher_check_fn(): 1789 # 1. Dequant-linear pattern matched in quantization weight prepack * 4 1790 self.assertEqual( 1791 counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4 1792 ) 1793 # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm] 1794 nodes_per_match = 6 if int8_mixed_bf16 else 4 1795 self.assertEqual( 1796 counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], 1797 4 * nodes_per_match, 1798 ) 1799 # 2. Qlinear Binary Unary fusion in post-grad fusion pass * 2 1800 self.assertEqual( 1801 counters["inductor"]["qlinear_binary_matcher_count"], 2 1802 ) 1803 # Two linear-binary patterns are matched 1804 # matched patter1 = [qlinear, add, (convert dtype), (relu), quantize_per_tensor] 1805 # matched patter2 = [qlinear, add, (convert dtype), (relu)] 1806 # If add_fn is x.add_(y), x is bf16 and y is fp32, there is a to_bf16 node after binary 1807 to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2) 1808 self.assertEqual( 1809 counters["inductor"]["qlinear_binary_matcher_nodes"], 1810 5 + 2 * use_relu + to_bf16_after_binary, 1811 ) 1812 1813 for is_qat in [False, True]: 1814 self._test_common( 1815 mod, 1816 (v,), 1817 check_quantization=True, 1818 check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, 1819 matcher_check_fn=matcher_check_fn, 1820 is_qat=is_qat, 1821 ) 1822 if torch._inductor.config.cpp_wrapper: 1823 # For CPP wrapper 1824 self._test_code_common( 1825 mod, 1826 (v,), 1827 [ 1828 "op_qlinear_pointwise.call", 1829 "op_qlinear_pointwise_binary.call", 1830 ], 1831 [], 1832 check_quantization=True, 1833 num_include_ops=[2, 2], 1834 ) 1835 else: 1836 # For python wrapper 1837 self._test_code_common( 1838 mod, 1839 (v,), 1840 [ 1841 "torch.ops.onednn.qlinear_pointwise.default", 1842 "torch.ops.onednn.qlinear_pointwise.binary", 1843 ], 1844 [], 1845 check_quantization=True, 1846 num_include_ops=[2, 2], 1847 ) 1848 1849 @skipIfNoDynamoSupport 1850 @skipIfNoONEDNN 1851 @skipIfRocm 1852 def test_qlinear_add_cpu(self): 1853 self._qlinear_add_cpu_test_helper() 1854 1855 @skipIfNoDynamoSupport 1856 @skipIfNoONEDNNBF16 1857 @skipIfNoONEDNN 1858 @skipIfRocm 1859 def test_qlinear_add_int8_mixed_bf16(self): 1860 self._qlinear_add_cpu_test_helper(int8_mixed_bf16=True) 1861 1862 @skipIfNoDynamoSupport 1863 @skipIfNoONEDNN 1864 @skipIfRocm 1865 def test_qlinear_add_relu_cpu(self): 1866 self._qlinear_add_cpu_test_helper(use_relu=True) 1867 1868 @skipIfNoDynamoSupport 1869 @skipIfNoONEDNNBF16 1870 @skipIfNoONEDNN 1871 @skipIfRocm 1872 def test_qlinear_add_relu_int8_mixed_bf16(self): 1873 self._qlinear_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True) 1874 1875 def _qlinear_dequant_promotion_cpu_test_helper( 1876 self, 1877 inputs, 1878 int8_mixed_bf16=False, 1879 is_dynamic=False, 1880 matcher_check_fn=None, 1881 ): 1882 class M(torch.nn.Module): 1883 def __init__( 1884 self, 1885 **kwargs, 1886 ): 1887 super().__init__() 1888 self.linear1 = torch.nn.Linear(4, 4) 1889 self.linear2 = torch.nn.Linear(4, 4) 1890 self.linear3 = torch.nn.Linear(4, 4) 1891 1892 def forward(self, x): 1893 temp = self.linear1(x) 1894 temp = self.linear2(temp) + self.linear3(temp) 1895 return temp 1896 1897 mod = M().eval() 1898 1899 def default_matcher_check_fn(): 1900 # 1. Dequant pattern matcher for dequant promotion * 1 1901 self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1) 1902 # 2. dequant-linear pattern matched in quantization weight prepack * 3 1903 self.assertEqual( 1904 counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 1905 ) 1906 # 3. QLinear Unary fusion in post-grad fusion pass * 1 1907 self.assertEqual(counters["inductor"]["qlinear_unary_matcher_count"], 1) 1908 1909 self._test_common( 1910 mod, 1911 inputs, 1912 check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, 1913 check_quantization=True, 1914 matcher_check_fn=matcher_check_fn 1915 if matcher_check_fn is not None 1916 else default_matcher_check_fn, 1917 is_dynamic=is_dynamic, 1918 ) 1919 1920 @skipIfNoDynamoSupport 1921 @skipIfNoONEDNN 1922 @skipIfRocm 1923 def test_qlinear_dequant_promotion_cpu(self): 1924 r""" 1925 This testcase test if dequant node before linear is promoted correctly: 1926 X 1927 | 1928 Linear1(X) 1929 / \ 1930 Linear2(X) Linear3(X) 1931 \ / 1932 Add 1933 | 1934 Y 1935 """ 1936 self._qlinear_dequant_promotion_cpu_test_helper((torch.randn((2, 4)),)) 1937 1938 @skipIfNoDynamoSupport 1939 @skipIfNoONEDNNBF16 1940 @skipIfNoONEDNN 1941 @skipIfRocm 1942 def test_qlinear_dequant_promotion_int8_mixed_bf16(self): 1943 r""" 1944 Test with int8_mixed_bf16 quantization. 1945 This testcase test if dequant node before linear is promoted correctly: 1946 X 1947 | 1948 Linear1(X) 1949 / \ 1950 Linear2(X) Linear3(X) 1951 \ / 1952 Add 1953 | 1954 Y 1955 """ 1956 self._qlinear_dequant_promotion_cpu_test_helper( 1957 (torch.randn((2, 4)),), int8_mixed_bf16=True 1958 ) 1959 1960 @skipIfNoDynamoSupport 1961 @skipIfNoONEDNN 1962 @skipIfRocm 1963 def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self): 1964 r""" 1965 This testcase test if dequant node before linear is promoted correctly: 1966 X 1967 | 1968 Linear1(X) 1969 / \ 1970 Linear2(X) Linear3(X) 1971 \ / 1972 Add 1973 | 1974 Y 1975 """ 1976 self._qlinear_dequant_promotion_cpu_test_helper((torch.randn((2, 3, 4)),)) 1977 1978 @skipIfNoDynamoSupport 1979 @skipIfNoONEDNNBF16 1980 @skipIfNoONEDNN 1981 @skipIfRocm 1982 def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self): 1983 r""" 1984 Test with int8_mixed_bf16 quantization. 1985 This testcase test if dequant node before linear is promoted correctly: 1986 X 1987 | 1988 Linear1(X) 1989 / \ 1990 Linear2(X) Linear3(X) 1991 \ / 1992 Add 1993 | 1994 Y 1995 """ 1996 self._qlinear_dequant_promotion_cpu_test_helper( 1997 (torch.randn((2, 3, 4)),), int8_mixed_bf16=True 1998 ) 1999 2000 @skipIfNoDynamoSupport 2001 @skipIfNoONEDNN 2002 @skipIfRocm 2003 def test_qlinear_dequant_promotion_dynamic_cpu(self): 2004 r""" 2005 This testcase test if dequant node before linear is promoted correctly: 2006 X 2007 | 2008 Linear1(X) 2009 / \ 2010 Linear2(X) Linear3(X) 2011 \ / 2012 Add 2013 | 2014 Y 2015 """ 2016 2017 def matcher_check_fn(): 2018 # 1. Dequant pattern matcher for dequant promotion * 1 2019 self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1) 2020 # 2. dequant-linear pattern matched in quantization weight prepack * 3 2021 self.assertEqual( 2022 counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 2023 ) 2024 2025 self._qlinear_dequant_promotion_cpu_test_helper( 2026 (torch.randn((2, 4)),), 2027 matcher_check_fn=matcher_check_fn, 2028 is_dynamic=True, 2029 ) 2030 2031 @skipIfNoDynamoSupport 2032 @skipIfNoONEDNN 2033 @skipIfRocm 2034 def test_qlinear_mul_cpu(self): 2035 r""" 2036 This testcase will quantize a Linear->Mul pattern. 2037 """ 2038 2039 class M(torch.nn.Module): 2040 def __init__(self, use_bias): 2041 super().__init__() 2042 self.linear = torch.nn.Linear(4, 5, use_bias) 2043 2044 def forward(self, x1, x2): 2045 return torch.mul(self.linear(x1), x2) 2046 2047 bias_list = [True, False] 2048 for bias in bias_list: 2049 mod = M(bias).eval() 2050 x1 = torch.randn((2, 4)) 2051 x2 = torch.randn((2, 5)) 2052 2053 def matcher_check_fn(): 2054 self.assertEqual( 2055 counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1 2056 ) 2057 2058 self._test_common( 2059 mod, 2060 (x1, x2), 2061 check_quantization=True, 2062 matcher_check_fn=matcher_check_fn, 2063 ) 2064 2065 @skipIfNoDynamoSupport 2066 @skipIfRocm 2067 def test_qmaxpool2d(self): 2068 r""" 2069 This testcase will quantize Conv2d->ReLU->MaxPool2d pattern. 2070 """ 2071 2072 class M(torch.nn.Module): 2073 def __init__( 2074 self, 2075 kwargs, 2076 ): 2077 super().__init__() 2078 self.conv = torch.nn.Conv2d( 2079 3, 64, 7, bias=True, stride=2, padding=3, dilation=1 2080 ) 2081 self.relu = torch.nn.ReLU() 2082 self.maxpool = torch.nn.MaxPool2d(3, **kwargs) 2083 2084 def forward(self, x): 2085 return self.maxpool(self.relu(self.conv(x))) 2086 2087 kwargs_list = [ 2088 {"stride": 2}, 2089 {"stride": 2, "padding": 1}, 2090 {"stride": 2, "padding": 1, "dilation": 1}, 2091 {"stride": 2, "padding": 1, "dilation": 1, "ceil_mode": False}, 2092 ] 2093 for kwargs in kwargs_list: 2094 mod = M(kwargs).eval() 2095 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( 2096 1 2097 ) 2098 2099 def matcher_check_fn(): 2100 self.assertEqual(counters["inductor"]["qmaxpool2d_matcher_count"], 1) 2101 self.assertEqual( 2102 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 2103 ) 2104 self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1) 2105 2106 self._test_common( 2107 mod, 2108 (v,), 2109 check_quantization=True, 2110 matcher_check_fn=matcher_check_fn, 2111 ) 2112 2113 @skipIfNoDynamoSupport 2114 @skipIfRocm 2115 def test_qflatten(self): 2116 r""" 2117 This testcase will quantize Conv2d->AdaptiveAvgPool2d->flatten pattern. 2118 """ 2119 2120 class M(torch.nn.Module): 2121 def __init__( 2122 self, 2123 ): 2124 super().__init__() 2125 self.conv = torch.nn.Conv2d( 2126 3, 64, 7, bias=True, stride=2, padding=3, dilation=1 2127 ) 2128 self.relu = torch.nn.ReLU() 2129 self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) 2130 2131 def forward(self, x): 2132 return torch.flatten( 2133 self.adaptive_avg_pool2d(self.relu(self.conv(x))), 1 2134 ) 2135 2136 mod = M().eval() 2137 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) 2138 2139 def matcher_check_fn(): 2140 self.assertEqual(counters["inductor"]["qreshape_matcher_count"], 1) 2141 2142 self._test_common( 2143 mod, 2144 (v,), 2145 check_quantization=True, 2146 matcher_check_fn=matcher_check_fn, 2147 ) 2148 2149 @skipIfNoDynamoSupport 2150 @skipIfRocm 2151 def test_qcat(self): 2152 r""" 2153 This testcase will quantize cat based pattern: 2154 X 2155 / \ 2156 Conv1(X) Pow(x) 2157 \ \ 2158 \ Conv2(X) 2159 \ / 2160 Cat 2161 | 2162 Y 2163 """ 2164 2165 class M(torch.nn.Module): 2166 def __init__( 2167 self, 2168 ): 2169 super().__init__() 2170 self.conv = torch.nn.Conv2d( 2171 3, 64, 7, bias=True, stride=2, padding=3, dilation=1 2172 ) 2173 self.conv2 = torch.nn.Conv2d( 2174 3, 64, 7, bias=True, stride=2, padding=3, dilation=1 2175 ) 2176 2177 def forward(self, x): 2178 temp1 = self.conv(x) 2179 temp2 = self.conv2(torch.pow(x, 2)) 2180 return torch.cat((temp1, temp2), 1) 2181 2182 mod = M().eval() 2183 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) 2184 2185 def matcher_check_fn(): 2186 self.assertEqual(counters["inductor"]["qcat_matcher_count"], 1) 2187 self.assertEqual( 2188 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 2189 ) 2190 self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2) 2191 2192 self._test_common( 2193 mod, 2194 (v,), 2195 check_quantization=True, 2196 matcher_check_fn=matcher_check_fn, 2197 ) 2198 2199 # https://github.com/pytorch/pytorch/issues/99841. 2200 def test_hardtanh_pattern_fallback(self): 2201 class Model(torch.nn.Module): 2202 def __init__(self): 2203 super().__init__() 2204 self.conv_transpose = torch.nn.ConvTranspose2d( 2205 in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1 2206 ) 2207 2208 def forward(self, x, min_value, max_value): 2209 conv_transpose_output = self.conv_transpose(x) 2210 clamp_min_output = torch.clamp_min(conv_transpose_output, min_value) 2211 clamp_max_output = torch.clamp_max(clamp_min_output, max_value) 2212 return clamp_max_output 2213 2214 # check works for min_value > max_value. 2215 min_values = [3, torch.randn(1, 32, 28, 28)] 2216 max_values = [0, torch.randn(1, 32, 28, 28)] 2217 v = torch.randn(1, 3, 28, 28) 2218 for min_value, max_value in zip(min_values, max_values): 2219 mod = Model().eval() 2220 self._test_common(mod, (v, min_value, max_value), 2, 4) 2221 2222 def test_leaky_relu_pattern_fallback(self): 2223 class Model(torch.nn.Module): 2224 def __init__(self): 2225 super().__init__() 2226 self.conv = torch.nn.Conv2d( 2227 in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1 2228 ) 2229 2230 def forward(self, x, negative_slope): 2231 conv_out = self.conv(x) 2232 return torch.where(conv_out > 0, conv_out, conv_out * negative_slope) 2233 2234 negative_slopes = [0.1, torch.randn(1, 32, 28, 28)] 2235 with torch.no_grad(): 2236 v = torch.randn(1, 3, 28, 28) 2237 for negative_slope in negative_slopes: 2238 mod = Model().eval() 2239 self._test_common(mod, (v, negative_slope), 2, 5) 2240 2241 # https://github.com/pytorch/pytorch/issues/99838. 2242 def test_conv2d_add_scalar(self): 2243 class Model(torch.nn.Module): 2244 def __init__(self): 2245 super().__init__() 2246 self.conv = torch.nn.Conv2d( 2247 in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1 2248 ) 2249 2250 def forward(self, x): 2251 out_conv = self.conv(x) 2252 out = torch.add(out_conv, 1.0) 2253 return out 2254 2255 with torch.no_grad(): 2256 mod = Model().eval() 2257 v = torch.randn(1, 3, 28, 28) 2258 self._test_common(mod, (v,), 1, 1) 2259 2260 def test_conv2d_binary_inplace_fusion_pass_cpu( 2261 self, include_ops=None, exclude_ops=None 2262 ): 2263 class Model_v1(torch.nn.Module): 2264 def __init__(self): 2265 super().__init__() 2266 self.conv = torch.nn.Conv2d( 2267 in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1 2268 ) 2269 2270 def forward(self, x, other): 2271 conv_out = self.conv(x) 2272 return torch.add(conv_out, other.relu()) 2273 2274 class Model_v2(torch.nn.Module): 2275 def __init__(self): 2276 super().__init__() 2277 self.conv = torch.nn.Conv2d( 2278 in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1 2279 ) 2280 self.conv2 = torch.nn.Conv2d( 2281 in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1 2282 ) 2283 self.conv3 = torch.nn.Conv2d( 2284 in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1 2285 ) 2286 2287 def forward(self, x, _): 2288 conv_out1 = self.conv(x) 2289 pow_out = torch.pow(conv_out1, 2) 2290 conv_out2 = self.conv2(pow_out) 2291 conv_out3 = self.conv3(conv_out2) 2292 res = torch.add(conv_out3, pow_out) 2293 return res 2294 2295 input = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last) 2296 others = [ 2297 torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last), 2298 torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last), 2299 ] 2300 mod_v1 = Model_v1().to(memory_format=torch.channels_last).eval() 2301 mod_v2 = Model_v2().to(memory_format=torch.channels_last).eval() 2302 2303 if include_ops is None: 2304 include_ops = ["mkldnn._convolution_pointwise_.binary"] 2305 if exclude_ops is None: 2306 exclude_ops = ["mkldnn._convolution_pointwise.binary"] 2307 2308 for other, mod in zip(others, [mod_v1, mod_v2]): 2309 self._test_code_common(mod, (input, other), include_ops, exclude_ops) 2310 2311 def test_conv2d_binary_inplace_fusion_failed_cpu( 2312 self, include_ops=None, exclude_ops=None 2313 ): 2314 # Written buffer is graph input, we can't fuse inplace. 2315 class Model_v1(torch.nn.Module): 2316 def __init__(self): 2317 super().__init__() 2318 self.conv = torch.nn.Conv2d( 2319 in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1 2320 ) 2321 2322 def forward(self, x, other): 2323 conv_out = self.conv(x) 2324 return torch.add(conv_out, other) 2325 2326 # Written buffer is an alias tensor, we can't fuse inplace. 2327 class Model_v2(torch.nn.Module): 2328 def __init__(self): 2329 super().__init__() 2330 self.conv = torch.nn.Conv2d( 2331 in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1 2332 ) 2333 2334 def forward(self, x, other): 2335 conv_out = self.conv(x) 2336 return torch.add(conv_out, other[1:2, :, :, :]), other 2337 2338 class Model_v3(torch.nn.Module): 2339 def __init__(self): 2340 super().__init__() 2341 self.conv = torch.nn.Conv2d( 2342 in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1 2343 ) 2344 self.conv2 = torch.nn.Conv2d( 2345 in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1 2346 ) 2347 2348 def forward(self, x, _): 2349 pow_out = torch.pow(self.conv(x), 2) 2350 other2 = F.relu(pow_out) 2351 conv_out2 = self.conv2(pow_out) 2352 res = torch.add(conv_out2, pow_out) 2353 res = res + other2 2354 return res 2355 2356 # Written buffer is an ReinterpretView, we can't fuse inplace. 2357 class Model_v4(torch.nn.Module): 2358 def __init__(self): 2359 super().__init__() 2360 self.conv = torch.nn.Conv2d(3, 32, 3, padding=1, bias=True) 2361 self.linear = torch.nn.Linear(32 * 28, 32 * 28) 2362 self.relu = torch.nn.ReLU() 2363 2364 def forward(self, x, y): 2365 x = self.conv(self.relu(x)) 2366 y = self.linear(y) 2367 y = torch.cat((y, y), 1) 2368 y = torch.ops.aten.permute.default(y, [0, 2, 1]).reshape(1, 32, 28, 28) 2369 return x + y 2370 2371 class Model_v5(torch.nn.Module): 2372 def __init__(self): 2373 super().__init__() 2374 self.conv = torch.nn.Conv2d(32, 32, 3, padding=1, bias=True) 2375 self.relu = torch.nn.ReLU() 2376 2377 def forward(self, _, x): 2378 x1 = self.relu(x) 2379 return self.conv(x1) + x1 2380 2381 input = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last) 2382 others = [ 2383 torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last), 2384 torch.randn(2, 32, 28, 28).to(memory_format=torch.channels_last), 2385 torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last), 2386 torch.randn(1, 14, 32 * 28), 2387 torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last), 2388 ] 2389 mod_v1 = Model_v1().to(memory_format=torch.channels_last).eval() 2390 mod_v2 = Model_v2().to(memory_format=torch.channels_last).eval() 2391 mod_v3 = Model_v3().to(memory_format=torch.channels_last).eval() 2392 mod_v4 = Model_v4().to(memory_format=torch.channels_last).eval() 2393 mod_v5 = Model_v5().to(memory_format=torch.channels_last).eval() 2394 2395 if include_ops is None: 2396 include_ops = ["mkldnn._convolution_pointwise.binary"] 2397 if exclude_ops is None: 2398 exclude_ops = ["mkldnn._convolution_pointwise_.binary"] 2399 2400 for other, mod in zip(others, [mod_v1, mod_v2, mod_v3, mod_v4, mod_v5]): 2401 self._test_code_common(mod, (input, other), include_ops, exclude_ops) 2402 2403 def test_conv2d_binary_fusion_failed(self): 2404 # we don't support alpha !=1 case or other has different size with conv's output. 2405 class Model(torch.nn.Module): 2406 def __init__(self): 2407 super().__init__() 2408 self.conv = torch.nn.Conv2d( 2409 in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1 2410 ) 2411 2412 def forward(self, x, other, alpha): 2413 conv_out = self.conv(x) 2414 return torch.add(conv_out, other, alpha=alpha) 2415 2416 # https://github.com/pytorch/pytorch/issues/100802. 2417 # we can't do the fusion when add's inputs are same tensor. 2418 class Model2(torch.nn.Module): 2419 def __init__(self): 2420 super().__init__() 2421 self.conv = torch.nn.Conv2d( 2422 in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1 2423 ) 2424 2425 def forward(self, x): 2426 out = self.conv(x) 2427 out = torch.add(out, out) 2428 return out 2429 2430 # https://github.com/pytorch/pytorch/issues/101374. 2431 # we can't do the fusion when add's inputs are mixed dtype. 2432 class Model3(torch.nn.Module): 2433 def __init__(self): 2434 super().__init__() 2435 self.conv = torch.nn.Conv2d( 2436 in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1 2437 ) 2438 2439 def forward(self, x): 2440 temp = self.conv(x) 2441 other = torch.ones(temp.shape, dtype=torch.double) 2442 out = torch.add(temp, other) 2443 return out 2444 2445 input = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last) 2446 others = [ 2447 torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last), 2448 torch.randn(32, 28, 28), 2449 ] 2450 include_ops = ["mkldnn._convolution_pointwise"] 2451 exclude_ops = [ 2452 "mkldnn._convolution_pointwise.binary", 2453 "mkldnn._convolution_pointwise_.binary", 2454 ] 2455 2456 # case1 2457 for other, alpha in zip(others, [0.1, 1.0]): 2458 mod = Model().to(memory_format=torch.channels_last).eval() 2459 self._test_code_common(mod, (input, other, alpha), include_ops, exclude_ops) 2460 # case2: 2461 mod = Model2().to(memory_format=torch.channels_last).eval() 2462 self._test_code_common(mod, (input,), include_ops, exclude_ops) 2463 # case3: 2464 mod = Model3().to(memory_format=torch.channels_last).eval() 2465 self._test_code_common(mod, (input,), include_ops, exclude_ops) 2466 2467 def test_reproduce_99842_issue(self): 2468 class Model(torch.nn.Module): 2469 def __init__(self): 2470 super().__init__() 2471 self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 2472 2473 def forward(self, input_tensor): 2474 x = self.conv(input_tensor) 2475 x = F.relu(x + torch.ones(x.size())) 2476 return x 2477 2478 input = torch.randn(1, 3, 14, 14) 2479 mod = Model().eval() 2480 include_ops = ["mkldnn._convolution_pointwise_.binary"] 2481 self._test_code_common(mod, (input,), include_ops, []) 2482 2483 def test_reproduce_113440_issue_1(self): 2484 class Mod(torch.nn.Module): 2485 def __init__( 2486 self, 2487 add_fn, 2488 **kwargs, 2489 ): 2490 super().__init__() 2491 self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 2492 self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 2493 self.add_fn = add_fn 2494 self.relu = torch.nn.ReLU(inplace=True) 2495 self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 2496 self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 2497 self.add_fn2 = add_fn 2498 self.relu2 = torch.nn.ReLU(inplace=True) 2499 self.use_relu = True 2500 2501 def forward(self, x): 2502 x1 = self.conv1(x) 2503 x2 = self.conv2(x) 2504 tmp = self.add_fn(x1, x2) 2505 if self.use_relu: 2506 tmp = self.relu(tmp) 2507 tmp1 = self.conv3(tmp) 2508 tmp2 = self.conv4(tmp) 2509 res = self.add_fn2(tmp1, tmp2) 2510 if self.use_relu: 2511 res = self.relu2(res) 2512 return res 2513 2514 with torch.no_grad(): 2515 example_inputs = ( 2516 torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( 2517 1 2518 ), 2519 ) 2520 example_inputs[0].get_device() 2521 m = Mod( 2522 lambda x, y: x.add_(y), 2523 ).eval() 2524 om = torch.compile(m) 2525 om(*example_inputs) 2526 om(*example_inputs) 2527 2528 def test_reproduce_113440_issue_2(self): 2529 class Mod(torch.nn.Module): 2530 def __init__( 2531 self, 2532 add_fn, 2533 **kwargs, 2534 ): 2535 super().__init__() 2536 self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 2537 self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) 2538 self.add_fn = add_fn 2539 self.relu = torch.nn.ReLU(inplace=True) 2540 self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 2541 self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 2542 self.add_fn2 = add_fn 2543 self.relu2 = torch.nn.ReLU(inplace=True) 2544 2545 self.conv5 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 2546 self.conv6 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) 2547 self.conv7 = torch.nn.Conv2d(6, 6, kernel_size=1, stride=1) 2548 self.add_fn3 = add_fn 2549 self.relu3 = torch.nn.ReLU(inplace=True) 2550 2551 self.use_relu = True 2552 2553 def forward(self, x): 2554 x1 = self.conv1(x) 2555 x2 = self.conv2(x) 2556 tmp = self.add_fn(x1, x2) 2557 if self.use_relu: 2558 tmp = self.relu(tmp) 2559 2560 tmp1 = self.conv3(tmp) 2561 res = self.relu2(tmp1) 2562 2563 return res 2564 2565 with torch.no_grad(): 2566 example_inputs = ( 2567 torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( 2568 1 2569 ), 2570 ) 2571 m = Mod( 2572 lambda x, y: x.add_(y), 2573 ).eval() 2574 om = torch.compile(m) 2575 om(*example_inputs) 2576 om(*example_inputs) 2577 2578 def test_reproduce_121253_issue(self): 2579 class Mod(torch.nn.Module): 2580 def __init__(self, weight, bias, beta, alpha): 2581 super().__init__() 2582 self.weight = weight 2583 self.bias = bias 2584 self.beta = beta 2585 self.alpha = alpha 2586 2587 def forward(self, x): 2588 return torch.addmm( 2589 self.bias, x, self.weight, beta=self.beta, alpha=self.alpha 2590 ) 2591 2592 dtypes = [torch.float32] 2593 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 2594 dtypes.append(torch.bfloat16) 2595 for dtype in dtypes: 2596 linear_op = ( 2597 "mkl._mkl_linear" 2598 if dtype == torch.float32 2599 else "mkldnn._linear_pointwise" 2600 ) 2601 for beta, alpha in zip([1.0, 0.1, 0.0], [1.0, 0.1, 1.0]): 2602 weight = torch.randn(64, 64, dtype=dtype) 2603 bias = torch.randn(64, dtype=dtype) 2604 mod = Mod(weight, bias, beta, alpha).to(dtype).eval() 2605 with torch.no_grad(): 2606 x = torch.randn(1, 64, dtype=dtype) 2607 include_ops = [] 2608 exclude_ops = [] 2609 if (beta != 1.0 and beta != 0.0) or alpha != 1.0: 2610 exclude_ops = [linear_op] 2611 else: 2612 include_ops = [linear_op] 2613 self._test_code_common(mod, (x,), include_ops, exclude_ops) 2614 2615 @skipIfNoDynamoSupport 2616 @skipIfRocm 2617 def test_woq_int8(self): 2618 class M(torch.nn.Module): 2619 def forward(self, x, weight, scales): 2620 return torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales 2621 2622 mod = M().eval() 2623 x_shape = (1, 1, 256) 2624 w_shape = (12, 256) 2625 s_shape = 12 2626 x_strides = [ 2627 (256, 256, 1), # linear dispatching to mm 2628 (256, 32, 1), # linear dispatching to bmm 2629 ] 2630 for x_stride in x_strides: 2631 x = torch.randn(x_shape, dtype=torch.bfloat16).as_strided(x_shape, x_stride) 2632 w = torch.randint(-128, 127, w_shape, dtype=torch.int8) 2633 s = torch.randn(s_shape, dtype=torch.bfloat16) 2634 2635 def matcher_check_fn(): 2636 self.assertEqual(counters["inductor"]["woq_matcher_count"], 1) 2637 2638 self._test_common( 2639 mod, 2640 (x, w, s), 2641 matcher_check_fn=matcher_check_fn, 2642 check_quantization=False, 2643 atol=0.001, 2644 rtol=0.07, 2645 ) 2646 2647 2648@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) 2649class TestDynamicPatternMatcher(TestPatternMatcherBase): 2650 _test_conv_unary_cpu_base = TestPatternMatcher._test_conv_unary_cpu_base 2651 test_conv2d_unary_dynamic_shapes = TestPatternMatcher.test_conv2d_unary_cpu 2652 test_conv3d_unary_dynamic_shapes = TestPatternMatcher.test_conv3d_unary_cpu 2653 _test_conv_binary_base = TestPatternMatcher._test_conv_binary_base 2654 test_conv2d_binary_dynamic_shapes = TestPatternMatcher.test_conv2d_binary 2655 test_conv3d_binary_dynamic_shapes = TestPatternMatcher.test_conv3d_binary 2656 test_linear_unary_dynamic_shapes = TestPatternMatcher.test_linear_unary 2657 2658 def test_conv_transpose2d_dynamic_shapes(self): 2659 # We don't support conv_transpose2d for now. 2660 class M(torch.nn.Module): 2661 def __init__(self): 2662 super().__init__() 2663 self.conv_transpose2d = torch.nn.ConvTranspose2d( 2664 3, 16, 3, stride=2, padding=1 2665 ) 2666 2667 def forward(self, x): 2668 return self.conv_transpose2d(x) 2669 2670 x_shape = (1, 3, 28, 28) 2671 mod = M().eval() 2672 v = torch.randn(x_shape, dtype=torch.float32) 2673 self._test_common(mod, (v,), 0, 0) 2674 2675 def test_multi_linear_share_same_input_dynamic(self): 2676 # llama pattern. 2677 class M(torch.nn.Module): 2678 def __init__( 2679 self, 2680 ): 2681 super().__init__() 2682 self.w1 = torch.nn.Linear(16, 16, bias=False) 2683 self.w2 = torch.nn.Linear(16, 16, bias=False) 2684 2685 def forward(self, x): 2686 return F.silu(self.w1(x)) * F.relu(self.w2(x)) 2687 2688 dtypes = [] 2689 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 2690 dtypes.append(torch.bfloat16) 2691 if torch.ops.mkldnn._is_mkldnn_fp16_supported(): 2692 dtypes.append(torch.float16) 2693 for dtype in dtypes: 2694 mod = M().to(dtype).eval() 2695 v = torch.randn(2, 4, 16).to(dtype) 2696 # 1. view(match_count=4, match_nodes=4). 2697 # 2. mm to packed linear(match_count=2, match_nodes=2). 2698 # 3. view+linear+view to linear(match_count=2, match_nodes=6). 2699 # 4. linear to linear+swish(match_count=1, match_nodes=2). 2700 # 5. linear to linear+relu(match_count=1, match_nodes=5). 2701 2702 match_count = 10 2703 match_nodes = 19 2704 self._test_common(mod, (v,), match_count, match_nodes, rtol=1e-2, atol=1e-2) 2705 2706 def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None): 2707 r""" 2708 This testcase will quantize a single Conv2d->Maxpool2d->Linear module 2709 with dynamic batch size input. 2710 """ 2711 2712 class M(torch.nn.Module): 2713 def __init__( 2714 self, 2715 **kwargs, 2716 ): 2717 super().__init__() 2718 self.conv = torch.nn.Conv2d( 2719 3, 16, (2, 2), stride=(1, 1), padding=(1, 1) 2720 ) 2721 self.relu = torch.nn.ReLU() 2722 self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 2723 self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) 2724 self.linear = torch.nn.Linear(16, 16) 2725 2726 def forward(self, x): 2727 temp = self.relu(self.conv(x)) 2728 temp = self.maxpool2d(temp) 2729 temp = self.avgpool(temp) 2730 temp = torch.flatten(temp, 1) 2731 return self.linear(temp) 2732 2733 mod = M().eval() 2734 v = torch.randn((2, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) 2735 if include_ops is None: 2736 include_ops = [ 2737 "torch.ops.onednn.qconv2d_pointwise", 2738 "torch.ops.quantized.max_pool2d", 2739 "torch.ops.onednn.qlinear_pointwise", 2740 ] 2741 exclude_ops = [] 2742 self._test_code_common( 2743 mod, 2744 (v,), 2745 include_ops, 2746 exclude_ops, 2747 check_quantization=True, 2748 check_dynamic=True, 2749 ) 2750 2751 @skipIfNoDynamoSupport 2752 @skipIfNoONEDNN 2753 @skipIfRocm 2754 def test_qat_bn_conv2d(self): 2755 r""" 2756 This testcase will quantize a single BN Conv2d module with qat flow. 2757 """ 2758 2759 class M(torch.nn.Module): 2760 def __init__( 2761 self, 2762 ): 2763 super().__init__() 2764 self.conv = torch.nn.Conv2d(3, 3, 3) 2765 self.bn1 = torch.nn.BatchNorm2d(3) 2766 self.bn2 = torch.nn.BatchNorm2d(3) 2767 2768 def forward(self, x): 2769 x = self.conv(self.bn1(x)) 2770 return self.bn2(x) 2771 2772 mod = M().train() 2773 v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) 2774 2775 def matcher_check_fn(): 2776 self.assertEqual( 2777 counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 2778 ) 2779 2780 self._test_common( 2781 mod, 2782 (v,), 2783 check_quantization=True, 2784 is_qat=True, 2785 matcher_check_fn=matcher_check_fn, 2786 ) 2787 2788 @skipIfNoDynamoSupport 2789 @skipIfNoONEDNN 2790 @skipIfRocm 2791 def test_q_attention_block(self): 2792 class SelfAttnLikeModule(torch.nn.Module): 2793 def __init__( 2794 self, 2795 input_dim, 2796 transpose_for_score=False, 2797 num_attention_heads=None, 2798 attention_head_size=None, 2799 ) -> None: 2800 super().__init__() 2801 self.input_dim = input_dim 2802 self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False) 2803 self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False) 2804 self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) 2805 self.softmax = torch.nn.Softmax(dim=-1) 2806 self.transpose_for_score = transpose_for_score 2807 if self.transpose_for_score: 2808 assert num_attention_heads is not None 2809 assert attention_head_size is not None 2810 self.num_attention_heads = num_attention_heads 2811 self.attention_head_size = attention_head_size 2812 2813 def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: 2814 new_x_shape = x.size()[:-1] + ( 2815 self.num_attention_heads, 2816 self.attention_head_size, 2817 ) 2818 x = x.view(new_x_shape) 2819 return x.permute(0, 2, 1, 3) 2820 2821 def forward(self, x): 2822 q = self.q_proj(x) 2823 k = self.k_proj(x) 2824 v = self.v_proj(x) 2825 if self.transpose_for_score: 2826 q = self.transpose_for_scores(q) 2827 k = self.transpose_for_scores(k) 2828 v = self.transpose_for_scores(v) 2829 scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) 2830 attention = self.softmax(scores) 2831 weighted = torch.matmul(attention, v) 2832 return weighted 2833 2834 for annotate_matmul in [False, True]: 2835 mod = SelfAttnLikeModule( 2836 input_dim=64 * 16, 2837 transpose_for_score=True, 2838 num_attention_heads=16, 2839 attention_head_size=64, 2840 ).eval() 2841 v = torch.randn(2, 384, 1024) 2842 2843 def matcher_check_fn(): 2844 self.assertEqual( 2845 counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 2846 ) 2847 self.assertEqual( 2848 counters["inductor"]["qlinear_unary_matcher_count"], 2849 3 if annotate_matmul else 0, 2850 ) 2851 2852 quantizer = X86InductorQuantizer() 2853 quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) 2854 if annotate_matmul: 2855 quantizer.set_function_type_qconfig( 2856 torch.matmul, quantizer.get_global_quantization_config() 2857 ) 2858 2859 self._test_common( 2860 mod, 2861 (v,), 2862 check_quantization=True, 2863 matcher_check_fn=matcher_check_fn, 2864 quantizer=quantizer, 2865 ) 2866 2867 2868if __name__ == "__main__": 2869 if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available(): 2870 run_tests() 2871