# Owner(s): ["module: inductor"]
# flake8: noqa: E731
# Skip do not assign a lambda expression, use a def
import functools
from unittest.mock import patch

import torch
import torch._dynamo.testing
import torch._inductor.test_case
from torch._higher_order_ops.triton_kernel_wrap import (
    generate_ttir,
    triton_kernel_wrapper_functional,
    triton_kernel_wrapper_mutation,
)
from torch._inductor import metrics
from torch._inductor.utils import run_and_get_code
from torch._library import capture_triton
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu, TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU
from torch.testing._internal.logging_utils import logs_to_string

# Defines all the kernels for tests
from torch.testing._internal.triton_utils import *  # noqa: F403
from torch.utils._triton import has_triton_package


if HAS_GPU:
    import triton
    from triton import language as tl

    if not TEST_WITH_ROCM:
        if HAS_CUDA:
            from triton.language.extra.cuda.libdevice import (
                fast_dividef,
                fast_dividef as my_fast_dividef,
            )
        elif HAS_XPU:
            from triton.language.extra.intel.libdevice import (
                fast_dividef,
                fast_dividef as my_fast_dividef,
            )

    # Define shared triton constants here.
    CONSTANT_C: tl.constexpr = 4
    STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C"
    BOOL_CONSTANT_C: tl.constexpr = True


class KernelTests(torch._inductor.test_case.TestCase):
    @requires_gpu
    def test_triton_kernel_with_kernel_param(self):
        @triton.jit
        def pass_kernel(kernel):
            pass

        @torch.compile(backend="eager")
        def f(x):
            grid = (x.numel(),)
            pass_kernel[grid](kernel=x)

        t1 = torch.rand(5, device=GPU_TYPE)
        f(t1)
        # No need to assert anything, the goal is to make sure dynamo does
        # not crash

    @requires_gpu
    def test_triton_kernel_higher_order_func(self):
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        add_kernel_id = kernel_side_table.add_kernel(add_kernel)

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)

        torch_add = t1 + t2

        # Test higher order function with mutation
        output = torch.zeros_like(t1)
        n_elements = output.numel()
        constant_args_idx = kernel_side_table.add_constant_args(
            {"n_elements": n_elements, "BLOCK_SIZE": 16}
        )
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        triton_kernel_wrapper_mutation(
            kernel_idx=add_kernel_id,
            constant_args_idx=constant_args_idx,
            grid=[grid],
            kwargs={
                "in_ptr0": t1,
                "in_ptr1": t2,
                "out_ptr": output,
            },
        )
        self.assertEqual(output, torch_add)
        # Make sure it is modified
        self.assertNotEqual(output, torch.zeros_like(t1))

        # Test higher order function without mutation
        output = torch.zeros_like(t1)
        out_dict = triton_kernel_wrapper_functional(
            kernel_idx=add_kernel_id,
            constant_args_idx=constant_args_idx,
            grid=[grid],
            kwargs={
                "in_ptr0": t1,
                "in_ptr1": t2,
                "out_ptr": output,
            },
            tensors_to_clone=["in_ptr0", "in_ptr1", "out_ptr"],
        )
        self.assertEqual(out_dict["out_ptr"], torch_add)
        # Make sure it is NOT modified
        self.assertEqual(output, torch.zeros_like(t1))

    @requires_gpu
    def test_triton_kernel_functionalize(self):
        from functorch import make_fx
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
        from torch._subclasses.functional_tensor import (
            CppFunctionalizeAPI,
            FunctionalTensorMode,
            PythonFunctionalizeAPI,
        )

        kernel_side_table.reset_table()

        def f(x, output):
            out = triton_kernel_wrapper_functional(
                kernel_idx=kernel_side_table.add_kernel(mul2_kernel),
                constant_args_idx=kernel_side_table.add_constant_args(
                    {"n_elements": output.numel(), "BLOCK_SIZE": 16}
                ),
                grid=[(x.numel(),)],
                kwargs={
                    "in_ptr0": x,
                    "out_ptr": output,
                },
                tensors_to_clone=["in_ptr0", "out_ptr"],
            )
            return out["out_ptr"]

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)
        with FunctionalTensorMode():
            gm = make_fx(PythonFunctionalizeAPI().functionalize(f))(t1, t2)
        # Make sure t2 was not modified
        self.assertNotEqual(gm(t1, t2), t2)

        gm = make_fx(CppFunctionalizeAPI().functionalize(f))(t1, t2)
        # Make sure t2 was not modified
        self.assertNotEqual(gm(t1, t2), t2)

        gm = make_fx(torch.func.functionalize(f))(t1, t2)
        # Make sure t2 was not modified
        self.assertNotEqual(gm(t1, t2), t2)

        gm = make_fx(f, tracing_mode="fake")(t1, t2)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1, output_1):
    triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None
    getitem = triton_kernel_wrapper_functional_proxy['in_ptr0']; getitem = None
    getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']; triton_kernel_wrapper_functional_proxy = None
    return getitem_1""",
        )

    @requires_gpu
    def test_triton_kernel_mutation_type(self):
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
        from torch._subclasses.fake_tensor import FakeTensorMode
        from torch._subclasses.functional_tensor import (
            FunctionalTensor,
            FunctionalTensorMode,
        )

        def prep():
            x = torch.ones(4, device=GPU_TYPE, requires_grad=True)
            with FunctionalTensorMode():
                x_func = FunctionalTensor.to_functional(x)
            self.assertTrue(torch._is_functional_tensor(x_func.elem))
            return x_func

        # normal mutation only
        with FakeTensorMode():
            x_func = prep()

            with FunctionalTensorMode():
                x_func.mul_(2)

            self.assertFalse(
                torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
            )

        # triton kernel mutation only
        with FakeTensorMode():
            x_func = prep()

            with FunctionalTensorMode():
                triton_kernel_wrapper_mutation(
                    kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel),
                    constant_args_idx=kernel_side_table.add_constant_args(
                        {"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
                    ),
                    grid=[(x_func.numel(),)],
                    kwargs={
                        "ptr": x_func,
                    },
                )

            self.assertTrue(
                torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
            )

        # normal mutation + triton kernel mutation
        with FakeTensorMode():
            x_func = prep()

            with FunctionalTensorMode():
                x_func.mul_(2)
                triton_kernel_wrapper_mutation(
                    kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel),
                    constant_args_idx=kernel_side_table.add_constant_args(
                        {"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
                    ),
                    grid=[(x_func.numel(),)],
                    kwargs={
                        "ptr": x_func,
                    },
                )

            self.assertFalse(
                torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
            )

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_with_views(self, dynamic, backend):
        def call_triton_take_view(x: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output

        def call_triton_return_view(x: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output.view(4, 4)

        t = torch.rand(4, 4, device=GPU_TYPE)
        t_view = t.view(16)

        compiled_func = torch.compile(
            call_triton_take_view, backend=backend, fullgraph=True, dynamic=dynamic
        )
        self.assertEqual(2 * t_view, compiled_func(t_view))
        self.assertEqual(2 * t, compiled_func(t_view).view(4, 4))

        compiled_func = torch.compile(
            call_triton_return_view, backend=backend, fullgraph=True, dynamic=dynamic
        )
        self.assertEqual(2 * t_view, compiled_func(t).view(16))
        self.assertEqual(2 * t, compiled_func(t))

    @requires_gpu
    def test_no_nan_kernels(self):
        @triton.jit
        def add_one_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            output = x + 1
            tl.store(out_ptr + offsets, output, mask=mask)

        def add_one(x, out):
            n_elements = x.numel()
            add_one_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)

        class AddOne(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = torch.empty_like(x)
                add_one(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = torch.empty_like(grad)
                add_one(saved, out)
                return out

        @torch.compile
        def f(x):
            return AddOne.apply(x)

        log_stream, ctx = logs_to_string("torch._inductor.codecache", "output_code")

        x = torch.randn(3, requires_grad=True, device=GPU_TYPE)
        with ctx():
            y = f(x)

        output_code = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        self.assertTrue(len(output_code) > 0, msg="output code is not empty")
        self.assertEqual(output_code.count('float("nan")'), 0)
        self.assertEqual(output_code.count("float('nan')"), 0)

    @requires_gpu
    @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_with_grad_option(self, grad_fn, backend):
        def call_triton(x: torch.Tensor):
            with grad_fn():
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
                mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
                return output

        t = torch.rand(5, device=GPU_TYPE)
        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
        self.assertEqual(2 * t, compiled_func(t))

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_inner_triton_function(self, backend):
        def f(x: torch.Tensor):
            @triton.jit
            def pow2_kernel(
                in_ptr0,
                out_ptr,
                n_elements,
                BLOCK_SIZE: "tl.constexpr",
            ):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n_elements
                x = tl.load(in_ptr0 + offsets, mask=mask)
                output = x * x
                tl.store(out_ptr + offsets, output, mask=mask)

            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output

        t = torch.rand(5, device=GPU_TYPE)

        compiled_func = torch.compile(f, backend=backend, fullgraph=True)
        # TODO(oulgen): NYI - Support this
        # self.assertEqual(t * t, compiled_func(t))

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @patch.object(torch._inductor.config, "implicit_fallbacks", False)
    def test_triton_kernel_no_clones(self, grad, dynamic):
        from torch._inductor.utils import run_and_get_code

        def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
            n_elements = output.numel()

            tmp = torch.add(x, 1)
            grid = (x.numel(),)
            add_kernel.run(
                x, y, output, n_elements, warmup=False, grid=grid, BLOCK_SIZE=16
            )

            return output, tmp

        t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        o1 = torch.zeros_like(t1, requires_grad=grad)

        torch_add = call_triton(t1, t2, o1)
        metrics.reset()
        o2 = torch.zeros_like(t1, requires_grad=grad)
        test, codes = run_and_get_code(
            torch.compile(call_triton, dynamic=dynamic), t1, t2, o2
        )
        if not grad:
            self.assertEqual(metrics.generated_kernel_count, 1)
        self.assertEqual(torch_add, test)
        # These two asserts are not optimal since they require the original aten
        # ops to be in the metadata, so there might be false negatives
        self.assertTrue("aten.copy" not in codes[0])
        self.assertTrue("aten.clone" not in codes[0])
        # The following checks that only the tensor output is in the
        # compiled graph
        if dynamic and grad:
            self.assertTrue("return (buf0, s0, )" in codes[0])
        else:
            self.assertTrue("return (buf0, )" in codes[0])

    @requires_gpu
    def test_triton_kernel_caching(self):
        from torch._inductor.utils import run_and_get_code

        def add_in_loop(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            return output

        def call_triton_add(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            for i in range(4):
                x = add_in_loop(x, y)
            return x

        t1 = torch.ones(5, device=GPU_TYPE)
        t2 = torch.ones(5, device=GPU_TYPE)

        test, (code,) = run_and_get_code(torch.compile(call_triton_add), t1, t2)
        self.assertEqual(test, 5 * torch.ones(5, device=GPU_TYPE))
        self.assertTrue("add_kernel_autotuned_1.run" not in code)

    @requires_gpu
    def test_triton_kernel_caching_duplicate(self):
        from torch._inductor.utils import run_and_get_code

        class C:
            @triton.jit
            def pass_kernel(
                in_ptr0,
                out_ptr,
                n_elements,
                BLOCK_SIZE: "tl.constexpr",
            ):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n_elements
                x = tl.load(in_ptr0 + offsets, mask=mask)
                tl.store(out_ptr + offsets, x, mask=mask)

        class D:
            @triton.jit
            def pass_kernel(
                in_ptr0,
                out_ptr,
                n_elements,
                BLOCK_SIZE: "tl.constexpr",
            ):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n_elements
                x = tl.load(in_ptr0 + offsets, mask=mask)
                tl.store(out_ptr + offsets, x, mask=mask)

        def call_triton(x: torch.Tensor):
            output1 = torch.zeros_like(x)
            output2 = torch.zeros_like(x)
            n_elements = output1.numel()
            grid = (n_elements,)
            C.pass_kernel[grid](x, output1, n_elements, BLOCK_SIZE=16)
            D.pass_kernel[grid](x, output2, n_elements, BLOCK_SIZE=16)
            return output1 + output2

        t = torch.ones(5, device=GPU_TYPE)
        test, (code,) = run_and_get_code(torch.compile(call_triton), t)
        # Make sure we emitted two kernels here
        self.assertTrue("pass_kernel_0.run" in code)
        self.assertTrue("pass_kernel_1.run" in code)

    @requires_gpu
    def test_triton_kernel_various_args(self):
        @triton.autotune(
            configs=[triton.Config({"BLOCK_SIZE": 128})],
            key=[],
        )
        @triton.jit
        def pass_kernel(
            out_ptr,
            n_elements,
            dummy_None,
            dummy_empty,
            dummy_float,
            BLOCK_SIZE: "tl.constexpr",
            RANDOM_SIZE: "tl.constexpr",
        ):
            pass

        @torch.compile
        def call_triton(output):
            n_elements = output.numel()
            grid = (n_elements,)
            pass_kernel[grid](
                output,
                n_elements,
                None,
                torch.empty_like(output),
                3.1415926,
                RANDOM_SIZE=0,
            )
            return output

        output = torch.randn(5, device=GPU_TYPE)
        # Make sure this does not crash
        call_triton(output)

    @requires_gpu
    @skipIfRocm
    def test_triton_kernel_dependancies(self):
        def call_triton(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            output2 = torch.zeros_like(output)
            add_kernel_autotuned[grid](output, y, output2, n_elements)
            output3 = torch.add(output2, 1)
            return output3

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)
        torch_result = call_triton(t1, t2)
        compiled_result = torch.compile(call_triton)(t1, t2)
        self.assertEqual(torch_result, compiled_result)

    @requires_gpu
    def test_triton_kernel_reinplace_inplaceable_pass(self):
        def call_triton(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            add_kernel_autotuned[grid](output, x, output, n_elements)
            return output

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)
        torch_result = call_triton(t1, t2)
        compiled_result = torch.compile(call_triton)(t1, t2)
        self.assertEqual(torch_result, compiled_result)

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    def test_triton_kernel_multi_kernel(self, grad):
        @triton.jit
        def mul2_and_add_and_zero_negatives_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
            ACTIVATION: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            indirection_kernel(
                in_ptr0,
                in_ptr0,
                n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
                ACTIVATION="mul2_inplace_kernel",
            )
            indirection_kernel(
                in_ptr1,
                in_ptr1,
                n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
                ACTIVATION="mul2_inplace_kernel",
            )
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            if ACTIVATION == "zero_negs":
                output = zero_negs(output)
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile
        def call_triton(
            x: torch.Tensor,
            y: torch.Tensor,
            xi: torch.Tensor,
            yi: torch.Tensor,
            output: torch.Tensor,
            outputi: torch.Tensor,
        ):
            n_elements = output.numel()

            grid = (x.numel(),)
            mul2_and_add_and_zero_negatives_kernel[grid](
                x, y, output, n_elements, BLOCK_SIZE=16, ACTIVATION="zero_negs"
            )
            mul2_and_add_and_zero_negatives_kernel[grid](
                xi, yi, outputi, n_elements, BLOCK_SIZE=16, ACTIVATION=None
            )

            return (output, outputi)

        t1 = torch.tensor(
            [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad
        )
        t2 = torch.tensor(
            [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad
        )
        float_result = 2 * t1 + 2 * t2
        float_result = float_result.where(float_result >= 0, 0.0)

        t1i = torch.randint(-2, 2, (5,), device=GPU_TYPE)
        t2i = torch.randint(-2, 2, (5,), device=GPU_TYPE)
        o = torch.zeros_like(t1, requires_grad=grad)
        oi = torch.zeros_like(t1i)
        int_result = 2 * t1i + 2 * t2i

        (result, resulti) = call_triton(t1, t2, t1i, t2i, o, oi)
        self.assertEqual(float_result, result)
        self.assertEqual(int_result, resulti)

    @requires_gpu
    @skipIfXpu
    @skipIfRocm
    def test_triton_kernel_constants(self):
        @triton.jit
        def mulC_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
            CONSTANT_NAME: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            if CONSTANT_NAME == STRING_CONSTANT_C:
                output = CONSTANT_C * x
            if BOOL_CONSTANT_C:
                output *= CONSTANT_C
            tl.store(out_ptr + offsets, output, mask=mask)

        def call_triton(
            x: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()

            grid = (x.numel(),)
            mulC_kernel[grid](
                x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C"
            )
            return output

        # Triton kernels capture global constants by their parse time value,
        # not their runtime value
        global CONSTANT_C
        prev_c = CONSTANT_C
        # If the behavior of triton kernels changes, this test will fail
        CONSTANT_C = 10
        assert CONSTANT_C != prev_c

        t = torch.randn(5, device=GPU_TYPE)
        torch_result = call_triton(t)
        compiled_result = torch.compile(call_triton)(t)

        self.assertEqual(torch_result, compiled_result)

        # reset back
        CONSTANT_C = prev_c

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("grid_type", [1, 2, 3])
    def test_triton_kernel_autotune(self, grad, dynamic, backend, grid_type):
        def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            if grid_type == 1:
                grid = (n_elements,)
            elif grid_type == 2:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            elif grid_type == 3:
                grid = grid_fn

            add_kernel_autotuned[grid](x, y, output, n_elements)
            return output

        t1 = torch.rand(256, device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand(256, device=GPU_TYPE, requires_grad=grad)
        output = torch.zeros_like(t1, requires_grad=grad)

        torch_add = call_triton(t1, t2, output)
        compiled_func = torch.compile(
            call_triton, backend=backend, fullgraph=True, dynamic=dynamic
        )

        output2 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, output2), torch_add)

    @requires_gpu
    @skipIfRocm  # https://github.com/pytorch/pytorch/actions/runs/10051552819/job/27782048305?pr=131431
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @patch.object(
        torch._inductor.config, "unsafe_ignore_unsupported_triton_autotune_args", True
    )
    def test_triton_kernel_autotune_with_unsupported_args(self, backend):
        def call_triton(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            add_kernel_autotuned_with_unsupported_args[(n_elements,)](
                x, y, output, n_elements
            )
            return output

        t1 = torch.rand(256, device=GPU_TYPE)
        t2 = torch.rand(256, device=GPU_TYPE)

        torch_add = call_triton(t1, t2)
        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
        compiled_add = compiled_func(t1, t2)
        self.assertEqual(compiled_add, torch_add)

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("grid_type", [1, 2, 3])
    def test_triton_kernel_2d_autotune(self, grad, dynamic, backend, grid_type):
        def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
            x_elements = output.size()[0]
            y_elements = output.size()[1]

            def grid_fn(meta):
                return (
                    triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                    triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                )

            if grid_type == 1:
                grid = (x_elements, y_elements)
            elif grid_type == 2:
                grid = lambda meta: (
                    triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                    triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                )
            elif grid_type == 3:
                grid = grid_fn

            add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements)
            return output

        t1 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad)
        output = torch.zeros_like(t1, requires_grad=grad)

        torch_result = call_triton(t1, t2, output)
        compiled_func = torch.compile(
            call_triton, backend=backend, fullgraph=True, dynamic=dynamic
        )
        output2 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, output2), torch_result)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_tracing(self, dynamic):
        def call_triton_add(
            x: torch.Tensor,
            y: torch.Tensor,
            grid_type: int,
            num=1,
            positional=False,
            autotuned=False,
        ):
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(num, meta["BLOCK_SIZE"]),)

            if grid_type == 0:
                grid = (x.numel(),)
            elif grid_type == 1:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            elif grid_type == 2:
                grid = grid_fn
            else:
                grid = [x.numel()]

            if autotuned:
                capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements)
            else:
                if positional:
                    capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
                else:
                    capture_triton(add_kernel)[grid](
                        x, y, output, n_elements, BLOCK_SIZE=16
                    )

            return output

        t0 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        t1 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        t2 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        t3 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        torch_add = t2 + t3

        tests = [
            functools.partial(call_triton_add, grid_type=0),
            functools.partial(call_triton_add, grid_type=1),
            functools.partial(call_triton_add, grid_type=1, num=1, positional=True),
            functools.partial(call_triton_add, grid_type=2, num=200),
            functools.partial(call_triton_add, grid_type=3),
            functools.partial(call_triton_add, grid_type=0, autotuned=True),
            functools.partial(call_triton_add, grid_type=1, num=1, autotuned=True),
            functools.partial(call_triton_add, grid_type=2, num=200, autotuned=True),
            functools.partial(call_triton_add, grid_type=3, autotuned=True),
        ]
        from functorch import make_fx

        tracing_mode = "symbolic" if dynamic else "fake"

        for test in tests:
            gm = make_fx(test, tracing_mode=tracing_mode)(t0, t1)
            result = test(t2, t3)
            self.assertEqual(result, torch_add)

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @patch.object(torch._inductor.config, "implicit_fallbacks", False)
    def test_triton_kernel_native(self, grad, dynamic, backend):
        def call_triton_add(
            x: torch.Tensor,
            y: torch.Tensor,
            output: torch.Tensor,
            grid_type: int,
            num=1,
            positional=False,
        ):
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(num, meta["BLOCK_SIZE"]),)

            if grid_type == 0:
                grid = (x.numel(),)
            elif grid_type == 1:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            else:
                grid = grid_fn

            if positional:
                add_kernel[grid](x, y, output, n_elements, 16)
            else:
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)

            return output

        t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        o1 = torch.zeros_like(t1, requires_grad=grad)

        torch_add = t1 + t2

        # No Dynamo -- Make sure triton kernel works
        self.assertEqual(call_triton_add(t1, t2, o1, 1), torch_add)
        # No Dynamo -- Make sure triton kernel works (with positional BLOCK_SIZE)
        o2 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(call_triton_add(t1, t2, o2, 1, True), torch_add)

        # With Dynamo
        compiled_func = torch.compile(
            call_triton_add, backend=backend, fullgraph=True, dynamic=dynamic
        )
        # With simple kernel
        o3 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o3, 0), torch_add)
        # With lambda kernel
        o4 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o4, 1), torch_add)
        # With lambda kernel (with positional BLOCK_SIZE)
        o5 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o5, 1, 1, True), torch_add)
        # With user defined function kernel
        o6 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o6, 2, 200), torch_add)

    @requires_gpu
    def test_triton_kernel_mutation_not_mark_dirty(self):
        @torch.compile
        def f(x):
            n_elements = x.numel()
            add_kernel[(n_elements,)](x, x, x, n_elements, 16)
            return x

        x = torch.randn(5, device=GPU_TYPE, requires_grad=True)
        x_cloned = x.clone()
        out = x_cloned.sin()
        f(x_cloned)
        out.sum().backward()

    @requires_cuda
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_triton_kernel_inputs_buffer_reuse(self):
        def _mul2(x):
            y = torch.empty_like(x)
            mul2_kernel[(10,)](
                in_ptr0=x,
                out_ptr=y,
                n_elements=x.numel(),
                BLOCK_SIZE=1,
            )
            return y

        @torch.compile
        def f(x):
            for _ in range(4):
                # The output of one kernel is the input to the next kernel, but
                # at some point we should re-use buffers rather than allocate
                # new ones.
                x = _mul2(x)
            return x + 1

        x = torch.randn(10, device="cuda", dtype=torch.float32)
        eager_out = f(x)
        compiled_out, (code,) = run_and_get_code(torch.compile(f), x)
        self.assertEqual(compiled_out, eager_out)

        # Check that we're allocating the minimal # of buffers.
        num_bufs_allocated = code.count(
            "empty_strided_cuda((10, ), (1, ), torch.float32)"
        )
        self.assertEqual(num_bufs_allocated, 2)

        # Check we're re-using buffers if not allocating.
        num_bufs_reused = code.count("# reuse")
        self.assertEqual(num_bufs_reused, 3)

    @requires_gpu
    def test_triton_kernel_matmul_tracking(self):
        @triton.jit
        def ones_kernel(x_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = 1.0
            tl.store(x_ptr + offsets, x, mask=mask)

        @torch.compile
        def f(x):
            out = torch.zeros_like(x)
            ones_kernel[(4,)](out, 16, BLOCK_SIZE=16)
            return torch.mm(out, x) + 10

        x = torch.randn(4, 4, device=GPU_TYPE)
        torch_out = f(x)
        python_out = torch.mm(torch.ones(4, 4, device=GPU_TYPE), x) + 10
        self.assertEqual(torch_out, python_out)

    @requires_gpu
    def test_triton_kernel_strided_input(self):
        def f(inp):
            # left has strides [256, 1]
            left, right = torch.split(inp, [128, 128], dim=1)
            out = torch.empty_like(left)
            X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
            grid = (left.size(1) // X_BLOCK_SIZE, left.size(0) // Y_BLOCK_SIZE)
            double_strided_kernel[grid](
                in_ptr=left,
                out_ptr=out,
                in_y_stride=left.stride(0),
                out_y_stride=out.stride(0),
                X_BLOCK_SIZE=X_BLOCK_SIZE,
                Y_BLOCK_SIZE=Y_BLOCK_SIZE,
            )
            return out

        inp = torch.randn(64, 256, device=GPU_TYPE)

        eager_out = f(inp)
        compiled_out = torch.compile(f)(inp)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_strided_input_nonzero_offset(self):
        def f(inp):
            # right has strides [256, 1] and storage offset 128
            left, right = torch.split(inp, [128, 128], dim=1)
            out = torch.empty_like(right)
            X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
            grid = (right.size(1) // X_BLOCK_SIZE, right.size(0) // Y_BLOCK_SIZE)
            double_strided_kernel[grid](
                in_ptr=right,
                out_ptr=out,
                in_y_stride=right.stride(0),
                out_y_stride=out.stride(0),
                X_BLOCK_SIZE=X_BLOCK_SIZE,
                Y_BLOCK_SIZE=Y_BLOCK_SIZE,
            )
            return out

        inp = torch.randn(64, 256, device=GPU_TYPE)

        eager_out = f(inp)
        compiled_out = torch.compile(f)(inp)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_slice_and_view_input(self):
        def f(inp):
            # left has strides [256, 1]
            left = inp[:, :128]
            left = left.view(64, 4, 32)
            out = torch.empty_like(left)
            X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
            grid = (
                (left.size(1) * left.size(2)) // X_BLOCK_SIZE,
                left.size(0) // Y_BLOCK_SIZE,
            )
            double_strided_kernel[grid](
                in_ptr=left,
                out_ptr=out,
                in_y_stride=left.stride(0),
                out_y_stride=out.stride(0),
                X_BLOCK_SIZE=X_BLOCK_SIZE,
                Y_BLOCK_SIZE=Y_BLOCK_SIZE,
            )
            return out + left

        inp = torch.randn(64, 256, device=GPU_TYPE)

        eager_out = f(inp)
        compiled_out = torch.compile(f)(inp)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_fallback(self):
        def f(x, y):
            out = torch.zeros_like(x)
            out2 = torch.zeros_like(x)
            # torch.mm is ExternKernelOut
            add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
            # torch.sort creates fallback kernel and hence MultiOutput
            add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
            return out, out2

        x = torch.randn(4, 4, device=GPU_TYPE)
        y = torch.randn(4, 4, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out = torch.compile(f)(x, y)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_out_of_order(self):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            BLOCK_SIZE: "tl.constexpr",
            out_ptr,
            n_elements,
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x, y):
            out = torch.zeros_like(x)
            n_elements = x.numel()
            add_kernel[(n_elements,)](x, y, 4, out, n_elements)
            return out

        x = torch.randn(4, device=GPU_TYPE)
        y = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out = torch.compile(f)(x, y)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_unbacked_shape_tensor(self, backend):
        @triton.jit
        def square(
            in_ptr,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr + offsets, mask=mask)
            output = x * x
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x):
            x = x[x > 2]
            n_elements = x.numel()
            output = torch.zeros_like(x)
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            square[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output

        x = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x)
        compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_equal_to_1_arg(self, dynamic):
        @triton.jit
        def add_kernel_half_n_elements(
            in_ptr0,
            in_ptr1,
            out_ptr,
            half_n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < half_n_elements * 2
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x, y):
            out = torch.empty_like(x)
            half_n_elements = x.numel() // 2
            add_kernel_half_n_elements[(half_n_elements,)](
                x, y, out, half_n_elements, BLOCK_SIZE=16
            )
            return out

        x = torch.randn(2, device=GPU_TYPE)
        y = torch.randn(2, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out, sources = run_and_get_code(
            torch.compile(f, dynamic=dynamic), x, y
        )

        if dynamic:
            # when half_n_elements passed to the Triton kernel is
            # dynamic, equal_to_1 specialization can't be enforced
            self.assertTrue("equal_to_1=()" in sources[0])
        else:
            self.assertTrue("equal_to_1=(3,)" in sources[0])
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
        def f(x, y):
            out = torch.empty_like(x)
            n_elements = x.numel()
            scaling_factor = (n_elements**0) / 1.0
            add_kernel_with_scaling[(n_elements,)](
                x,
                y,
                out,
                n_elements,
                scaling_factor,
                BLOCK_SIZE=16,
            )
            return out

        x = torch.randn(2, device=GPU_TYPE)
        y = torch.randn(2, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out, sources = run_and_get_code(
            torch.compile(f, dynamic=dynamic), x, y
        )

        # float 1.0 (whether literal or symbolic)
        # should not be added to equal_to_1
        self.assertTrue("equal_to_1=()" in sources[0])
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @skipIfRocm
    def test_triton_kernel_with_imported_symbol(self):
        @triton.jit
        def add_kernel_with_imported_symbol(
            in_ptr,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr + offsets, mask=mask)
            output = fast_dividef(x, 3.14)
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x):
            out = torch.empty_like(x)
            n_elements = x.numel()
            add_kernel_with_imported_symbol[(n_elements,)](
                x, out, n_elements, BLOCK_SIZE=16
            )
            return out

        x = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x)
        compiled_out = torch.compile(f)(x)

        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @skipIfRocm
    def test_triton_kernel_with_imported_symbol_with_custom_name(self):
        @triton.jit
        def add_kernel_with_imported_symbol(
            in_ptr,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr + offsets, mask=mask)
            output = my_fast_dividef(x, 3.14)
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x):
            out = torch.empty_like(x)
            n_elements = x.numel()
            add_kernel_with_imported_symbol[(n_elements,)](
                x, out, n_elements, BLOCK_SIZE=16
            )
            return out

        x = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x)
        compiled_out = torch.compile(f)(x)

        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("size", [4, 16])
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_different_shapes(self, size, dynamic):
        from torch._inductor.utils import run_and_get_code

        def f(x, y, xx, yy):
            n_elements = x.numel()
            output_1 = torch.zeros_like(x)
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output_1, n_elements, BLOCK_SIZE=4)

            n_elements = xx.numel()
            output_2 = torch.zeros_like(xx)
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](xx, yy, output_2, n_elements, BLOCK_SIZE=4)

            return output_1, output_2

        x = torch.rand(size, device=GPU_TYPE)
        y = torch.rand(size, device=GPU_TYPE)
        xx = torch.rand(size, size, device=GPU_TYPE)
        yy = torch.rand(size, size, device=GPU_TYPE)
        args = [x, y, xx, yy]

        eager_out = f(*args)
        compiled_out, (code,) = run_and_get_code(
            torch.compile(f, fullgraph=True, dynamic=dynamic, backend="inductor"), *args
        )
        if size == 4 and not dynamic:
            # Produce 2 kernels due to divisibility
            self.assertTrue("add_kernel_0.run" in code)
            self.assertTrue("add_kernel_1.run" in code)
        else:
            # size == 16 or dynamic
            # Only one kernel
            self.assertTrue("add_kernel_0.run" in code)
            self.assertTrue("add_kernel_1.run" not in code)

        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_reset_to_zero(self):
        @triton.autotune(
            configs=[
                triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
                triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
            ],
            key=["n_elements"],
            reset_to_zero=["out_ptr"],
        )
        @triton.jit
        def add_kernel_autotuned_reset(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile(fullgraph=True)
        def f(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned_reset[grid](x, y, output, n_elements)
            return output

        x = torch.randn(4, device=GPU_TYPE)
        msg = "Only configs and keys are supported for triton.autotune"
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            f(x, x)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_triton_dtype(self, dynamic, backend):
        @triton.jit
        def add_kernel_with_dtype(
            in_ptr0,
            in_ptr1,
            out_ptr,
            dtype: "tl.constexpr",
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask).to(dtype)
            y = tl.load(in_ptr1 + offsets, mask=mask).to(dtype)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x, y, dtype_torch, dtype_triton):
            output = torch.zeros_like(x).to(dtype=dtype_torch)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_with_dtype[grid](
                x, y, output, dtype_triton, n_elements, BLOCK_SIZE=4
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        y = torch.randn(4, device=GPU_TYPE)
        args_list = (
            [x, y, torch.float32, tl.float32],
            [x, y, torch.bfloat16, tl.bfloat16],
        )
        for args in args_list:
            eager_out = f(*args)
            compiled_out = torch.compile(
                f, fullgraph=True, backend=backend, dynamic=dynamic
            )(*args)
            self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_special_kwargs_with_autotune(self, backend):
        @triton.autotune(
            configs=[
                triton.Config({"BLOCK_SIZE": 128}),
                triton.Config({"BLOCK_SIZE": 64}),
            ],
            key=["n_elements"],
        )
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile(fullgraph=True, backend=backend)
        def f(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](
                x,
                y,
                output,
                n_elements,
                num_warps=8,
                num_stages=3,
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        f(x, x)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_multiple_outputs(self, dynamic, backend):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            out_ptr2,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)
            tl.store(out_ptr2 + offsets, output + 1, mask=mask)

        @torch.compile(fullgraph=True, backend=backend, dynamic=dynamic)
        def f(x, y, z):
            output = torch.empty_like(x)
            output2 = torch.empty_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, output2, n_elements, BLOCK_SIZE=16)
            # The z return is intentional: we're testing training
            return output, output2, z**2

        x = torch.randn(3, requires_grad=True, device=GPU_TYPE)
        y = torch.randn(3, requires_grad=True, device=GPU_TYPE)
        z = torch.randn(3, requires_grad=True, device=GPU_TYPE)
        out, out2, out3 = f(x, y, z)
        self.assertEqual(out, x + y)
        self.assertEqual(out2, x + y + 1)
        self.assertEqual(out3, z**2)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_num_ctas(self, backend):
        @triton.jit
        def kernel(X):
            return

        @torch.compile(backend=backend)
        def f(x):
            kernel[(1,)](x, num_ctas=1)
            kernel.run(x, num_ctas=1, grid=(1,), warmup=False)
            return x

        x = torch.randn(4, device=GPU_TYPE)
        f(x)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_special_kwargs_without_autotune(self, backend):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile(fullgraph=True, backend=backend)
        def f(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](
                x,
                y,
                output,
                n_elements,
                BLOCK_SIZE=128,
                num_warps=8,
                num_stages=3,
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        f(x, x)


def make_mutation_test(fn):
    @requires_gpu
    def test_fn(self):
        from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

        kernel, inputs, outputs = fn()
        self.assertListEqual(
            identify_mutated_tensors(kernel, inputs),
            outputs,
        )

    return test_fn


# Triton codegen suffers from scoping issues.
# Define helpers here
if HAS_GPU:

    @triton.jit
    def helper_id(p):
        return p

    @triton.jit
    def helper_add_and_out(x, y, out_ptr):
        return x + y, out_ptr


class MutationTests(torch._inductor.test_case.TestCase):
    # Tests injected below

    @make_mutation_test
    def test_out_of_order_kernel():
        @triton.jit
        def add_kernel_out_of_order(
            in_ptr0,
            n_elements,
            in_ptr1,
            out_ptr,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_kernel_out_of_order,
            {
                "in_ptr0": t,
                "n_elements": 4,
                "in_ptr1": t,
                "out_ptr": t,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_out_of_order_kernel_call():
        @triton.jit
        def add_kernel_out_of_order_fn1(
            in_ptr0,
            n_elements,
            in_ptr1,
            out_ptr,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            add_kernel_out_of_order_fn2(
                in_ptr0, in_ptr1, n_elements, out_ptr, BLOCK_SIZE=BLOCK_SIZE
            )

        t = torch.randn(4)
        return (
            add_kernel_out_of_order_fn1,
            {
                "in_ptr0": t,
                "n_elements": 4,
                "in_ptr1": t,
                "out_ptr": t,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_reduce_sum():
        @triton.jit
        def reduce_sum_kernel(a_ptr, c_ptr, stride_am, stride_an):
            offs_am = tl.arange(0, 4)
            offs_an = tl.arange(0, 4)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_an[None, :] * stride_an
            )
            a = tl.load(a_ptrs)
            m = tl.sum(a, axis=1)
            tl.store(c_ptr + tl.arange(0, 4), m)

        t = torch.randn(4)
        kernel = reduce_sum_kernel
        kwargs = {
            "a_ptr": t,
            "c_ptr": t,
            "stride_am": 4,
            "stride_an": 4,
        }

        # TODO(aakhundov): tt.reduce is now supported, but only
        # in the new MLIR-based Triton analysis pass (not in the
        # old TTIR string parsing-based one). remove this gating
        # and use ["c_ptr"] as `expected` after the new Triton
        # pin lands both in OSS and internally.
        ttir_module, _ = generate_ttir(kernel, kwargs)
        if hasattr(ttir_module, "walk"):
            # with MLIR-based Triton analysis pass
            expected = ["c_ptr"]
        else:
            # with TTIR string parsing-based Triton analysis pass
            expected = ["a_ptr", "c_ptr"]

        return (
            kernel,
            kwargs,
            expected,
        )

    @make_mutation_test
    def test_argmax():
        @triton.jit
        def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an):
            offs_am = tl.arange(0, 4)
            offs_an = tl.arange(0, 4)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_an[None, :] * stride_an
            )
            a = tl.load(a_ptrs)
            m = tl.argmax(a, axis=1)
            tl.store(c_ptr + tl.arange(0, 4), m)

        t = torch.randn(4)
        kernel = argmax_kernel
        kwargs = {
            "a_ptr": t,
            "c_ptr": t,
            "stride_am": 4,
            "stride_an": 4,
        }

        # TODO(aakhundov): tt.reduce is now supported, but only
        # in the new MLIR-based Triton analysis pass (not in the
        # old TTIR string parsing-based one). remove this gating
        # and use ["c_ptr"] as `expected` after the new Triton
        # pin lands both in OSS and internally.
        ttir_module, _ = generate_ttir(kernel, kwargs)
        if hasattr(ttir_module, "walk"):
            # with MLIR-based Triton analysis pass
            expected = ["c_ptr"]
        else:
            # with TTIR string parsing-based Triton analysis pass
            expected = ["a_ptr", "c_ptr"]

        return (
            kernel,
            kwargs,
            expected,
        )

    @requires_cuda
    @skipIfRocm
    def test_triton_kernel_inference_mode(self):
        def f(x, y, out):
            n_elements = x.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=4)

        with torch.inference_mode():
            x = torch.ones(32, device="cuda")
            y = torch.ones(32, device="cuda")
            out_ref = torch.zeros_like(x)
            out_test = torch.zeros_like(x)
            f(x, y, out_ref)
            torch.compile(f)(x, y, out_test)
            self.assertEqual(out_ref, out_test)

    @make_mutation_test
    def test_cumsum():
        @triton.jit
        def cumsum_kernel(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
            rindex = tl.arange(0, RBLOCK)[None, :]
            xindex = tl.arange(0, XBLOCK)[:, None]
            data = tl.load(in_ptr + rindex)
            scan = tl.cumsum(data, 1)
            expected_max = tl.sum(data, 1)
            tl.device_assert(scan <= expected_max)
            tl.store(out_ptr + xindex * RBLOCK + rindex, scan)

        t = torch.randn(4)
        kernel = cumsum_kernel
        kwargs = {
            "in_ptr": t,
            "out_ptr": t,
            "XBLOCK": 4,
            "RBLOCK": 16,
        }

        # TODO(aakhundov): tt.scan is now supported, but only
        # in the new MLIR-based Triton analysis pass (not in the
        # old TTIR string parsing-based one). remove this gating
        # and use ["out_ptr"] as `expected` after the new Triton
        # pin lands both in OSS and internally.
        ttir_module, _ = generate_ttir(kernel, kwargs)
        if hasattr(ttir_module, "walk"):
            # with MLIR-based Triton analysis pass
            expected = ["out_ptr"]
        else:
            # with TTIR string parsing-based Triton analysis pass
            expected = ["in_ptr", "out_ptr"]

        return (
            kernel,
            kwargs,
            expected,
        )

    @make_mutation_test
    def test_fn_call_one_return():
        @triton.jit
        def add_kernel_with_fn_call(
            in_ptr0,
            in_ptr1,
            n_elements,
            out_ptr,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            out = helper_id(out_ptr)
            tl.store(out + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_kernel_with_fn_call,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "n_elements": 4,
                "out_ptr": t,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_fn_call_multi_return():
        @triton.jit
        def add_kernel_with_fn_call(
            in_ptr0,
            in_ptr1,
            n_elements,
            out_ptr,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output, out = helper_add_and_out(x, y, out_ptr)
            tl.store(out + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_kernel_with_fn_call,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "n_elements": 4,
                "out_ptr": t,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_nested_cond_op_kernel():
        @triton.jit
        def nested_cond_op_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            if tl.program_id(0) == 0:
                if tl.program_id(1) == 0:
                    output = x + y
                    tl.store(out_ptr + offsets, output, mask=mask)
            else:
                pass

        t = torch.randn(4)
        return (
            nested_cond_op_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_add_for_loop():
        @triton.jit
        def add_4_times_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = tl.zeros((n_elements,), dtype=tl.float32)
            for i in range(4):
                output += x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_4_times_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_add_for_loop2():
        @triton.jit
        def add_1_time_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            for i in range(0, BLOCK_SIZE):
                i = tl.multiple_of(i, 1)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_1_time_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_add_nested_for_loop():
        @triton.jit
        def add_4_times_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = tl.zeros((n_elements,), dtype=tl.float32)
            for i in range(2):
                for j in range(2):
                    output += x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_4_times_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_add_nested_for_loop_multi_return():
        @triton.jit
        def add_4_times_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output1 = tl.zeros((n_elements,), dtype=tl.float32)
            output2 = tl.zeros((n_elements,), dtype=tl.float32)
            for i in range(2):
                for j in range(2):
                    output1 += y
                    output2 += x
            output = output1 + output2
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_4_times_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_labels():
        @triton.jit
        def kernel_with_label(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            if pid > 1:
                return
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            kernel_with_label,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_for_loop_arg():
        @triton.jit
        def fwd_kernel(
            X_ptr,
            W1_ptr,
            b1_ptr,
            O_ptr,
            M: tl.constexpr,
            C1: tl.constexpr,
            C2: tl.constexpr,
            BLOCK_SIZE_M: tl.constexpr,
            BLOCK_SIZE_C2: tl.constexpr,
        ):
            # Get program ids
            pid_m = tl.program_id(0)

            # Compute offsets
            offs_c1 = tl.arange(0, C1)
            offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

            # Load input data
            x_block_ptr = X_ptr + offs_m[:, None] * C1 + offs_c1[None, :]
            x = tl.load(x_block_ptr)

            # Compute gating
            for c2 in range(0, tl.cdiv(C2, BLOCK_SIZE_C2)):
                # Compute block pointers
                offs_c2 = c2 * BLOCK_SIZE_C2 + tl.arange(0, BLOCK_SIZE_C2)
                o_block_ptr = O_ptr + offs_m[:, None] * C2 + offs_c2[None, :]
                w1_block_ptr = W1_ptr + offs_c1[:, None] * C2 + offs_c2[None, :]
                b1_block_ptr = b1_ptr + offs_c2

                # Compute output
                w = tl.load(w1_block_ptr)
                b = tl.load(b1_block_ptr)
                o = tl.dot(x, w, allow_tf32=False)
                o += b[None, :]

                # Store output
                tl.store(o_block_ptr, o)

        t = torch.randn(64)
        return (
            fwd_kernel,
            {
                "X_ptr": t,
                "W1_ptr": t,
                "b1_ptr": t,
                "O_ptr": t,
                "M": 64,
                "C1": 64,
                "C2": 64,
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_C2": 64,
            },
            ["O_ptr"],
        )

    @make_mutation_test
    def test_for_loop_arg_2():
        @triton.jit
        def fwd_kernel(
            x_ptr,
            o_ptr,
            M,
            N,
            stride_m,
            stride_n,
            BLOCK_B: tl.constexpr,
            BLOCK_M: tl.constexpr,
            BLOCK_N: tl.constexpr,
        ):
            # Get program ids
            pid_m = tl.program_id(0)
            X_block_ptr = tl.make_block_ptr(
                base=x_ptr,
                shape=(M, N),
                strides=(stride_m, stride_n),
                offsets=(0, 0),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )
            O_block_ptr = tl.make_block_ptr(
                base=o_ptr,
                shape=(M, N),
                strides=(stride_m, stride_n),
                offsets=(0, 0),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )

            for _ in range(BLOCK_B):
                x = tl.load(X_block_ptr)
                tl.store(O_block_ptr, x)

                X_block_ptr = tl.advance(X_block_ptr, (BLOCK_M, 0))
                O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M, 0))

        t = torch.randn((32, 64, 128))
        o = torch.empty_like(t)
        B, M, N = t.shape
        return (
            fwd_kernel,
            {
                "x_ptr": t,
                "o_ptr": o,
                "M": M,
                "N": N,
                "stride_m": N,
                "stride_n": 1,
                "BLOCK_B": B,
                "BLOCK_M": M,
                "BLOCK_N": N,
            },
            ["o_ptr"],
        )

    @make_mutation_test
    def test_while_loop():
        @triton.jit
        def fwd_kernel(
            x_ptr,
            o_ptr,
            M,
            N,
            stride_m,
            stride_n,
            BLOCK_B: tl.constexpr,
            BLOCK_M: tl.constexpr,
            BLOCK_N: tl.constexpr,
        ):
            # Get program ids
            pid_m = tl.program_id(0)
            X_block_ptr = tl.make_block_ptr(
                base=x_ptr,
                shape=(M, N),
                strides=(stride_m, stride_n),
                offsets=(0, 0),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )
            O_block_ptr = tl.make_block_ptr(
                base=o_ptr,
                shape=(M, N),
                strides=(stride_m, stride_n),
                offsets=(0, 0),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )

            i = 0
            while i < BLOCK_B:
                x = tl.load(X_block_ptr)
                tl.store(O_block_ptr, x)

                X_block_ptr = tl.advance(X_block_ptr, (BLOCK_M, 0))
                O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M, 0))
                i += 1

        t = torch.randn((32, 64, 128))
        o = torch.empty_like(t)
        B, M, N = t.shape
        return (
            fwd_kernel,
            {
                "x_ptr": t,
                "o_ptr": o,
                "M": M,
                "N": N,
                "stride_m": N,
                "stride_n": 1,
                "BLOCK_B": B,
                "BLOCK_M": M,
                "BLOCK_N": N,
            },
            ["o_ptr"],
        )


if HAS_GPU:
    t = torch.randn(4)
    tt = torch.randn(4, 1)
    tests = [
        [
            add_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
        [
            add_kernel_2d_autotuned,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "x_elements": 4,
                "y_elements": 4,
            },
            ["out_ptr"],
        ],
        [
            indirection_kernel,
            {
                "in_ptr0": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
                "ACTIVATION": "mul2_inplace_kernel",
            },
            ["in_ptr0", "out_ptr"],
        ],
        [
            indirection_kernel,
            {
                "in_ptr0": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
                "ACTIVATION": "add_kernel",
            },
            ["out_ptr"],
        ],
        [
            mul2_inplace_kernel,
            {"ptr": t, "n_elements": 4, "BLOCK_SIZE": 4},
            ["ptr"],
        ],
        # Can't optimize since the kernel contains a tl.inline_asm_elementwise
        [
            inline_asm_kernel,
            {"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4},
            ["X", "Y", "Z"],
        ],
        [
            add_kernel_with_block_ptr,
            {
                "x_ptr": t,
                "y_ptr": t,
                "output_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["output_ptr"],
        ],
        [
            kernel_with_block_ptr_2d,
            {
                "x_ptr": tt,
                "output_ptr": tt,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["output_ptr"],
        ],
        [
            add_kernel_with_import,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
        [
            atomic_add_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
        [
            add_4_times_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
        [
            cond_op_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
    ]
    for kernel, inputs, outputs in tests:
        fn = make_mutation_test(
            # Add default arguments to avoid the Python lambda capture pitfall.
            # This forces the capture at lambda creation.
            lambda kernel=kernel, inputs=inputs, outputs=outputs: (
                kernel,
                inputs,
                outputs,
            )
        )
        name = f"test_mutations_{kernel.fn.__name__}"
        # Poor way to make test names unique
        while name in MutationTests.__dict__:
            name += "1"

        setattr(MutationTests, name, fn)


class CustomOpTests(torch._inductor.test_case.TestCase):
    """Tests for custom ops wrapping triton kernels"""

    @requires_gpu
    @common_utils.parametrize("autotuned", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    def test_add_kernel(self, autotuned, dynamic):
        from torch._inductor.utils import run_and_get_code

        libname = "my_cool_namespace"
        opname = "my_triton_operator"

        @torch._library.triton_op(f"{libname}::{opname}", mutates_args={})
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            if autotuned:
                capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements)
            else:
                capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
            return output

        def f(x, y):
            return add(x, y)

        x = torch.randn(3, device=GPU_TYPE)
        y = torch.randn(3, device=GPU_TYPE)

        out = f(x, y)
        expected = x + y
        self.assertEqual(out, expected)
        out_compiled, codes = run_and_get_code(torch.compile(f, dynamic=dynamic), x, y)
        self.assertEqual(out_compiled, expected)
        self.assertEqual(len(codes), 1)

        # Check that we decomposed the operator away
        code = "\n".join(codes[0])
        self.assertNotIn(libname, code)
        self.assertNotIn(opname, code)

    @unittest.skipIf(not has_triton_package(), "requires triton")
    def test_capture_triton_meta(self):
        import triton
        import triton.language as tl

        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch._library.triton_op("mylib::add", mutates_args=())
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
            return output

        def f(x, y):
            return add(x, y)

        x = torch.randn(3, device="meta")
        y = torch.randn(3, device="meta")

        out = f(x, y)
        expected = torch.empty_like(x)
        self.assertEqual(out, expected)

    @requires_gpu
    def test_capture_triton_disabled_in_triton_op(self):
        import triton
        import triton.language as tl

        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        add_kernel_decorated = torch._library.capture_triton(add_kernel)

        status = []

        @torch._library.triton_op("mylib::add", mutates_args=())
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            import torch._higher_order_ops.triton_kernel_wrap

            status.append(torch._library.triton.is_capture_triton_enabled())

            # capture_triton should return the kernel directly if disabled
            result = torch._library.capture_triton(add_kernel)
            self.assertIs(result, add_kernel)

            # Smoke test: check that with capture_triton disabled this still does something
            output = torch.empty_like(x)
            output2 = torch.empty_like(x)

            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_decorated[grid](x, y, output, n_elements, BLOCK_SIZE=16)

            add_kernel_decorated.run(
                x, y, output2, n_elements, BLOCK_SIZE=16, grid=grid, warmup=False
            )

            return output + output2

        x = torch.randn(3, device=GPU_TYPE)
        y = torch.randn(3, device=GPU_TYPE)
        z = add(x, y)
        self.assertEqual(status[-1], False)
        self.assertEqual(z, (x + y) * 2)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("autotune", [False, True])
    def test_capture_triton_special_kwargs(self, dynamic, autotune):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        if autotune:
            add_kernel = triton.autotune(
                configs=[
                    triton.Config({"BLOCK_SIZE": 128}),
                    triton.Config({"BLOCK_SIZE": 64}),
                ],
                key=["n_elements"],
            )(add_kernel)

        def f(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            if autotune:
                kwargs = {}
            else:
                kwargs = {"BLOCK_SIZE": 128}
            capture_triton(add_kernel)[grid](
                x,
                y,
                output,
                n_elements,
                num_warps=8,
                num_stages=3,
                **kwargs,
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        tracing_mode = "symbolic" if dynamic else "fake"

        result = f(x, x)
        self.assertEqual(result, x + x)

        from functorch import make_fx

        gm = make_fx(f, tracing_mode=tracing_mode)(x, x)
        self.assertEqual(gm(x, x), x + x)


common_utils.instantiate_parametrized_tests(KernelTests)
common_utils.instantiate_parametrized_tests(CustomOpTests)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()