# Owner(s): ["module: inductor"]
import copy
import itertools
import os
import sys
import tempfile
import types
import unittest
from typing import Dict, Tuple
from unittest import skip

import torch
import torch._export
import torch._inductor
import torch._inductor.config
import torch.nn as nn
from torch._dynamo.testing import rand_strided, same
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.exc import CppWrapperCodeGenError
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim, export
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
from torch.testing._internal.common_quantization import (
    skip_if_no_torchvision,
    skipIfNoFBGEMM,
)
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    find_library_location,
    IS_CI,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    skipIfRocm,
    TEST_WITH_ROCM,
)
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
from torch.utils import _pytree as pytree


if HAS_CUDA:
    import triton

    from torch.testing._internal.triton_utils import (
        add_kernel,
        add_kernel_2d_autotuned,
        add_kernel_autotuned,
        add_kernel_autotuned_weird_param_order,
        add_kernel_with_optional_param,
        add_kernel_with_scaling,
        mul2_inplace_kernel,
    )

if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

try:
    try:
        from .test_aot_inductor_utils import AOTIRunnerUtil
        from .test_control_flow import (
            CondModels,
            prepend_counters,
            prepend_predicates,
            WhileLoopModels,
        )
        from .test_torchinductor import copy_tests, requires_multigpu, TestFailure
    except ImportError:
        from test_aot_inductor_utils import AOTIRunnerUtil
        from test_control_flow import (
            CondModels,
            prepend_counters,
            prepend_predicates,
            WhileLoopModels,
        )
        from test_torchinductor import copy_tests, requires_multigpu, TestFailure
except (unittest.SkipTest, ImportError) as e:
    if __name__ == "__main__":
        sys.exit(0)
    raise


def check_model(
    self: TestCase,
    model,
    example_inputs,
    options=None,
    dynamic_shapes=None,
    disable_constraint_solver=False,
    atol=None,
    rtol=None,
):
    with torch.no_grad(), config.patch(
        {
            "abi_compatible": self.abi_compatible,
            "allow_stack_allocation": self.allow_stack_allocation,
            "use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
        }
    ):
        torch.manual_seed(0)
        if not isinstance(model, types.FunctionType):
            model = model.to(self.device)
        ref_model = copy.deepcopy(model)
        ref_inputs = copy.deepcopy(example_inputs)
        expected = ref_model(*ref_inputs)

        torch.manual_seed(0)
        actual = AOTIRunnerUtil.run(
            self.device,
            model,
            example_inputs,
            options,
            dynamic_shapes,
            disable_constraint_solver,
        )

    self.assertEqual(actual, expected, atol=atol, rtol=rtol)


def check_model_with_multiple_inputs(
    self: TestCase,
    model,
    list_example_inputs,
    options=None,
    dynamic_shapes=None,
):
    with torch.no_grad(), config.patch(
        {
            "abi_compatible": self.abi_compatible,
            "allow_stack_allocation": self.allow_stack_allocation,
        }
    ):
        torch.manual_seed(0)
        model = model.to(self.device)
        ref_model = copy.deepcopy(model)
        ref_inputs = copy.deepcopy(list_example_inputs)
        list_expected = [ref_model(*inputs) for inputs in ref_inputs]

        torch.manual_seed(0)
        list_actual = AOTIRunnerUtil.run_multiple(
            self.device, model, list_example_inputs, options, dynamic_shapes
        )

    self.assertTrue(same(list_actual, list_expected))


def code_check_count(
    self: TestCase,
    model,
    example_inputs,
    target_str: str,
    target_count: int,
):
    so_path = torch._export.aot_compile(model, example_inputs)
    with open(os.path.splitext(so_path)[0] + ".cpp") as cpp:
        src_code = cpp.read()
        FileCheck().check_count(
            target_str,
            target_count,
            exactly=True,
        ).run(src_code)


class AOTInductorTestsTemplate:
    def test_simple(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_small_constant(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        example_inputs = (torch.randn(4, 4, device=self.device),)
        with config.patch({"always_keep_tensor_constants": True}):
            self.check_model(Model().to(self.device), example_inputs)

    def test_output_path_1(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        with config.patch("aot_inductor.output_path", "tmp_output_"):
            self.check_model(Model(), example_inputs)

    def test_output_path_2(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        model = Model().to(device=self.device)
        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        expected_path = os.path.join(tempfile.mkdtemp(dir=cache_dir()), "model.so")
        actual_path = AOTIRunnerUtil.compile(
            model, example_inputs, options={"aot_inductor.output_path": expected_path}
        )
        self.assertTrue(actual_path == expected_path)

    def test_constant_folding(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.w_pre = torch.randn(4, 4, device=device)
                self.b = torch.randn(4, device=device)

            def forward(self, x):
                w_transpose = torch.transpose(self.w_pre, 0, 1)
                w_relu = torch.nn.functional.relu(w_transpose)
                w = w_relu + self.b
                return torch.matmul(x, w)

        example_inputs = (torch.randn(4, 4, device=self.device),)
        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
            self.check_model(Model(self.device), example_inputs)

    @requires_cuda
    def test_duplicate_constant_folding(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.w1 = torch.randn(4, 4, device=device)
                self.w2 = torch.randn(4, 4, device=device)
                self.w3 = torch.randn(4, 4, device=device)
                self.w4 = torch.randn(4, 4, device=device)

            def forward(self, x):
                w_concat = torch.cat((self.w1, self.w2, self.w3, self.w4))
                return torch.cat((x, w_concat))

        example_inputs = (torch.randn(4, 4, device=self.device),)
        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
            self.check_model(Model(self.device), example_inputs)

    @requires_cuda
    def test_multi_device(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                x = x + 1
                x = x.cpu()
                x = x + 2
                x = x.cuda()
                return x

        example_inputs = (torch.randn(32, 64, device=self.device),)
        self.check_model(Model(), example_inputs)

    def test_large_weight(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2048, 262144)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(1, 262144, device=self.device),
            torch.randn(1, 2048, device=self.device),
        )

        # We only test compilation since we often get OOM running in CI.
        model = Model()
        model = model.to(self.device)
        AOTIRunnerUtil.compile(model, example_inputs)

    def test_large_mmaped_weights(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(512, 250112)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(1, 250112, device=self.device),
            torch.randn(1, 512, device=self.device),
        )
        with config.patch({"aot_inductor.force_mmap_weights": True}):
            self.check_model(Model(), example_inputs)

    def test_with_offset(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.orig_tensor = torch.randn(2, 15, 10, device=device)[0]
                self.tensor = self.orig_tensor[5:, :]

            def forward(self, x, y):
                return (
                    x
                    + torch.nn.functional.linear(y, self.orig_tensor[:10, :])
                    + self.tensor
                )

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(self.device), example_inputs)

    @unittest.skipIf(
        IS_FBCODE,
        "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
    )
    def test_freezing(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.weight = torch.randn(9, 10, device=device)
                self.padding = torch.randn(1, 10, device=device)

            def forward(self, x, y):
                padded_weight = torch.cat((self.weight, self.padding), dim=0)
                return x + torch.nn.functional.linear(y, padded_weight)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )

        with config.patch({"freezing": True}):
            self.check_model(Model(self.device), example_inputs)

    @unittest.skipIf(
        IS_FBCODE,
        "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
    )
    def test_conv_freezing(self):
        for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]):
            iC = 2
            oC = 3

            class Model(torch.nn.Module):
                def __init__(self, device):
                    super().__init__()
                    self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
                        dtype
                    )

                def forward(self, y):
                    return torch.nn.functional.conv2d(y, self.weight, groups=groups)

            example_inputs = (
                torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
            )

            with config.patch({"freezing": True}):
                self.check_model(Model(self.device), example_inputs)

    @unittest.skipIf(
        IS_FBCODE,
        "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
    )
    def test_deconv_freezing(self):
        dtypes = [torch.float]
        if torch._C._has_mkldnn and torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        for dtype, groups in itertools.product(dtypes, [2, 1]):
            iC = 4
            oC = 2

            class Model(torch.nn.Module):
                def __init__(self, device):
                    super().__init__()
                    self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
                        dtype
                    )

                def forward(self, y):
                    return torch.nn.functional.conv_transpose2d(
                        y, self.weight, groups=groups
                    )

            example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
            with config.patch({"freezing": True}):
                self.check_model(Model(self.device), example_inputs)

    @unittest.skipIf(
        IS_FBCODE,
        "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
    )
    def test_linear_freezing(self):
        for dtype in [torch.float32, torch.bfloat16]:

            class LinearModel(torch.nn.Module):
                def __init__(self, device):
                    super().__init__()
                    self.weight = torch.randn(10, 10, device=device).to(dtype)
                    self.bias = torch.randn(10, device=device).to(dtype)

                def forward(self, y):
                    return torch.nn.functional.linear(y, self.weight, self.bias)

            example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)

            with config.patch({"freezing": True}):
                self.check_model(LinearModel(self.device), example_inputs)

    @torch._inductor.config.patch(
        pre_grad_fusion_options={
            "normalization_pass": {},
            "remove_split_with_size_one_pass": {},
            "merge_getitem_cat_pass": {},
            "merge_stack_tahn_unbind_pass": {},
            "merge_splits_pass": {},
            "mutate_cat_pass": {},
            "split_cat_pass": {},
            "unbind_stack_pass": {},
        },
        post_grad_fusion_options={},
    )
    def test_simple_split(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)

        example_inputs = (torch.randn(2, 8, device=self.device),)
        counters.clear()
        self.check_model(Model(), example_inputs)
        self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
        self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
        self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)

    def test_amp_fallback_random(self):
        def fn(x, w):
            return torch.functional.F.linear(x, w)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        if self.device == "cuda":
            ctx = torch.cuda.amp.autocast
        elif self.device == "cpu":
            ctx = torch.cpu.amp.autocast
        else:
            raise AssertionError("Unsupported device")

        with config.patch({"fallback_random": True}):
            with ctx():
                self.check_model(fn, example_inputs)

    def test_missing_output(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                a = torch.sin(x)
                b = torch.mm(a, y)
                c = torch.cos(b)
                return c

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_output_misaligned(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x_unsqueeze = torch.unsqueeze(x, dim=0)
                y_unsqueeze = torch.unsqueeze(y, dim=0)
                cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
                x_getitem = cat[0]
                y_getitem = cat[1]
                x_sigmoid = torch.sigmoid(x_getitem)
                return x_sigmoid, y_getitem

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    @skip("Test was marked as expected failure, but does not fail always anymore.")
    def test_dynamic_smem_above_default_limit(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x @ y

        model = Model().to(self.device)
        # on A100, the generated Triton kernel for this MM
        # requires 55296 bytes of dynamic SMEM which is above
        # the A100's default dynamic SMEM limit of 49152 bytes.
        example_inputs = (
            torch.randn(10285, 96, device=self.device),
            torch.randn(96, 1, device=self.device),
        )
        self.check_model(
            model,
            example_inputs,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )

    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
    def test_seq(self):
        layernorm = torch.nn.LayerNorm(10)
        net = torch.nn.Sequential(
            layernorm,
            torch.nn.ReLU(),
            layernorm,
            torch.nn.ReLU(),
        )

        example_inputs = (torch.randn(10, device=self.device),)
        self.check_model(net.eval(), example_inputs)

    def test_addmm(self):
        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        M = 8
        N = 6
        K = 16
        model = Model(N, K, self.device)
        batch = 2
        a = torch.randn(batch, M, K, device=self.device)
        example_inputs = (a,)
        self.check_model(model, example_inputs)

    def test_aliased_buffer_reuse(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = 2 * x
                y = 2 * y
                c = torch.cat([x, y], dim=-1)
                d = 1 + c
                m = torch.mm(d, d)
                return m[:, :2] + x

        example_inputs = (
            torch.randn(4, 2, device=self.device),
            torch.randn(4, 2, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_buffer_reuse(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                a = torch.sin(x)
                b = torch.cos(y)
                c = torch.mm(a, b)
                d = torch.relu(c)
                e = torch.sigmoid(d)
                f = torch.mm(x, y)
                g = e + f
                return g

        example_inputs = (
            torch.randn(4, 4, device=self.device),
            torch.randn(4, 4, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_duplicated_params(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.rand(6))
                self.q = self.p

            def forward(self, x):
                return self.p * x + self.q

        example_inputs = (torch.rand(6, device=self.device),)
        self.check_model(Model(), example_inputs)

    @unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
    def test_inf(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        x = torch.randn(10, 10, device=self.device)
        x[0][0] = float("Inf")
        example_inputs = (
            x,
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(
            Model().to(self.device),
            example_inputs,
            options={"debug_check_inf_and_nan": True},
        )

    @unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
    def test_nan(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        x = torch.randn(10, 10, device=self.device)
        x[0][0] = float("nan")
        example_inputs = (
            x,
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(
            Model().to(self.device),
            example_inputs,
            options={"debug_check_inf_and_nan": True},
        )

    def test_assert_async(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                u0 = x.item()
                torch._check(u0 > 3)
                return torch.ones(u0)[0]

        x = torch.tensor(23, device=self.device)
        example_inputs = (x,)
        self.check_model(Model(), example_inputs)

    def test_simple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                add_0 = x + y
                return torch.nn.functional.relu(input=add_0, inplace=False)

        x = torch.randn(128, 2048, device=self.device)
        y = torch.randn(128, 2048, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
        example_inputs = (x, y)
        self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)

    @unittest.skipIf(
        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
        "FP8 is only supported on H100+",
    )
    @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
    def test_fp8(self):
        class Model(torch.nn.Module):
            def __init__(self, dtype):
                super().__init__()
                self.out_dtype = dtype

            def forward(self, x, weight, bias, scale_a, scale_b):
                weight = weight.to(torch.float8_e4m3fn)
                output = torch._scaled_mm(
                    x,
                    weight,
                    bias=input_bias,
                    out_dtype=self.out_dtype,
                    scale_a=scale_a,
                    scale_b=scale_b,
                )
                return output

        dtype = torch.float16

        a_scale = torch.Tensor([1.0]).to(device="cuda")
        b_scale = torch.Tensor([1.0]).to(device="cuda")
        input_bias = torch.rand(32, device="cuda", dtype=dtype)
        weight_shape = (32, 16)
        weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T
        a_inverse_scale = 1 / a_scale
        b_inverse_scale = 1 / b_scale

        x_shape = (16, 16)
        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(torch.float8_e4m3fn)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = ({0: dim0_x}, None, None, None, None)
        self.check_model(
            Model(dtype),
            (x, weight, input_bias, a_inverse_scale, b_inverse_scale),
            dynamic_shapes=dynamic_shapes,
        )

    @unittest.skipIf(
        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
        "FP8 is only supported on H100+",
    )
    @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
    def test_fp8_view_of_param(self):
        # cuda only
        if self.device != "cuda":
            return

        class Model(torch.nn.Module):
            def __init__(self, dtype, weight):
                super().__init__()
                self.out_dtype = dtype
                self.weight = weight

            def forward(self, x, bias, scale_a, scale_b):
                # test: do the view inside of the graph,
                # AOTI needs to materialize this view before passing
                # it into the scaled_mm extern kernel
                weight = self.weight.T
                output = torch._scaled_mm(
                    x,
                    weight,
                    bias=input_bias,
                    out_dtype=self.out_dtype,
                    scale_a=scale_a,
                    scale_b=scale_b,
                )
                return output

        dtype = torch.float16

        a_scale = torch.Tensor([1.0]).to(device=self.device)
        b_scale = torch.Tensor([1.0]).to(device=self.device)
        input_bias = torch.rand(32, device=self.device, dtype=dtype)
        weight_shape = (32, 16)
        weight = torch.rand(*weight_shape, device=self.device, dtype=dtype).to(
            torch.float8_e4m3fn
        )
        a_inverse_scale = 1 / a_scale
        b_inverse_scale = 1 / b_scale

        x_shape = (16, 16)
        x = torch.rand(*x_shape, device=self.device, dtype=dtype).to(
            torch.float8_e4m3fn
        )
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = ({0: dim0_x}, None, None, None)
        self.check_model(
            Model(dtype, weight),
            (x, input_bias, a_inverse_scale, b_inverse_scale),
            dynamic_shapes=dynamic_shapes,
        )

    def test_poi_multiple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                add_0 = x + y
                return torch.nn.functional.relu(input=add_0, inplace=False)

        x = torch.randn(128, 2048, device=self.device)
        y = torch.randn(128, 2048, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
        list_example_inputs = [(x, y)]
        list_example_inputs.append(
            (
                torch.randn(64, 2048, device=self.device),
                torch.randn(64, 2048, device=self.device),
            ),
        )
        list_example_inputs.append(
            (
                torch.randn(211, 2048, device=self.device),
                torch.randn(211, 2048, device=self.device),
            ),
        )
        self.check_model_with_multiple_inputs(
            Model(), list_example_inputs, dynamic_shapes=dynamic_shapes
        )

    def test_addmm_multiple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        M = 8
        N = 6
        K = 16
        model = Model(N, K, self.device)
        batch = 2
        a = torch.randn(batch, M, K, device=self.device)
        dim0_a = Dim("dim0_a", min=1, max=2048)
        dynamic_shapes = {"a": {0: dim0_a}}
        list_example_inputs = [(a,)]
        batch = 2048
        list_example_inputs.append(
            (torch.randn(batch, M, K, device=self.device),),
        )
        batch = 128
        list_example_inputs.append(
            (torch.randn(batch, M, K, device=self.device),),
        )
        self.check_model_with_multiple_inputs(
            model,
            list_example_inputs,
            dynamic_shapes=dynamic_shapes,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )

    def test_bmm_multiple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        M = 8
        N = 6
        K = 16
        model = Model()
        batch = 1024
        a = torch.randn(batch, M, K, device=self.device)
        b = torch.randn(batch, K, N, device=self.device)
        dim0_a = Dim("dim0_a", min=1, max=2048)
        dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_a}}
        list_example_inputs = [(a, b)]
        batch = 2048
        list_example_inputs.append(
            (
                torch.randn(batch, M, K, device=self.device),
                torch.randn(batch, K, N, device=self.device),
            ),
        )
        batch = 128
        list_example_inputs.append(
            (
                torch.randn(batch, M, K, device=self.device),
                torch.randn(batch, K, N, device=self.device),
            ),
        )
        self.check_model_with_multiple_inputs(
            model,
            list_example_inputs,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
            dynamic_shapes=dynamic_shapes,
        )

    def test_foreach_multiple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x_unsqueeze = torch.unsqueeze(x, dim=0)
                y_unsqueeze = torch.unsqueeze(y, dim=0)
                cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
                return cat

        model = Model()
        x = torch.randn(128, 2048, device=self.device)
        y = torch.randn(128, 2048, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
        list_example_inputs = [(x, y)]
        list_example_inputs.append(
            (
                torch.randn(64, 2048, device=self.device),
                torch.randn(64, 2048, device=self.device),
            ),
        )
        list_example_inputs.append(
            (
                torch.randn(211, 2048, device=self.device),
                torch.randn(211, 2048, device=self.device),
            ),
        )
        self.check_model_with_multiple_inputs(
            model,
            list_example_inputs,
            dynamic_shapes=dynamic_shapes,
        )

    # scaled_dot_product_flash_attention
    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
    @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
    def test_sdpa(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, q, k, v):
                return torch.nn.functional.scaled_dot_product_attention(q, k, v)[0]

        example_inputs = (
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
    @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
    def test_sdpa_2(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, q, k, v, x):
                t = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, is_causal=True
                )[0]
                return x + t

        example_inputs = (
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    @skipIfNoFBGEMM
    def test_quantized_linear(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.weight = torch.randn(10, 10, device=device)
                self.bias = torch.randn(10, device=device)

            def forward(self, x):
                return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(
                    x, self.weight, self.bias
                )

        example_inputs = (torch.randn(10, 10, device=self.device),)
        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
            self.check_model(Model(self.device), example_inputs)

    @skipIfNoFBGEMM
    def test_quanatized_int8_linear(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.weight = torch.randn(10, 10, device=device)
                self.bias = torch.randn(10, device=device)
                self.input_scale = torch.tensor(0.1)
                self.input_zero_point = torch.tensor(0)
                self.weight_scale = torch.tensor(0.1)
                self.weight_zero_point = torch.tensor(0)
                self.output_scale = torch.tensor(0.1)
                self.output_zero_point = torch.tensor(0)
                self.out_channel = 10

            def forward(self, x):
                return torch.ops._quantized.wrapped_quantized_linear(
                    x,
                    self.input_scale,
                    self.input_zero_point,
                    self.weight,
                    self.weight_scale,
                    self.weight_zero_point,
                    self.bias,
                    self.output_scale,
                    self.output_zero_point,
                    self.out_channel,
                )

        example_inputs = (torch.randn(10, 10, device=self.device),)
        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
            self.check_model(Model(self.device), example_inputs)

    def test_zero_grid_with_unbacked_symbols(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                nz = torch.nonzero(x)
                b = torch.ones_like(nz, dtype=torch.float16)
                c = torch.zeros_like(nz, dtype=torch.float16)
                d = (b + c) @ y
                return d.sum()

        example_inputs = (
            torch.tensor([1, 1, 1], device=self.device),
            torch.randn((1, 32), dtype=torch.float16, device=self.device),
        )
        self.check_model(Repro(), example_inputs)

    def test_large_grid(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, primals_5):
                view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
                primals_5 = None
                permute = torch.ops.aten.permute.default(view, [0, 2, 1])
                clone = torch.ops.aten.clone.default(
                    permute, memory_format=torch.contiguous_format
                )
                return clone

        # let y_grid = 65537
        s0 = 16777472
        s1 = 8
        example_inputs = (torch.rand(s0, s1, device=self.device),)
        self.check_model(Model(), example_inputs)

    def test_cond_simple(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.Simple(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_nested(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_abc = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p0": {},
            "p1": {},
            "p2": {},
            "a": {0: dim0_abc, 1: None},
            "b": {0: dim0_abc, 1: None},
            "c": {0: dim0_abc, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.Nested(),
            prepend_predicates(inputs, num_predicates=3),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_with_parameters(self):
        inputs = (torch.randn((10, 20), device=self.device),)
        dim0_abc = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_abc, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.Parameters(self.device),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_with_reinterpret_view_inputs_outputs(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=3, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.ReinterpretView(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_with_multiple_outputs(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
            torch.randn((30, 40), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dim0_c = Dim("s1", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
            "c": {0: dim0_c, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.MultipleOutputs(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_with_outer_code_before_after(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.OuterCode(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_use_buffers_from_outer_scope(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_abc = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_abc, 1: None},
            "b": {0: dim0_abc, 1: None},
            "c": {0: dim0_abc, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.OuterBuffers(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    @common_utils.parametrize("dynamic", [False, True])
    def test_cond_non_tensor_predicates(self, dynamic):
        inputs1 = (
            torch.randn((10, 20), device=self.device),
            torch.randn((15, 20), device=self.device),
        )
        inputs2 = (
            torch.randn((10, 20), device=self.device),
            torch.randn((5, 20), device=self.device),
        )
        inputs = (inputs1,)
        dynamic_shapes = None
        if dynamic:
            inputs = (inputs1, inputs2)
            dim0_a = Dim("s0", min=2, max=1024)
            dim0_b = Dim("s1", min=2, max=1024)
            dynamic_shapes = {
                "a": {0: dim0_a, 1: None},
                "b": {0: dim0_b, 1: None},
            }
        self.check_model_with_multiple_inputs(
            CondModels.WithNonTensorPredicate(),
            inputs,
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_simple(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "ci": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            WhileLoopModels.Simple(),
            prepend_counters(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_nested(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "ci": {},
            "cj": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            WhileLoopModels.Nested(),
            prepend_counters(inputs, num_counters=2),
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_with_outer_code(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "c": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            WhileLoopModels.OuterCode(),
            prepend_counters(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_with_parameters(self):
        inputs = (torch.randn((10, 20), device=self.device),)
        dim0_a = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "c": {},
            "a": {0: dim0_a, 1: None},
        }
        self.check_model_with_multiple_inputs(
            WhileLoopModels.Parameters(self.device),
            prepend_counters(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_with_outer_buffers(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        # dynamic shapes don't work now due to
        # https://github.com/pytorch/pytorch/issues/123596
        # dim0_ab = Dim("s0", min=2, max=1024)
        # dynamic_shapes = {
        #     "c": {},
        #     "a": {0: dim0_ab, 1: None},
        #     "b": {0: dim0_ab, 1: None},
        # }
        dynamic_shapes = None
        self.check_model_with_multiple_inputs(
            WhileLoopModels.OuterBuffers(),
            prepend_counters(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    @config.patch({"is_predispatch": True})
    def test_constant(self):
        class M(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.device = device

            def forward(self, x):
                t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
                t = torch.sqrt(t * 3)
                return x * t

        self.check_model(M(self.device), (torch.randn(5, 5, device=self.device),))

    def test_zero_grid_with_backed_symbols(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, b):
                return x + b

        example_inputs = (
            x := torch.randn((3, 2), device=self.device),
            torch.randn((1, 2), device=self.device),
        )
        torch._dynamo.mark_dynamic(x, index=0)  # Create dynamic symbol

        # Compile & run model where dynamic dim size > 0.
        so_path: str = AOTIRunnerUtil.compile(
            Repro(),
            example_inputs,
        )
        aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path)
        aot_inductor_module(*example_inputs)

        # Re-run where dynamic dim size is 0.
        example_inputs = (
            torch.randn((0, 2), device=self.device),
            torch.randn((1, 2), device=self.device),
        )
        actual = aot_inductor_module(*example_inputs)
        expected = Repro()(*example_inputs)
        torch.testing.assert_close(actual, expected)

    def test_repeat_interleave(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.ops.aten.repeat_interleave.Tensor(x, output_size=12)

        example_inputs = (torch.ones((1,), dtype=torch.int32, device=self.device) * 12,)
        self.check_model(Repro(), example_inputs)

    def test_dynamic_cat(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.cat([a, b], dim=0)

        a = torch.randn(2, 4, device=self.device)
        b = torch.randn(3, 4, device=self.device)
        dim0_a = Dim("dim0_a", min=1, max=10)
        dim0_b = Dim("dim0_b", min=1, max=20)
        dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
        example_inputs = (a, b)
        self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)

    def test_buffer_mutation_1(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.foo = torch.nn.Buffer(torch.randn(4, 4, device=device))

            def forward(self, x):
                self.foo.add_(1)
                return self.foo + x

        example_inputs = (torch.rand(4, 4, device=self.device),)
        self.check_model(Model(self.device), example_inputs)

    def test_non_tensor_input(self):
        class Model(torch.nn.Module):
            def forward(self, a, b, alpha=1.0):
                return torch.add(a, b, alpha=alpha)

        a = torch.randn(10, device=self.device)
        b = torch.randn(10, device=self.device)

        for simdlen in [0, None]:
            with torch._inductor.config.patch({"cpp.simdlen": simdlen}):
                so_path = torch._export.aot_compile(
                    torch.ops.aten.add,
                    args=(a, b),
                    kwargs={"alpha": 2.0},
                )
                kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path)
                res = kernel_runner.run([a, b])
                self.assertTrue(isinstance(res, list))
                self.assertTrue(len(res) == 1)
                self.assertEqual(Model()(a, b, alpha=2.0), res[0])

    def test_buffer_mutation_2(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.foo = torch.nn.Buffer(torch.arange(10, device=device))
                self.bar = torch.nn.Buffer(torch.arange(10, device=device))

            def forward(self, x):
                self.bar.mul_(2)
                self.foo[5] = self.bar[0]
                return x + self.bar, x * self.foo

        example_inputs = (torch.randn(10, device=self.device),)
        self.check_model(Model(self.device), example_inputs)

    def test_buffer_mutation_3(self):
        class KVCache(torch.nn.Module):
            def __init__(
                self,
                max_batch_size,
                max_seq_length,
                n_heads,
                head_dim,
                dtype=torch.float,
            ):
                super().__init__()
                cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
                self.k_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))
                self.v_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))

            def update(self, input_pos, k_val, v_val):
                # input_pos: [S], k_val: [B, H, S, D]
                k_out = self.k_cache
                v_out = self.v_cache
                k_out[:, :, input_pos] = k_val
                v_out[:, :, input_pos] = v_val

                return k_out, v_out

        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.kv_cache = KVCache(1, 256, 6, 48)

            def forward(self, inp_pos, k, v):
                self.kv_cache.update(inp_pos, k, v)
                return self.kv_cache.k_cache + 1, self.kv_cache.v_cache / 2

        example_inputs = (
            torch.tensor([0], device=self.device),
            torch.randn(1, 6, 1, 48, device=self.device),
            torch.randn(1, 6, 1, 48, device=self.device),
        )
        model = Model(self.device)
        self.check_model(model, example_inputs)
        self.code_check_count(model, example_inputs, "empty_strided", 2)

    def test_buffer_mutation_4(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer(
                    "_tensor_constant0",
                    torch.randint(1, size=[38], dtype=torch.int64, device="cpu"),
                )

            def forward(self, x):
                return x + self._tensor_constant0.to(torch.device(type="cuda", index=0))

        example_inputs = (
            torch.randint(1, size=[38], dtype=torch.int64, device="cuda"),
        )
        torch._export.aot_compile(Model(), example_inputs)

    @requires_multigpu()
    def test_replicate_on_devices(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self, w1, w2):
                super().__init__()
                self.w1 = w1
                self.w2 = w2

            def forward(self, x, y):
                a = x * self.w1
                b = y * self.w2
                return a + b

        w1 = torch.randn(10, 10)
        w2 = torch.randn(10, 10)
        inputs = (torch.randn(10, 10), torch.randn(10, 10))
        result_cpu = Model(w1, w2)(*inputs)

        # Compile model with AOTInductor
        with torch.cuda.device(0), config.patch("abi_compatible", self.abi_compatible):
            so_path = AOTIRunnerUtil.compile(
                model=Model(w1.cuda(0), w2.cuda(0)),
                example_inputs=tuple(t.cuda(0) for t in inputs),
            )

        # Run model on cuda:N
        for i in range(torch.cuda.device_count()):
            with torch.cuda.device(i):
                example_inputs = tuple(t.cuda(i) for t in inputs)
                optimized = AOTIRunnerUtil.load("cuda", so_path)
                result_cuda = optimized(*example_inputs)
                self.assertTrue(same(result_cpu, result_cuda.cpu()))

    def test_pytree_inputs(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: Dict[str, torch.Tensor]):
                device = next(iter(x.values())).device
                add_ = torch.zeros(5, device=device)
                mul_ = torch.ones(5, device=device)
                for v in x.values():
                    add_ += v
                    mul_ *= v

                return [add_, mul_]

        self.check_model(
            M(),
            (
                {
                    "x": torch.ones(5, device=self.device),
                    "y": torch.ones(5, device=self.device),
                },
            ),
        )

    @requires_multigpu()
    def test_non_default_cuda_device(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x, y):
                return x + torch.nn.functional.linear(y, self.weight)

        weight = torch.randn(10, 10)
        inputs = (torch.randn(10, 10), torch.randn(10, 10))
        result_cpu = Model(weight)(*inputs)

        with torch.cuda.device(0), torch.no_grad(), config.patch(
            "abi_compatible", self.abi_compatible
        ):
            result_cuda_0 = AOTIRunnerUtil.run(
                "cuda", Model(weight.cuda(0)), tuple(t.cuda(0) for t in inputs)
            )

        with torch.cuda.device(1), torch.no_grad(), config.patch(
            "abi_compatible", self.abi_compatible
        ):
            result_cuda_1 = AOTIRunnerUtil.run(
                "cuda", Model(weight.cuda(1)), tuple(t.cuda(1) for t in inputs)
            )

        self.assertTrue(same(result_cpu, result_cuda_0.cpu()))
        self.assertTrue(same(result_cpu, result_cuda_1.cpu()))

    def test_reuse_kernel(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                a = torch.sin(x)
                b = torch.mm(a, y)
                c = torch.sin(b)
                d = torch.mm(b, c)
                return d

        example_inputs = (
            torch.randn(87, 87, device=self.device),
            torch.randn(87, 87, device=self.device),
        )
        model = Model()
        self.check_model(
            model, example_inputs, atol=1e-4, rtol=1e-4
        )  # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py

        if self.device == "cuda":
            self.code_check_count(
                model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1
            )

    def test_reuse_kernel_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.cst = torch.randn(48, device=device, dtype=torch.float)
                self.weights = torch.randn(6, 48, 48, device=device, dtype=torch.float)
                self.cst_1 = torch.randn(48, device=device, dtype=torch.float)
                self.weights_1 = torch.randn(
                    6, 48, 48, device=device, dtype=torch.float
                )

            def forward(self, x, y, z):
                dim0 = x.size(1)
                add_0 = z + z
                expand_2 = add_0.expand(-1, -1, 48)
                # [s0, 6, 48]
                mul_3 = add_0 * expand_2
                # [6, s0, 48]
                permute_4 = torch.permute(mul_3, (1, 0, 2))
                # [6, s0, 48]
                bmm_5 = torch.bmm(permute_4, self.weights)
                add_6 = bmm_5 + self.cst
                reshape_7 = torch.reshape(add_6, [6, dim0 * 6, 8])
                # [6*s0, 6, 8]
                permute_8 = torch.permute(reshape_7, (1, 0, 2))
                mul_9 = permute_8 * 0.123
                reshape_10 = torch.reshape(y, [8, dim0 * 6, 4])
                # [6*s0, 8, 4]
                permute_11 = torch.permute(reshape_10, (1, 0, 2))
                bmm_12 = torch.bmm(mul_9, permute_11)

                add_0_1 = z + z
                expand_2_1 = add_0_1.expand(-1, -1, 48)
                # [s0, 6, 48]
                mul_3_1 = add_0_1 * expand_2_1
                # [6, s0, 48]
                permute_4_1 = torch.permute(mul_3_1, (1, 0, 2))
                # [6, s0, 48]
                bmm_5_1 = torch.bmm(permute_4_1, self.weights_1)
                add_6_1 = bmm_5_1 + self.cst_1
                reshape_7_1 = torch.reshape(add_6_1, [6, dim0 * 6, 8])
                # [6*s0, 6, 8]
                permute_8_1 = torch.permute(reshape_7_1, (1, 0, 2))
                mul_9_1 = permute_8_1 * 0.123
                reshape_10_1 = torch.reshape(y, [8, dim0 * 6, 4])
                # [6*s0, 8, 4]
                permute_11_1 = torch.permute(reshape_10_1, (1, 0, 2))
                bmm_12_1 = torch.bmm(mul_9_1, permute_11_1)
                return bmm_12 + bmm_12_1

        x = torch.randn(6, 2, 48, device=self.device, dtype=torch.float)
        y = torch.randn(48, 2, 4, device=self.device, dtype=torch.float)
        z = torch.randn(2, 6, 1, device=self.device, dtype=torch.float)
        dim0 = Dim("dim0", min=1, max=2048)
        dynamic_shapes = {
            "x": {1: dim0},
            "y": {1: dim0},
            "z": {0: dim0},
        }

        example_inputs = (x, y, z)
        m = Model(self.device).to(dtype=torch.float)
        self.check_model(m, example_inputs, dynamic_shapes=dynamic_shapes)

    def test_fake_tensor_device_validation(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return x + y

        example_inputs = (torch.randn(10, 10), torch.randn(10, 10))

        # Export on CPU
        exported_program = export(Model(), example_inputs)

        # Compile exported model on CUDA
        gm = exported_program.graph_module.to(self.device)
        with self.assertRaisesRegex(ValueError, "Device mismatch between fake input"):
            torch._inductor.aot_compile(
                gm, tuple(i.to(self.device) for i in example_inputs)
            )

    def test_fx_gm_return_tuple_validation(self):
        from torch.fx.experimental.proxy_tensor import make_fx

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return x + y

        example_inputs = (torch.randn(10, 10), torch.randn(10, 10))

        gm = make_fx(Model(), tracing_mode="symbolic")(*example_inputs)
        with self.assertRaisesRegex(
            AssertionError,
            r"Graph output must be a tuple\(\). This is so that we can avoid "
            "pytree processing of the outputs.",
        ):
            torch._inductor.aot_compile(gm, example_inputs)

    @unittest.mock.patch("torch._inductor.graph.supported_dtype_of_cpp_wrapper")
    def test_unsupported_input_dtype(self, supported_dtype_of_cpp_wrapper_mock):
        supported_dtype_of_cpp_wrapper_mock.return_value = False

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return x + y

        example_inputs = (
            torch.randn(10, 10).to(self.device),
            torch.randn(10, 10).to(self.device),
        )
        with self.assertRaisesRegex(
            CppWrapperCodeGenError, "Unsupported input dtype torch.float32"
        ):
            torch._export.aot_compile(Model(), example_inputs)

        supported_dtype_of_cpp_wrapper_mock.assert_called_once_with(
            torch.float32, self.device == "cuda"
        )

    def test_consecutive_compiles(self):
        """Test that compilation behaves correctly with cache hits"""

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + 1

        mod = TestModule()
        inp = torch.rand(1)
        mod(inp)
        mod2 = torch.fx.symbolic_trace(mod, concrete_args=[inp])
        so = torch._export.aot_compile(mod2, (inp,))
        assert so is not None
        # compile the 2nd time with cache hit
        so = torch._export.aot_compile(mod2, (inp,))
        assert so is not None

    def test_normal_functional(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.ops.aten.normal_functional.default(x)

        self.check_model(Model(), (torch.empty(4, 1, 4, 4),))

    def test_empty_graph(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x

        example_inputs = (torch.randn(8, 4, 4, device=self.device),)
        self.check_model(Model(), example_inputs)

    @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
    def test_dup_unbacked_sym_decl(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                abs_1 = torch.ops.aten.abs.default(x)
                lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
                eq = torch.ops.aten.eq.Scalar(lt, 0)
                index_1 = torch.ops.aten.index.Tensor(x, [eq])
                sin = torch.ops.aten.sin.default(index_1)
                index_2 = torch.ops.aten.index.Tensor(x, [eq])
                div_3 = torch.ops.aten.div.Tensor(sin, index_2)
                return div_3

        example_inputs = (torch.randn(4, 4, 4, 4).to(self.device),)
        self.check_model(Model(), example_inputs)

    # This exercises _eliminate_unbacked path in ShapeEnv
    @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
    def test_dup_unbacked_sym_decl_with_refinement(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                abs_1 = torch.ops.aten.abs.default(x)
                lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
                eq = torch.ops.aten.eq.Scalar(lt, 0)
                index_1 = torch.ops.aten.index.Tensor(x, [eq])
                torch._check(index_1.size(0) == 4**4)
                sin = torch.ops.aten.sin.default(index_1)
                index_2 = torch.ops.aten.index.Tensor(x, [eq])
                div_3 = torch.ops.aten.div.Tensor(sin, index_2)
                return div_3

        example_inputs = (torch.ones(4, 4, 4, 4).to(self.device),)
        self.check_model(Model(), example_inputs)

    def test_run_with_grad_enabled(self):
        class Model(torch.nn.Module):
            def forward(self, x, weight, bias):
                return torch.ops.aten.addmm(bias, weight, x)

        m = Model().to(device=self.device)
        x = torch.rand(8, 8, device=self.device, requires_grad=True)
        weight = torch.rand(8, 8, device=self.device, requires_grad=True)
        bias = torch.rand(8, device=self.device, requires_grad=True)
        example_inputs = (x, weight, bias)

        expected = m(*example_inputs)
        expected = pytree.tree_leaves(expected)

        # compiler under no_grad
        with torch.no_grad():
            so_path = AOTIRunnerUtil.compile(m, example_inputs)

        # run under grad enabled
        self.assertTrue(torch.is_grad_enabled())

        optimized = AOTIRunnerUtil.load(self.device, so_path)
        actual = optimized(*example_inputs)
        actual = pytree.tree_leaves(actual)

        self.assertTrue(same(actual, expected))

    def test_return_constant(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.cst = torch.randn(5, 5, device=device)

            def forward(self, x):
                a = self.cst.clone()
                return (x, a)

        x = torch.randn(5, device=self.device)
        self.check_model(Model(self.device), (x,))

    def test_return_view_constant(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.cst = torch.randn(5, 5, device=device)

            def forward(self, x):
                a = torch.transpose(self.cst, 0, 1)
                return (x, a)

        x = torch.randn(5, device=self.device)
        self.check_model(Model(self.device), (x,))

    def test_with_profiler(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
            self.check_model(Model(), example_inputs)

    def test_with_no_triton_profiler(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.permute(x, (1, 0))

        example_inputs = (torch.randn(10, 10, device=self.device),)
        with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
            self.check_model(Model(), example_inputs)

    def test_repeat_output(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                y = torch.sin(x)
                return y, y

        example_inputs = (torch.randn(3, 10, device=self.device),)
        self.check_model(Model(), example_inputs)

    def test_view_outputs(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                y = torch.sin(x)
                y_same_size = y.view(*y.shape)
                y_diff_size = y.view(1, *y.shape)
                return y, y_same_size, y_diff_size

        example_inputs = (torch.randn(3, 10, device=self.device),)
        self.check_model(Model(), example_inputs)

    @skip_if_no_torchvision
    def test_missing_cubin(self):
        from torchvision.models.resnet import Bottleneck, ResNet

        class Model(ResNet):
            def __init__(self) -> None:
                super().__init__(
                    block=Bottleneck,
                    layers=[3, 4, 6, 3],
                    replace_stride_with_dilation=[False, False, True],
                    norm_layer=None,
                )

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                f1 = x
                x = self.maxpool(x)
                x = self.layer1(x)
                f2 = x
                x = self.layer2(x)
                f3 = x
                x = self.layer3(x)
                x = self.layer4(x)
                f4 = x
                return [f1, f2, f3, f4]

        # Call eval() here so that batch_norm won't update the running stats
        # Use float64 to avoid numeric difference failure
        model = Model().to(device=self.device, dtype=torch.float64).eval()
        example_inputs = (
            torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64),
        )
        self.check_model(model, example_inputs)

    @common_utils.parametrize("grid_type", [1, 2, 3])
    @common_utils.parametrize("num_dims", [1, 2])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("autotune", [False, True])
    def test_triton_kernel(self, grid_type, num_dims, dynamic, autotune):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                output = torch.zeros_like(x)
                if autotune and num_dims == 2:
                    x_elements = output.size()[0]
                    y_elements = output.size()[1]
                else:
                    n_elements = output.numel()

                # Select grid
                if autotune and num_dims == 2:
                    if grid_type == 1:
                        grid = (x_elements, y_elements)
                    elif grid_type == 2:
                        grid = lambda meta: (  # noqa: E731
                            triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                            triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                        )
                    else:

                        def grid_fn(meta):
                            return (
                                triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                                triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                            )

                        grid = grid_fn
                else:
                    if grid_type == 1:
                        grid = (n_elements,)
                    elif grid_type == 2:
                        grid = lambda meta: (  # noqa: E731
                            triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                        )
                    else:

                        def grid_fn(meta):
                            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

                        grid = grid_fn

                # Select kernel
                if autotune:
                    if num_dims == 1:
                        add_kernel_autotuned[grid](x, y, output, n_elements)
                    else:
                        add_kernel_2d_autotuned[grid](
                            x, y, output, x_elements, y_elements
                        )
                else:
                    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output

        dims = [10] * num_dims
        x = torch.randn(*dims, device=self.device)
        y = torch.randn(*dims, device=self.device)
        dynamic_shapes = []
        if dynamic:
            dim0_x = Dim("dim0_x", min=1, max=10)
            dim0_y = Dim("dim0_y", min=1, max=10)
            dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
        self.check_model(Model(), (x, y), dynamic_shapes=dynamic_shapes)

    def test_triton_kernel_dynamic_shape_with_div(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        @triton.jit
        def pass_kernel(x, num):
            pass

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                num = x.numel() // 4

                grid = lambda meta: (triton.cdiv(num, 16),)  # noqa: E731
                pass_kernel[grid](x, num)
                return x

        x = torch.randn(10, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=10)
        dynamic_shapes = 
{"x": {0: dim0_x}} 2068 self.check_model(Model(), (x,), dynamic_shapes=dynamic_shapes) 2069 2070 def test_triton_kernel_reinterpret_view(self): 2071 if self.device != "cuda": 2072 raise unittest.SkipTest("requires CUDA") 2073 2074 @triton.jit 2075 def pass_kernel(x, y): 2076 pass 2077 2078 class Model(torch.nn.Module): 2079 def __init__(self) -> None: 2080 super().__init__() 2081 2082 def forward(self, x): 2083 out = torch.zeros_like(x[:, 4:]) 2084 # the slicing below creates two ReinterpretView 2085 # instances: with offset=3 and offset=4 2086 add_kernel[(10,)]( 2087 in_ptr0=x[:, 3:-1], 2088 in_ptr1=x[:, 4:], 2089 out_ptr=out, 2090 n_elements=160, 2091 BLOCK_SIZE=16, 2092 ) 2093 return out 2094 2095 example_inputs = (torch.randn(10, 20, device=self.device),) 2096 self.check_model(Model(), example_inputs) 2097 2098 def test_triton_kernel_sympy_expr_arg(self): 2099 if self.device != "cuda": 2100 raise unittest.SkipTest("requires CUDA") 2101 2102 class Model(torch.nn.Module): 2103 def forward(self, x, e): 2104 sympy_expr = max(1, e.item()) 2105 out = torch.zeros_like(x) 2106 add_kernel[(1,)]( 2107 in_ptr0=x, 2108 in_ptr1=x, 2109 out_ptr=out, 2110 n_elements=sympy_expr, 2111 BLOCK_SIZE=1, 2112 ) 2113 return out 2114 2115 NUMEL = 64 2116 inputs = ( 2117 torch.randn(NUMEL, device=self.device), 2118 torch.tensor(NUMEL, device=self.device), 2119 ) 2120 self.check_model(Model(), inputs) 2121 2122 def test_triton_kernel_sympy_fn_like_arg(self): 2123 # This test should hit sympy.expand("sqrt") which crashes with 2124 # AttributeError: 'function' object has no attribute 'expand'. 2125 if self.device != "cuda": 2126 raise unittest.SkipTest("requires CUDA") 2127 2128 class Model(torch.nn.Module): 2129 def forward(self, x): 2130 out = torch.zeros_like(x) 2131 add_kernel_with_optional_param[1,]( 2132 in_ptr0=x, 2133 in_ptr1=x, 2134 out_ptr=out, 2135 n_elements=x.numel(), 2136 BLOCK_SIZE=1, 2137 ARGS_PASSED="sqrt", # sqrt is a valid sympy fn 2138 ) 2139 return out 2140 2141 inputs = (torch.randn(4, device=self.device),) 2142 self.check_model(Model(), inputs) 2143 2144 def test_triton_kernel_with_none_input(self): 2145 if self.device != "cuda": 2146 raise unittest.SkipTest("requires CUDA") 2147 2148 class Model(torch.nn.Module): 2149 def __init__(self) -> None: 2150 super().__init__() 2151 2152 def forward(self, x, y): 2153 n_elements = x.size()[0] 2154 BLOCK_SIZE = 1024 2155 2156 output_wo_y = torch.empty_like(x) 2157 output_with_y = torch.empty_like(x) 2158 2159 wo_kernel = add_kernel_with_optional_param[(1,)]( 2160 x, 2161 None, 2162 output_wo_y, 2163 n_elements, 2164 ARGS_PASSED="one", 2165 BLOCK_SIZE=BLOCK_SIZE, 2166 ) 2167 with_kernel = add_kernel_with_optional_param[(1,)]( 2168 x, 2169 y, 2170 output_with_y, 2171 n_elements, 2172 ARGS_PASSED="two", 2173 BLOCK_SIZE=BLOCK_SIZE, 2174 ) 2175 2176 return 2.71 * output_wo_y + 3.14 * output_with_y 2177 2178 example_inputs = ( 2179 torch.randn(1023, device=self.device), 2180 torch.randn(1023, device=self.device), 2181 ) 2182 2183 self.check_model(Model(), example_inputs) 2184 2185 def test_triton_kernel_equal_to_1_arg(self): 2186 if self.device != "cuda": 2187 raise unittest.SkipTest("requires CUDA") 2188 2189 class Model(torch.nn.Module): 2190 def forward(self, x, y): 2191 out = torch.empty_like(x) 2192 n_elements = x.numel() 2193 add_kernel[(n_elements,)](x, y, out, n_elements, BLOCK_SIZE=16) 2194 return out 2195 2196 example_inputs = ( 2197 torch.randn(1, device=self.device), 2198 torch.randn(1, device=self.device), 2199 ) 2200 2201 
self.check_model(Model(), example_inputs) 2202 2203 @common_utils.parametrize("dynamic", [False, True]) 2204 def test_triton_kernel_equal_to_1_float_arg(self, dynamic): 2205 if self.device != "cuda": 2206 raise unittest.SkipTest("requires CUDA") 2207 2208 class Model(torch.nn.Module): 2209 def forward(self, x, y): 2210 out = torch.empty_like(x) 2211 n_elements = x.numel() 2212 scaling_factor = (n_elements**0) / 1.0 2213 add_kernel_with_scaling[(n_elements,)]( 2214 x, 2215 y, 2216 out, 2217 n_elements, 2218 scaling_factor, 2219 BLOCK_SIZE=16, 2220 ) 2221 return out 2222 2223 dynamic_shapes = None 2224 if dynamic: 2225 dim0_xy = Dim("s0", min=2, max=1024) 2226 dynamic_shapes = { 2227 "x": {0: dim0_xy, 1: None}, 2228 "y": {0: dim0_xy, 1: None}, 2229 } 2230 example_inputs = ( 2231 torch.randn(2, device=self.device), 2232 torch.randn(2, device=self.device), 2233 ) 2234 self.check_model( 2235 Model(), 2236 example_inputs, 2237 dynamic_shapes=dynamic_shapes, 2238 ) 2239 2240 def test_triton_kernel_weird_param_order(self): 2241 if self.device != "cuda": 2242 raise unittest.SkipTest("requires CUDA") 2243 2244 class Model(torch.nn.Module): 2245 def __init__(self) -> None: 2246 super().__init__() 2247 2248 def forward(self, x): 2249 out = torch.empty_like(x) 2250 add_kernel_autotuned_weird_param_order[16,]( 2251 in_ptr0=x, 2252 in_ptr1=x, 2253 n_elements=x.numel(), 2254 out_ptr=out, 2255 ) 2256 return out 2257 2258 x = torch.randn(16, 16, device=self.device) 2259 self.check_model(Model(), (x,)) 2260 2261 def test_shifted_constraint_ranges(self): 2262 class Model(torch.nn.Module): 2263 def __init__(self) -> None: 2264 super().__init__() 2265 2266 def forward( 2267 self, 2268 x: torch.Tensor, 2269 y: torch.Tensor, 2270 ): 2271 torch._check(y.size(0) == x.size(0) + 1) 2272 return x.sum(0) + y.sum(0) 2273 2274 a = torch.randn((4, 5), device=self.device) 2275 b = torch.randn((5, 5), device=self.device) 2276 dim0_x = Dim("dim0_x", min=2, max=1024) 2277 dim0_y = dim0_x + 1 2278 dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}} 2279 self.check_model( 2280 Model(), 2281 (a, b), 2282 dynamic_shapes=dynamic_shapes, 2283 ) 2284 2285 def test_scatter_fallback(self): 2286 class Model(torch.nn.Module): 2287 def __init__(self) -> None: 2288 super().__init__() 2289 2290 def forward( 2291 self, 2292 inp: torch.Tensor, 2293 index: torch.Tensor, 2294 src: torch.Tensor, 2295 ): 2296 return torch.scatter(inp, 1, index, src) 2297 2298 inputs = ( 2299 torch.ones((3, 5), device=self.device, dtype=torch.int64), 2300 torch.tensor([[0, 1, 2, 0]], device=self.device, dtype=torch.int64), 2301 torch.zeros((2, 5), device=self.device, dtype=torch.int64), 2302 ) 2303 2304 self.check_model(Model(), inputs) 2305 2306 def test_scatter_reduce_fallback(self): 2307 class Model(torch.nn.Module): 2308 def __init__(self) -> None: 2309 super().__init__() 2310 2311 def forward( 2312 self, 2313 inp: torch.Tensor, 2314 index: torch.Tensor, 2315 src: torch.Tensor, 2316 ): 2317 return torch.scatter_reduce(inp, 0, index, src, reduce="sum") 2318 2319 inputs = ( 2320 torch.tensor([1, 10, 100, 1000], device=self.device, dtype=torch.int64), 2321 torch.tensor([0, 1, 0, 1, 2, 1], device=self.device, dtype=torch.int64), 2322 torch.tensor([1, 2, 3, 4, 5, 6], device=self.device, dtype=torch.int64), 2323 ) 2324 2325 self.check_model(Model(), inputs) 2326 2327 def test_index_put_fallback(self): 2328 # index_put falls back in the deterministic mode 2329 with DeterministicGuard(True): 2330 2331 class Model(torch.nn.Module): 2332 def __init__(self) -> None: 
2333 super().__init__() 2334 2335 def forward( 2336 self, 2337 self_tensor: torch.Tensor, 2338 indices: Tuple[torch.Tensor], 2339 values: torch.Tensor, 2340 ): 2341 return torch.index_put( 2342 self_tensor, indices, values, accumulate=True 2343 ) 2344 2345 inputs = ( 2346 torch.ones(4, device=self.device, dtype=torch.int64), 2347 (torch.tensor([1, 1, 2, 2], device=self.device, dtype=torch.bool),), 2348 torch.ones(4, device=self.device, dtype=torch.int64), 2349 ) 2350 2351 self.check_model(Model(), inputs) 2352 2353 def test_repeated_user_defined_triton_kernel(self): 2354 if self.device != "cuda": 2355 raise unittest.SkipTest("requires CUDA") 2356 2357 class Model(torch.nn.Module): 2358 def __init__(self) -> None: 2359 super().__init__() 2360 2361 def forward(self, x): 2362 for _ in range(3): 2363 mul2_inplace_kernel[4,](x, n_elements=4, BLOCK_SIZE=16) 2364 return x 2365 2366 inputs = (torch.randn(4, 4, device=self.device),) 2367 self.check_model(Model(), inputs) 2368 2369 def test_convolution(self): 2370 class Model(torch.nn.Module): 2371 def __init__(self) -> None: 2372 super().__init__() 2373 2374 def forward(self, x, w, b): 2375 return torch.ops.aten.convolution(x, w, b, [4], [0], [1], True, [0], 1) 2376 2377 example_inputs = ( 2378 torch.randn([2, 32, 90], device=self.device), 2379 torch.randn([32, 16, 8], device=self.device), 2380 torch.randn([16], device=self.device), 2381 ) 2382 with config.patch( 2383 { 2384 "max_autotune": True, 2385 "max_autotune_gemm_backends": "Triton", 2386 } 2387 ): 2388 self.check_model(Model(), example_inputs) 2389 2390 def test_zero_size_weight(self): 2391 class Model(torch.nn.Module): 2392 def __init__(self, channel, r=8): 2393 super().__init__() 2394 self.pool = torch.nn.AdaptiveAvgPool2d(1) 2395 self.net = torch.nn.Sequential( 2396 torch.nn.Linear(channel, channel // r, bias=False), 2397 torch.nn.ReLU(inplace=True), 2398 torch.nn.Linear(channel // r, channel, bias=False), 2399 torch.nn.Sigmoid(), 2400 ) 2401 2402 def forward(self, inp): 2403 b, c, _, _ = inp.shape 2404 x = self.pool(inp).view(b, c) 2405 x = self.net(x).view(b, c, 1, 1) 2406 x = inp * x 2407 return x 2408 2409 inputs = (torch.rand(4, 4, 4, 4, device=self.device),) 2410 self.check_model(Model(4), inputs) 2411 2412 def test_no_args(self): 2413 class Model(torch.nn.Module): 2414 def __init__(self, m, n): 2415 super().__init__() 2416 self.weight = torch.nn.Parameter( 2417 torch.randn(m, n), 2418 ) 2419 self.alpha = torch.nn.Parameter(torch.randn(m, n)) 2420 2421 def forward(self): 2422 return self.weight * self.alpha 2423 2424 self.check_model(Model(6, 4), ()) 2425 2426 def test_dynamic_scalar(self): 2427 class Model(torch.nn.Module): 2428 def __init__(self) -> None: 2429 super().__init__() 2430 self.criterion_ce = torch.nn.CrossEntropyLoss(reduction="none") 2431 2432 def forward(self, inputs, targets, split_index=None): 2433 statistics = {} 2434 total_loss = self.criterion_ce(inputs, targets).sum() 2435 statistics["dl"] = total_loss.item() 2436 return total_loss, statistics 2437 2438 inputs = ( 2439 torch.rand(4, 4, 4, 4, device=self.device), 2440 torch.rand(4, 4, 4, 4, device=self.device), 2441 ) 2442 self.check_model(Model(), inputs) 2443 2444 def test_constant_original_fqn_and_dtype(self): 2445 class FooBarModule(torch.nn.Module): 2446 def __init__(self) -> None: 2447 super().__init__() 2448 self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4))) 2449 self.test_buf = torch.nn.Buffer(torch.randn(3, 4)) 2450 self.register_parameter( 2451 "test_param", 
torch.nn.Parameter(torch.randn(3, 4)) 2452 ) 2453 2454 def forward(self, x): 2455 return ((x + self.test_buf) * getattr(self, "0")) / self.test_param 2456 2457 class TestModule(torch.nn.Module): 2458 def __init__(self) -> None: 2459 super().__init__() 2460 self.foo_bar = FooBarModule() 2461 self.register_parameter( 2462 "test_param", torch.nn.Parameter(torch.randn(3, 4)) 2463 ) 2464 self.test_buf = torch.nn.Buffer(torch.randn(3, 4)) 2465 2466 def forward(self, x): 2467 return (self.foo_bar(x) + self.test_param) * self.test_buf 2468 2469 with torch.no_grad(): 2470 so_path = AOTIRunnerUtil.compile( 2471 model=TestModule().to(device=self.device), 2472 example_inputs=(torch.rand(3, 4, device=self.device),), 2473 ) 2474 2475 runner = AOTIRunnerUtil.load_runner(self.device, so_path) 2476 2477 expected_original_fqns = { 2478 "L__self___test_param": "test_param", 2479 "L__self___test_buf": "test_buf", 2480 "getattr_L__self___foo_bar___0__": "foo_bar.0", 2481 "L__self___foo_bar_test_param": "foo_bar.test_param", 2482 "L__self___foo_bar_test_buf": "foo_bar.test_buf", 2483 } 2484 self.assertEqual( 2485 expected_original_fqns, runner.get_constant_names_to_original_fqns() 2486 ) 2487 2488 expected_dtypes = { 2489 "L__self___test_param": 6, 2490 "L__self___test_buf": 6, 2491 "getattr_L__self___foo_bar___0__": 6, 2492 "L__self___foo_bar_test_param": 6, 2493 "L__self___foo_bar_test_buf": 6, 2494 } 2495 self.assertEqual(expected_dtypes, runner.get_constant_names_to_dtypes()) 2496 2497 def test_fqn(self): 2498 class NestedChild(torch.nn.Module): 2499 def __init__(self) -> None: 2500 super().__init__() 2501 self.nestedchild3buffer = torch.nn.Buffer(torch.ones(2, 3) * 3) 2502 2503 def forward(self, x): 2504 return x / self.nestedchild3buffer 2505 2506 class Child1(torch.nn.Module): 2507 def __init__(self) -> None: 2508 super().__init__() 2509 self.nested = NestedChild() 2510 self.register_parameter( 2511 "child1param", torch.nn.Parameter(torch.ones(2, 3)) 2512 ) 2513 2514 def forward(self, x): 2515 x = self.nested(x) 2516 return x + self.child1param 2517 2518 class Child2(torch.nn.Module): 2519 def __init__(self) -> None: 2520 super().__init__() 2521 self.child2buffer = torch.nn.Buffer(torch.ones(2, 3) * 2) 2522 2523 def forward(self, x): 2524 return x - self.child2buffer 2525 2526 class MyModule(torch.nn.Module): 2527 def __init__(self) -> None: 2528 super().__init__() 2529 self.foo = Child1() 2530 self.bar = Child2() 2531 self.register_parameter( 2532 "rootparam", torch.nn.Parameter(torch.ones(2, 3) * 4) 2533 ) 2534 2535 def forward(self, x): 2536 x = x * self.rootparam 2537 x = self.foo(x) 2538 x = self.bar(x) 2539 return x 2540 2541 orig_eager = MyModule() 2542 2543 self.check_model(MyModule(), (torch.randn(2, 3, device=self.device),)) 2544 2545 def test_model_modified_weights(self): 2546 class Model(torch.nn.Module): 2547 def __init__(self, n, k, device): 2548 super().__init__() 2549 self.weight = torch.randn(n, k, device=device) 2550 self.bias = torch.randn(n, device=device) 2551 2552 def forward(self, a): 2553 return torch.nn.functional.linear(a, self.weight, self.bias) 2554 2555 M = 16 2556 N = 10 2557 K = 128 2558 batch = 8 2559 example_inputs = (torch.randn(2, M, K, device=self.device),) 2560 model = Model(N, K, self.device) 2561 self.check_model(model, example_inputs) 2562 # Update model weights, after this AOTInductor should re-generate model.so 2563 # if weights are stored in the model.so 2564 model.weight += 1 2565 self.check_model(model, example_inputs) 2566 2567 def test_custom_op_add(self) -> 
None: 2568 class M(torch.nn.Module): 2569 def forward(self, x, y): 2570 return torch.ops.aoti_custom_ops.custom_add(x, y) 2571 2572 m = M().to(device=self.device) 2573 args = ( 2574 torch.randn(3, 3, device=self.device), 2575 torch.randn(3, 3, device=self.device), 2576 ) 2577 self.check_model(m, args) 2578 2579 def test_custom_op_all_inputs(self) -> None: 2580 class MyModel(torch.nn.Module): 2581 # pyre-fixme[3]: Return type must be annotated. 2582 def __init__(self): 2583 super().__init__() 2584 2585 # pyre-fixme[3]: Return type must be annotated. 2586 # pyre-fixme[2]: Parameter must be annotated. 2587 def forward(self, x, y): 2588 with torch.no_grad(): 2589 x_dim0 = x.shape[0] 2590 x_dim1 = x.shape[1] 2591 y_dim0 = y.shape[0] 2592 y_dim1 = y.shape[1] 2593 symint_0 = x_dim0 + x_dim1 2594 symint_1 = y_dim0 * y_dim1 2595 2596 z = torch.concat((x, x)) 2597 2598 _2547 = torch.ops.aoti_custom_ops.fn_with_all_inputs( 2599 tensor=x, 2600 tensors=[x, y], 2601 optional_tensors=[None, z], 2602 b8=False, 2603 b8s=[True, False], 2604 i64=42, 2605 i64s=[16, 17], 2606 symint=symint_0, 2607 symints=[symint_0, symint_1], 2608 f64=3.14, 2609 f64s=[2.2, 3.3], 2610 scalar=1.23, 2611 scalars=[45, 67], 2612 string="hello", 2613 strings=["ab", "cde"], 2614 # dtype=torch.float16, 2615 # memory_format=torch.contiguous_format, 2616 # layout=torch.strided, 2617 device=torch.device("cpu"), 2618 # optional 2619 o_tensor=None, 2620 o_tensors=[x, y], 2621 o_b8=False, 2622 o_b8s=[True, False], 2623 o_i64=None, 2624 o_i64s=[16, 17], 2625 o_symint=symint_1, 2626 o_symints=[symint_1, symint_0], 2627 o_f64=3.14, 2628 o_f64s=None, 2629 o_scalar=None, 2630 o_scalars=[89, 910], 2631 o_string="hello", 2632 o_strings=["ab", "cde"], 2633 # o_dtype=None, 2634 # o_memory_format=torch.contiguous_format, 2635 # o_layout=torch.strided, 2636 o_device=None, 2637 ) 2638 2639 return _2547 2640 2641 m = MyModel().to(device=self.device) 2642 x = torch.zeros(4, 8, device=self.device) 2643 y = torch.ones(3, 9, device=self.device) 2644 args = (x, y) 2645 m(*args) 2646 2647 self.check_model(m, args) 2648 2649 def test_custom_op_with_multiple_outputs(self) -> None: 2650 class Model(torch.nn.Module): 2651 def forward(self, x, y): 2652 out = x + y 2653 # tuple of Tensor output 2654 out3, out4 = torch.ops.aoti_custom_ops.fn_with_tuple_output(out, 1) 2655 # TensorList output 2656 out5, out6 = torch.ops.aoti_custom_ops.fn_with_list_output( 2657 [out3, out4], 1 2658 ) 2659 # tuple of Tensor and TensorList 2660 out7, [out8, out9] = torch.ops.aoti_custom_ops.fn_with_mix_outputs( 2661 out5, [out6, out4] 2662 ) 2663 return out3, out4, out5, out6, out7, out8, out9 2664 2665 m = Model().to(device=self.device) 2666 args = ( 2667 torch.randn(4, 4, device=self.device), 2668 torch.randn(4, 4, device=self.device), 2669 ) 2670 m(*args) 2671 2672 self.check_model(m, args) 2673 2674 def test_custom_op_with_reinterpret_view_inputs(self) -> None: 2675 class Model(torch.nn.Module): 2676 def forward(self, x): 2677 out = x.permute([1, 0]) 2678 return torch.ops.aoti_custom_ops.fn_with_default_input(out, 1) 2679 2680 m = Model().to(device=self.device) 2681 args = (torch.randn(2, 3, device=self.device),) 2682 2683 self.check_model(m, args) 2684 2685 def test_custom_op_with_concat_inputs(self) -> None: 2686 class Model(torch.nn.Module): 2687 def forward(self, x, y): 2688 out = torch.concat([x, y], dim=0) 2689 return torch.ops.aoti_custom_ops.fn_with_default_input(out, 1) 2690 2691 m = Model().to(device=self.device) 2692 args = ( 2693 torch.randn(2, 3, device=self.device), 
2694 torch.randn(2, 3, device=self.device), 2695 ) 2696 2697 self.check_model(m, args) 2698 2699 def test_custom_op_missing_arg_with_default_value(self) -> None: 2700 class Model(torch.nn.Module): 2701 def forward(self, x): 2702 # missing second arg 2703 return torch.ops.aoti_custom_ops.fn_with_default_input(x) 2704 2705 m = Model().to(device=self.device) 2706 args = (torch.randn(2, 3, device=self.device),) 2707 2708 self.check_model(m, args) 2709 2710 def test_triton_kernel_extern_kernel_arg(self): 2711 if self.device != "cuda": 2712 raise unittest.SkipTest("requires CUDA") 2713 2714 class Model(torch.nn.Module): 2715 def forward(self, x, y): 2716 out = torch.zeros_like(x) 2717 # torch.mm is ExternKernelOut 2718 add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16) 2719 return out 2720 2721 example_inputs = ( 2722 torch.randn(4, 4, device="cuda"), 2723 torch.randn(4, 4, device="cuda"), 2724 ) 2725 2726 self.check_model(Model(), example_inputs) 2727 2728 def test_triton_kernel_multi_output_arg(self): 2729 if self.device != "cuda": 2730 raise unittest.SkipTest("requires CUDA") 2731 2732 class Model(torch.nn.Module): 2733 def forward(self, x, y): 2734 out = torch.zeros_like(x) 2735 # torch.sort creates fallback kernel and hence MultiOutput 2736 add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16) 2737 return out 2738 2739 example_inputs = ( 2740 torch.randn(4, 4, device="cuda"), 2741 torch.randn(4, 4, device="cuda"), 2742 ) 2743 2744 self.check_model(Model(), example_inputs) 2745 2746 @config.patch({"abi_compatible": True}) 2747 def test_triton_kernel_reinterpret_view_mem_leak(self): 2748 # Check for memory leak when using user-defined Triton Kernel + AOTI. 2749 if self.device != "cuda": 2750 raise unittest.SkipTest("requires CUDA") 2751 2752 class Model(torch.nn.Module): 2753 def __init__(self) -> None: 2754 super().__init__() 2755 2756 def forward(self, x, y): 2757 out = torch.zeros_like(x) 2758 yy = y * y 2759 # reshape creates a ReinterpretView 2760 add_kernel[(4,)](x, yy.reshape_as(x), out, 4, 16) 2761 return out 2762 2763 example_inputs = ( 2764 torch.randn(4, 4, device="cuda"), 2765 torch.randn(1, 16, device="cuda"), 2766 ) 2767 2768 so_path: str = AOTIRunnerUtil.compile( 2769 Model(), 2770 example_inputs, 2771 ) 2772 aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path) 2773 2774 # Don't assign outputs to a variable b/c it will allocate GPU memory. 
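        # The leak check below relies on torch.cuda.memory_allocated() counting
        # only live allocations: running the compiled module twice and dropping
        # the outputs should leave the counter unchanged unless an intermediate
        # (e.g. a ReinterpretView handle) is leaked by the AOTI module.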
2775 device: int = torch.cuda.current_device() 2776 mem_before = torch.cuda.memory_allocated(device) 2777 aot_inductor_module(*example_inputs) 2778 aot_inductor_module(*example_inputs) 2779 mem_after = torch.cuda.memory_allocated(device) 2780 self.assertEqual(mem_before, mem_after) 2781 2782 actual = aot_inductor_module(*example_inputs) 2783 expected = Model()(*example_inputs) 2784 torch.testing.assert_close(actual, expected) 2785 2786 @torch._dynamo.config.patch(capture_scalar_outputs=True) 2787 @common_utils.parametrize("dynamic", [False, True]) 2788 @common_utils.parametrize("autotuning", [False, True]) 2789 def test_triton_kernel_unbacked_symint_in_grid(self, dynamic, autotuning): 2790 if self.device != "cuda": 2791 raise unittest.SkipTest("requires CUDA") 2792 2793 class Model(torch.nn.Module): 2794 def forward(self, x, y, n_elements_tensor): 2795 output = torch.zeros_like(x) 2796 n_elements_symint = n_elements_tensor.item() 2797 n_elements = x.numel() 2798 2799 def grid(meta): 2800 return (triton.cdiv(n_elements_symint, meta["BLOCK_SIZE"]),) 2801 2802 if autotuning: 2803 add_kernel_autotuned[grid]( 2804 x, 2805 y, 2806 output, 2807 n_elements, 2808 ) 2809 else: 2810 add_kernel[grid]( 2811 x, 2812 y, 2813 output, 2814 n_elements, 2815 BLOCK_SIZE=16, 2816 ) 2817 2818 return output 2819 2820 example_inputs = ( 2821 torch.randn(123, device="cuda"), 2822 torch.randn(123, device="cuda"), 2823 torch.tensor(123), 2824 ) 2825 2826 dynamic_shapes = None 2827 if dynamic: 2828 dim0 = Dim("s0", min=2, max=1024) 2829 dynamic_shapes = { 2830 "x": {0: dim0}, 2831 "y": {0: dim0}, 2832 "n_elements_tensor": {}, 2833 } 2834 2835 self.check_model( 2836 Model(), 2837 example_inputs, 2838 dynamic_shapes=dynamic_shapes, 2839 ) 2840 2841 @skipIfRocm # USE_MEM_EFF_ATTENTION was not enabled for build. 
2842 def test_scaled_dot_product_efficient_attention(self): 2843 if self.device != "cuda": 2844 raise unittest.SkipTest("requires CUDA") 2845 2846 class Model(torch.nn.Module): 2847 def forward(self, q, k, v, attn_bias): 2848 return torch.ops.aten._scaled_dot_product_efficient_attention( 2849 q, k, v, attn_bias, False 2850 )[0] 2851 2852 example_inputs = ( 2853 torch.randn(4, 4, 36, 36, device="cuda"), 2854 torch.randn(4, 4, 36, 36, device="cuda"), 2855 torch.randn(4, 4, 36, 36, device="cuda"), 2856 torch.randn(4, 4, 36, 36, device="cuda"), 2857 ) 2858 self.check_model(Model(), example_inputs) 2859 2860 def test_index_put_with_none_index(self): 2861 # index_put falls back in the deterministic mode 2862 with DeterministicGuard(True): 2863 2864 class Model(torch.nn.Module): 2865 def forward(self, x, i1, i2, y): 2866 return torch.ops.aten.index_put( 2867 x, 2868 (None, None, i1, i2.transpose(0, 1)), 2869 y, 2870 accumulate=True, 2871 ) 2872 2873 example_inputs = ( 2874 torch.rand(8, 192, 30, 30, device=self.device), 2875 torch.zeros(3, 14, 1, 1, dtype=torch.int64, device=self.device), 2876 torch.ones(14, 3, dtype=torch.int64, device=self.device), 2877 torch.randn(8, 192, 3, 14, 3, 14, device=self.device), 2878 ) 2879 self.check_model(Model(), example_inputs) 2880 2881 def test_runtime_checks(self): 2882 class Model(torch.nn.Module): 2883 def __init__(self) -> None: 2884 super().__init__() 2885 2886 def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9): 2887 return (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) 2888 2889 inputs = [] 2890 for dtype in ( 2891 torch.float16, 2892 torch.float32, 2893 torch.float64, 2894 torch.bfloat16, 2895 torch.bool, 2896 torch.int8, 2897 torch.int16, 2898 torch.int32, 2899 torch.int64, 2900 torch.uint8, 2901 ): 2902 inputs.append(torch.ones(4, 8, 10, dtype=dtype, device=self.device)) 2903 dim0 = Dim("s0", min=2, max=1024) 2904 dim1 = Dim("s1", min=2, max=512) 2905 dim2 = Dim("s2", min=2, max=128) 2906 dynamic_shapes = { 2907 "x0": {0: dim0}, 2908 "x1": {0: dim0}, 2909 "x2": {0: dim0}, 2910 "x3": {1: dim1}, 2911 "x4": {1: dim1}, 2912 "x5": {1: dim1}, 2913 "x6": {}, 2914 "x7": {2: dim2}, 2915 "x8": {2: dim2}, 2916 "x9": {2: dim2}, 2917 } 2918 m = Model() 2919 inputs = tuple(inputs) 2920 with torch.no_grad(), config.patch( 2921 { 2922 "abi_compatible": self.abi_compatible, 2923 "aot_inductor.debug_compile": True, 2924 } 2925 ): 2926 so_path = AOTIRunnerUtil.compile(m, inputs, dynamic_shapes=dynamic_shapes) 2927 with open(os.path.splitext(so_path)[0] + ".cpp") as cpp: 2928 src_code = cpp.read() 2929 FileCheck().check_count( 2930 "unmatched dtype", 2931 10, 2932 exactly=True, 2933 ).run(src_code) 2934 FileCheck().check_count( 2935 "unmatched dim value at", 2936 21, # we have 9 dynamic dims for which we generate different checks 2937 exactly=True, 2938 ).run(src_code) 2939 FileCheck().check_count( 2940 "dim value is too", 2941 18, # we have 9 dynamic dims for which we generate two checks 2942 exactly=True, 2943 ).run(src_code) 2944 FileCheck().check_count( 2945 "unmatched stride value at", 2946 21, # we have 9 symbolic strides for which we don't generate checks 2947 exactly=True, 2948 ).run(src_code) 2949 optimized = AOTIRunnerUtil.load(self.device, so_path) 2950 actual = optimized(*inputs) 2951 expected = m(*inputs) 2952 torch.testing.assert_close(actual, expected) 2953 2954 @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") 2955 @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+") 2956 def test_runtime_checks_fp8(self): 2957 class 
Model(torch.nn.Module): 2958 def __init__(self) -> None: 2959 super().__init__() 2960 2961 def forward(self, x0, x1): 2962 t = x0.to(torch.float) + x1.to(torch.float) 2963 return t 2964 2965 inputs = [] 2966 for dtype in ( 2967 torch.float8_e4m3fn, 2968 torch.float8_e5m2, 2969 # FP8 funz are for AMD 2970 # see https://github.com/pytorch/pytorch/issues/126734 2971 # torch.float8_e4m3fnuz, 2972 # torch.float8_e5m2fnuz, 2973 ): 2974 inputs.append(torch.ones(8, 8, 8, dtype=dtype, device=self.device)) 2975 dim0 = Dim("s0", min=2, max=1024) 2976 dynamic_shapes = { 2977 "x0": {0: dim0}, 2978 "x1": {0: dim0}, 2979 } 2980 with torch.no_grad(), config.patch( 2981 { 2982 "abi_compatible": self.abi_compatible, 2983 "aot_inductor.debug_compile": True, 2984 } 2985 ): 2986 self.check_model( 2987 Model(), 2988 tuple(inputs), 2989 dynamic_shapes=dynamic_shapes, 2990 ) 2991 2992 def test_runtime_checks_complex(self): 2993 class Model(torch.nn.Module): 2994 def __init__(self) -> None: 2995 super().__init__() 2996 2997 def forward(self, x0, x1, x2): 2998 return (x0, x1, x2) 2999 3000 inputs = [] 3001 x0 = torch.tensor([1, -1], dtype=torch.complex32, device=self.device) 3002 x1 = torch.tensor( 3003 [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], 3004 dtype=torch.complex64, 3005 device=self.device, 3006 ) 3007 x2 = torch.tensor(128, dtype=torch.complex128, device=self.device) 3008 inputs.append(x0) 3009 inputs.append(x1) 3010 inputs.append(x2) 3011 dim0 = Dim("s0", min=2, max=1024) 3012 dynamic_shapes = { 3013 "x0": {0: dim0}, 3014 "x1": {}, 3015 "x2": {}, 3016 } 3017 with torch.no_grad(), config.patch( 3018 { 3019 "abi_compatible": self.abi_compatible, 3020 "aot_inductor.debug_compile": True, 3021 } 3022 ): 3023 self.check_model( 3024 Model(), 3025 tuple(inputs), 3026 dynamic_shapes=dynamic_shapes, 3027 ) 3028 3029 @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") 3030 def test_runtime_checks_dtype_failed(self): 3031 class Model(torch.nn.Module): 3032 def __init__(self) -> None: 3033 super().__init__() 3034 3035 def forward(self, x): 3036 y = x.type(torch.float) 3037 return y 3038 3039 x = torch.randn(1, 4, dtype=torch.float16, device=self.device) 3040 model = Model() 3041 with torch.no_grad(), config.patch( 3042 { 3043 "abi_compatible": self.abi_compatible, 3044 "aot_inductor.debug_compile": True, 3045 } 3046 ): 3047 so_path: str = AOTIRunnerUtil.compile( 3048 model, 3049 (x,), 3050 ) 3051 aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path) 3052 x_casted = x.float() 3053 with self.assertRaisesRegex(Exception, ""): 3054 aot_inductor_module(x_casted) 3055 3056 def test_non_contiguous_output_alias(self): 3057 # Test return x, x.contiguous() where x is non-contiguous. 
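        # Both outputs are derived from the same intermediate, but calling
        # .contiguous() on a non-contiguous tensor must materialize a fresh
        # buffer, so the assertions below also verify that the two returned
        # tensors do not share a data_ptr.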
3058 class Model(torch.nn.Module): 3059 def forward(self, x): 3060 squared = x * x 3061 transposed = squared.t() # non-contiguous 3062 contig = transposed.contiguous() 3063 return transposed, contig 3064 3065 x = torch.randn(3, 4, dtype=torch.float16, device=self.device) 3066 model = Model() 3067 with torch.no_grad(), config.patch( 3068 { 3069 "abi_compatible": self.abi_compatible, 3070 } 3071 ): 3072 result = AOTIRunnerUtil.run( 3073 self.device, 3074 model, 3075 (x,), 3076 ) 3077 actual = model(x) 3078 self.assertTrue(same(result, actual)) 3079 3080 # contiguous() should create a new tensor 3081 self.assertTrue(result[0].data_ptr() != result[1].data_ptr()) 3082 3083 def test_multiple_output_alias(self): 3084 # Test when mutliple outputs alias the same tensor 3085 class Model(torch.nn.Module): 3086 def forward(self, x): 3087 squared = x * x 3088 contig = squared.contiguous() # alias 3089 reshaped = squared.reshape(squared.shape) # alias 3090 cubed = squared * x 3091 return squared, contig, reshaped, cubed 3092 3093 x = torch.randn(3, 4, dtype=torch.float32, device=self.device) 3094 model = Model() 3095 3096 with torch.no_grad(), config.patch( 3097 { 3098 "abi_compatible": self.abi_compatible, 3099 } 3100 ): 3101 result = AOTIRunnerUtil.run( 3102 self.device, 3103 model, 3104 (x,), 3105 ) 3106 actual = model(x) 3107 self.assertTrue(same(result, actual)) 3108 3109 # squared, contig and reshaped alias the same tensor. 3110 self.assertTrue(result[0].data_ptr() == result[1].data_ptr()) 3111 self.assertTrue(result[0].data_ptr() == result[2].data_ptr()) 3112 # cubed shouldn't be an alias. 3113 self.assertTrue(result[0].data_ptr() != result[3].data_ptr()) 3114 3115 def test_runtime_checks_shape_failed(self): 3116 class Model(torch.nn.Module): 3117 def __init__(self) -> None: 3118 super().__init__() 3119 3120 def forward(self, x): 3121 return x 3122 3123 x = torch.randn(4, 4, 4, dtype=torch.float16, device=self.device) 3124 y0 = torch.randn(8, 4, 4, dtype=torch.float16, device=self.device) 3125 y1 = torch.randn(4, 8, 4, dtype=torch.float16, device=self.device) 3126 y2 = rand_strided( 3127 (4, 4, 4), (16, 1, 4), dtype=torch.float16, device=self.device 3128 ) 3129 # batch size is outside of the range 3130 y3 = torch.randn(2048, 3, 4, dtype=torch.float16, device=self.device) 3131 y4 = torch.randn(2048, 4, 4, dtype=torch.float16, device=self.device) 3132 dim0 = Dim("s0", min=4, max=1024) 3133 dynamic_shapes = { 3134 "x": {0: dim0}, 3135 } 3136 model = Model() 3137 with torch.no_grad(), config.patch( 3138 { 3139 "abi_compatible": self.abi_compatible, 3140 "aot_inductor.debug_compile": True, 3141 } 3142 ): 3143 so_path: str = AOTIRunnerUtil.compile( 3144 model, (x,), dynamic_shapes=dynamic_shapes 3145 ) 3146 aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path) 3147 # dynamic dim works fine 3148 _ = aot_inductor_module(y0) 3149 with self.assertRaisesRegex(Exception, ""): 3150 aot_inductor_module(y1) 3151 with self.assertRaisesRegex(Exception, ""): 3152 aot_inductor_module(y2) 3153 with self.assertRaisesRegex(Exception, ""): 3154 aot_inductor_module(y3) 3155 with self.assertRaisesRegex(Exception, ""): 3156 aot_inductor_module(y4) 3157 3158 def test_add_complex(self): 3159 class Model(torch.nn.Module): 3160 def forward(self, a, b): 3161 return torch.add(a, b) 3162 3163 x = torch.tensor( 3164 [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device 3165 ) 3166 y = torch.tensor( 3167 [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device 3168 ) 3169 
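        # Note: this test is skipped for the minimal-arrayref-interface CPU
        # variant; see "test_add_complex" in CPU_TEST_FAILURES below.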
self.check_model(Model(), (x, y)) 3170 3171 def test_embedding_bag(self): 3172 class Model(torch.nn.Module): 3173 def forward(self, w, i, o): 3174 return torch.ops.aten._embedding_bag(w, i, o, False, 0, False, None) 3175 3176 example_inputs = ( 3177 torch.randn([10, 4], device=self.device), 3178 torch.randint(10, [8], device=self.device), 3179 torch.tensor([0, 2, 6], device=self.device), 3180 ) 3181 self.check_model(Model(), example_inputs) 3182 3183 def test_fft_c2c(self): 3184 class Model(torch.nn.Module): 3185 def forward(self, x): 3186 return torch.fft.fftn(x), torch.fft.fftn(x).real 3187 3188 example_inputs = (torch.randn(16, 16, 16, device=self.device),) 3189 self.check_model(Model(), example_inputs) 3190 3191 def test_bool_input(self): 3192 # Specialize on whichever branch the example input for b is 3193 class Model(torch.nn.Module): 3194 def forward(self, x, b): 3195 if b: 3196 return x * x 3197 else: 3198 return x + x 3199 3200 example_inputs = (torch.randn(3, 3, device=self.device), True) 3201 self.check_model(Model(), example_inputs) 3202 3203 def test_int_list_input(self): 3204 class Model(torch.nn.Module): 3205 def forward(self, x, i): 3206 return x * i[0] * i[1] 3207 3208 example_inputs = (torch.randn(3, 3, device=self.device), [3, 4]) 3209 self.check_model(Model(), example_inputs) 3210 3211 def test_nested_tensor_from_jagged(self): 3212 class Model(nn.Module): 3213 def __init__(self) -> None: 3214 super().__init__() 3215 self.mlp = nn.Sequential( 3216 nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid() 3217 ) 3218 3219 def forward(self, values, offsets): 3220 nt = torch.nested.nested_tensor_from_jagged(values, offsets) 3221 res = self.mlp(nt) 3222 return res.values() 3223 3224 model = Model().to(device=self.device) 3225 3226 example_inputs_1 = ( 3227 torch.randn((15, 128), device=self.device), 3228 torch.tensor([0, 3, 4, 10, 15], device=self.device), 3229 ) 3230 3231 # same "NT batch size", different actual amount of data 3232 example_inputs_2 = ( 3233 torch.randn((31, 128), device=self.device), 3234 torch.tensor([0, 1, 20, 25, 31], device=self.device), 3235 ) 3236 3237 # same actual amount of data, different "NT batch size" 3238 example_inputs_3 = ( 3239 torch.randn((15, 128), device=self.device), 3240 torch.tensor([0, 3, 10, 15], device=self.device), 3241 ) 3242 3243 # different "NT batch size" 3244 example_inputs_4 = ( 3245 torch.randn((37, 128), device=self.device), 3246 torch.tensor([0, 5, 16, 25, 29, 37], device=self.device), 3247 ) 3248 3249 dim0_values = Dim("dim0_values", min=1, max=128) 3250 dim0_offsets = Dim("dim0_offsets", min=1, max=9) 3251 dynamic_shapes = {"values": {0: dim0_values}, "offsets": {0: dim0_offsets}} 3252 example_inputs_list = [ 3253 example_inputs_1, 3254 example_inputs_2, 3255 example_inputs_3, 3256 example_inputs_4, 3257 ] 3258 3259 self.check_model_with_multiple_inputs( 3260 model, example_inputs_list, dynamic_shapes=dynamic_shapes 3261 ) 3262 3263 @common_utils.parametrize("max_autotune", [False, True]) 3264 def test_misc_1(self, max_autotune): 3265 if self.device == "cpu" and IS_MACOS and max_autotune: 3266 raise unittest.SkipTest("max_autotune not supported on macos") 3267 3268 class Model(nn.Module): 3269 def __init__(self) -> None: 3270 super().__init__() 3271 self.mlp = nn.Sequential( 3272 nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid() 3273 ) 3274 self.emb = nn.EmbeddingBag(num_embeddings=128, embedding_dim=32) 3275 self.over_arch = nn.Sequential( 3276 nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32), 
nn.Sigmoid() 3277 ) 3278 3279 def forward(self, x, y): 3280 mlp_output = self.mlp(x) 3281 emb_output = self.emb(y) 3282 return self.over_arch(torch.concat([mlp_output, emb_output], dim=1)) 3283 3284 example_inputs = ( 3285 torch.randn(16, 128, device=self.device), 3286 torch.randint(0, 128, (16, 10), device=self.device), 3287 ) 3288 self.check_model( 3289 Model(), example_inputs, options=dict(max_autotune=max_autotune) 3290 ) 3291 3292 def test_aoti_debug_printer_codegen(self): 3293 # basic addmm model to test codegen for aoti intermediate debug printer 3294 class Model(torch.nn.Module): 3295 def __init__(self, n, k, device): 3296 super().__init__() 3297 self.weight = torch.randn(n, k, device=device) 3298 self.bias = torch.randn(n, device=device) 3299 3300 def forward(self, a): 3301 return torch.nn.functional.linear(a, self.weight, self.bias) 3302 3303 M = 8 3304 N = 6 3305 K = 16 3306 model = Model(N, K, self.device) 3307 batch = 2 3308 a = torch.randn(batch, M, K, device=self.device) 3309 example_inputs = (a,) 3310 3311 kernel_calls = ( 3312 [ 3313 ("triton_poi_fused_0", 1), 3314 ("aoti_torch_cuda_addmm_out", 2), 3315 ] 3316 if self.device == "cuda" 3317 else [ 3318 ("aoti_torch_cpu_addmm_out", 2), 3319 ] 3320 ) 3321 3322 # test default debug printing all tensor values codegen 3323 with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): 3324 result, code = run_and_get_cpp_code( 3325 AOTIRunnerUtil.compile, model, example_inputs 3326 ) 3327 3328 # check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected 3329 self.assertEqual("aoti_torch_print_tensor_handle" in code, True) 3330 3331 # check the codegen for debug printing around the actual kernel call is expected 3332 3333 for kernel_call, count in kernel_calls: 3334 FileCheck().check_count( 3335 f"before_launch - {kernel_call}", 3336 count, 3337 ).run(code) 3338 FileCheck().check_count( 3339 f"after_launch - {kernel_call}", 3340 count, 3341 ).run(code) 3342 3343 # test printing selected kernel's tensor values codegen 3344 filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out" 3345 with config.patch( 3346 { 3347 "aot_inductor.debug_intermediate_value_printer": "2", 3348 "aot_inductor.filtered_kernel_names": filtered_kernel_name, 3349 } 3350 ): 3351 result, code = run_and_get_cpp_code( 3352 AOTIRunnerUtil.compile, model, example_inputs 3353 ) 3354 filtered_kernel_calls = [ 3355 (filtered_kernel_name, 2), 3356 ] 3357 for kernel_call, count in filtered_kernel_calls: 3358 FileCheck().check_count( 3359 f"before_launch - {kernel_call}", 3360 count, 3361 ).run(code) 3362 FileCheck().check_count( 3363 f"after_launch - {kernel_call}", 3364 count, 3365 ).run(code) 3366 3367 kernel_calls_not_to_print = [ 3368 kernel_call 3369 for kernel_call in kernel_calls 3370 if kernel_call[0] != filtered_kernel_name 3371 ] 3372 for kernel_name, _ in kernel_calls_not_to_print: 3373 FileCheck().check_not(f"before_launch - {kernel_name}").run(code) 3374 FileCheck().check_not(f"after_launch - {kernel_name}").run(code) 3375 3376 def test_aoti_debug_printer_user_defined_triton_kernel(self): 3377 if self.device != "cuda": 3378 raise unittest.SkipTest("requires CUDA") 3379 3380 class Model(torch.nn.Module): 3381 def __init__(self) -> None: 3382 super().__init__() 3383 3384 def forward(self, x, y): 3385 out = torch.zeros_like(x) 3386 add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16) 3387 return out 3388 3389 example_inputs = ( 3390 torch.randn(4, 4, device=self.device), 3391 
torch.randn(4, 4, device=self.device), 3392 ) 3393 3394 kernel_calls = [ 3395 ("add_kernel_0", 3), 3396 ] 3397 3398 with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): 3399 result, code = run_and_get_cpp_code( 3400 AOTIRunnerUtil.compile, Model(), example_inputs 3401 ) 3402 # check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected 3403 self.assertEqual("aoti_torch_print_tensor_handle" in code, True) 3404 # check the codegen for debug printing around the actual kernel call is expected 3405 for kernel_call, count in kernel_calls: 3406 FileCheck().check_count( 3407 f"before_launch - {kernel_call}", 3408 count, 3409 ).run(code) 3410 FileCheck().check_count( 3411 f"after_launch - {kernel_call}", 3412 count, 3413 ).run(code) 3414 3415 def test_size_from_multi_output(self): 3416 class Model(torch.nn.Module): 3417 def __init__(self): 3418 super().__init__() 3419 self.relu = torch.nn.ReLU() 3420 3421 def forward(self, x): 3422 _x, _i = torch.unique(x, sorted=True, return_inverse=True) 3423 _x = _x.clone().detach() 3424 return self.relu(_x), _i 3425 3426 example_inputs = (torch.randn(8, device=self.device),) 3427 self.check_model(Model(), example_inputs) 3428 3429 3430common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) 3431 3432 3433class AOTITestCase(TestCase): 3434 def setUp(self): 3435 if IS_SANDCASTLE or IS_FBCODE: 3436 torch.ops.load_library("//caffe2/test/inductor:custom_ops") 3437 elif IS_MACOS: 3438 raise unittest.SkipTest("non-portable load_library call used in test") 3439 else: 3440 lib_file_path = find_library_location("libaoti_custom_ops.so") 3441 if IS_WINDOWS: 3442 lib_file_path = find_library_location("aoti_custom_ops.dll") 3443 torch.ops.load_library(str(lib_file_path)) 3444 super().setUp() 3445 3446 3447class AOTInductorTestABICompatibleCpu(AOTITestCase): 3448 device = "cpu" 3449 abi_compatible = True 3450 check_model = check_model 3451 check_model_with_multiple_inputs = check_model_with_multiple_inputs 3452 code_check_count = code_check_count 3453 allow_stack_allocation = False 3454 use_minimal_arrayref_interface = False 3455 3456 3457def fail_with_and_without_stack_allocation(is_skip=False): 3458 return TestFailure( 3459 ( 3460 "abi_compatible_cpu", 3461 "abi_compatible_cpu_with_stack_allocation", 3462 "abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface", 3463 ), 3464 is_skip=is_skip, 3465 ) 3466 3467 3468def fail_stack_allocation(is_skip=False): 3469 return TestFailure( 3470 ( 3471 "abi_compatible_cpu_with_stack_allocation", 3472 "abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface", 3473 ), 3474 is_skip=is_skip, 3475 ) 3476 3477 3478def fail_minimal_arrayref_interface(is_skip=False): 3479 return TestFailure( 3480 ("abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface",), 3481 is_skip=is_skip, 3482 ) 3483 3484 3485def fail_cuda(is_skip=False): 3486 return TestFailure( 3487 ("abi_compatible_cuda", "non_abi_compatible_cuda"), 3488 is_skip=is_skip, 3489 ) 3490 3491 3492def fail_abi_compatible_cuda(is_skip=False): 3493 return TestFailure( 3494 ("abi_compatible_cuda",), 3495 is_skip=is_skip, 3496 ) 3497 3498 3499def fail_non_abi_compatible_cuda(is_skip=False): 3500 return TestFailure( 3501 ("non_abi_compatible_cuda",), 3502 is_skip=is_skip, 3503 ) 3504 3505 3506# test_failures, xfail by default, set is_skip=True to skip 3507CPU_TEST_FAILURES = { 3508 # TODO: error: ‘complex64’ was not declared in this scope 3509 
"test_add_complex": fail_minimal_arrayref_interface(is_skip=True), 3510 # TODO: test_conv_freezing_abi_compatible_cpu fails, 3511 # AssertionError: None, i.e. optional output is not supported 3512 "test_conv_freezing": fail_with_and_without_stack_allocation(is_skip=True), 3513 # TODO: test_deconv_freezing_abi_compatible_cpu fails, 3514 # AssertionError: None, i.e. optional output is not supported 3515 "test_deconv_freezing": fail_with_and_without_stack_allocation(is_skip=True), 3516 # FIXME: failed with Segfault while exiting the Python runtime 3517 "test_duplicate_constant_folding": fail_with_and_without_stack_allocation( 3518 is_skip=True 3519 ), 3520 # TODO: use of deleted function RAIIAtenTensorHandle 3521 "test_dup_unbacked_sym_decl": fail_minimal_arrayref_interface(is_skip=True), 3522 # TODO: use of deleted function RAIIAtenTensorHandle 3523 "test_dup_unbacked_sym_decl_with_refinement": fail_minimal_arrayref_interface( 3524 is_skip=True 3525 ), 3526 # TODO: error: cannot convert ArrayRefTensor<float> to AtenTensorHandle 3527 "test_dynamic_cat": fail_minimal_arrayref_interface(), 3528 # https://github.com/pytorch/pytorch/issues/129550 3529 # https://github.com/pytorch/pytorch/issues/123691 3530 "test_dynamic_scalar": fail_minimal_arrayref_interface(is_skip=True), 3531 # https://github.com/pytorch/pytorch/issues/122980 3532 "test_fft_c2c": fail_stack_allocation(is_skip=True), 3533 # TODO: test_freezing_abi_compatible_cpu fails, 3534 # AssertionError: None, i.e. optional output is not supported 3535 "test_freezing": fail_with_and_without_stack_allocation(is_skip=True), 3536 # TODO: test_linear_freezing_abi_compatible_cpu fails, 3537 # AssertionError: None, i.e. optional output is not supported 3538 "test_linear_freezing": fail_with_and_without_stack_allocation(is_skip=True), 3539 # FIXME: failed with Segfault while exiting the Python runtime 3540 "test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True), 3541 # minimal arrayref interface only works with CPU; test crashes. 
3542 # https://github.com/pytorch/pytorch/issues/122983 3543 "test_multi_device": fail_minimal_arrayref_interface(is_skip=True), 3544 # TODO: AssertionError: unsupported Optional type in convert_arg_type: Generator 3545 "test_normal_functional": fail_with_and_without_stack_allocation(is_skip=True), 3546 # TODO: The same issue as https://github.com/pytorch/pytorch/issues/122978 3547 # error: cannot convert ArrayRefTensor<float> to AtenTensorHandle 3548 "test_reuse_kernel_dynamic": fail_minimal_arrayref_interface(is_skip=True), 3549 # the test segfaults 3550 "test_repeat_output": fail_stack_allocation(is_skip=True), 3551 # TODO: failed internally 3552 "test_multiple_output_alias": fail_with_and_without_stack_allocation(is_skip=True), 3553 # segfault 3554 "test_buffer_mutation_1": fail_stack_allocation(is_skip=True), 3555 # segfault 3556 "test_buffer_mutation_2": fail_stack_allocation(is_skip=True), 3557 # segfault 3558 "test_bool_input": fail_stack_allocation(is_skip=True), 3559 # segfault 3560 "test_int_list_input": fail_stack_allocation(is_skip=True), 3561 # segfault 3562 # 'AOTInductorTestABICompatibleCpuWithStackAllocation' object has no attribute 'code_check_count' 3563 "test_buffer_mutation_3": fail_stack_allocation(is_skip=True), 3564 # FIXME: failed with Segfault while exiting the Python runtime 3565 "test_scatter_fallback": fail_stack_allocation(is_skip=True), 3566 # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978 3567 "test_scatter_reduce_fallback": fail_minimal_arrayref_interface(is_skip=True), 3568 # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978 3569 "test_index_put_fallback": fail_minimal_arrayref_interface(is_skip=True), 3570 # https://github.com/pytorch/pytorch/issues/122984 3571 "test_index_put_with_none_index": fail_minimal_arrayref_interface(is_skip=True), 3572 # FIXME: failed with Segfault while exiting the Python runtime 3573 "test_constant": fail_stack_allocation(is_skip=True), 3574 # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978 3575 "test_shifted_constraint_ranges": fail_with_and_without_stack_allocation( 3576 is_skip=True 3577 ), 3578 # https://github.com/pytorch/pytorch/issues/123691 3579 "test_amp_fallback_random": fail_minimal_arrayref_interface(is_skip=True), 3580 "test_simple_dynamic": fail_minimal_arrayref_interface(), 3581 # https://github.com/pytorch/pytorch/issues/123691 3582 "test_zero_grid_with_unbacked_symbols": fail_minimal_arrayref_interface( 3583 is_skip=True 3584 ), 3585 # failed on MacOS 3586 "test_zero_grid_with_backed_symbols": fail_with_and_without_stack_allocation( 3587 is_skip=True 3588 ), 3589 # https://github.com/pytorch/pytorch/issues/122990 3590 "test_cond_non_tensor_predicates_dynamic_False": fail_stack_allocation( 3591 is_skip=True 3592 ), 3593 # same issue as https://github.com/pytorch/pytorch/issues/122990 3594 "test_cond_non_tensor_predicates_dynamic_True": fail_stack_allocation(is_skip=True), 3595 # https://github.com/pytorch/pytorch/issues/122991 3596 "test_runtime_checks_complex": fail_with_and_without_stack_allocation(is_skip=True), 3597 "test_runtime_checks_fp8": fail_with_and_without_stack_allocation(is_skip=True), 3598 "test_while_loop_simple": fail_stack_allocation(is_skip=True), 3599 "test_while_loop_nested": fail_stack_allocation(is_skip=True), 3600 "test_while_loop_with_outer_code": fail_stack_allocation(is_skip=True), 3601 # TODO: error: cannot convert ArrayRefTensor<float> to AtenTensorHandle 3602 
"test_while_loop_with_outer_buffers": fail_stack_allocation(is_skip=True), 3603 # TODO: use of undeclared identifier 'float8_e4m3fn' and 'half' 3604 "test_fp8": fail_minimal_arrayref_interface(is_skip=True), 3605 "test_custom_op_add": fail_minimal_arrayref_interface(is_skip=True), 3606 "test_custom_op_all_inputs": fail_minimal_arrayref_interface(is_skip=True), 3607 "test_custom_op_with_multiple_outputs": fail_minimal_arrayref_interface( 3608 is_skip=True 3609 ), 3610 "test_custom_op_with_reinterpret_view_inputs": fail_minimal_arrayref_interface( 3611 is_skip=True 3612 ), 3613 "test_custom_op_with_concat_inputs": fail_minimal_arrayref_interface(is_skip=True), 3614 "test_custom_op_missing_arg_with_default_value": fail_minimal_arrayref_interface( 3615 is_skip=True 3616 ), 3617 "test_size_from_multi_output": fail_stack_allocation(is_skip=True), 3618} 3619 3620# test_failures, xfail by default, set is_skip=True to skip 3621CUDA_TEST_FAILURES = { 3622 # TODO: AssertionError: unsupported Optional type in convert_arg_type: Generator 3623 "test_normal_functional": fail_abi_compatible_cuda(is_skip=True), 3624 # no runtime checks for non_abi_compatible mode 3625 "test_runtime_checks": fail_non_abi_compatible_cuda(is_skip=True), 3626 "test_runtime_checks_complex": fail_non_abi_compatible_cuda(is_skip=True), 3627 "test_runtime_checks_fp8": fail_non_abi_compatible_cuda(is_skip=True), 3628 "test_runtime_checks_dtype_failed": fail_non_abi_compatible_cuda(is_skip=True), 3629 "test_runtime_checks_shape_failed": fail_non_abi_compatible_cuda(is_skip=True), 3630 # quantized unsupported for GPU 3631 "test_quantized_linear": fail_cuda(is_skip=True), 3632 "test_quanatized_int8_linear": fail_cuda(is_skip=True), 3633 "test_custom_op_add": fail_non_abi_compatible_cuda(is_skip=True), 3634 # fp8 to be re-enabled for AOTI 3635 "test_fp8": fail_cuda(is_skip=True), 3636 "test_custom_op_all_inputs": fail_non_abi_compatible_cuda(is_skip=True), 3637 "test_custom_op_missing_arg_with_default_value": fail_non_abi_compatible_cuda( 3638 is_skip=True 3639 ), 3640 "test_custom_op_with_concat_inputs": fail_non_abi_compatible_cuda(is_skip=True), 3641 "test_custom_op_with_reinterpret_view_inputs": fail_non_abi_compatible_cuda( 3642 is_skip=True 3643 ), 3644 "test_custom_op_with_multiple_outputs": fail_non_abi_compatible_cuda(is_skip=True), 3645 # non-abi compatible mode aoti debug printer is not supported yet 3646 "test_aoti_debug_printer_codegen": fail_non_abi_compatible_cuda(is_skip=True), 3647 "test_aoti_debug_printer_user_defined_triton_kernel": fail_non_abi_compatible_cuda( 3648 is_skip=True 3649 ), 3650} 3651 3652 3653if not IS_FBCODE: 3654 # The following tests look like they pass in both pytest and unittest (xml 3655 # and terminal output say pass), but the process will segfault. This only 3656 # happens in OSS CI and is fine internally. 
    CPU_TEST_FAILURES.update(
        {
            "test_duplicated_params": fail_stack_allocation(is_skip=True),
            "test_embedding_bag": fail_stack_allocation(is_skip=True),
            "test_fqn": fail_stack_allocation(is_skip=True),
            "test_no_args": fail_stack_allocation(is_skip=True),
            "test_output_misaligned": fail_stack_allocation(is_skip=True),
            "test_pytree_inputs": fail_stack_allocation(is_skip=True),
            "test_seq": fail_stack_allocation(is_skip=True),
            "test_simple_split": fail_stack_allocation(is_skip=True),
            "test_addmm": fail_minimal_arrayref_interface(is_skip=True),
            "test_aliased_buffer_reuse": fail_minimal_arrayref_interface(is_skip=True),
            "test_buffer_reuse": fail_minimal_arrayref_interface(is_skip=True),
            "test_constant_folding": fail_minimal_arrayref_interface(is_skip=True),
            "test_convolution": fail_minimal_arrayref_interface(is_skip=True),
            "test_empty_graph": fail_minimal_arrayref_interface(is_skip=True),
            "test_large_weight": fail_minimal_arrayref_interface(is_skip=True),
            "test_large_mmaped_weights": fail_minimal_arrayref_interface(is_skip=True),
            "test_normal_functional": fail_minimal_arrayref_interface(is_skip=True),
            "test_misc_1": fail_minimal_arrayref_interface(is_skip=True),
            "test_missing_output": fail_minimal_arrayref_interface(is_skip=True),
            "test_model_modified_weights": fail_minimal_arrayref_interface(
                is_skip=True
            ),
            "test_output_path_1": fail_minimal_arrayref_interface(is_skip=True),
            "test_quantized_linear": fail_minimal_arrayref_interface(is_skip=True),
            "test_quanatized_int8_linear": fail_minimal_arrayref_interface(
                is_skip=True
            ),
            "test_repeat_interleave": fail_minimal_arrayref_interface(is_skip=True),
            "test_return_constant": fail_minimal_arrayref_interface(is_skip=True),
            "test_reuse_kernel": fail_minimal_arrayref_interface(is_skip=True),
            "test_simple": fail_minimal_arrayref_interface(is_skip=True),
            "test_small_constant": fail_minimal_arrayref_interface(is_skip=True),
            "test_with_no_triton_profiler": fail_minimal_arrayref_interface(
                is_skip=True
            ),
            "test_with_offset": fail_minimal_arrayref_interface(is_skip=True),
            "test_with_profiler": fail_minimal_arrayref_interface(is_skip=True),
            "test_zero_size_weight": fail_minimal_arrayref_interface(is_skip=True),
            "test_aoti_debug_printer_codegen": fail_with_and_without_stack_allocation(
                is_skip=True
            ),
        }
    )
    # The following test passes internally but fails in OSS CI. To be investigated.

    # The following test passes internally but fails in OSS CI. To be investigated.
    CUDA_TEST_FAILURES.update(
        {
            "test_aoti_debug_printer_codegen": fail_cuda(is_skip=True),
            "test_aoti_debug_printer_user_defined_triton_kernel": fail_cuda(
                is_skip=True
            ),
        }
    )

copy_tests(
    AOTInductorTestsTemplate,
    AOTInductorTestABICompatibleCpu,
    "abi_compatible_cpu",
    CPU_TEST_FAILURES,
)


class AOTInductorTestABICompatibleCpuWithStackAllocation(AOTITestCase):
    device = "cpu"
    abi_compatible = True
    check_model = check_model
    check_model_with_multiple_inputs = check_model_with_multiple_inputs
    code_check_count = code_check_count
    allow_stack_allocation = True
    use_minimal_arrayref_interface = False


copy_tests(
    AOTInductorTestsTemplate,
    AOTInductorTestABICompatibleCpuWithStackAllocation,
    "abi_compatible_cpu_with_stack_allocation",
    CPU_TEST_FAILURES,
)


class AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterface(
    TestCase
):
    device = "cpu"
    abi_compatible = True
    check_model = check_model
    check_model_with_multiple_inputs = check_model_with_multiple_inputs
    allow_stack_allocation = True
    use_minimal_arrayref_interface = True


copy_tests(
    AOTInductorTestsTemplate,
    AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterface,
    "abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface",
    CPU_TEST_FAILURES,
)


@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
class AOTInductorTestABICompatibleCuda(AOTITestCase):
    device = "cuda"
    abi_compatible = True
    check_model = check_model
    check_model_with_multiple_inputs = check_model_with_multiple_inputs
    code_check_count = code_check_count
    allow_stack_allocation = False
    use_minimal_arrayref_interface = False


copy_tests(
    AOTInductorTestsTemplate,
    AOTInductorTestABICompatibleCuda,
    "abi_compatible_cuda",
    CUDA_TEST_FAILURES,
)
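

# copy_tests (imported from test_torchinductor) stamps each test_* method of
# AOTInductorTestsTemplate onto the target class under a suffixed name (e.g.
# test_simple becomes test_simple_abi_compatible_cuda above), applying the
# xfail/skip markers from the failure dictionary.  A simplified sketch of that
# idea only; TargetClass and suffix are placeholders, not the real
# implementation:
#
#   for name, fn in vars(AOTInductorTestsTemplate).items():
#       if name.startswith("test_") and callable(fn):
#           setattr(TargetClass, f"{name}_{suffix}", fn)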
"test_custom_op_missing_arg_with_default_value": TestFailure( 3815 ("non_abi_compatible_cpu",), is_skip=True 3816 ), 3817 "test_custom_op_with_concat_inputs": TestFailure( 3818 ("non_abi_compatible_cpu",), is_skip=True 3819 ), 3820 "test_custom_op_with_multiple_outputs": TestFailure( 3821 ("non_abi_compatible_cpu",), is_skip=True 3822 ), 3823 "test_custom_op_with_reinterpret_view_inputs": TestFailure( 3824 ("non_abi_compatible_cpu",), is_skip=True 3825 ), 3826 }, 3827) 3828 3829 3830@unittest.skipIf( 3831 IS_FBCODE or sys.platform == "darwin", 3832 "NonABI mode should not be used in fbcode nor on MacOS", 3833) 3834class AOTInductorTestNonABICompatibleCuda(AOTITestCase): 3835 device = "cuda" 3836 abi_compatible = False 3837 check_model = check_model 3838 check_model_with_multiple_inputs = check_model_with_multiple_inputs 3839 code_check_count = code_check_count 3840 allow_stack_allocation = False 3841 use_minimal_arrayref_interface = False 3842 3843 3844copy_tests( 3845 AOTInductorTestsTemplate, 3846 AOTInductorTestNonABICompatibleCuda, 3847 "non_abi_compatible_cuda", 3848 CUDA_TEST_FAILURES, 3849) 3850 3851 3852if __name__ == "__main__": 3853 from torch._inductor.test_case import run_tests 3854 3855 # cpp_extension N/A in fbcode 3856 if HAS_CUDA or sys.platform == "darwin": 3857 run_tests(needs="filelock") 3858