# Owner(s): ["oncall: cpu inductor"]
import contextlib
import copy
import functools
import itertools
import math
import platform
import sys
import unittest
from typing import Callable
from unittest.mock import patch

import numpy as np
import sympy

import torch
from torch import nn
from torch._C import FileCheck
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import codecache, config, metrics
from torch._inductor.codegen.common import OptimizationContext
from torch._inductor.codegen.cpp import (
    CppOverrides,
    CppVecKernelChecker,
    CppVecOverrides,
)
from torch._inductor.compile_fx import (
    compile_fx,
    compile_fx_inner,
    complex_memory_overlap,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import InterpreterShim
from torch._inductor.utils import timed
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_MACOS,
    parametrize,
    slowTest,
)
from torch.utils._python_dispatch import TorchDispatchMode

try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


vec_dtypes = test_torchinductor.vec_dtypes
_lowp_fp_dtypes = (
    torch.bfloat16,
    torch.float16,
)
run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
TestCase = test_torchinductor.TestCase
aten = torch.ops.aten
check_model = test_torchinductor.check_model

requires_vectorization = unittest.skipUnless(
    codecache.valid_vec_isa_list(), "Does not support vectorization"
)


def check_metrics_vec_kernel_count(num_expected_vec_kernels):
    if codecache.valid_vec_isa_list():
        assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels


@contextlib.contextmanager
def set_num_threads(num_threads):
    orig_num_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    yield
    torch.set_num_threads(orig_num_threads)


class LstmModule(torch.nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers,
        bias=True,
        bidirectional=False,
        batch_first=False,
    ):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            bidirectional=bidirectional,
            batch_first=batch_first,
        )

    def forward(self, x, h=None):
        x, h = self.lstm(x, h)
        return x, h


@instantiate_parametrized_tests
class CPUReproTests(TestCase):
    common = check_model

    def test_conv_stride_constraints(self):
        for fmt in [torch.contiguous_format, torch.channels_last]:
            # TorchDispatch doesn't work in our cuda invocation for some reason
            m = torch.nn.Conv2d(5, 6, [3, 3])

            def fn(inp, weight):
                return (
                    F.conv2d(
                        inp, weight, None, m.stride, m.padding, m.dilation, m.groups
                    ),
                )

            inp = torch.randn([2, 5, 16, 16])
            inps = [inp, m.weight.to(memory_format=fmt)]
            fn_fx = make_fx(fn)(*inps)
            fn_compiled = compile_fx_inner(fn_fx, inps)
            test_self = self
            conv_seen = False

            class RecordFunctions(TorchDispatchMode):
                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                    kwargs = kwargs if kwargs else {}
                    if func == torch.ops.aten.convolution.default:
                        # For CPU with mkldnn enabled, we always use channels last
                        nonlocal fmt
                        if (
                            torch.backends.mkldnn.enabled
                            and torch.backends.mkldnn.is_available()
                        ):
                            fmt = torch.channels_last
                        test_self.assertTrue(args[0].is_contiguous(memory_format=fmt))
                        test_self.assertTrue(args[1].is_contiguous(memory_format=fmt))
                        nonlocal conv_seen
                        conv_seen = True

                    return func(*args, **kwargs)

            with RecordFunctions():
                out = fn_compiled(inps)

            self.assertTrue(conv_seen)

    @patch("torch.cuda.is_available", lambda: False)
    def test_conv2d_bn_mixed_dtype(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3,
                    16,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                    dtype=torch.bfloat16,
                )
                self.bn = torch.nn.BatchNorm2d(
                    16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
                )

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        v = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
        mod = Model().eval()
        with torch.no_grad():
            self.common(
                mod,
                (v,),
            )

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_conv2d_packed(self):
        options = itertools.product([[3, 56, 56]], [True, False], [0, (0,)])
        for x_shape, mode_train, padding in options:
            mod = torch.nn.Sequential(
                torch.nn.Conv2d(3, 64, 3, 3, padding=padding)
            ).train(mode=mode_train)
            v = torch.randn(x_shape, dtype=torch.float32)

            with torch.no_grad():
                self.common(
                    mod,
                    (v,),
                )

    @patch("torch.cuda.is_available", lambda: False)
    def test_conv2d_autocast(self):
        v = torch.randn(1, 3, 28, 18, dtype=torch.float32)
        mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).eval()
        with torch.no_grad(), torch.cpu.amp.autocast():
            self.common(
                mod,
                (v,),
            )

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_unsupported_conv_transpose(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_transpose = torch.nn.ConvTranspose2d(
                    3, 6, 3, stride=1, padding=1, output_padding=1
                )

            def forward(self, input_tensor):
                x = self.conv_transpose(input_tensor)
                output = torch.tanh(x)
                return output

        input = torch.randn(1, 3, 28, 28)
        m = Model().eval()

        with torch.no_grad():
            compiled_m = torch.compile(m)
            with self.assertRaisesRegex(
                RuntimeError,
                "output padding must be smaller than either stride or dilation",
            ):
                compiled_m(input)

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_conv_used_from_multiple_places(self):
        class M(torch.nn.Module):
            def __init__(self, conv_in_channel, conv_out_channel) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(conv_in_channel, conv_out_channel, (3, 3))

            def forward(self, x):
                res = self.conv(x)
                res = F.relu(res)
                res = self.conv(res)
                return res

        with torch.no_grad():
            mod = M(3, 3).eval()
            x = torch.randn(1, 3, 224, 224)
            self.common(
                mod,
                (x,),
            )

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_linear_used_from_multiple_places(self):
        class M(torch.nn.Module):
            def __init__(self, in_channel, out_channel) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(in_channel, out_channel)

            def forward(self, x):
                res = self.linear(x)
                res = F.relu(res)
                res = self.linear(res)
                return res

        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        for dtype in dtypes:
            with torch.no_grad():
                m = M(224, 224).to(dtype).eval()
                m_opt = torch.compile(m)
                x = torch.randn(224, 224, dtype=dtype)
                m_opt(x)
                self.assertEqual(m(x), m_opt(x))

    @config.patch(implicit_fallbacks=True)
    def test_multihead_attention_cpu(self):
        def fn(
            q,
            k,
            v,
            embed_dim,
            num_heads,
            qkv_weight,
            qkv_bias,
            proj_weight,
            proj_bias,
            mask,
            need_weights,
        ):
            return torch._native_multi_head_attention(
                q,
                k,
                v,
                embed_dim,
                num_heads,
                qkv_weight,
                qkv_bias,
                proj_weight,
                proj_bias,
                mask,
                need_weights,
            )

        B = 1
        T = 3
        embed_dim = 6
        num_heads = 2
        q = torch.randn([B, T, embed_dim])
        k = torch.randn([B, T, embed_dim])
        v = torch.randn([B, T, embed_dim])
        qkv_weight = torch.randn([3 * embed_dim, embed_dim])
        qkv_bias = torch.randn([3 * embed_dim])
        proj_weight = torch.randn([3 * embed_dim, embed_dim])
        proj_bias = torch.randn([3 * embed_dim])
        mask = None
        need_weights = False

        inps = [
            q,
            k,
            v,
            embed_dim,
            num_heads,
            qkv_weight,
            qkv_bias,
            proj_weight,
            proj_bias,
            mask,
            need_weights,
        ]
        self.common(fn, inps)

    @config.patch(freezing=True)
    def test_module_buffer_mutation(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("foo", torch.rand((3, 10)))

            def forward(self, x):
                lx = [x, x.clone(), x.clone()]
                y = []
                for i in range(3):
                    y.append(lx[i] + self.foo[i])
                return torch.cat(y, 1)

        with torch.no_grad():
            example_inputs = (torch.rand(1, 10),)
            self.common(Model(), example_inputs)

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_linear_packed(self):
        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        options = itertools.product(
            [[2, 3, 10], [2, 10], [10], [2, 0]], [3, 0], [True, False], dtypes
        )
        for input_shape, out_dim, bias, dtype in options:
            mod = torch.nn.Sequential(
                torch.nn.Linear(input_shape[-1], out_dim, bias=bias)
            ).eval()

            v = torch.randn(input_shape)
            with torch.no_grad():
                self.common(
                    mod.to(dtype),
                    (v.to(dtype),),
                )

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_conv_transpose2d_packed_cpu(self):
        options = itertools.product([[1, 3, 28, 28], [3, 28, 28]], [0, (0,)])
        for x_shape, padding in options:
            mod = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(3, 64, 3, 3, padding=padding)
            ).eval()
            v = torch.randn(x_shape, dtype=torch.float32)
            with torch.no_grad():
                self.common(
                    mod,
                    (v,),
                )

    @config.patch(freezing=True)
    @unittest.skipIf(not torch._C._has_mkldnn, "MKLDNN is not enabled")
    @torch._dynamo.config.patch(dynamic_shapes=True)
    @torch._dynamo.config.patch(assume_static_by_default=False)
    def test_conv_in_channel_1_dynamic_shapes(self):
        class M(torch.nn.Module):
            def __init__(self, in_channel, out_channel) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(in_channel, out_channel, 3)

            def forward(self, x):
                res = self.conv(x)
                res = F.relu(res)
                return res

        # test the case where the channels dim of the input is 1
        # Reproducer from the maml_omniglot model in Torchbench
        in_channel = 1
        out_channel = 3
        amp_enabled_configs = [False]
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            # When amp is enabled here, the input to Conv is a FlexibleLayout.
            # When it's disabled, the input is a FixedLayout.
            amp_enabled_configs.append(True)
        for amp_enabled in amp_enabled_configs:
            mod = M(in_channel, out_channel).eval()
            v = torch.randn(5, in_channel, 15, 15)
            with torch.no_grad(), torch.cpu.amp.autocast(enabled=amp_enabled):
                self.common(
                    mod,
                    (v,),
                )

    @unittest.skipIf(not torch._C._has_mkldnn, "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    @torch._dynamo.config.patch(dynamic_shapes=True)
    @torch._dynamo.config.patch(assume_static_by_default=False)
    @torch._dynamo.config.patch(allow_rnn=True)
    @config.patch(freezing=True)
    def _test_lstm_packed(self, params_dict, change_input_sizes=False):
        from torch._dynamo.utils import counters

        for (
            unbatched,
            input_size,
            hidden_size,
            num_layers,
            bidirectional,
            bias,
            empty_state,
            batch_first,
            batch_size,
            seq_len,
        ) in itertools.product(*list(params_dict.values())):
            dtypes = [torch.float]
            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
                dtypes.append(torch.bfloat16)
            if torch.ops.mkldnn._is_mkldnn_fp16_supported():
                dtypes.append(torch.float16)
            for dtype in dtypes:
                counters.clear()
                num_directions = 2 if bidirectional else 1

                seq_len_var = seq_len + 3
                if unbatched:
                    v = torch.randn(seq_len, input_size)
                    v_var = torch.randn(seq_len_var, input_size)
                    h = torch.randn(num_layers * num_directions, hidden_size)
                    c = torch.randn(num_layers * num_directions, hidden_size)
                else:
                    if batch_first:
                        v = torch.randn(batch_size, seq_len, input_size)
                        v_var = torch.randn(batch_size, seq_len_var, input_size)
                    else:
                        v = torch.randn(seq_len, batch_size, input_size)
                        v_var = torch.randn(seq_len_var, batch_size, input_size)
                    h = torch.randn(
                        num_layers * num_directions, batch_size, hidden_size
                    )
                    c = torch.randn(
                        num_layers * num_directions, batch_size, hidden_size
                    )

                mod = LstmModule(
                    input_size,
                    hidden_size,
                    num_layers,
                    bias,
                    bidirectional,
                    batch_first,
                ).eval()
                maybe_autocast = (
                    torch.cpu.amp.autocast()
                    if dtype == torch.bfloat16
                    else contextlib.nullcontext()
                )

                with torch.no_grad(), maybe_autocast:
                    inps = [v]
                    if not empty_state:
                        inps.append((h, c))

                    fn_opt = torch._dynamo.optimize("inductor")(mod)
                    _, code = run_and_get_cpp_code(fn_opt, *inps)

                    # Check that _flat_weights are not functional_tensor, otherwise
                    # deepcopy will fail during recompilation.
                    fn_opt_copy = copy.deepcopy(fn_opt)
                    _flat_weights = fn_opt_copy.lstm._flat_weights
                    for _flat_weight in _flat_weights:
                        self.assertFalse(torch._is_functional_tensor(_flat_weight))

                    self.assertTrue("aten.mkldnn_rnn_layer" in code)
                    self.assertEqual(fn_opt(*inps), mod(*inps))
                    self.assertEqual(
                        counters["inductor"]["pattern_matcher_count"],
                        num_layers * num_directions
                        + 2,  # num of mkldnn_rnn_layer call + 2 view call on the concatenated hy, cy.
                    )

                    # Change input sizes
                    if change_input_sizes:
                        inps_var = [v_var]
                        self.assertEqual(fn_opt(*inps_var), mod(*inps_var))

    @slowTest
    def test_lstm_packed(self):
        params_dict = {
            "unbatched": [True, False],
            "input_size": [1, 2],
            "hidden_size": [2],
            "num_layers": [1, 2],
            "bidirectional": [False, True],
            "bias": [False, True],
            "empty_state": [False, True],
            "batch_first": [True, False],
            "batch_size": [1, 2],
            "seq_len": [1, 2],
        }
        self._test_lstm_packed(params_dict)

    def test_lstm_packed_change_input_sizes_cpu(self):
        params_dict = {
            "unbatched": [False],
            "input_size": [2],
            "hidden_size": [5],
            "num_layers": [3],
            "bidirectional": [True],
            "bias": [True],
            "empty_state": [False],
            "batch_first": [False],
            "batch_size": [2],
            "seq_len": [3],
        }
        self._test_lstm_packed(params_dict, change_input_sizes=True)

    @torch._dynamo.config.patch(dynamic_shapes=True)
    @torch._dynamo.config.patch(assume_static_by_default=False)
    @torch._dynamo.config.patch(allow_rnn=True)
    def test_pack_padded_sequence_lstm(self):
        embedding_dim = 12
        hidden_dim = 10
        batch_size = 24
        num_layers = 1
        bidirectional = True
        num_direc = 2
        max_lens = 96

        sent = torch.randn(batch_size, max_lens, embedding_dim)
        hid_0 = torch.rand(num_layers * num_direc, batch_size, hidden_dim)
        hid_1 = torch.randn(num_layers * num_direc, batch_size, hidden_dim)

        sent_lens = torch.Tensor(
            [1, 2, 3, 4, 5, 1, 3, 2, 96, 5, 3, 1, 1, 2, 1, 2, 3, 6, 1, 2, 4, 6, 2, 1]
        )

        assert sent_lens.shape[0] == batch_size
        assert sent_lens.max().item() == max_lens

        hidden_0 = hid_0.clone().requires_grad_(False)
        hidden_1 = hid_1.clone().requires_grad_(False)
        embeds = torch.nn.utils.rnn.pack_padded_sequence(
            sent, sent_lens, batch_first=True, enforce_sorted=False
        )

        mod = LstmModule(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            bias=True,
            bidirectional=bidirectional,
            batch_first=True,
        ).eval()

        with torch.no_grad():
            inps = [embeds, (hidden_0, hidden_1)]
            fn_opt = torch._dynamo.optimize("inductor")(mod)
            _, code = run_and_get_cpp_code(fn_opt, *inps)
            # This case is unsupported
            self.assertFalse("torch.ops.mkldnn._lstm" in code)
            self.assertEqual(fn_opt(*inps), mod(*inps))

    @patch("torch.cuda.is_available", lambda: False)
    def test_conv_transpose2d_has_output_size_input(self):
        # https://github.com/pytorch/pytorch/issues/100344.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_transpose = torch.nn.ConvTranspose2d(
                    in_channels=3, out_channels=1, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x):
                return self.conv_transpose(x, output_size=(10, 10))

        mod = M().eval()
        v = torch.randn(1, 3, 10, 10, dtype=torch.float32)
        with torch.no_grad():
            self.common(
                mod,
                (v,),
            )

    def test_pad_with_nan_value(self):
        # https://github.com/pytorch/pytorch/issues/100988.
        class Model(torch.nn.Module):
            def forward(self, x):
                x = F.pad(x, (1, 1, 1, 1), value=float("nan"))
                return x

        mod = Model().eval()
        v = torch.randn(1, 3, 10, 10, dtype=torch.float32)
        with torch.no_grad():
            self.common(
                mod,
                (v,),
            )

    def test_masked_fill_with_inf_or_nan_value(self):
        def fn(value, mask):
            y1 = torch.masked_fill(value, mask, float("inf"))
            y2 = torch.masked_fill(value, mask, float("-inf"))
            y3 = torch.masked_fill(value, mask, float("nan"))
            return y1, y2, y3

        value = torch.randn((2, 17))
        mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool)
        with torch.no_grad():
            self.common(
                fn,
                (value, mask),
            )

    def test_relu_with_inf_value(self):
        # https://github.com/pytorch/pytorch/issues/117544.

        def fn(out):
            out = torch.sinh(input=out)
            out = torch.relu(input=out)
            return out

        x = torch.Tensor([-572373.5000, 755109.1250, 330995.5625])
        with torch.no_grad():
            self.common(
                fn,
                (x,),
            )

    def test_acosh_with_negative_large_input(self):
        # https://github.com/pytorch/pytorch/issues/118267.

        def fn(input):
            out = torch.acosh(input)
            return out

        x = torch.Tensor(
            [
                [
                    -8493.9854,
                    431654.1250,
                    71741.5859,
                    608234.5000,
                    -103814.7500,
                    -699397.0000,
                    -910685.8125,
                    -832737.1875,
                    875343.5000,
                ]
            ]
        ).repeat(3, 9)

        for dtype in [torch.float32, torch.bfloat16, torch.double]:
            with torch.no_grad():
                torch._dynamo.reset()
                metrics.reset()
                _x = x.to(dtype)
                self.common(
                    fn,
                    (_x,),
                )

    @config.patch(implicit_fallbacks=True)
    def test_repeat_interleave(self):
        def fn(y):
            return torch.repeat_interleave(y, 2, output_size=8)

        a = torch.tensor([[1, 2], [3, 4]])
        self.common(
            fn,
            (a,),
        )

    def test_inplace_squeeze_needed(self):
        mod = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.LayerNorm(10),
            torch.nn.ReLU(),
        ).eval()

        def fn(x):
            return mod(x)

        v = torch.randn(10)
        # TODO: OMP parallel reduction order is not deterministic.
        # Hence, the accuracy might vary up and down. For the short term,
        # we increase the tolerance and will fix it later by using
        # aten parallel.
        self.common(fn, (v,), atol=5e-1, rtol=5e-1)

    def test_cat_mul(self):
        # https://github.com/pytorch/pytorch/issues/93365
        def fn(p0, p1):
            y1 = torch.cat([p0, p1], dim=0)
            y2 = torch.mul(y1, y1)
            return y1, y2

        p0 = torch.randn(3, 4)
        p1 = torch.randn(3, 4)
        self.common(fn, (p0, p1))

    def test_pow_cos(self):
        # https://github.com/pytorch/pytorch/issues/98149
        def fn(x):
            t = x.pow(5)
            return torch.cos(t)

        x = torch.tensor([4], dtype=torch.uint8)
        self.common(fn, (x,))

    def test_reduce_with_masked(self):
        # https://github.com/pytorch/pytorch/issues/96484
        def fn(a, b):
            a = torch.nn.functional.pad(a, (0, -1))
            c = a + b
            return c.min(0).values

        a = torch.randn([2])
        b = torch.randn([2])
        self.common(fn, (a, b))

    def test_scalar_sign_with_min(self):
        # https://github.com/pytorch/pytorch/issues/101340
        def fn(a):
            t1 = torch.tanh(a)
            t2 = torch.sign(t1)
            return torch.min(t1, t2)

        a = torch.randn(1, 3)
        self.common(fn, (a,))

    def test_index_propagation_issue_102065(self):
        def fn(x):
            x = torch.arange(x.numel())
            return (x.unsqueeze(0) - x.unsqueeze(1)) ** 2

        self.common(
            fn,
            (torch.randn(8),),
        )

    def test_ModularIndexing_range_issue_103133(self):
        def fn(q, k):
            einsum = torch.einsum("bcxd,bcyd->bcxy", (q, k))
            constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                einsum, [0, 0, 0, 1], 0.0
            )
            view = torch.ops.aten.view.default(constant_pad_nd, [12, 1, 512, 513])
            y = view.new_zeros((12, 2, 256, 513))
            y[:, :-1, :, 256:] = view[:, :, :256, :257]
            return y

        self.common(
            fn,
            (
                torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)),
                torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)),
            ),
        )

    @patch("torch.cuda.is_available", lambda: False)
    def test_max_reduction_lowp_fp(self):
        def fn(x):
            return torch.ops.aten.max(x, 1, keepdim=True)[0].float()

        for dtype in _lowp_fp_dtypes:
            self.common(
                fn,
                (torch.randn(1, 32, 4, 4).to(dtype),),
            )

    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_transpose_lowp_fp(self):
        for dtype in _lowp_fp_dtypes:

            def fn(x):
                return x.to(memory_format=torch.channels_last).to(dtype)

            self.common(
                fn,
                (torch.randn(2, 3, 4, 4),),
            )

    def test_load_inf_bf16(self):
        def fn1(x):
            return torch.where(x > 0, x, math.inf)

        def fn2(x):
            return torch.where(x > 0, x, -math.inf)

        for fn in [fn1, fn2]:
            self.common(
                fn,
                (torch.randn(1, 3, 16, 16),),
            )

    @patch("torch.cuda.is_available", lambda: False)
    def test_fp32_load_with_to_lowp_fp(self):
        # From llama model.
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.cache_k = torch.zeros(8, 4, 2, 2)

            def forward(self, x, xk):
                bsz, seqlen, _ = x.shape
                self.cache_k = self.cache_k.to(x)
                self.cache_k[:bsz, 1 : 1 + seqlen] = xk
                return self.cache_k

        for dtype in _lowp_fp_dtypes:
            ref_model = Model().eval()
            opt_model = torch.compile()(Model().eval())
            x = torch.randn(4, 2, 2).to(dtype)
            xk = torch.randn(4, 2, 2, 2).to(dtype)
            self.assertEqual(opt_model(x, xk), ref_model(x, xk))

    @requires_vectorization
    @patch("torch.cuda.is_available", lambda: False)
    def test_sigmoid_with_reduction(self):
        def fn(x):
            x = torch.ops.aten.sigmoid.default(x)
            return torch.ops.aten.mean.dim(x, [-1, -2], True)

        x = torch.randn((1, 8, 8, 8))
        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            self.common(fn, (x,))

    def test_slice_scatter_default_end_value(self):
        # From HF AllenaiLongformerBase.
        def fn(query, key, window_overlap):
            batch_size, seq_len, num_heads, head_dim = query.size()
            assert (
                seq_len % (window_overlap * 2) == 0
            ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"

            chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
            diagonal_chunked_attention_scores = key
            diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
                (
                    batch_size * num_heads,
                    chunks_count + 1,
                    window_overlap,
                    window_overlap * 2 + 1,
                )
            )
            diagonal_attention_scores[
                :, :3, :, window_overlap:
            ] = diagonal_chunked_attention_scores[
                :, :, :window_overlap, : window_overlap + 1
            ]
            return diagonal_attention_scores

        self.common(
            fn,
            (
                torch.randn(1, 1024, 12, 64),
                torch.randn(12, 3, 512, 513),
                256,
            ),
        )

    @requires_vectorization
    @patch("torch.cuda.is_available", lambda: False)
    def test_to_uint8_rounding_method(self):
        def fn(x):
            return x.to(torch.uint8)

        numerical_testsuit = [4.4, 4.5, 4.6, 5.5]
        for numerical_number in numerical_testsuit:
            x = torch.ones(17) * numerical_number
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x,))
                check_metrics_vec_kernel_count(1)

    @requires_vectorization
    def _test_decomposed_dequant_relu_quant_helper(self, dtype):
        def fn(
            x, scale, zero_point, use_dequant, use_quant, quant_min, quant_max, dtype
        ):
            # For quantized_decomposed.dequantize_per_tensor
            # Refer to torch/ao/quantization/fx/_decomposed.py
            if use_dequant:
                x = (x.to(torch.float32) - zero_point) * scale

            x = torch.relu(x)

            # For quantized_decomposed.quantize_per_tensor
            # Refer to torch/ao/quantization/fx/_decomposed.py
            if use_quant:
                inv_scale = 1.0 / scale
                x = torch.clamp(
                    torch.round(x * inv_scale) + zero_point, quant_min, quant_max
                ).to(dtype)
            return x

        assert dtype in [torch.uint8, torch.int8]
        quant_min = 0 if dtype == torch.uint8 else -128
        quant_max = 255 if dtype == torch.uint8 else 127

        use_dequant_list = [False, True]
        use_quant_list = [False, True]
        for use_dequant, use_quant in itertools.product(
            use_dequant_list, use_quant_list
        ):
            x = torch.clamp(
                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100,
                quant_min,
                quant_max,
            )
            if use_dequant:
                x = x.to(dtype)
            zero_point = 100
            scale = 0.01
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(
                    fn,
                    (
                        x,
                        scale,
                        zero_point,
                        use_dequant,
                        use_quant,
                        quant_min,
                        quant_max,
                        dtype,
                    ),
                )
                check_metrics_vec_kernel_count(1)

    @requires_vectorization
    def test_decomposed_dequant_relu_quant_uint8(self):
        self._test_decomposed_dequant_relu_quant_helper(torch.uint8)

    @requires_vectorization
    def test_decomposed_dequant_relu_quant_int8(self):
        self._test_decomposed_dequant_relu_quant_helper(torch.int8)

    def _test_dequant_quant_lowering_helper(self, dtype):
        def fn(
            x, scale, zero_point, use_dequant, use_quant, quant_min, quant_max, dtype
        ):
            if use_dequant:
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, scale, zero_point, quant_min, quant_max, dtype
                )

            x = torch.relu(x)

            if use_quant:
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, scale, zero_point, quant_min, quant_max, dtype
                )
            return x

        use_dequant_list = [False, True]
        use_quant_list = [False, True]
        use_tensor_overload_list = [False, True]

        assert dtype in [torch.uint8, torch.int8]
        quant_min = 0 if dtype == torch.uint8 else -128
        quant_max = 255 if dtype == torch.uint8 else 127

        for use_dequant, use_quant, use_tensor_overload in itertools.product(
            use_dequant_list, use_quant_list, use_tensor_overload_list
        ):
            x = torch.clamp(
                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100,
                quant_min,
                quant_max,
            )
            if use_dequant:
                x = x.to(dtype)
            zero_point = 100
            scale = 0.01
            if use_tensor_overload:
                zero_point = torch.tensor(zero_point, dtype=torch.int64)
                scale = torch.tensor(scale)
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(
                    fn,
                    (
                        x,
                        scale,
                        zero_point,
                        use_dequant,
                        use_quant,
                        quant_min,
                        quant_max,
                        dtype,
                    ),
                )
                check_metrics_vec_kernel_count(1)

    @requires_vectorization
    def test_dequant_quant_lowering_uint8(self):
        self._test_dequant_quant_lowering_helper(torch.uint8)

    @requires_vectorization
    def test_dequant_quant_lowering_int8(self):
        self._test_dequant_quant_lowering_helper(torch.int8)

    def _test_dequant_maxpool2d_lowering_helper(self, dtype):
        def fn(x, scale, zero_point, quant_min, quant_max, dtype):
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, scale, zero_point, quant_min, quant_max, dtype
            )
            max_pool2d_with_indices_default = (
                torch.ops.aten.max_pool2d_with_indices.default(
                    x, [2, 2], [2, 2], [1, 1]
                )[0]
            )
            return max_pool2d_with_indices_default

        assert dtype in [torch.uint8, torch.int8]
        quant_min = 0 if dtype == torch.uint8 else -128
        quant_max = 255 if dtype == torch.uint8 else 127

        use_tensor_overload_list = [False, True]
        for use_tensor_overload in use_tensor_overload_list:
            x = (
                torch.clamp(
                    torch.randn((3, 16, 8, 8), dtype=torch.float32) * 100,
                    quant_min,
                    quant_max,
                )
                .to(dtype)
                .contiguous(memory_format=torch.channels_last)
            )
            zero_point = 100
            scale = 0.01
            if use_tensor_overload:
                zero_point = torch.tensor(zero_point, dtype=torch.int64)
                scale = torch.tensor(scale)
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x, scale, zero_point, quant_min, quant_max, dtype))
                check_metrics_vec_kernel_count(1)

    @requires_vectorization
    def test_dequant_maxpool2d_lowering_uint8(self):
        self._test_dequant_maxpool2d_lowering_helper(torch.uint8)

    @requires_vectorization
    def test_dequant_maxpool2d_lowering_int8(self):
        self._test_dequant_maxpool2d_lowering_helper(torch.int8)

    def _test_tile2d_load_decomposed_dequant_add_relu_quant_helper(self, dtype):
        def fn(
            x,
            scale,
            zero_point,
            x2,
            scale2,
            zero_point2,
            output_scale,
            output_zero_point,
            use_dequant,
            use_dequant2,
            use_quant,
            quant_min,
            quant_max,
            dtype,
        ):
            if use_dequant:
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, scale, zero_point, quant_min, quant_max, dtype
                )
            if use_dequant2:
                x2 = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x2, scale2, zero_point2, quant_min, quant_max, dtype
                )
            temp = x + x2
            y = torch.relu(temp)

            if use_quant:
                y = torch.ops.quantized_decomposed.quantize_per_tensor(
                    y, output_scale, output_zero_point, quant_min, quant_max, dtype
                )
            return y.contiguous()

        assert dtype in [torch.uint8, torch.int8]
        quant_min = 0 if dtype == torch.uint8 else -128
        quant_max = 255 if dtype == torch.uint8 else 127

        use_dequant_list = [False, True]
        use_dequant_list2 = [False, True]
        use_quant_list = [False, True]

        for use_dequant, use_dequant2, use_quant in itertools.product(
            use_dequant_list, use_dequant_list2, use_quant_list
        ):
            x = torch.clamp(
                torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100,
                quant_min,
                quant_max,
            ).contiguous(memory_format=torch.channels_last)
            x2 = torch.clamp(
                torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100,
                quant_min,
                quant_max,
            ).contiguous(memory_format=torch.channels_last)
            if use_dequant:
                x = x.to(dtype).contiguous(memory_format=torch.channels_last)
            if use_dequant2:
                x2 = x2.to(dtype).contiguous(memory_format=torch.channels_last)
            zero_point = 1
            scale = 0.01
            zero_point2 = 2
            scale2 = 0.02
            output_zero_point = 3
            output_scale = 0.03
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(
                    fn,
                    (
                        x,
                        scale,
                        zero_point,
                        x2,
                        scale2,
                        zero_point2,
                        output_scale,
                        output_zero_point,
                        use_dequant,
                        use_dequant2,
                        use_quant,
                        quant_min,
                        quant_max,
                        dtype,
                    ),
                )
                check_metrics_vec_kernel_count(2)

    @requires_vectorization
    def test_tile2d_load_decomposed_dequant_add_relu_quant_uint8(self):
        self._test_tile2d_load_decomposed_dequant_add_relu_quant_helper(torch.uint8)

    @requires_vectorization
    def test_tile2d_load_decomposed_dequant_add_relu_quant_int8(self):
        self._test_tile2d_load_decomposed_dequant_add_relu_quant_helper(torch.int8)

    @requires_vectorization
    def _test_per_tensor_fake_quant_helper(self, dtype):
        def fn(input, scales, zero_points, quant_min, quant_max, dtype):
            input = torch.ops.quantized_decomposed.quantize_per_tensor(
                input, scales, zero_points, quant_min, quant_max, dtype
            )
            input = torch.ops.quantized_decomposed.dequantize_per_tensor(
                input, scales, zero_points, quant_min, quant_max, dtype
            )
            return input

        use_tensor_overload_list = [False, True]
        for use_tensor_overload in use_tensor_overload_list:
            assert dtype in [torch.uint8, torch.int8]
            quant_min = 0 if dtype == torch.uint8 else -128
            quant_max = 255 if dtype == torch.uint8 else 127
            x = torch.clamp(
                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100,
                quant_min,
                quant_max,
            )
            zero_point = 100
            scale = 0.01
            if use_tensor_overload:
                zero_point = torch.tensor(zero_point, dtype=torch.int64)
                scale = torch.tensor(scale)
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x, scale, zero_point, quant_min, quant_max, dtype))
                assert metrics.generated_cpp_vec_kernel_count == 1

    @requires_vectorization
    def test_per_tensor_fake_quant_uint8(self):
        self._test_per_tensor_fake_quant_helper(torch.uint8)

    @requires_vectorization
    def test_per_tensor_fake_quant_int8(self):
        self._test_per_tensor_fake_quant_helper(torch.int8)

    def _test_per_channel_fake_quant_helper(self, dtype, input_dtype=torch.float32):
        def fn(input, scales, zero_points, axis, quant_min, quant_max, dtype):
            input = torch.ops.quantized_decomposed.quantize_per_channel(
                input, scales, zero_points, axis, quant_min, quant_max, dtype
            )
            input = torch.ops.quantized_decomposed.dequantize_per_channel(
                input, scales, zero_points, axis, quant_min, quant_max, dtype
            )
            return input

        assert dtype in [torch.uint8, torch.int8]
        quant_min = 0 if dtype == torch.uint8 else -128
        quant_max = 255 if dtype == torch.uint8 else 127
        x = torch.clamp(
            torch.randn((1, 3, 224, 224), dtype=torch.float32) * 100,
            quant_min,
            quant_max,
        )
        if input_dtype != torch.float32:
            x = x.to(dtype=input_dtype)
        scales = torch.ones((3,))
        zero_points = torch.zeros((3,))
        axis = 1
        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            self.common(fn, (x, scales, zero_points, axis, quant_min, quant_max, dtype))
            check_metrics_vec_kernel_count(1)

    @requires_vectorization
    def test_per_channel_fake_quant_uint8(self):
        self._test_per_channel_fake_quant_helper(torch.uint8)

    @requires_vectorization
    def test_per_channel_fake_quant_module_uint8(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.scales = torch.ones((3,)).to(torch.float64)
                self.zero_points = torch.zeros((3,)).to(torch.int64)
                self.axis = 1
                self.quant_min = 0
                self.quant_max = 255
                self.dtype = torch.uint8

            def forward(self, input):
                input = torch.ops.quantized_decomposed.quantize_per_channel(
                    input,
                    self.scales,
                    self.zero_points,
                    self.axis,
                    self.quant_min,
                    self.quant_max,
                    self.dtype,
                )
                input = torch.ops.quantized_decomposed.dequantize_per_channel(
                    input,
                    self.scales,
                    self.zero_points,
                    self.axis,
                    self.quant_min,
                    self.quant_max,
                    self.dtype,
                )
                return input

        m = Mod().eval()
        x = torch.clamp(
            torch.randn((1, 3, 224, 224), dtype=torch.float32) * 100,
            0,
            255,
        )
        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            self.common(m, (x,))
            assert metrics.generated_cpp_vec_kernel_count == 1

    @requires_vectorization
    def test_per_channel_fake_quant_int8(self):
        self._test_per_channel_fake_quant_helper(torch.int8)

    @requires_vectorization
    def test_per_channel_fake_quant_uint8_bf16_input(self):
        self._test_per_channel_fake_quant_helper(
            torch.uint8, input_dtype=torch.bfloat16
        )

    @requires_vectorization
    def test_per_channel_fake_quant_int8_bf16_input(self):
        self._test_per_channel_fake_quant_helper(torch.int8, input_dtype=torch.bfloat16)

    def _test_non_contiguous_load_buf_quant_helper(self, dtype):
        def fn(
            x1,
            x2,
            groups,
            quant_min,
            quant_max,
            dtype,
        ):
            x = torch.cat((x1, x2), dim=1)
            batchsize, num_channels, height, width = x.size()
            channels_per_group = num_channels // groups
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, 1.0, 0, quant_min, quant_max, dtype
            )
            x = x.view(batchsize, groups, channels_per_group, height, width)
            x = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, 1.0, 0, quant_min, quant_max, dtype
            )
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, 1.0, 0, quant_min, quant_max, dtype
            )
            x = torch.transpose(x, 1, 2).contiguous()
            x = x.view(batchsize, num_channels, height, width)
            return x

        assert dtype in [torch.uint8, torch.int8]
        quant_min = 0 if dtype == torch.uint8 else -128
        quant_max = 255 if dtype == torch.uint8 else 127

        x = torch.randint(0, 8, (1, 116, 28, 28), dtype=dtype).contiguous(
            memory_format=torch.channels_last
        )
        x2 = torch.randint(0, 8, (1, 116, 28, 28), dtype=dtype).contiguous(
            memory_format=torch.channels_last
        )

        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            self.common(
                fn,
                (
                    x,
                    x2,
                    2,
                    quant_min,
                    quant_max,
                    dtype,
                ),
            )
            check_metrics_vec_kernel_count(2)

    @requires_vectorization
    def test_non_contiguous_load_buf_quant_uint8(self):
        self._test_non_contiguous_load_buf_quant_helper(torch.uint8)

    @requires_vectorization
    def test_non_contiguous_load_buf_quant_int8(self):
        self._test_non_contiguous_load_buf_quant_helper(torch.int8)

    def _test_tile2d_store_channel_shuffle_cl_quant_output_helper(self, dtype):
        def channel_shuffle(
            x, groups, output_scale, output_zero_point, quant_min, quant_max, dtype
        ):
            batchsize, num_channels, height, width = x.size()
            channels_per_group = num_channels // groups
            x = x.view(batchsize, groups, channels_per_group, height, width)
            x = torch.transpose(x, 1, 2).contiguous()
            x = x.view(batchsize, -1, height, width)
            x = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, output_scale, output_zero_point, quant_min, quant_max, dtype
            )
            return x.contiguous(memory_format=torch.channels_last)

        assert dtype in [torch.uint8, torch.int8]
        quant_min = 0 if dtype == torch.uint8 else -128
        quant_max = 255 if dtype == torch.uint8 else 127

        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            x = torch.randn(64, 58, 28, 28)
            output_zero_point = 3
            output_scale = 0.03
            self.common(
                channel_shuffle,
                (x, 2, output_scale, output_zero_point, quant_min, quant_max, dtype),
            )
            check_metrics_vec_kernel_count(2)

    @requires_vectorization
    def test_tile2d_store_channel_shuffle_cl_quant_output_uint8(self):
        self._test_tile2d_store_channel_shuffle_cl_quant_output_helper(torch.uint8)

    @requires_vectorization
    def test_tile2d_store_channel_shuffle_cl_quant_output_int8(self):
        self._test_tile2d_store_channel_shuffle_cl_quant_output_helper(torch.int8)

    def _test_dequant_relu_quant_dequant_relu_quant_lowering_helper(self, dtype):
        def fn(
            x,
            scale,
            zero_point,
            scale2,
            zero_point2,
            scale3,
            zero_point3,
            quant_min,
            quant_max,
            dtype,
        ):
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, scale, zero_point, quant_min, quant_max, dtype
            )
            x = torch.relu(x)
            x = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, scale2, zero_point2, quant_min, quant_max, dtype
            )
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, scale2, zero_point2, quant_min, quant_max, dtype
            )
            x = torch.relu(x)
            x = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, scale3, zero_point3, quant_min, quant_max, dtype
            )
            return x

        assert dtype in [torch.uint8, torch.int8]
        quant_min = 0 if dtype == torch.uint8 else -128
        quant_max = 255 if dtype == torch.uint8 else 127

        for use_tensor_overload in [True, False]:
            x = torch.clamp(
                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100,
                quant_min,
                quant_max,
            ).to(dtype)
            zero_point_list = [100, 101, 102]
            scale_list = [0.01, 0.02, 0.03]
            if use_tensor_overload:
                for i in range(len(zero_point_list)):
                    zero_point_list[i] = torch.tensor(
                        zero_point_list[i], dtype=torch.int64
                    )
                    scale_list[i] = torch.tensor(scale_list[i])
            zero_point, zero_point2, zero_point3 = zero_point_list
            scale, scale2, scale3 = scale_list
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(
                    fn,
                    (
                        x,
                        scale,
                        zero_point,
                        scale2,
                        zero_point2,
                        scale3,
                        zero_point3,
                        quant_min,
                        quant_max,
                        dtype,
                    ),
                    rtol=1e-2,
                    atol=1e-2,
                )
                check_metrics_vec_kernel_count(1)

    @requires_vectorization
    def test_dequant_relu_quant_dequant_relu_quant_lowering_uint8(self):
        self._test_dequant_relu_quant_dequant_relu_quant_lowering_helper(torch.uint8)

    @requires_vectorization
    def test_dequant_relu_quant_dequant_relu_quant_lowering_int8(self):
        self._test_dequant_relu_quant_dequant_relu_quant_lowering_helper(torch.int8)

    def test_inplace_add_alpha(self):
        def fn(x, y):
            aten.add_.Tensor(x, y, alpha=0.55)
            return (x,)

        x1 = torch.zeros(10)
        x2 = torch.zeros(10)
        x3 = torch.zeros(10)
        y = torch.randn(10)
        fn_fx = make_fx(fn)(x1, y)
        fn_compiled = compile_fx_inner(fn_fx, [x1, y])
        fn(x2, y)
        fn_compiled([x3, y])
        assert same(x2, x3)

    def test_int_div(self):
        def fn(x, y):
            s3 = x.size(1)
            a = torch.zeros((1 + s3) // 2)
            a += y
            return a, s3

        p0 = torch.randint(5, (1, 8))
        p1 = torch.randn(1)
        self.common(fn, (p0, p1))

    def test_no_op_squeeze(self):
        @torch._dynamo.optimize("inductor")
        def forward(arg0_1):
            return torch.ops.aten.squeeze.dim(arg0_1, 1)

        x = torch.randn((10, 20))
        self.common(forward, (x,))

    def test_parallel_num_threads(self):
        @torch._dynamo.optimize("inductor")
        def fn(x1, x2):
            return x1 + x2

        x1 = torch.randn((10, 20))
        x2 = torch.randn((10, 20))
        with set_num_threads(1):
            assert same(x1 + x2, fn(x1, x2))
        with set_num_threads(4):
            assert same(x1 + x2, fn(x1, x2))

    @patch("torch.cuda.is_available", lambda: False)
    def test_timed_cpu_only(self):
        timed(lambda: torch.randn(10), ())

    def test_complex_memory_overlap(self):
        dense = torch.zeros(64, 32)
        self.assertFalse(complex_memory_overlap(dense))
        self.assertFalse(complex_memory_overlap(dense.t()))

        strided = dense.split(4, dim=1)
        self.assertFalse(complex_memory_overlap(strided[0]))
        self.assertFalse(complex_memory_overlap(strided[0].t()))

        unsqueezed = dense.unsqueeze(1)
        self.assertFalse(complex_memory_overlap(unsqueezed))
        self.assertFalse(complex_memory_overlap(unsqueezed.permute(1, 2, 0)))

        gathered = dense.index_select(0, torch.IntTensor([1, 0, 1]))
        self.assertFalse(complex_memory_overlap(gathered))
        self.assertFalse(complex_memory_overlap(gathered.t()))

    @requires_vectorization
    def test_vec_dynamic_shapes(self):
        def fn(x):
            return torch.softmax(x, -1)

        value = torch.randn((2, 10))
        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            self.common(fn, (value,))

    @unittest.skipIf(
        platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(),
        "Does not support vectorization or not x86_64 machine",
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_auto_simd(self):
        vec_avx512 = codecache.supported_vec_isa_list[0]
        vec_avx2 = codecache.supported_vec_isa_list[1]
        self.assertTrue(vec_avx512.bit_width() == 512)
        self.assertTrue(vec_avx2.bit_width() == 256)
        self.assertTrue(vec_avx512.nelements() == 16)
        self.assertTrue(vec_avx2.nelements() == 8)
        self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
        self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)

        with config.patch({"cpp.simdlen": None}):
            isa = codecache.pick_vec_isa()
            if vec_avx512 in codecache.valid_vec_isa_list():
                self.assertTrue(isa == vec_avx512)
            else:
                self.assertTrue(isa == vec_avx2)

        with config.patch({"cpp.simdlen": 0}):
            isa = codecache.pick_vec_isa()
            self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 1}):
            isa = codecache.pick_vec_isa()
            self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 257}):
            isa = codecache.pick_vec_isa()
            self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 513}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx512 in isa_list:
                self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 512}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx512 in isa_list:
                isa = codecache.pick_vec_isa()
                self.assertTrue(isa == vec_avx512)

        with config.patch({"cpp.simdlen": 256}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx2 in isa_list:
                isa = codecache.pick_vec_isa()
                self.assertTrue(isa == vec_avx2)

    @requires_vectorization
    @patch("torch.cuda.is_available", lambda: False)
    def test_masked_fill_softmax(self):
        def fn(value, mask):
            mask = mask.to(torch.bool)
            x = torch.masked_fill(value, mask, -33.0)
            return torch.softmax(x, -1)

        for dtype in vec_dtypes:
            value = torch.randn((2, 17), dtype=dtype)
            mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8)
            with config.patch({"cpp.simdlen": None}):
                for cpp_wrapper_flag in [True, False]:
                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                        torch._dynamo.reset()
                        metrics.reset()
                        self.common(fn, (value, mask))
                        assert metrics.generated_cpp_vec_kernel_count >= 1

    def test_channels_last_view_as_complex(self):
        # https://github.com/pytorch/pytorch/issues/122448#issuecomment-2046169554

        def reduce_example(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            """Applies the rotary embedding to the query and key tensors."""
            x_out = torch.view_as_complex(torch.stack([x.float(), y.float()], dim=-1))
            return x_out

        args = [torch.randn(1, 1, 1, 128), torch.randn(1, 1, 1, 128)]
        expected = reduce_example(*args)
        actual = torch.compile(reduce_example, fullgraph=True)(*args)
        self.assertEqual(expected, actual)

    def test_load_same_bool_tensor_twice(self):
        @torch._dynamo.optimize("inductor")
        def fn(a, b):
            x = torch.masked_fill(a, b, -33.0)
            y = torch.masked_fill(a, b, -33.0)
            return x, y

        value = torch.randn((2, 17))
        mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool)
        fn(value, mask)

    def test_cpu_vec_cosim(self):
        cpp_vec_op_list = []
        cpp_op_list = []

        for k, v in CppVecOverrides.__dict__.items():
            if isinstance(v, staticmethod):
                cpp_vec_op_list.append(k)
        for k, v in CppOverrides.__dict__.items():
            if isinstance(v, staticmethod):
                cpp_op_list.append(k)

        diff = [
            "airy_ai",
            "bessel_j0",
            "bessel_j1",
            "bessel_y0",
            "bessel_y1",
            "modified_bessel_i0",
            "modified_bessel_i1",
            "modified_bessel_k0",
            "modified_bessel_k1",
            "scaled_modified_bessel_k0",
            "scaled_modified_bessel_k1",
            "spherical_bessel_j0",
            "i1",
            "i1e",
            "ndtr",
            "ndtri",
            "log_ndtr",
            "erfcx",
            "gammainc",
            "gammaincc",
            "igamma",
            "igammac",
            "polygamma",
            "zeta",
            "shifted_chebyshev_polynomial_u",
            "chebyshev_polynomial_u",
            "chebyshev_polynomial_t",
            "shifted_chebyshev_polynomial_w",
            "chebyshev_polynomial_w",
            "shifted_chebyshev_polynomial_t",
            "chebyshev_polynomial_v",
            "shifted_chebyshev_polynomial_v",
            "hermite_polynomial_he",
            "laguerre_polynomial_l",
            "hermite_polynomial_h",
            "legendre_polynomial_p",
            "constant",
            "index_expr",
            "signbit",
            "isinf",
            "frexp",
            "mod",
            "masked",
            "randn",
            "isnan",
            "rand",
            "randint64",
            "logical_and",
            "logical_not",
            "logical_or",
            "logical_xor",
            "bitwise_and",
            "bitwise_left_shift",
            "bitwise_not",
            "bitwise_right_shift",
            "bitwise_or",
            "bitwise_xor",
            "to_dtype_bitcast",
        ]
        union = {*cpp_vec_op_list, *diff}
        self.assertTrue(
            set(cpp_op_list).issubset(union), f"unexpected: {set(cpp_op_list) - union}"
        )

    def test_atomic_add_lowp_fp(self):
        def fn(test_args):
            res = torch.gather(**test_args)
            return res

        for dtype in _lowp_fp_dtypes:
            input_tensor_for_ref = torch.tensor(
                [[3.0, -5.0]], dtype=dtype, requires_grad=True
            )
            input_tensor_for_opt = torch.tensor(
                [[3.0, -5.0]], dtype=dtype, requires_grad=True
            )

            test_args_for_ref = {
                "input": input_tensor_for_ref,
1768 "dim": 1, 1769 "index": torch.tensor([[1]]), 1770 } 1771 test_args_for_opt = { 1772 "input": input_tensor_for_opt, 1773 "dim": 1, 1774 "index": torch.tensor([[1]]), 1775 } 1776 1777 opt_fn = torch.compile(fn) 1778 1779 ref_fwd = fn(test_args_for_ref) 1780 res_fwd = opt_fn(test_args_for_opt) 1781 self.assertEqual(res_fwd, ref_fwd) 1782 1783 torch.manual_seed(1) 1784 bwd_tensor_for_ref = torch.randn(ref_fwd.shape, dtype=dtype) 1785 torch.manual_seed(1) 1786 bwd_tensor_for_opt = torch.randn(res_fwd.shape, dtype=dtype) 1787 self.assertEqual(bwd_tensor_for_ref, bwd_tensor_for_opt) 1788 1789 ref_fwd.backward(bwd_tensor_for_ref) 1790 res_fwd.backward(bwd_tensor_for_opt) 1791 1792 ref_grad = test_args_for_ref["input"].grad 1793 res_grad = test_args_for_opt["input"].grad 1794 self.assertEqual(ref_grad, res_grad) 1795 1796 def test_meta_device(self): 1797 @torch.compile(fullgraph=True) 1798 def fn(): 1799 x = torch.ops.aten.empty.memory_format( 1800 [1024, 128, 128], 1801 dtype=torch.float16, 1802 device="meta", 1803 pin_memory=False, 1804 ) 1805 return x.sin() + 1 1806 1807 self.assertEqual(fn().shape, [1024, 128, 128]) 1808 1809 def test_decomposed_fake_quant_per_channel(self): 1810 def fq(input, scales, zero_points, axis, quant_min, quant_max): 1811 res = torch.fake_quantize_per_channel_affine( 1812 input, scales, zero_points, axis, quant_min, quant_max 1813 ) 1814 return res 1815 1816 def qdq(input, scales, zero_points, axis, quant_min, quant_max): 1817 res = torch.ops.quantized_decomposed.fake_quant_per_channel( 1818 input, scales, zero_points, axis, quant_min, quant_max 1819 ) 1820 return res 1821 1822 def run_eager_aten_fake_quant( 1823 input, scales, zero_points, axis, quant_min, quant_max 1824 ): 1825 input.grad = None 1826 res = fq(input, scales, zero_points, axis, quant_min, quant_max) 1827 res.sum().backward() 1828 return res, input.grad 1829 1830 def run_eager_decomposed_fake_quant( 1831 input, scales, zero_points, axis, quant_min, quant_max 1832 ): 1833 input.grad = None 1834 res = qdq(input, scales, zero_points, axis, quant_min, quant_max) 1835 res.sum().backward() 1836 return res, input.grad 1837 1838 def run_compile_decomposed_fake_quant( 1839 input, scales, zero_points, axis, quant_min, quant_max 1840 ): 1841 input.grad = None 1842 compiled_qdq = torch.compile(qdq) 1843 res = compiled_qdq(input, scales, zero_points, axis, quant_min, quant_max) 1844 res.sum().backward() 1845 return res, input.grad 1846 1847 input = torch.randn(2, 3, 224, 224) 1848 input[1, 2, 3, 4] = 257 1849 input.requires_grad_() 1850 scales = torch.ones((3,)) 1851 zero_points = torch.zeros((3,)) 1852 axis = 1 1853 quant_min = -128 1854 quant_max = 127 1855 1856 aten_input = copy.deepcopy(input) 1857 compiler_input = copy.deepcopy(input) 1858 1859 res_aten_eager, input_grad_aten_eager = run_eager_aten_fake_quant( 1860 aten_input, scales, zero_points, axis, quant_min, quant_max 1861 ) 1862 res_decomp_eager, input_grad_decomp_eager = run_eager_decomposed_fake_quant( 1863 input, scales, zero_points, axis, quant_min, quant_max 1864 ) 1865 res, input_grad = run_compile_decomposed_fake_quant( 1866 compiler_input, scales, zero_points, axis, quant_min, quant_max 1867 ) 1868 1869 self.assertEqual(res_aten_eager, res) 1870 self.assertEqual(res_decomp_eager, res) 1871 self.assertEqual(input_grad_aten_eager, input_grad) 1872 self.assertEqual(input_grad_decomp_eager, input_grad) 1873 self.assertEqual(input_grad[1, 2, 3, 4], torch.tensor(0.0)) 1874 # For forward and backward kernel 1875 check_metrics_vec_kernel_count(2) 

    @requires_vectorization
    def test_ops_masked_with_bool_input(self):
        x = torch.zeros(129, dtype=torch.bool)
        size = [2, 3]
        res_aten_eager = torch.constant_pad_nd(x, size)
        cfn = torch.compile(torch.constant_pad_nd)
        res = cfn(x, size)
        self.assertEqual(res_aten_eager, res)
        check_metrics_vec_kernel_count(1)

    def test_bitwise_right_shift(self):
        x = torch.randint(-1, 0, (1, 1, 1), device="cpu", dtype=torch.int64)
        bit_num = 31
        res_aten_eager = torch.bitwise_right_shift(x, bit_num)
        cfn = torch.compile(torch.bitwise_right_shift)
        res = cfn(x, bit_num)
        self.assertEqual(res_aten_eager, res)

    @patch("torch.cuda.is_available", lambda: False)
    def test_scatter_using_atomic_add(self):
        def fn(a, dim, index, b):
            return aten.scatter(a, dim, index, b, reduce="add")

        inps = (
            torch.randn(5, 29, 13),
            2,
            torch.tensor([[[3, 5, 7, 9]]]),
            torch.randn(1, 1, 10),
        )

        def _internal_check(
            _fn,
            _inps,
            _target_code_check=None,
            _target_code_check_not=None,
        ):
            torch._dynamo.reset()
            metrics.reset()
            _fn_opt = torch.compile()(_fn)
            _, code = run_and_get_cpp_code(_fn_opt, *inps)
            if _target_code_check:
                FileCheck().check(_target_code_check).run(code)
            if _target_code_check_not:
                FileCheck().check_not(_target_code_check_not).run(code)

            self.assertEqual(
                _fn(*_inps),
                _fn_opt(*_inps),
            )

        with config.patch({"cpp.fallback_scatter_reduce_sum": False}):
            _internal_check(fn, inps, "atomic_add")

        with config.patch({"cpp.fallback_scatter_reduce_sum": True}):
            _internal_check(fn, inps, "aten.scatter_reduce_")

        if "ATen parallel backend: OpenMP" in torch.__config__.parallel_info():
            # Fix https://github.com/pytorch/pytorch/issues/118518
            # which fails to change thread number with native thread pool
            with set_num_threads(1):
                _internal_check(fn, inps, _target_code_check_not="aten.scatter_reduce_")

            with config.patch({"cpp.dynamic_threads": True}), set_num_threads(1):
                _internal_check(fn, inps, "aten.scatter_reduce_")

    @requires_vectorization
    @patch("torch.cuda.is_available", lambda: False)
    def test_new_vec_op_cpu_only(self):
        def fn(x):
            return torch.log1p(torch.expm1(torch.erf(x)))

        for dtype in vec_dtypes:
            torch.manual_seed(0)
            x = torch.randn((2, 9), dtype=dtype)
            x[0, 0] = torch.nan
            x[1, -1] = torch.nan

            tol = 1e-2 if dtype == torch.bfloat16 else 1e-4

            with config.patch({"cpp.simdlen": None}):
                for cpp_wrapper_flag in [True, False]:
                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                        torch._dynamo.reset()
                        metrics.reset()
                        self.common(fn, (x,))
                        check_metrics_vec_kernel_count(1)

    @requires_vectorization
    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_cpu_only_for_all_available_isa(self):
        def fn(x):
            return torch.sin(torch.cos(torch.erf(x)))

        x = torch.randn((2, 9))
        x[0, 0] = torch.nan
        x[1, -1] = torch.nan

        bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] + [None]
        for item in bit_widths:
            with config.patch({"cpp.simdlen": item}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x,))
                check_metrics_vec_kernel_count(1)

    @slowTest
    @requires_vectorization
    @patch("torch.cuda.is_available", lambda: False)
1985 def test__adaptive_avg_pool2d(self): 1986 def wrap_fn(oh, ow): 1987 def fn(x): 1988 return torch._adaptive_avg_pool2d(x, (oh, ow)) 1989 1990 return fn 1991 1992 bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] 1993 ih = [16, 65] 1994 iw = ih 1995 oh = ih 1996 ow = ih 1997 for _ih, _iw, _oh, _ow, _simd_len, dtype in itertools.product( 1998 ih, iw, oh, ow, bit_widths, vec_dtypes 1999 ): 2000 x = torch.randn(2, 3, _ih, _iw, dtype=dtype).to( 2001 memory_format=torch.channels_last 2002 ) 2003 _fn = wrap_fn(_oh, _ow) 2004 with config.patch({"cpp.simdlen": _simd_len}): 2005 torch._dynamo.reset() 2006 metrics.reset() 2007 self.common(_fn, (x,)) 2008 check_metrics_vec_kernel_count(1) 2009 2010 @requires_vectorization 2011 @patch("torch.cuda.is_available", lambda: False) 2012 def test_vec_logical(self): 2013 def wrap_fn1(op: Callable): 2014 def fn(x: torch.Tensor): 2015 return torch.where(op(x), 1.0, 0.0) 2016 2017 return fn 2018 2019 def wrap_fn2(op: Callable): 2020 def fn(x: torch.Tensor, y: torch.Tensor): 2021 return torch.where(op(x, y), 1.0, 0.0) 2022 2023 return fn 2024 2025 for dtype in vec_dtypes: 2026 x = torch.randn(64, dtype=dtype) 2027 y = torch.randn(64, dtype=dtype) 2028 logical_fns = [ 2029 torch.logical_and, 2030 torch.logical_not, 2031 torch.logical_or, 2032 torch.logical_xor, 2033 ] 2034 for logical_fn in logical_fns: 2035 torch._dynamo.reset() 2036 metrics.reset() 2037 if logical_fn == torch.logical_not: 2038 _fn = wrap_fn1(logical_fn) 2039 _args = (x,) 2040 else: 2041 _fn = wrap_fn2(logical_fn) 2042 _args = (x, y) 2043 self.common(_fn, _args) 2044 check_metrics_vec_kernel_count(1) 2045 2046 @requires_vectorization 2047 @patch("torch.cuda.is_available", lambda: False) 2048 def test_vec_compare_op_cpu_only(self): 2049 def fn(x): 2050 y1 = torch.eq(x, 1.0) 2051 x = torch.where(y1, x, -x) 2052 y2 = torch.ne(x, 0.0) 2053 x = torch.where(y2, x, -x) 2054 y3 = torch.lt(x, 5.0) 2055 x = torch.where(y3, x, x - 1.0) 2056 y4 = torch.gt(x, -2.0) 2057 x = torch.where(y4, x, x + 1.0) 2058 y5 = torch.le(x, 8.0) 2059 x = torch.where(y5, x, x - 1.0) 2060 y6 = torch.ge(x, -3.0) 2061 x = torch.where(y6, x, x + 1.0) 2062 y7 = x == 1.0 2063 x = torch.where(y7, x, -x) 2064 y8 = x != 0.0 2065 x = torch.where(y8, x, -x) 2066 y9 = x < 5.0 2067 x = torch.where(y9, x, x - 1.0) 2068 y10 = x > -2.0 2069 x = torch.where(y10, x, x + 1.0) 2070 y11 = x <= 8.0 2071 x = torch.where(y11, x, x - 1.0) 2072 y12 = x >= -3.0 2073 x = torch.where(y12, x, x + 1.0) 2074 return x 2075 2076 for dtype in vec_dtypes: 2077 x = torch.randn((2, 9), dtype=dtype) 2078 2079 with config.patch({"cpp.simdlen": None}): 2080 torch._dynamo.reset() 2081 metrics.reset() 2082 self.common(fn, (x,)) 2083 check_metrics_vec_kernel_count(1) 2084 assert ( 2085 metrics.generated_kernel_count 2086 - metrics.generated_cpp_vec_kernel_count 2087 ) == 0 2088 2089 def test_skip_cpp_codegen(self): 2090 with config.patch({"disable_cpp_codegen": True}): 2091 inps = torch.ones([20]), torch.rand([20]) 2092 2093 def f(x, y): 2094 return x + y + torch.tensor(1) 2095 2096 f_opt = torch.compile()(f) 2097 2098 _, code = run_and_get_cpp_code(f_opt, inps[0], inps[1]) 2099 FileCheck().check_not("void kernel").run(code) 2100 2101 self.assertEqual( 2102 f(*inps), 2103 f_opt(*inps), 2104 ) 2105 2106 # constant needs to be propagated on fallback 2107 def f(x): 2108 return x[torch.tensor(1) :] * 2 2109 2110 f_opt = torch.compile()(f) 2111 _, code = run_and_get_cpp_code(f_opt, inps[0]) 2112 FileCheck().check_not("void kernel").run(code) 2113 
self.assertEqual(f_opt(inps[0]), f(inps[0])) 2114 2115 class Model(torch.nn.Module): 2116 def __init__( 2117 self, 2118 ): 2119 super().__init__() 2120 2121 def forward(self, v1: torch.Tensor): 2122 vx = v1.min(dim=1).values 2123 v2 = torch.randn_like(vx) 2124 return v2 2125 2126 model = Model() 2127 x = torch.rand(10, 3, 0) 2128 model_f = torch.compile()(model) 2129 2130 self.assertEqual(model(x), model_f(x)) 2131 2132 def test_redundant_to_node_elimination_lowp_fp(self): 2133 def fn(x, y): 2134 res = x + y 2135 res = torch.mean(res) 2136 return res 2137 2138 for dtype in _lowp_fp_dtypes: 2139 x = torch.randn((2, 9), dtype=dtype) 2140 y = torch.randn((2, 9), dtype=dtype) 2141 2142 for torch_compile_debug in [True, False]: 2143 with config.patch( 2144 {"trace.enabled": torch_compile_debug, "cpp.simdlen": None} 2145 ): 2146 torch._dynamo.reset() 2147 metrics.reset() 2148 self.common(fn, (x, y)) 2149 check_metrics_vec_kernel_count(1) 2150 2151 def test_do_not_insert_to_dtype_for_memory_copy_only_kernel(self): 2152 def fn(x): 2153 res = x.clone() 2154 return res 2155 2156 x = torch.randn((100, 100), dtype=torch.bfloat16) 2157 2158 torch._dynamo.reset() 2159 metrics.reset() 2160 self.common(fn, (x,)) 2161 assert metrics.cpp_to_dtype_count == 0 2162 check_metrics_vec_kernel_count(1) 2163 2164 def test_insert_to_dtype_count(self): 2165 def fn(x): 2166 res = x.relu() 2167 return res 2168 2169 x = torch.randn((100, 100), dtype=torch.bfloat16) 2170 2171 torch._dynamo.reset() 2172 metrics.reset() 2173 self.common(fn, (x,)) 2174 assert metrics.cpp_to_dtype_count == 2 2175 check_metrics_vec_kernel_count(1) 2176 2177 def test_memory_copy_with_fusion(self): 2178 def fn(x): 2179 res = x.relu() 2180 x.copy_(res) 2181 return (res,) 2182 2183 x = torch.randn((100, 100), dtype=torch.bfloat16) 2184 2185 torch._dynamo.reset() 2186 metrics.reset() 2187 self.common(fn, (x,)) 2188 assert metrics.cpp_to_dtype_count == 2 2189 check_metrics_vec_kernel_count(1) 2190 2191 @requires_vectorization 2192 @patch("torch.cuda.is_available", lambda: False) 2193 def test_cpp_vec_constant_checker(self): 2194 _graph: torch.fx.Graph = torch.fx.Graph() 2195 a: torch.fx.Node = _graph.create_node("placeholder", "ops") 2196 iv: torch.fx.Node = _graph.create_node("placeholder", "iv") 2197 fv: torch.fx.Node = _graph.create_node("placeholder", "fv") 2198 b: torch.fx.Node = _graph.create_node( 2199 "call_method", 2200 "constant", 2201 args=( 2202 a, 2203 iv, 2204 torch.int64, 2205 ), 2206 ) 2207 c: torch.fx.Node = _graph.create_node( 2208 "call_method", 2209 "constant", 2210 args=( 2211 a, 2212 fv, 2213 torch.double, 2214 ), 2215 ) 2216 d: torch.fx.Node = _graph.create_node( 2217 "call_method", 2218 "ge", 2219 args=( 2220 a, 2221 b, 2222 b, 2223 ), 2224 ) 2225 _graph.output((d, c)) 2226 2227 def get_index(): 2228 return "" 2229 2230 submodules = {"get_index": get_index} 2231 2232 graph_lowering = GraphLowering( 2233 torch.fx.GraphModule(submodules, _graph), 2234 shape_env=None, 2235 ) 2236 2237 def set_opt_dtype(graph): 2238 for node in graph.nodes: 2239 if node.target == "constant": 2240 if OptimizationContext.key in node.meta: 2241 opt_ctx = node.meta[OptimizationContext.key] 2242 else: 2243 opt_ctx = OptimizationContext() 2244 opt_ctx.dtype = node.args[-1] 2245 node.meta[OptimizationContext.key] = opt_ctx 2246 2247 with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( 2248 graph_lowering 2249 ): 2250 # The moset inner loop variable is used in the index_expr 2251 tiling_factor = 
codecache.pick_vec_isa().nelements(dtype=torch.float) 2252 with CppVecKernelChecker( 2253 args=None, num_threads=1, tiling_factor=tiling_factor 2254 ) as vec_checker: 2255 i32_iinfo = np.iinfo(np.int32) 2256 f32_iinfo = np.finfo(np.float32) 2257 set_opt_dtype(_graph) 2258 InterpreterShim(_graph, submodules).run( 2259 V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max 2260 ) 2261 self.assertTrue(vec_checker.simd_vec) 2262 2263 vec_checker.simd_vec = True 2264 set_opt_dtype(_graph) 2265 InterpreterShim(_graph, submodules).run( 2266 V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min 2267 ) 2268 self.assertTrue(vec_checker.simd_vec) 2269 2270 vec_checker.simd_vec = True 2271 set_opt_dtype(_graph) 2272 InterpreterShim(_graph, submodules).run( 2273 V.get_ops_handler(), i32_iinfo.min, np.inf 2274 ) 2275 self.assertTrue(vec_checker.simd_vec) 2276 2277 vec_checker.simd_vec = True 2278 set_opt_dtype(_graph) 2279 InterpreterShim(_graph, submodules).run( 2280 V.get_ops_handler(), i32_iinfo.min, -np.inf 2281 ) 2282 self.assertTrue(vec_checker.simd_vec) 2283 2284 vec_checker.simd_vec = True 2285 set_opt_dtype(_graph) 2286 InterpreterShim(_graph, submodules).run( 2287 V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min 2288 ) 2289 self.assertTrue(vec_checker.simd_vec) 2290 2291 vec_checker.simd_vec = True 2292 set_opt_dtype(_graph) 2293 InterpreterShim(_graph, submodules).run( 2294 V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max 2295 ) 2296 self.assertTrue(vec_checker.simd_vec) 2297 2298 vec_checker.simd_vec = True 2299 set_opt_dtype(_graph) 2300 InterpreterShim(_graph, submodules).run( 2301 V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min * (1 + 1e-5) 2302 ) 2303 self.assertFalse(vec_checker.simd_vec) 2304 2305 vec_checker.simd_vec = True 2306 set_opt_dtype(_graph) 2307 InterpreterShim(_graph, submodules).run( 2308 V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max * (1 + 1e-5) 2309 ) 2310 self.assertFalse(vec_checker.simd_vec) 2311 2312 @requires_vectorization 2313 @patch("torch.cuda.is_available", lambda: False) 2314 def test_cpp_vec_index_expr_checker(self): 2315 _graph: torch.fx.Graph = torch.fx.Graph() 2316 a: torch.fx.Node = _graph.create_node("placeholder", "ops") 2317 b: torch.fx.Node = _graph.create_node("call_module", "get_index", args=()) 2318 c: torch.fx.Node = _graph.create_node( 2319 "call_method", 2320 "index_expr", 2321 args=( 2322 a, 2323 b, 2324 torch.int64, 2325 ), 2326 ) 2327 d: torch.fx.Node = _graph.create_node( 2328 "call_method", 2329 "ge", 2330 args=( 2331 a, 2332 c, 2333 c, 2334 ), 2335 ) 2336 _graph.output(d) 2337 2338 def get_index(): 2339 return "" 2340 2341 submodules = {"get_index": get_index} 2342 graph_lowering = GraphLowering( 2343 torch.fx.GraphModule(submodules, _graph), 2344 shape_env=None, 2345 ) 2346 with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( 2347 graph_lowering 2348 ): 2349 itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")] 2350 2351 tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float) 2352 # The most inner loop variable is used in the index_expr 2353 with CppVecKernelChecker( 2354 args=None, num_threads=1, tiling_factor=tiling_factor 2355 ) as vec_checker: 2356 2357 def get_index(): 2358 return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] 2359 2360 ranges = [0, 100, 200] 2361 vec_checker.itervars = itervars[:2] 2362 vec_checker.ranges = ranges[:2] 2363 submodules = {"get_index": get_index} 2364 InterpreterShim(_graph, submodules).run(V.get_ops_handler()) 2365 
self.assertTrue(vec_checker.simd_vec) 2366 2367 # Innermost loop variable is irrelevant (not used in the index_expr) 2368 with CppVecKernelChecker( 2369 args=None, num_threads=1, tiling_factor=tiling_factor 2370 ) as vec_checker: 2371 2372 def get_index(): 2373 return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] 2374 2375 ranges = [0, 100, 200] 2376 vec_checker.itervars = itervars 2377 vec_checker.ranges = ranges 2378 submodules = {"get_index": get_index} 2379 InterpreterShim(_graph, submodules).run(V.get_ops_handler()) 2380 self.assertTrue(vec_checker.simd_vec) 2381 2382 i32_iinfo = np.iinfo(np.int32) 2383 _max_value = i32_iinfo.max + 1 2384 ranges = [_max_value, _max_value, _max_value] 2385 # Innermost loop variable is irrelevant, but the max index value is greater than 2386 # the max value of INT32 2387 with CppVecKernelChecker( 2388 args=None, num_threads=1, tiling_factor=tiling_factor 2389 ) as vec_checker: 2390 2391 def get_index(): 2392 return itervars[0] 2393 2394 submodules = {"get_index": get_index} 2395 vec_checker.itervars = itervars 2396 vec_checker.ranges = ranges 2397 InterpreterShim(_graph, submodules).run(V.get_ops_handler()) 2398 self.assertFalse(vec_checker.simd_vec) 2399 2400 # Innermost loop variable is irrelevant, but the min index value is smaller than 2401 # the min value of INT32 2402 with CppVecKernelChecker( 2403 args=None, num_threads=1, tiling_factor=tiling_factor 2404 ) as vec_checker: 2405 2406 def get_index(): 2407 return -itervars[0] - 2 2408 2409 submodules = {"get_index": get_index} 2410 vec_checker.itervars = itervars 2411 vec_checker.ranges = ranges 2412 InterpreterShim(_graph, submodules).run(V.get_ops_handler()) 2413 self.assertFalse(vec_checker.simd_vec) 2414 2415 @requires_vectorization 2416 @patch("torch.cuda.is_available", lambda: False) 2417 def test_maxpool2d_cpu_only(self): 2418 for dtype in vec_dtypes: 2419 input = torch.randn(26, 32, 112, 112, dtype=dtype).to( 2420 memory_format=torch.channels_last 2421 ) 2422 maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 2423 2424 def func(x): 2425 return maxpool(x) 2426 2427 with patch.object(config.cpp, "simdlen", None): 2428 torch._dynamo.reset() 2429 metrics.reset() 2430 self.common(func, (input,)) 2431 check_metrics_vec_kernel_count(1) 2432 2433 @requires_vectorization 2434 @patch("torch.cuda.is_available", lambda: False) 2435 def test_maxpool2d_with_pre_loop_collapse_cpu_only(self): 2436 x1 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last) 2437 x2 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last) 2438 maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) 2439 2440 def func(x1, x2): 2441 y = x1 + x2 2442 return maxpool(y) 2443 2444 with patch.object(config.cpp, "simdlen", None): 2445 torch._dynamo.reset() 2446 metrics.reset() 2447 self.common(func, (x1, x2)) 2448 check_metrics_vec_kernel_count(2) 2449 2450 def test_randint_symint_input(self): 2451 # https://github.com/pytorch/pytorch/issues/122405 2452 @torch.compile(fullgraph=True) 2453 def get_traj_idx(lengths: torch.Tensor, num_slices: int) -> torch.Tensor: 2454 return torch.randint(lengths.shape[0], (num_slices,), device=lengths.device) 2455 2456 lengths = torch.zeros(10, dtype=torch.long) 2457 get_traj_idx(lengths, num_slices=4) 2458 lengths = torch.zeros(11, dtype=torch.long) 2459 get_traj_idx(lengths, num_slices=4) 2460 2461 @requires_vectorization 2462 @patch("torch.cuda.is_available", lambda: False) 2463 def test_sign_cpu_only(self): 2464 def fn(x): 2465 return torch.sign(x) 2466 2467 for dtype in vec_dtypes: 2468 x =
torch.randn((2, 9), dtype=dtype) 2469 x[0, 0] = torch.nan 2470 x[1, -1] = torch.nan 2471 2472 with config.patch({"cpp.simdlen": None}): 2473 torch._dynamo.reset() 2474 metrics.reset() 2475 self.common(fn, (x,)) 2476 check_metrics_vec_kernel_count(1) 2477 2478 @requires_vectorization 2479 @patch("torch.cuda.is_available", lambda: False) 2480 def test_reduction_cpu_only(self): 2481 def fn(x): 2482 return torch.argmax(x, -1) 2483 2484 for dtype in vec_dtypes: 2485 x = torch.randn((10, 10), dtype=dtype) 2486 2487 with config.patch({"cpp.simdlen": None}): 2488 torch._dynamo.reset() 2489 metrics.reset() 2490 self.common(fn, (x,)) 2491 assert metrics.generated_cpp_vec_kernel_count == 0 2492 2493 def test_outer_loop_fusion(self): 2494 def fn(x): 2495 max = torch.amax(x, dim=-1, keepdim=True) 2496 return x - max 2497 2498 x = torch.randn(4, 12, 1023, 1022) 2499 2500 with config.patch({"cpp.simdlen": None}): 2501 torch._dynamo.reset() 2502 metrics.reset() 2503 self.common(fn, (x,)) 2504 assert len(metrics.cpp_outer_loop_fused_inner_counts) == 1 2505 assert metrics.cpp_outer_loop_fused_inner_counts[0] == 2 2506 2507 def test_argmin(self): 2508 def fn(x): 2509 return torch.argmin(x, -1) 2510 2511 for dtype in vec_dtypes: 2512 x = torch.randn((10, 10), dtype=dtype) 2513 torch._dynamo.reset() 2514 metrics.reset() 2515 self.common(fn, (x,)) 2516 assert metrics.generated_cpp_vec_kernel_count == 0 2517 2518 def test_argmax_argmin_with_nan_value(self): 2519 def fn(x): 2520 return torch.argmax(x) 2521 2522 def fn2(x): 2523 return torch.argmin(x) 2524 2525 inputs = [ 2526 torch.Tensor([-755832.1250, 100]), 2527 torch.Tensor([-755832.1250, 100, 200]), 2528 torch.Tensor([100, -755832.1250]), 2529 torch.Tensor([100, 200, -755832.1250]), 2530 ] 2531 2532 for x in inputs: 2533 x = x.repeat(16, 16) 2534 x = torch.log1p(x) 2535 2536 # Test argmax 2537 torch._dynamo.reset() 2538 metrics.reset() 2539 self.common(fn, (x,)) 2540 assert metrics.generated_cpp_vec_kernel_count == 0 2541 2542 # Test argmin 2543 torch._dynamo.reset() 2544 metrics.reset() 2545 self.common(fn2, (x,)) 2546 assert metrics.generated_cpp_vec_kernel_count == 0 2547 2548 # Currently, we enable AVX2 and AVX512 for vectorization. If the platform is not 2549 # supported, vectorization will not be used and this test case will be skipped. To support 2550 # ARM or other platforms, we just need to add the ISA info to the supported_vector_isa list 2551 # and include the proper ATen vectorization header file. 2552 @requires_vectorization 2553 @patch("torch.cuda.is_available", lambda: False) 2554 def test_vec_kernel_cpu_only(self): 2555 def fn(x1, x2): 2556 # Currently, there are some limitations as follows.
2557 # rsqrt: 2558 # assert [both a fallback and a decomp for same kernel: aten.rsqrt.default] 2559 # round: 2560 # couldn't find symbolic meta function/decomposition 2561 # fmod/logical_and/logic_or: 2562 # vec kernel has not support to_type 2563 x = torch.abs(x1) 2564 x = torch.sin(x) 2565 x = torch.neg(x) 2566 x = torch.square(x) 2567 x = torch.sigmoid(x) 2568 x = torch.relu(x) 2569 x = torch.cos(x) 2570 x = torch.exp(x) 2571 x = torch.sqrt(x) 2572 x = torch.add(x, x1) 2573 x = torch.sub(x, x2) 2574 x = torch.mul(x, x1) 2575 x = torch.div(x, x1) 2576 x = torch.pow(x, 10) 2577 x = torch.log(x) 2578 x = torch.floor(x) 2579 x = torch.ceil(x) 2580 x = torch.trunc(x) 2581 x = torch.lgamma(x) 2582 x = torch.fmod(x, x2) 2583 x = torch.sign(x) 2584 res = x + x2 2585 return res 2586 2587 for dtype in vec_dtypes: 2588 torch.manual_seed(0) 2589 x1 = torch.randn((5, 20), dtype=dtype) 2590 x2 = torch.randn((5, 20), dtype=dtype) 2591 2592 tol = 1e-2 if dtype == torch.bfloat16 else 1e-4 2593 with config.patch({"cpp.simdlen": 1}): 2594 torch._dynamo.reset() 2595 metrics.reset() 2596 self.common(fn, (x1, x2)) 2597 assert metrics.generated_cpp_vec_kernel_count == 0 2598 2599 with config.patch({"cpp.simdlen": None}): 2600 torch._dynamo.reset() 2601 metrics.reset() 2602 self.common(fn, (x1, x2)) 2603 check_metrics_vec_kernel_count(1) 2604 2605 with config.patch({"cpp.simdlen": None}): 2606 torch._dynamo.reset() 2607 metrics.reset() 2608 x1 = torch.randn(10, 20).permute(1, 0) 2609 x2 = torch.randn((20, 10)) 2610 self.common(fn, (x1, x2)) 2611 check_metrics_vec_kernel_count(2) 2612 2613 torch._dynamo.reset() 2614 metrics.reset() 2615 x1 = torch.randn((10, 7)) 2616 x2 = torch.randn((10, 7)) 2617 self.common(fn, (x1, x2)) 2618 check_metrics_vec_kernel_count(1) 2619 2620 @unittest.skipIf( 2621 sys.platform != "linux", "cpp kernel profile only support linux now" 2622 ) 2623 @patch("torch.cuda.is_available", lambda: False) 2624 @config.patch({"cpp.enable_kernel_profile": True}) 2625 @config.patch({"cpp.descriptive_names": "original_aten"}) 2626 def test_cpp_kernel_profile(self): 2627 from torch.profiler import profile 2628 2629 @torch._dynamo.optimize("inductor", nopython=True) 2630 def fn(a, b): 2631 return a + b 2632 2633 a = torch.rand((100,)) 2634 b = torch.rand((100,)) 2635 with profile() as prof: 2636 fn(a, b) 2637 2638 kernel_profile_events = [] 2639 for e in prof.profiler.function_events: 2640 if "cpp_fused_add_0" in e.name: 2641 kernel_profile_events.append(e.name) 2642 assert len(kernel_profile_events) > 0 2643 2644 @requires_vectorization 2645 def test_channel_shuffle_cl_output(self): 2646 """code and shape extracted from shufflenet_v2_x1_0""" 2647 2648 def channel_shuffle(x, groups): 2649 batchsize, num_channels, height, width = x.size() 2650 channels_per_group = num_channels // groups 2651 x = x.view(batchsize, groups, channels_per_group, height, width) 2652 x = torch.transpose(x, 1, 2).contiguous() 2653 x = x.view(batchsize, -1, height, width) 2654 return x.contiguous(memory_format=torch.channels_last) 2655 2656 for simdlen in (None, 256, 1): 2657 with config.patch({"cpp.simdlen": simdlen}): 2658 torch._dynamo.reset() 2659 metrics.reset() 2660 x = torch.randn(64, 58, 28, 28) 2661 self.common(channel_shuffle, (x, 2)) 2662 if simdlen != 1: 2663 check_metrics_vec_kernel_count(2) 2664 2665 @slowTest 2666 @requires_vectorization 2667 def test_transpose_with_norm(self): 2668 """a sub-module from TIMM gmlp_s16_224""" 2669 2670 class Model(torch.nn.Module): 2671 def __init__(self): 2672 super().__init__() 
2673 self.linear = torch.nn.Linear( 2674 in_features=256, out_features=1536, bias=True 2675 ) 2676 self.act = torch.nn.GELU() 2677 self.norm = torch.nn.LayerNorm(768) 2678 self.proj = torch.nn.Linear(196, 196) 2679 self.fc = torch.nn.Linear(in_features=768, out_features=256, bias=True) 2680 2681 def forward(self, x): 2682 x = self.linear(x) 2683 x = self.act(x) 2684 u, v = x.chunk(2, dim=-1) 2685 v = self.norm(v) 2686 v = self.proj(v.transpose(-1, -2)) 2687 y = u * v.transpose(-1, -2) 2688 return self.fc(y) 2689 2690 x = torch.randn(128, 196, 256) 2691 for simdlen in (None, 256, 1): 2692 with config.patch({"cpp.simdlen": simdlen}): 2693 for eval_mode in [True, False]: 2694 torch._dynamo.reset() 2695 metrics.reset() 2696 m = Model().eval() if eval_mode else Model() 2697 self.common(m, (x,)) 2698 if simdlen != 1: 2699 check_metrics_vec_kernel_count(8) 2700 2701 @requires_vectorization 2702 def test_transpose_copy(self): 2703 def fn(a): 2704 return a.t().contiguous() 2705 2706 for simdlen in (None, 256, 1): 2707 with config.patch({"cpp.simdlen": simdlen}): 2708 for dtype in (torch.float, torch.bfloat16): 2709 for shape in ( 2710 (7, 7), 2711 (8, 8), 2712 (9, 9), 2713 (16, 16), 2714 (17, 17), 2715 (32, 32), 2716 (33, 33), 2717 ): 2718 torch._dynamo.reset() 2719 metrics.reset() 2720 x = torch.randn(shape, dtype=dtype) 2721 self.common(fn, (x,)) 2722 if simdlen != 1: 2723 check_metrics_vec_kernel_count(2) 2724 2725 @torch._dynamo.config.patch(specialize_int=False) 2726 def test_slice_scatter_issue122291(self): 2727 @torch.compile(fullgraph=True) 2728 def fn(t, t_src, dim, start, end, step): 2729 return t.slice_scatter(t_src, dim, start, end, step) 2730 2731 shape = ((16, 16), (16, 2), 1, 4, 10, 1) 2732 input_tensor = torch.zeros(shape[0], requires_grad=False, device="cpu") 2733 src_tensor = torch.ones(shape[1], requires_grad=False, device="cpu") 2734 with self.assertRaisesRegex( 2735 torch._dynamo.exc.BackendCompilerFailed, r".*shape error in scatter op" 2736 ): 2737 fn(input_tensor, src_tensor, shape[2], shape[3], shape[4], shape[5]) 2738 2739 def test_horizontal_fusion(self): 2740 def fn(a, b, c, idx): 2741 _a = torch.index_select(a, dim=0, index=idx) 2742 _b = torch.index_select(b, dim=0, index=idx) 2743 _c = torch.index_select(c, dim=0, index=idx) 2744 return _a, _b, _c 2745 2746 with config.patch({"cpp.max_horizontal_fusion_size": 0}): 2747 metrics.reset() 2748 torch._dynamo.reset() 2749 a = torch.randn(size=(4, 16), dtype=torch.bfloat16) 2750 b = torch.randn(size=(4, 16), dtype=torch.bfloat16) 2751 c = torch.randn(size=(4, 16), dtype=torch.bfloat16) 2752 idx = torch.zeros(size=[4], dtype=torch.int64) 2753 opt_fn = torch._dynamo.optimize("inductor")(fn) 2754 opt_fn(a, b, c, idx) 2755 self.assertEqual(metrics.generated_kernel_count, 3) 2756 self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx))) 2757 2758 with config.patch({"cpp.max_horizontal_fusion_size": 1}): 2759 metrics.reset() 2760 torch._dynamo.reset() 2761 a = torch.randn(size=(4, 32), dtype=torch.bfloat16) 2762 b = torch.randn(size=(4, 32), dtype=torch.bfloat16) 2763 c = torch.randn(size=(4, 32), dtype=torch.bfloat16) 2764 idx = torch.zeros(size=[4], dtype=torch.int64) 2765 opt_fn = torch._dynamo.optimize("inductor")(fn) 2766 opt_fn(a, b, c, idx) 2767 self.assertEqual(metrics.generated_kernel_count, 3) 2768 self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx))) 2769 2770 with config.patch({"cpp.max_horizontal_fusion_size": 2}): 2771 metrics.reset() 2772 torch._dynamo.reset() 2773 a = torch.randn(size=(4, 64), 
dtype=torch.bfloat16) 2774 b = torch.randn(size=(4, 64), dtype=torch.bfloat16) 2775 c = torch.randn(size=(4, 64), dtype=torch.bfloat16) 2776 idx = torch.zeros(size=[4], dtype=torch.int64) 2777 opt_fn = torch._dynamo.optimize("inductor")(fn) 2778 opt_fn(a, b, c, idx) 2779 print(metrics.generated_kernel_count) 2780 self.assertEqual(metrics.generated_kernel_count, 2) 2781 self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx))) 2782 2783 with config.patch({"cpp.max_horizontal_fusion_size": 3}): 2784 metrics.reset() 2785 torch._dynamo.reset() 2786 a = torch.randn(size=(4, 128), dtype=torch.bfloat16) 2787 b = torch.randn(size=(4, 128), dtype=torch.bfloat16) 2788 c = torch.randn(size=(4, 128), dtype=torch.bfloat16) 2789 idx = torch.zeros(size=[4], dtype=torch.int64) 2790 opt_fn = torch._dynamo.optimize("inductor")(fn) 2791 opt_fn(a, b, c, idx) 2792 self.assertEqual(metrics.generated_kernel_count, 1) 2793 self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx))) 2794 2795 def test_lowp_fp_neg_abs(self): 2796 def fn(x): 2797 return x.neg().abs() 2798 2799 for dtype in _lowp_fp_dtypes: 2800 metrics.reset() 2801 x = torch.randn(100, 100).to(dtype) 2802 opt_fn = torch._dynamo.optimize("inductor")(fn) 2803 self.assertTrue(same(fn(x), opt_fn(x))) 2804 assert metrics.cpp_to_dtype_count == 0 2805 check_metrics_vec_kernel_count(1) 2806 2807 def test_transpose_non_contiguous(self): 2808 def fn(a): 2809 # From part of timm HaloAttn: 2810 # (https://github.com/rwightman/pytorch-image-models/blob/main/timm/layers/halo_attn.py#L97). 2811 # Fixed https://github.com/pytorch/pytorch/issues/94269 accuracy issue. 2812 as_strided = torch.ops.aten.as_strided.default( 2813 a, [1, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680] 2814 ) 2815 as_strided_1 = torch.ops.aten.as_strided.default( 2816 as_strided, 2817 [1, 384, 2, 2, 12, 12], 2818 [153600, 1, 61440, 3072, 7680, 384], 2819 ) 2820 clone_1 = torch.ops.aten.clone.default( 2821 as_strided_1, memory_format=torch.contiguous_format 2822 ) 2823 _unsafe_view_1 = torch.ops.aten._unsafe_view.default( 2824 clone_1, [8, 48, 4, 144] 2825 ) 2826 permute_2 = torch.ops.aten.permute.default(_unsafe_view_1, [0, 2, 3, 1]) 2827 split_with_sizes = torch.ops.aten.split_with_sizes.default( 2828 permute_2, [16, 32], -1 2829 ) 2830 getitem = split_with_sizes[0] 2831 getitem_1 = split_with_sizes[1] 2832 permute_3 = torch.ops.aten.permute.default(getitem, [0, 1, 3, 2]) 2833 expand_1 = torch.ops.aten.expand.default(permute_3, [8, 4, 16, 144]) 2834 clone_3 = torch.ops.aten.clone.default( 2835 expand_1, memory_format=torch.contiguous_format 2836 ) 2837 return clone_3 2838 2839 metrics.reset() 2840 x = torch.randn(1, 384, 20, 20).to(memory_format=torch.channels_last) 2841 self.common(fn, (x,)) 2842 check_metrics_vec_kernel_count(1) 2843 2844 def test_non_contiguous_index_with_constant_stride(self): 2845 def fn(x): 2846 x1 = x[:, :, :, ::2] 2847 x2 = x[:, :, :, 1::2] 2848 x = torch.stack((-x2, x1), dim=-1) 2849 return x.flatten(-2) 2850 2851 metrics.reset() 2852 x = torch.randn(1, 32, 16, 68) 2853 opt_fn = torch._dynamo.optimize("inductor")(fn) 2854 _, code = run_and_get_cpp_code(opt_fn, x) 2855 self.assertTrue(same(fn(x), opt_fn(x))) 2856 # def and use 2857 FileCheck().check_count("cpp_fused", 2, exactly=True).run(code) 2858 2859 def test_invalid_index_of_empty_tensor(self): 2860 def fn(a): 2861 b = a[[0]] 2862 return b 2863 2864 a = torch.tensor([]) 2865 with self.assertRaises(RuntimeError): 2866 torch.compile(fn)(a) 2867 2868 @torch.no_grad() 2869 
@torch._inductor.config.patch(freezing=True) 2870 def test_issue122380(self): 2871 def func(x): 2872 t1 = torch.unbind(x) 2873 t2 = torch.stack(t1, dim=1) 2874 t3 = torch.tanh(t2) 2875 return t3 2876 2877 x = torch.randn(2, 3, 4) 2878 self.assertEqual(torch.compile(func)(x), func(x)) 2879 2880 def test_ir_node_str(self): 2881 @torch.compile 2882 def fn(x: torch.Tensor) -> torch.Tensor: 2883 return x.sin(), torch.nn.Softmax(dim=1)(x.cos()) 2884 2885 def run_node_alt(*args, **kwargs): 2886 rv = run_node(*args, **kwargs) 2887 strings.append(str(rv)) 2888 return rv 2889 2890 strings = [] 2891 run_node = GraphLowering.run_node 2892 with patch.object(GraphLowering, "run_node", run_node_alt): 2893 fn(torch.randn([8, 128])) 2894 self.assertGreater(len(strings), 3) 2895 2896 def test_vertical_sum_cpu_only(self): 2897 def fn1(a): 2898 return a.sum(dim=0) 2899 2900 def fn2(a): 2901 return a.sum(dim=1) 2902 2903 metrics.reset() 2904 x = torch.randn(100, 100) 2905 self.common(fn1, (x,)) 2906 check_metrics_vec_kernel_count(1) 2907 2908 metrics.reset() 2909 x = torch.randn(100, 100, 100) 2910 self.common(fn2, (x,)) 2911 check_metrics_vec_kernel_count(1) 2912 2913 def test_transpose_vertical_sum_cpu_only(self): 2914 def fn(a, b): 2915 c = a * b 2916 return c.sum(dim=1) 2917 2918 metrics.reset() 2919 x = torch.randn(100, 50, 50) 2920 y = torch.randn(100, 50, 50).transpose(1, 2) 2921 self.common(fn, (x, y)) 2922 check_metrics_vec_kernel_count(2) 2923 2924 def test_transpose_mxn_16_16_bf16_fp16(self): 2925 def fn(a, b): 2926 c = a * b 2927 return c.sum(dim=1) 2928 2929 for dtype in [torch.bfloat16, torch.float16]: 2930 metrics.reset() 2931 x = torch.randn(100, 50, 50).to(dtype) 2932 y = torch.randn(100, 50, 50).to(dtype).transpose(1, 2) 2933 self.common(fn, (x, y)) 2934 check_metrics_vec_kernel_count(2) 2935 2936 def test_transpose_mxn_32_32_bf16_fp16(self): 2937 def fn(a): 2938 return a.permute(0, 2, 1).contiguous() 2939 2940 for dtype in [torch.bfloat16, torch.float16]: 2941 metrics.reset() 2942 x = torch.randn(2, 9216, 9216).to(dtype) 2943 self.common(fn, (x,)) 2944 check_metrics_vec_kernel_count(2) 2945 2946 def test_transpose_sum2d_cpu_only(self): 2947 def fn(a, b): 2948 c = a * b 2949 return c.sum() 2950 2951 metrics.reset() 2952 x = torch.randn(50, 50) 2953 y = torch.randn(50, 50).transpose(0, 1) 2954 self.common(fn, (x, y)) 2955 check_metrics_vec_kernel_count(2) 2956 2957 def test_transpose_sum_outer(self): 2958 # https://github.com/pytorch/pytorch/issues/98573 2959 def fn(a): 2960 return a.transpose(2, 3).sum(dim=1).contiguous() 2961 2962 metrics.reset() 2963 x = torch.randn(10, 50, 50, 50) 2964 self.common(fn, (x,)) 2965 check_metrics_vec_kernel_count(1) 2966 2967 def test_to_dtype_bool_float(self): 2968 # https://github.com/pytorch/pytorch/issues/100800 2969 def f(a): 2970 return torch.where( 2971 torch.ones_like(a).to(torch.bool), 2972 torch.zeros_like(a), 2973 torch.ones_like(a) * 2, 2974 ) 2975 2976 self.common(f, (torch.ones(16),)) 2977 2978 def test_to_dtype_float_bool(self): 2979 # https://github.com/pytorch/pytorch/issues/100466 2980 def f(a): 2981 a = a * torch.tensor(a >= 0, dtype=torch.float32) 2982 return a 2983 2984 x = torch.rand(16) 2985 self.common(f, (x,)) 2986 2987 def test_constant_store(self): 2988 # https://github.com/pytorch/pytorch/issues/104515 2989 def f(a): 2990 a[0, [3, 3]] = -float("inf") 2991 return a 2992 2993 x = torch.rand(4, 5) 2994 self.common(f, (x,)) 2995 2996 def test_to_channels_last_lowp_fp(self): 2997 def f(a): 2998 return 
a.to(memory_format=torch.channels_last) 2999 3000 for dtype in _lowp_fp_dtypes: 3001 x = torch.rand(2, 3, 14, 14).to(dtype) 3002 self.common(f, (x,)) 3003 3004 def test_broadcast_mul_lowp_fp(self): 3005 def f(a, b): 3006 return a * b 3007 3008 for dtype in _lowp_fp_dtypes: 3009 a = torch.randn(2, 16, 16).to(dtype) 3010 b = torch.randn(2, 1, 1).to(dtype) 3011 self.common(f, (a, b)) 3012 3013 def test_linear_buffer_reuse(self): 3014 class M(torch.nn.Module): 3015 def __init__(self): 3016 super().__init__() 3017 self.linear1 = torch.nn.Linear(16, 16) 3018 self.tanh = torch.nn.Tanh() 3019 self.linear2 = torch.nn.Linear(16, 16) 3020 3021 def forward(self, x): 3022 x = self.linear1(x) 3023 x = self.tanh(x) 3024 x = self.linear2(x) 3025 return x 3026 3027 mod = M().eval() 3028 v = torch.randn(1, 16) 3029 3030 with torch.no_grad(): 3031 3032 def compile_fx_wrapper(model_, example_inputs_): 3033 return compile_fx(model_, example_inputs_) 3034 3035 def run(*ex, **kwargs): 3036 return mod(*ex, **kwargs) 3037 3038 run = torch._dynamo.optimize(compile_fx_wrapper)(run) 3039 _, code = run_and_get_cpp_code(run, v) 3040 self.assertFalse("= as_strided(" in code) 3041 self.assertEqual(run(*v), mod(*v)) 3042 3043 def test_invalid_dropout_args(self): 3044 class MyModel(torch.nn.Module): 3045 def forward(self, x): 3046 x = x * 2 3047 x = torch.nn.functional.dropout(x, p=0.5) 3048 x = torch.relu(x) 3049 return x 3050 3051 example_inputs = torch.tensor([[1, 2, 3], [4, 5, 6]]) 3052 3053 func = MyModel() 3054 jit_func = torch.compile(func) 3055 self.assertRaises(RuntimeError, lambda: func(example_inputs)) 3056 self.assertRaises(RuntimeError, lambda: jit_func(example_inputs)) 3057 3058 def test_nn_param_assign(self): 3059 # https://github.com/pytorch/pytorch/issues/99569 3060 class Model2(nn.Module): 3061 def __init__(self): 3062 super().__init__() 3063 self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3) 3064 self.batchnorm = nn.BatchNorm2d(num_features=5) 3065 self.conv_weight = torch.randn(5, 3, 3, 3) 3066 self.conv_bias = torch.randn(5) 3067 3068 def forward(self, x): 3069 self.conv.weight = nn.Parameter(self.conv_weight) 3070 self.conv.bias = nn.Parameter(self.conv_bias, requires_grad=False) 3071 self.conv.eval() 3072 x = self.conv(x) 3073 x = self.batchnorm(x) 3074 x = F.relu(x) 3075 return x 3076 3077 input_tensor = torch.randn(1, 3, 10, 10) 3078 func = Model2().to("cpu") 3079 3080 with torch.no_grad(): 3081 func.train(False) 3082 v1 = func(input_tensor) 3083 jit_func = torch.compile(func, fullgraph=True) 3084 v2 = jit_func(input_tensor) 3085 self.assertEqual(v1, v2) 3086 3087 def test_nn_param_assign_wrapped(self): 3088 class Model2(nn.Module): 3089 def __init__(self): 3090 super().__init__() 3091 self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3) 3092 self.batchnorm = nn.BatchNorm2d(num_features=5) 3093 self.conv_weight = torch.randn(5, 3, 3, 3) 3094 self.conv_bias = torch.randn(5) 3095 3096 def forward(self, x): 3097 self.conv.weight = nn.Parameter(self.conv_weight) 3098 self.conv.bias = nn.Parameter(self.conv_bias, requires_grad=False) 3099 self.conv.eval() 3100 x = self.conv(x) 3101 x = self.batchnorm(x) 3102 x = F.relu(x) 3103 return x 3104 3105 input_tensor = torch.randn(1, 3, 10, 10) 3106 func = Model2().to("cpu") 3107 3108 @functools.wraps(func) 3109 def wrapper(*args, **kwargs): 3110 return func(*args, **kwargs) 3111 3112 with torch.no_grad(): 3113 func.train(False) 3114 v1 = func(input_tensor) 3115 jit_func = torch.compile(wrapper, fullgraph=True) 3116 v2 = 
jit_func(input_tensor) 3117 self.assertEqual(v1, v2) 3118 3119 @config.patch(inplace_buffers=True) 3120 def test_in_out_buffer(self): 3121 def fn(x, y): 3122 z = torch.matmul(x, y.transpose(-1, -2)) / 8.0 3123 return z 3124 3125 inps = [torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4)] 3126 fn_opt = torch._dynamo.optimize("inductor")(fn) 3127 _, code = run_and_get_cpp_code(fn_opt, *inps) 3128 self.assertTrue("in_out_ptr" in code) 3129 self.assertEqual(fn_opt(*inps), fn(*inps)) 3130 3131 def test_eliminate_meaningless_copy(self): 3132 def fn(x1, x2): 3133 permute = torch.ops.aten.permute.default(x2, [0, 2, 1, 3]) 3134 clone = torch.ops.aten.clone.default( 3135 permute, memory_format=torch.contiguous_format 3136 ) 3137 view = torch.ops.aten.view.default(clone, [1024, -1, 32]) 3138 bmm = torch.ops.aten.bmm.default(view, x1) 3139 permute = torch.ops.aten.permute.default(view, [0, 2, 1]) 3140 return (bmm, permute) 3141 3142 metrics.reset() 3143 self.common( 3144 fn, 3145 [ 3146 rand_strided( 3147 (1024, 32, 128), (4096, 1, 32), device="cpu", dtype=torch.float32 3148 ), 3149 rand_strided( 3150 (64, 128, 16, 32), 3151 (65536, 512, 32, 1), 3152 device="cpu", 3153 dtype=torch.float32, 3154 ), 3155 ], 3156 ) 3157 self.assertEqual(metrics.generated_kernel_count, 1) 3158 3159 def test_attention_size_mismatch(self): 3160 class Attention(torch.nn.Module): 3161 def __init__(self, hidden_size, num_heads): 3162 super().__init__() 3163 self.hidden_size = hidden_size 3164 self.num_heads = num_heads 3165 self.head_size = hidden_size // num_heads 3166 self.query = torch.nn.Linear(hidden_size, hidden_size) 3167 self.key = torch.nn.Linear(hidden_size, hidden_size) 3168 self.value = torch.nn.Linear(hidden_size, hidden_size) 3169 self.inv_scale = torch.nn.Parameter( 3170 torch.Tensor([1 / self.head_size**0.5]), requires_grad=False 3171 ) 3172 3173 def forward(self, x): 3174 query = self.query(x) 3175 key = self.key(x) 3176 value = self.value(x) 3177 (batch_size, seq_len, hidden_size) = query.size() 3178 query = query.view( 3179 batch_size, seq_len, self.num_heads, self.head_size 3180 ).permute(0, 2, 1, 3) 3181 key = key.view( 3182 batch_size, seq_len, self.num_heads, self.head_size 3183 ).permute(0, 2, 3, 1) 3184 value = value.view( 3185 batch_size, seq_len, self.num_heads, self.head_size 3186 ).permute(0, 2, 1, 3) 3187 attention_weights = ( 3188 torch.matmul(query, key).mul(self.inv_scale).softmax(dim=-1) 3189 ) 3190 output = torch.matmul(attention_weights, value) 3191 return output 3192 3193 torch.manual_seed(123) 3194 hidden_size = 16 3195 num_heads = 1 3196 seq_len = 4 3197 batch_size = 1 3198 x = torch.randn(batch_size, seq_len, hidden_size) 3199 3200 func = Attention(hidden_size, num_heads).to("cpu") 3201 3202 with torch.no_grad(): 3203 res1 = func(x) 3204 jit_func = torch.compile(func) 3205 res2 = jit_func(x) 3206 self.assertEqual(res1, res2) 3207 3208 def test_scalar_mul_bfloat16(self): 3209 def f(x): 3210 return torch.ops.aten.mul.Tensor(x, 1.7015043497085571) 3211 3212 metrics.reset() 3213 x = torch.randn(4, 5, dtype=torch.bfloat16) 3214 self.common(f, (x,)) 3215 check_metrics_vec_kernel_count(1) 3216 3217 def test_bf16_zeros(self): 3218 def fn(): 3219 x = torch.zeros(1, 1, 32, dtype=torch.bfloat16) 3220 return x 3221 3222 self.common(fn, ()) 3223 3224 def test_select_tiliing_with_index_expr(self): 3225 def fn(x, y): 3226 x = torch.ops.aten.view.default(x, [8, 8, 8, 3136]) 3227 x = torch.ops.aten.permute.default(x, [0, 1, 3, 2]) 3228 y = torch.ops.aten.mul.Tensor(y, x) 3229 return 
torch.ops.aten.constant_pad_nd.default(y, [0, 0, 1, 0, 0, 0], 0.0) 3230 3231 x = torch.randn(8, 64, 56, 56) 3232 y = torch.randn(8, 8, 3136, 8) 3233 self.common(fn, (x, y)) 3234 3235 @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled") 3236 @patch("torch.cuda.is_available", lambda: False) 3237 @config.patch(freezing=True) 3238 def test_linear_with_no_default_contiguous_input(self): 3239 dtypes = [ 3240 torch.float32, 3241 ] 3242 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 3243 dtypes.append(torch.bfloat16) 3244 if torch.ops.mkldnn._is_mkldnn_fp16_supported(): 3245 dtypes.append(torch.float16) 3246 mod = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval() 3247 temp = torch.randn(1, 16, 1, 1) 3248 v = torch.as_strided(temp, [1, 16], [0, 1], 0) 3249 self.assertTrue(v.is_contiguous()) 3250 for dtype in dtypes: 3251 with torch.no_grad(): 3252 self.common( 3253 mod.to(dtype), 3254 (v.to(dtype),), 3255 ) 3256 3257 @patch("torch.cuda.is_available", lambda: False) 3258 @config.patch(freezing=True) 3259 def test_linear_with_reshape(self): 3260 class M(torch.nn.Module): 3261 def __init__(self): 3262 super().__init__() 3263 self.linear = torch.nn.Linear(16, 16, bias=False) 3264 3265 def forward(self, x): 3266 x = self.linear(x) 3267 return x.view(4, 4, 4) 3268 3269 mod = M().eval() 3270 v = torch.randn(4, 16) 3271 with torch.no_grad(): 3272 torch._dynamo.reset() 3273 metrics.reset() 3274 self.common( 3275 mod, 3276 (v,), 3277 ) 3278 assert metrics.generated_kernel_count == 0 3279 3280 @config.patch(implicit_fallbacks=True) 3281 def test_aten_normal_dtype(self): 3282 for dtype in [torch.float64, torch.float16, None]: 3283 3284 def fn(): 3285 return torch.normal(2, 3, (10, 10), dtype=dtype, device="cpu") 3286 3287 self.assertEqual( 3288 torch.compile(fn, backend="aot_eager_decomp_partition")().dtype, 3289 dtype if dtype else torch.float32, 3290 ) 3291 self.assertEqual( 3292 torch.compile(fn, backend="inductor")().dtype, 3293 dtype if dtype else torch.float32, 3294 ) 3295 3296 def test_group_norm_vec(self): 3297 class M(torch.nn.Module): 3298 def __init__(self): 3299 super().__init__() 3300 self.group_norm = torch.nn.GroupNorm(32, 32) 3301 3302 def forward(self, x): 3303 return self.group_norm(x) 3304 3305 metrics.reset() 3306 mod = M().eval() 3307 x = torch.randn(2, 32, 32, 32) 3308 with torch.no_grad(): 3309 self.common(mod, (x,)) 3310 # 2 generated kernels (one for var_mean, the other for result) 3311 check_metrics_vec_kernel_count(2) 3312 3313 def test_int_div_vec(self): 3314 def fn(x, y, mode): 3315 return torch.div(x, y, rounding_mode=mode) 3316 3317 x = torch.randint(1, 100, (32, 32)) 3318 y = torch.randint(1, 100, (32, 32)) 3319 for mode in [None, "trunc", "floor"]: 3320 with torch.no_grad(): 3321 metrics.reset() 3322 self.common(fn, (x, y, mode)) 3323 check_metrics_vec_kernel_count(1) 3324 3325 def test_uint8_add(self): 3326 # https://github.com/pytorch/pytorch/issues/113016 3327 def fn(x, y): 3328 return torch.add(x, y).neg().to(torch.int32) 3329 3330 x = torch.randint(0, 255, (3, 3), dtype=torch.uint8) 3331 y = torch.randint(0, 255, (3, 3), dtype=torch.uint8) 3332 self.common(fn, (x, y)) 3333 3334 def test_uint8_sub(self): 3335 # https://github.com/pytorch/pytorch/issues/113016 3336 def fn(x, y): 3337 return torch.sub(x, y).neg().to(torch.int32) 3338 3339 x = torch.randint(0, 255, (3, 3), dtype=torch.uint8) 3340 y = torch.randint(0, 255, (3, 3), dtype=torch.uint8) 3341 self.common(fn, (x, y)) 3342 3343 def test_non_contiguous_reduction_store(self): 3344 # 
https://github.com/pytorch/pytorch/issues/113018 3345 class M(torch.nn.Module): 3346 def __init__(self): 3347 super().__init__() 3348 self.conv = torch.nn.Conv2d(39, 1, kernel_size=(1, 17), stride=(2, 2)) 3349 3350 def forward(self, x): 3351 return self.conv(x.max(3).values) 3352 3353 m = M() 3354 x = torch.randn(1, 39, 1, 18, 17) 3355 self.common(m, (x,)) 3356 3357 def test_embedding_vec(self): 3358 class M(torch.nn.Module): 3359 def __init__(self): 3360 super().__init__() 3361 self.emb = torch.nn.Embedding(64, 128) 3362 3363 def forward(self, idx, x): 3364 return self.emb(idx) + x 3365 3366 idx = torch.randint(0, 64, (4, 32)) 3367 x = torch.randn(4, 32, 128) 3368 m = M().eval() 3369 with torch.no_grad(): 3370 metrics.reset() 3371 self.common(m, (idx, x)) 3372 check_metrics_vec_kernel_count(1) 3373 3374 def test_embedding_vec_bf16(self): 3375 class M(torch.nn.Module): 3376 def __init__(self): 3377 super().__init__() 3378 self.emb = torch.nn.Embedding(64, 128) 3379 3380 def forward(self, idx, x): 3381 return self.emb(idx) 3382 3383 idx = torch.randint(0, 64, (4, 32)) 3384 x = torch.randn(4, 32, 128).to(torch.bfloat16) 3385 m = M().eval() 3386 with torch.no_grad(): 3387 metrics.reset() 3388 self.common(m, (idx, x)) 3389 check_metrics_vec_kernel_count(1) 3390 3391 # we are doing direct load/store, make sure we do not generate 3392 # redundant type casts 3393 m_opt = torch.compile(m) 3394 _, code = run_and_get_cpp_code(m_opt, idx, x) 3395 self.assertTrue("Vectorized" in code) 3396 self.assertTrue("cvt_lowp_fp_to_fp32" not in code) 3397 self.assertTrue("cvt_fp32_to_lowp_fp" not in code) 3398 3399 def test_concat_inner_vec(self): 3400 def fn(x, y): 3401 return F.relu(torch.cat([x, y], dim=1)) 3402 3403 x = torch.randn(32, 35) 3404 y = torch.randn(32, 120) 3405 metrics.reset() 3406 self.common(fn, (x, y)) 3407 check_metrics_vec_kernel_count(3) 3408 3409 def test_expr_vec_non_contiguous(self): 3410 def fn(x): 3411 # the pattern from sebotnet33ts_256 3412 y = torch.nn.functional.pad(x, (0, 31)).reshape(-1, 33, 63) 3413 y = y[:, :32, 31:].reshape(4, 32, 1, 32, 32).expand(-1, -1, 32, -1, -1) 3414 y = y.permute(0, 3, 1, 4, 2).clone(memory_format=torch.contiguous_format) 3415 y = y.view(4, 1024, 1024) 3416 return y.softmax(dim=-1) 3417 3418 x = torch.randn(128, 2048) 3419 opt_fn = torch.compile(fn) 3420 metrics.reset() 3421 _, code = run_and_get_cpp_code(opt_fn, x) 3422 self.assertTrue(same(fn(x), opt_fn(x))) 3423 # 4 kernels for max, exp, sum and div 3424 check_metrics_vec_kernel_count(4) 3425 FileCheck().check_count( 3426 "Vectorized<int>::loadu(tmpbuf.data())", 0, exactly=True 3427 ).run(code) 3428 3429 def test_vec_contiguous_ModularIndexing(self): 3430 # https://github.com/pytorch/pytorch/issues/114488 3431 class M(torch.nn.Module): 3432 def __init__(self, dim): 3433 super().__init__() 3434 self.norm = torch.nn.LayerNorm(dim * 4) 3435 3436 def forward(self, x): 3437 # the pattern from swin_base_patch4_window7_224 3438 B, H, W, C = x.shape 3439 x = ( 3440 x.reshape(B, H // 2, 2, W // 2, 2, C) 3441 .permute(0, 1, 3, 4, 2, 5) 3442 .flatten(3) 3443 ) 3444 x = self.norm(x) 3445 return x 3446 3447 x = torch.randn(1, 56, 56, 128) 3448 m = M(128) 3449 opt_m = torch.compile(m) 3450 with torch.no_grad(): 3451 metrics.reset() 3452 _, code = run_and_get_cpp_code(opt_m, x) 3453 self.assertTrue(same(m(x), opt_m(x))) 3454 # Two kernels: one for reduction, one pointwises 3455 check_metrics_vec_kernel_count(2) 3456 FileCheck().check_count( 3457 "Vectorized<float>::loadu(tmpbuf.data())", 0, exactly=True 3458 
).run(code) 3459 3460 @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float)) 3461 @parametrize("shape", ("15,3,13", "4,2048,4096")) 3462 def test_fp8_cast(self, dtype: torch.dtype, shape: str): 3463 def fp8_cast(x): 3464 y0 = x.to(dtype=torch.float8_e4m3fn).to(dtype) 3465 y1 = x.to(dtype=torch.float8_e5m2).to(dtype) 3466 return y0, y1 3467 3468 shape = [int(dim) for dim in shape.split(",")] 3469 x = torch.rand(*shape, device="cpu", dtype=dtype) 3470 self.common(fp8_cast, (x,)) 3471 3472 def test_logical_op_store_to_lowp_data_dtype(self): 3473 # https://github.com/pytorch/pytorch/issues/117624 3474 # https://github.com/pytorch/pytorch/issues/117627 3475 def fn(out1, out2, input, other): 3476 o1 = torch.logical_or(out=out1, input=input, other=other) 3477 o2 = torch.logical_xor(out=out2, input=input, other=other) 3478 return o1, o2 3479 3480 x = torch.rand([3, 3, 2, 8, 9, 2], dtype=torch.float) 3481 y = torch.rand([3, 3, 2, 8, 9, 2], dtype=torch.float) 3482 for dtype in _lowp_fp_dtypes: 3483 o1 = torch.rand([3, 3, 2, 8, 9, 2], dtype=dtype) 3484 o2 = torch.rand([3, 3, 2, 8, 9, 2], dtype=dtype) 3485 with torch.no_grad(): 3486 self.common(fn, (o1, o2, x, y)) 3487 3488 def test_constant_bool_vec(self): 3489 def fn(x): 3490 mask = torch.zeros(1, dtype=torch.bool) 3491 return torch.where(mask, x, -1.0) 3492 3493 x = torch.rand(1000) 3494 metrics.reset() 3495 self.common(fn, (x,)) 3496 check_metrics_vec_kernel_count(1) 3497 3498 @torch._dynamo.config.patch(dynamic_shapes=True) 3499 @torch._dynamo.config.patch(assume_static_by_default=False) 3500 def test_symbolic_shape_scalar_value_reduction(self): 3501 def fn(x, y): 3502 return y + torch.ones(x).sum() 3503 3504 with torch.no_grad(): 3505 metrics.reset() 3506 y = torch.randn(100) 3507 self.common(fn, (100, y)) 3508 check_metrics_vec_kernel_count(2) 3509 3510 def test_int32_pointwise_vec(self): 3511 def fn(x): 3512 return x * x 3513 3514 x = torch.randint(0, 100, (32, 32), dtype=torch.int32) 3515 metrics.reset() 3516 self.common(fn, (x,)) 3517 check_metrics_vec_kernel_count(1) 3518 3519 def test_int32_reduction_vec(self): 3520 def fn(x): 3521 return x.sum(dim=1) 3522 3523 x = torch.randint(0, 100, (32, 32), dtype=torch.int32) 3524 metrics.reset() 3525 self.common(fn, (x,)) 3526 check_metrics_vec_kernel_count(1) 3527 3528 def test_uint32_pointwise_vec(self): 3529 def fn(x): 3530 return x * x 3531 3532 x = torch.randint(0, 100, (32, 32), dtype=torch.uint32) 3533 metrics.reset() 3534 self.common(fn, (x,)) 3535 # TODO(jgong5): change to 1 with vectorized uint32 load 3536 assert metrics.generated_cpp_vec_kernel_count == 0 3537 3538 def test_uint32_reduction_vec(self): 3539 def fn(x): 3540 return x.sum(dim=1) 3541 3542 x = torch.randint(0, 100, (32, 32), dtype=torch.uint32) 3543 metrics.reset() 3544 self.common(fn, (x,)) 3545 # TODO(jgong5): change to 1 with vectorized uint32/uint64 load 3546 assert metrics.generated_cpp_vec_kernel_count == 0 3547 3548 def test_int64_pointwise_vec(self): 3549 def fn(x): 3550 return x * x 3551 3552 x = torch.randint(0, 100, (32, 32), dtype=torch.int64) 3553 metrics.reset() 3554 self.common(fn, (x,)) 3555 check_metrics_vec_kernel_count(1) 3556 3557 def test_int64_reduction_vec(self): 3558 def fn(x): 3559 return x.sum(dim=1) 3560 3561 x = torch.randint(0, 100, (32, 32), dtype=torch.int64) 3562 metrics.reset() 3563 self.common(fn, (x,)) 3564 check_metrics_vec_kernel_count(1) 3565 3566 def test_uint64_pointwise_vec(self): 3567 def fn(x): 3568 return x * x 3569 3570 x = torch.randint(0, 100, (32, 32), 
dtype=torch.uint64) 3571 metrics.reset() 3572 self.common(fn, (x,)) 3573 # TODO(jgong5): change to 1 with vectorized uint64 load 3574 assert metrics.generated_cpp_vec_kernel_count == 0 3575 3576 def test_uint64_reduction_vec(self): 3577 def fn(x): 3578 return x.sum(dim=1) 3579 3580 x = torch.randint(0, 100, (32, 32), dtype=torch.uint64) 3581 metrics.reset() 3582 self.common(fn, (x,)) 3583 # TODO(jgong5): change to 1 with vectorized uint64 load 3584 assert metrics.generated_cpp_vec_kernel_count == 0 3585 3586 def test_convert_int32_to_int64_vec(self): 3587 def fn(x): 3588 return x.to(torch.int64) 3589 3590 x = torch.randint(0, 100, (32, 32), dtype=torch.int32) 3591 metrics.reset() 3592 self.common(fn, (x,)) 3593 check_metrics_vec_kernel_count(1) 3594 3595 def test_convert_int64_to_int32_vec(self): 3596 def fn(x): 3597 return x.to(torch.int32) 3598 3599 x = torch.randint(0, 100, (32, 32), dtype=torch.int64) 3600 metrics.reset() 3601 self.common(fn, (x,)) 3602 check_metrics_vec_kernel_count(1) 3603 3604 def test_convert_fp32_to_int64_vec(self): 3605 def fn(x): 3606 return x.to(torch.int64) 3607 3608 x = torch.rand(32, 32) 3609 metrics.reset() 3610 self.common(fn, (x,)) 3611 check_metrics_vec_kernel_count(1) 3612 3613 def test_convert_int64_to_fp32_vec(self): 3614 def fn(x): 3615 return x.to(torch.float32) 3616 3617 x = torch.randint(0, 100, (32, 32), dtype=torch.int64) 3618 metrics.reset() 3619 self.common(fn, (x,)) 3620 check_metrics_vec_kernel_count(1) 3621 3622 def test_no_redundant_to_dtypes_between_fused_scheduler_node(self): 3623 # https://github.com/pytorch/pytorch/issues/115260 3624 p0 = torch.tensor([1.0879], dtype=torch.float16) 3625 3626 class Model1(torch.nn.Module): 3627 def __init__(self): 3628 super().__init__() 3629 3630 def forward(self, *args): 3631 cat = torch.cat((args[3], args[2], args[1], args[0]), dim=2) 3632 max_1 = torch.max(args[4], p0) 3633 mul = torch.mul(cat, max_1) 3634 tan = torch.tan(mul) 3635 return (mul, tan) 3636 3637 metrics.reset() 3638 m = Model1() 3639 self.common( 3640 m, 3641 ( 3642 torch.randn((17, 5, 1, 7)).half(), 3643 torch.randn((17, 5, 1, 7)).half(), 3644 torch.randn((17, 5, 11, 7)).half(), 3645 torch.randn((17, 5, 1, 7)).half(), 3646 torch.tensor(4.39, dtype=torch.float16), 3647 ), 3648 ) 3649 3650 def test_masked_load_int64_vec(self): 3651 # https://github.com/pytorch/pytorch/issues/120377 3652 def fn(x): 3653 return torch.nn.functional.pad(x, (0, 13)) 3654 3655 x = torch.randint(0, 100, (819,), dtype=torch.int64) 3656 metrics.reset() 3657 self.common(fn, (x,)) 3658 assert metrics.generated_cpp_vec_kernel_count == 1 3659 3660 def test_reduction_float_to_int64(self): 3661 # https://github.com/pytorch/pytorch/issues/124821 3662 def fn(x): 3663 return x.max(0).values 3664 3665 x = torch.randint(0, 100, (22, 51), dtype=torch.int64) 3666 metrics.reset() 3667 self.common(fn, (x,)) 3668 assert metrics.generated_cpp_vec_kernel_count == 1 3669 3670 @config.patch({"cpp.dynamic_threads": True}) 3671 def test_reduction_with_dynamic_threads(self): 3672 def fn(a, b): 3673 return a.sum(), b.sum() 3674 3675 self.common( 3676 fn, 3677 (torch.randn(1000), torch.rand(1000)), 3678 ) 3679 3680 @patch("torch.cuda.is_available", lambda: False) 3681 @config.patch(freezing=True) 3682 def test_linear_float64(self): 3683 class M(torch.nn.Module): 3684 def __init__(self): 3685 super().__init__() 3686 self.weight1 = torch.nn.Parameter( 3687 torch.randn(10, 10, dtype=torch.float64) 3688 ) 3689 self.weight2 = torch.nn.Parameter( 3690 torch.randn(10, 10, dtype=torch.float64) 
3691 ) 3692 self.bias = torch.nn.Parameter(torch.randn(10, dtype=torch.float64)) 3693 3694 def forward(self, x1): 3695 v1 = torch.mm(x1, self.weight1) 3696 v2 = torch.addmm(self.bias, x1, self.weight2) 3697 return (v1, v2) 3698 3699 mod = M().eval() 3700 v = torch.randn(10, 10, dtype=torch.float64) 3701 with torch.no_grad(): 3702 self.common( 3703 mod, 3704 (v,), 3705 ) 3706 3707 def test_fused_attention_conv(self): 3708 # https://github.com/pytorch/pytorch/issues/121174. 3709 class Model(torch.nn.Module): 3710 def __init__(self): 3711 super().__init__() 3712 self.q_conv = torch.nn.Conv2d(4, 4, 1) 3713 self.k_conv = torch.nn.Conv2d(4, 4, 1) 3714 self.v_conv = torch.nn.Conv2d(4, 4, 1) 3715 3716 def forward(self, x): 3717 q = self.q_conv(x) 3718 k = self.k_conv(x) 3719 v = self.v_conv(x) 3720 q = q.permute(0, 2, 1, 3) 3721 k = k.permute(0, 2, 1, 3) 3722 v = v.permute(0, 2, 1, 3) 3723 return torch.nn.functional.scaled_dot_product_attention( 3724 q, k, v, dropout_p=0.0, is_causal=False 3725 ) 3726 3727 fn = Model() 3728 x = torch.randn(1, 4, 2, 2) 3729 self.common(fn, (x,)) 3730 3731 @requires_vectorization 3732 def test_vec_indirect_load_cse_cache(self): 3733 # https://github.com/pytorch/pytorch/issues/123502 3734 from math import inf 3735 3736 def fn(arg0_1): 3737 full_default = torch.ops.aten.full.default([209985], 1) 3738 select = torch.ops.aten.select.int(arg0_1, 0, 0) 3739 select_1 = torch.ops.aten.select.int(arg0_1, 0, 1) 3740 view = torch.ops.aten.reshape.default(select_1, [-1]) 3741 expand = torch.ops.aten.expand.default(view, [209985]) 3742 full_default_1 = torch.ops.aten.full.default([10000], 0) 3743 scatter_add = torch.ops.aten.scatter_add.default( 3744 full_default_1, 0, expand, full_default 3745 ) 3746 pow_1 = torch.ops.aten.pow.Tensor_Scalar(scatter_add, -0.5) 3747 eq = torch.ops.aten.eq.Scalar(pow_1, inf) 3748 full_default_2 = torch.ops.aten.full.default([], 0.0) 3749 where = torch.ops.aten.where.self(eq, full_default_2, pow_1) 3750 index = torch.ops.aten.index.Tensor(where, [select]) 3751 index_1 = torch.ops.aten.index.Tensor(where, [select_1]) 3752 mul_1 = torch.ops.aten.mul.Tensor(index, index_1) 3753 return (mul_1,) 3754 3755 x = torch.zeros(2, 209985).to(torch.int64) 3756 opt_fn = torch._dynamo.optimize("inductor")(fn) 3757 _, code = run_and_get_cpp_code(opt_fn, x) 3758 FileCheck().check_count( 3759 "return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(),", 3760 4, 3761 exactly=True, 3762 ).run(code) 3763 3764 3765if __name__ == "__main__": 3766 from torch._inductor.test_case import run_tests 3767 from torch.testing._internal.inductor_utils import HAS_CPU 3768 3769 if HAS_CPU and not IS_MACOS: 3770 run_tests(needs="filelock") 3771
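# Usage note (assumes the standard PyTorch test harness; the file name below is
# illustrative): an individual test can be selected by name with unittest's -k
# filter, e.g.
#   python test_cpu_repro.py -k test_vec_kernel_cpu_only
# run_tests(needs="filelock") indicates the suite requires the `filelock`
# package in order to run.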