# xref: /aosp_15_r20/external/pytorch/test/inductor/test_triton_kernels.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: inductor"]
# flake8: noqa: E731
# Skip "do not assign a lambda expression, use a def" (E731)
import functools
from unittest.mock import patch

import torch
import torch._dynamo.testing
import torch._inductor.test_case
from torch._higher_order_ops.triton_kernel_wrap import (
    generate_ttir,
    triton_kernel_wrapper_functional,
    triton_kernel_wrapper_mutation,
)
from torch._inductor import metrics
from torch._inductor.utils import run_and_get_code
from torch._library import capture_triton
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu, TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU
from torch.testing._internal.logging_utils import logs_to_string

# Defines all the kernels for tests
from torch.testing._internal.triton_utils import *  # noqa: F403
from torch.utils._triton import has_triton_package


if HAS_GPU:
    import triton
    from triton import language as tl

    if not TEST_WITH_ROCM:
        if HAS_CUDA:
            from triton.language.extra.cuda.libdevice import (
                fast_dividef,
                fast_dividef as my_fast_dividef,
            )
        elif HAS_XPU:
            from triton.language.extra.intel.libdevice import (
                fast_dividef,
                fast_dividef as my_fast_dividef,
            )

    # Define shared triton constants here.
    CONSTANT_C: tl.constexpr = 4
    STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C"
    BOOL_CONSTANT_C: tl.constexpr = True


class KernelTests(torch._inductor.test_case.TestCase):
    @requires_gpu
    def test_triton_kernel_with_kernel_param(self):
        @triton.jit
        def pass_kernel(kernel):
            pass

        @torch.compile(backend="eager")
        def f(x):
            grid = (x.numel(),)
            pass_kernel[grid](kernel=x)

        t1 = torch.rand(5, device=GPU_TYPE)
        f(t1)
        # No need to assert anything, the goal is to make sure dynamo does
        # not crash

    @requires_gpu
    def test_triton_kernel_higher_order_func(self):
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        add_kernel_id = kernel_side_table.add_kernel(add_kernel)

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)

        torch_add = t1 + t2

        # Test higher order function with mutation
        output = torch.zeros_like(t1)
        n_elements = output.numel()
        constant_args_idx = kernel_side_table.add_constant_args(
            {"n_elements": n_elements, "BLOCK_SIZE": 16}
        )
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        triton_kernel_wrapper_mutation(
            kernel_idx=add_kernel_id,
            constant_args_idx=constant_args_idx,
            grid=[grid],
            kwargs={
                "in_ptr0": t1,
                "in_ptr1": t2,
                "out_ptr": output,
            },
        )
        self.assertEqual(output, torch_add)
        # Make sure it is modified
        self.assertNotEqual(output, torch.zeros_like(t1))

        # Test higher order function without mutation
        output = torch.zeros_like(t1)
        out_dict = triton_kernel_wrapper_functional(
            kernel_idx=add_kernel_id,
            constant_args_idx=constant_args_idx,
            grid=[grid],
            kwargs={
                "in_ptr0": t1,
                "in_ptr1": t2,
                "out_ptr": output,
            },
            tensors_to_clone=["in_ptr0", "in_ptr1", "out_ptr"],
        )
        self.assertEqual(out_dict["out_ptr"], torch_add)
        # Make sure it is NOT modified
        self.assertEqual(output, torch.zeros_like(t1))

    @requires_gpu
    def test_triton_kernel_functionalize(self):
        from functorch import make_fx
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
        from torch._subclasses.functional_tensor import (
            CppFunctionalizeAPI,
            FunctionalTensorMode,
            PythonFunctionalizeAPI,
        )

        kernel_side_table.reset_table()

        def f(x, output):
            out = triton_kernel_wrapper_functional(
                kernel_idx=kernel_side_table.add_kernel(mul2_kernel),
                constant_args_idx=kernel_side_table.add_constant_args(
                    {"n_elements": output.numel(), "BLOCK_SIZE": 16}
                ),
                grid=[(x.numel(),)],
                kwargs={
                    "in_ptr0": x,
                    "out_ptr": output,
                },
                tensors_to_clone=["in_ptr0", "out_ptr"],
            )
            return out["out_ptr"]

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)
        with FunctionalTensorMode():
            gm = make_fx(PythonFunctionalizeAPI().functionalize(f))(t1, t2)
        # Make sure t2 was not modified
        self.assertNotEqual(gm(t1, t2), t2)

        gm = make_fx(CppFunctionalizeAPI().functionalize(f))(t1, t2)
        # Make sure t2 was not modified
        self.assertNotEqual(gm(t1, t2), t2)

        gm = make_fx(torch.func.functionalize(f))(t1, t2)
        # Make sure t2 was not modified
        self.assertNotEqual(gm(t1, t2), t2)

        gm = make_fx(f, tracing_mode="fake")(t1, t2)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1, output_1):
    triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']);  x_1 = output_1 = None
    getitem = triton_kernel_wrapper_functional_proxy['in_ptr0'];  getitem = None
    getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr'];  triton_kernel_wrapper_functional_proxy = None
    return getitem_1""",
        )

    @requires_gpu
    def test_triton_kernel_mutation_type(self):
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
        from torch._subclasses.fake_tensor import FakeTensorMode
        from torch._subclasses.functional_tensor import (
            FunctionalTensor,
            FunctionalTensorMode,
        )

        def prep():
            x = torch.ones(4, device=GPU_TYPE, requires_grad=True)
            with FunctionalTensorMode():
                x_func = FunctionalTensor.to_functional(x)
            self.assertTrue(torch._is_functional_tensor(x_func.elem))
            return x_func

        # normal mutation only
        with FakeTensorMode():
            x_func = prep()

            with FunctionalTensorMode():
                x_func.mul_(2)

            self.assertFalse(
                torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
            )

        # triton kernel mutation only
        with FakeTensorMode():
            x_func = prep()

            with FunctionalTensorMode():
                triton_kernel_wrapper_mutation(
                    kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel),
                    constant_args_idx=kernel_side_table.add_constant_args(
                        {"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
                    ),
                    grid=[(x_func.numel(),)],
                    kwargs={
                        "ptr": x_func,
                    },
                )

            self.assertTrue(
                torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
            )

        # normal mutation + triton kernel mutation
        with FakeTensorMode():
            x_func = prep()

            with FunctionalTensorMode():
                x_func.mul_(2)
                triton_kernel_wrapper_mutation(
                    kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel),
                    constant_args_idx=kernel_side_table.add_constant_args(
                        {"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
                    ),
                    grid=[(x_func.numel(),)],
                    kwargs={
                        "ptr": x_func,
                    },
                )

            self.assertFalse(
                torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
            )

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_with_views(self, dynamic, backend):
        def call_triton_take_view(x: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output

        def call_triton_return_view(x: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output.view(4, 4)

        t = torch.rand(4, 4, device=GPU_TYPE)
        t_view = t.view(16)

        compiled_func = torch.compile(
            call_triton_take_view, backend=backend, fullgraph=True, dynamic=dynamic
        )
        self.assertEqual(2 * t_view, compiled_func(t_view))
        self.assertEqual(2 * t, compiled_func(t_view).view(4, 4))

        compiled_func = torch.compile(
            call_triton_return_view, backend=backend, fullgraph=True, dynamic=dynamic
        )
        self.assertEqual(2 * t_view, compiled_func(t).view(16))
        self.assertEqual(2 * t, compiled_func(t))

    @requires_gpu
    def test_no_nan_kernels(self):
        @triton.jit
        def add_one_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            output = x + 1
            tl.store(out_ptr + offsets, output, mask=mask)

        def add_one(x, out):
            n_elements = x.numel()
            add_one_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)

        class AddOne(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = torch.empty_like(x)
                add_one(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = torch.empty_like(grad)
                add_one(saved, out)
                return out

        @torch.compile
        def f(x):
            return AddOne.apply(x)

        log_stream, ctx = logs_to_string("torch._inductor.codecache", "output_code")

        x = torch.randn(3, requires_grad=True, device=GPU_TYPE)
        with ctx():
            y = f(x)

        output_code = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        self.assertTrue(len(output_code) > 0, msg="output code is not empty")
        self.assertEqual(output_code.count('float("nan")'), 0)
        self.assertEqual(output_code.count("float('nan')"), 0)

    @requires_gpu
    @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_with_grad_option(self, grad_fn, backend):
        def call_triton(x: torch.Tensor):
            with grad_fn():
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
                mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
                return output

        t = torch.rand(5, device=GPU_TYPE)
        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
        self.assertEqual(2 * t, compiled_func(t))

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_inner_triton_function(self, backend):
        def f(x: torch.Tensor):
            @triton.jit
            def pow2_kernel(
                in_ptr0,
                out_ptr,
                n_elements,
                BLOCK_SIZE: "tl.constexpr",
            ):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n_elements
                x = tl.load(in_ptr0 + offsets, mask=mask)
                output = x * x
                tl.store(out_ptr + offsets, output, mask=mask)

            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output

        t = torch.rand(5, device=GPU_TYPE)

        compiled_func = torch.compile(f, backend=backend, fullgraph=True)
        # TODO(oulgen): NYI - Support this
        # self.assertEqual(t * t, compiled_func(t))

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @patch.object(torch._inductor.config, "implicit_fallbacks", False)
    def test_triton_kernel_no_clones(self, grad, dynamic):
        from torch._inductor.utils import run_and_get_code

        def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
            n_elements = output.numel()

            tmp = torch.add(x, 1)
            grid = (x.numel(),)
            add_kernel.run(
                x, y, output, n_elements, warmup=False, grid=grid, BLOCK_SIZE=16
            )

            return output, tmp

        t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        o1 = torch.zeros_like(t1, requires_grad=grad)

        torch_add = call_triton(t1, t2, o1)
        metrics.reset()
        o2 = torch.zeros_like(t1, requires_grad=grad)
        test, codes = run_and_get_code(
            torch.compile(call_triton, dynamic=dynamic), t1, t2, o2
        )
        if not grad:
            self.assertEqual(metrics.generated_kernel_count, 1)
        self.assertEqual(torch_add, test)
        # These two asserts are not optimal since they require the original aten
        # ops to be in the metadata, so there might be false negatives
        self.assertTrue("aten.copy" not in codes[0])
        self.assertTrue("aten.clone" not in codes[0])
        # The following checks that only the tensor output is in
        # the compiled graph
        if dynamic and grad:
            self.assertTrue("return (buf0, s0, )" in codes[0])
        else:
            self.assertTrue("return (buf0, )" in codes[0])

    @requires_gpu
    def test_triton_kernel_caching(self):
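        # The same autotuned kernel launched four times in a loop should be
        # cached and emitted once, not re-compiled per call site.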
        from torch._inductor.utils import run_and_get_code

        def add_in_loop(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            return output

        def call_triton_add(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            for i in range(4):
                x = add_in_loop(x, y)
            return x

        t1 = torch.ones(5, device=GPU_TYPE)
        t2 = torch.ones(5, device=GPU_TYPE)

        test, (code,) = run_and_get_code(torch.compile(call_triton_add), t1, t2)
        self.assertEqual(test, 5 * torch.ones(5, device=GPU_TYPE))
        self.assertTrue("add_kernel_autotuned_1.run" not in code)

    @requires_gpu
    def test_triton_kernel_caching_duplicate(self):
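        # Two structurally identical kernels defined on different classes must
        # still be emitted as two distinct kernels rather than deduplicated.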
        from torch._inductor.utils import run_and_get_code

        class C:
            @triton.jit
            def pass_kernel(
                in_ptr0,
                out_ptr,
                n_elements,
                BLOCK_SIZE: "tl.constexpr",
            ):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n_elements
                x = tl.load(in_ptr0 + offsets, mask=mask)
                tl.store(out_ptr + offsets, x, mask=mask)

        class D:
            @triton.jit
            def pass_kernel(
                in_ptr0,
                out_ptr,
                n_elements,
                BLOCK_SIZE: "tl.constexpr",
            ):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n_elements
                x = tl.load(in_ptr0 + offsets, mask=mask)
                tl.store(out_ptr + offsets, x, mask=mask)

        def call_triton(x: torch.Tensor):
            output1 = torch.zeros_like(x)
            output2 = torch.zeros_like(x)
            n_elements = output1.numel()
            grid = (n_elements,)
            C.pass_kernel[grid](x, output1, n_elements, BLOCK_SIZE=16)
            D.pass_kernel[grid](x, output2, n_elements, BLOCK_SIZE=16)
            return output1 + output2

        t = torch.ones(5, device=GPU_TYPE)
        test, (code,) = run_and_get_code(torch.compile(call_triton), t)
        # Make sure we emitted two kernels here
        self.assertTrue("pass_kernel_0.run" in code)
        self.assertTrue("pass_kernel_1.run" in code)

    @requires_gpu
    def test_triton_kernel_various_args(self):
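        # Launch with a mix of None, tensor, and float arguments plus an
        # autotuned config; the test only checks that compilation does not crash.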
        @triton.autotune(
            configs=[triton.Config({"BLOCK_SIZE": 128})],
            key=[],
        )
        @triton.jit
        def pass_kernel(
            out_ptr,
            n_elements,
            dummy_None,
            dummy_empty,
            dummy_float,
            BLOCK_SIZE: "tl.constexpr",
            RANDOM_SIZE: "tl.constexpr",
        ):
            pass

        @torch.compile
        def call_triton(output):
            n_elements = output.numel()
            grid = (n_elements,)
            pass_kernel[grid](
                output,
                n_elements,
                None,
                torch.empty_like(output),
                3.1415926,
                RANDOM_SIZE=0,
            )
            return output

        output = torch.randn(5, device=GPU_TYPE)
        # Make sure this does not crash
        call_triton(output)

    @requires_gpu
    @skipIfRocm
    def test_triton_kernel_dependencies(self):
        def call_triton(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            output2 = torch.zeros_like(output)
            add_kernel_autotuned[grid](output, y, output2, n_elements)
            output3 = torch.add(output2, 1)
            return output3

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)
        torch_result = call_triton(t1, t2)
        compiled_result = torch.compile(call_triton)(t1, t2)
        self.assertEqual(torch_result, compiled_result)

    @requires_gpu
    def test_triton_kernel_reinplace_inplaceable_pass(self):
        def call_triton(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            add_kernel_autotuned[grid](output, x, output, n_elements)
            return output

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)
        torch_result = call_triton(t1, t2)
        compiled_result = torch.compile(call_triton)(t1, t2)
        self.assertEqual(torch_result, compiled_result)

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    def test_triton_kernel_multi_kernel(self, grad):
        @triton.jit
        def mul2_and_add_and_zero_negatives_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
            ACTIVATION: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
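            # Double both inputs in place via the indirection kernel before adding them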
            indirection_kernel(
                in_ptr0,
                in_ptr0,
                n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
                ACTIVATION="mul2_inplace_kernel",
            )
            indirection_kernel(
                in_ptr1,
                in_ptr1,
                n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
                ACTIVATION="mul2_inplace_kernel",
            )
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            if ACTIVATION == "zero_negs":
                output = zero_negs(output)
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile
        def call_triton(
            x: torch.Tensor,
            y: torch.Tensor,
            xi: torch.Tensor,
            yi: torch.Tensor,
            output: torch.Tensor,
            outputi: torch.Tensor,
        ):
            n_elements = output.numel()

            grid = (x.numel(),)
            mul2_and_add_and_zero_negatives_kernel[grid](
                x, y, output, n_elements, BLOCK_SIZE=16, ACTIVATION="zero_negs"
            )
            mul2_and_add_and_zero_negatives_kernel[grid](
                xi, yi, outputi, n_elements, BLOCK_SIZE=16, ACTIVATION=None
            )

            return (output, outputi)

        t1 = torch.tensor(
            [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad
        )
        t2 = torch.tensor(
            [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad
        )
        float_result = 2 * t1 + 2 * t2
        float_result = float_result.where(float_result >= 0, 0.0)

        t1i = torch.randint(-2, 2, (5,), device=GPU_TYPE)
        t2i = torch.randint(-2, 2, (5,), device=GPU_TYPE)
        o = torch.zeros_like(t1, requires_grad=grad)
        oi = torch.zeros_like(t1i)
        int_result = 2 * t1i + 2 * t2i

        (result, resulti) = call_triton(t1, t2, t1i, t2i, o, oi)
        self.assertEqual(float_result, result)
        self.assertEqual(int_result, resulti)

    @requires_gpu
    @skipIfXpu
    @skipIfRocm
    def test_triton_kernel_constants(self):
        @triton.jit
        def mulC_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
            CONSTANT_NAME: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            if CONSTANT_NAME == STRING_CONSTANT_C:
                output = CONSTANT_C * x
            if BOOL_CONSTANT_C:
                output *= CONSTANT_C
            tl.store(out_ptr + offsets, output, mask=mask)

        def call_triton(
            x: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()

            grid = (x.numel(),)
            mulC_kernel[grid](
                x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C"
            )
            return output

        # Triton kernels capture global constants by their parse-time value,
        # not their runtime value
        global CONSTANT_C
        prev_c = CONSTANT_C
        # If the behavior of triton kernels changes, this test will fail
        CONSTANT_C = 10
        assert CONSTANT_C != prev_c

        t = torch.randn(5, device=GPU_TYPE)
        torch_result = call_triton(t)
        compiled_result = torch.compile(call_triton)(t)

        self.assertEqual(torch_result, compiled_result)

        # reset back
        CONSTANT_C = prev_c

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("grid_type", [1, 2, 3])
    def test_triton_kernel_autotune(self, grad, dynamic, backend, grid_type):
        def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            if grid_type == 1:
                grid = (n_elements,)
            elif grid_type == 2:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            elif grid_type == 3:
                grid = grid_fn

            add_kernel_autotuned[grid](x, y, output, n_elements)
            return output

        t1 = torch.rand(256, device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand(256, device=GPU_TYPE, requires_grad=grad)
        output = torch.zeros_like(t1, requires_grad=grad)

        torch_add = call_triton(t1, t2, output)
        compiled_func = torch.compile(
            call_triton, backend=backend, fullgraph=True, dynamic=dynamic
        )

        output2 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, output2), torch_add)

    @requires_gpu
    @skipIfRocm  # https://github.com/pytorch/pytorch/actions/runs/10051552819/job/27782048305?pr=131431
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @patch.object(
        torch._inductor.config, "unsafe_ignore_unsupported_triton_autotune_args", True
    )
    def test_triton_kernel_autotune_with_unsupported_args(self, backend):
        def call_triton(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            add_kernel_autotuned_with_unsupported_args[(n_elements,)](
                x, y, output, n_elements
            )
            return output

        t1 = torch.rand(256, device=GPU_TYPE)
        t2 = torch.rand(256, device=GPU_TYPE)

        torch_add = call_triton(t1, t2)
        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
        compiled_add = compiled_func(t1, t2)
        self.assertEqual(compiled_add, torch_add)

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("grid_type", [1, 2, 3])
    def test_triton_kernel_2d_autotune(self, grad, dynamic, backend, grid_type):
        def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
            x_elements = output.size()[0]
            y_elements = output.size()[1]

            def grid_fn(meta):
                return (
                    triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                    triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                )

            if grid_type == 1:
                grid = (x_elements, y_elements)
            elif grid_type == 2:
                grid = lambda meta: (
                    triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                    triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                )
            elif grid_type == 3:
                grid = grid_fn

            add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements)
            return output

        t1 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad)
        output = torch.zeros_like(t1, requires_grad=grad)

        torch_result = call_triton(t1, t2, output)
        compiled_func = torch.compile(
            call_triton, backend=backend, fullgraph=True, dynamic=dynamic
        )
        output2 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, output2), torch_result)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_tracing(self, dynamic):
        def call_triton_add(
            x: torch.Tensor,
            y: torch.Tensor,
            grid_type: int,
            num=1,
            positional=False,
            autotuned=False,
        ):
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(num, meta["BLOCK_SIZE"]),)

            if grid_type == 0:
                grid = (x.numel(),)
            elif grid_type == 1:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            elif grid_type == 2:
                grid = grid_fn
            else:
                grid = [x.numel()]

            if autotuned:
                capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements)
            else:
                if positional:
                    capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
                else:
                    capture_triton(add_kernel)[grid](
                        x, y, output, n_elements, BLOCK_SIZE=16
                    )

            return output

        t0 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        t1 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        t2 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        t3 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        torch_add = t2 + t3

        tests = [
            functools.partial(call_triton_add, grid_type=0),
            functools.partial(call_triton_add, grid_type=1),
            functools.partial(call_triton_add, grid_type=1, num=1, positional=True),
            functools.partial(call_triton_add, grid_type=2, num=200),
            functools.partial(call_triton_add, grid_type=3),
            functools.partial(call_triton_add, grid_type=0, autotuned=True),
            functools.partial(call_triton_add, grid_type=1, num=1, autotuned=True),
            functools.partial(call_triton_add, grid_type=2, num=200, autotuned=True),
            functools.partial(call_triton_add, grid_type=3, autotuned=True),
        ]
        from functorch import make_fx

        tracing_mode = "symbolic" if dynamic else "fake"

        for test in tests:
            gm = make_fx(test, tracing_mode=tracing_mode)(t0, t1)
            result = test(t2, t3)
            self.assertEqual(result, torch_add)

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @patch.object(torch._inductor.config, "implicit_fallbacks", False)
    def test_triton_kernel_native(self, grad, dynamic, backend):
        def call_triton_add(
            x: torch.Tensor,
            y: torch.Tensor,
            output: torch.Tensor,
            grid_type: int,
            num=1,
            positional=False,
        ):
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(num, meta["BLOCK_SIZE"]),)

            if grid_type == 0:
                grid = (x.numel(),)
            elif grid_type == 1:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            else:
                grid = grid_fn

            if positional:
                add_kernel[grid](x, y, output, n_elements, 16)
            else:
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)

            return output

        t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        o1 = torch.zeros_like(t1, requires_grad=grad)

        torch_add = t1 + t2

        # No Dynamo -- Make sure triton kernel works
        self.assertEqual(call_triton_add(t1, t2, o1, 1), torch_add)
        # No Dynamo -- Make sure triton kernel works (with positional BLOCK_SIZE)
        o2 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(call_triton_add(t1, t2, o2, 1, True), torch_add)

        # With Dynamo
        compiled_func = torch.compile(
            call_triton_add, backend=backend, fullgraph=True, dynamic=dynamic
        )
        # With simple kernel
        o3 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o3, 0), torch_add)
        # With lambda kernel
        o4 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o4, 1), torch_add)
        # With lambda kernel (with positional BLOCK_SIZE)
        o5 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o5, 1, 1, True), torch_add)
        # With user defined function kernel
        o6 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o6, 2, 200), torch_add)

    @requires_gpu
    def test_triton_kernel_mutation_not_mark_dirty(self):
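        # The in-place mutation performed by the kernel must not mark x_cloned
        # dirty for autograd; the backward pass through `out` should still succeed.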
        @torch.compile
        def f(x):
            n_elements = x.numel()
            add_kernel[(n_elements,)](x, x, x, n_elements, 16)
            return x

        x = torch.randn(5, device=GPU_TYPE, requires_grad=True)
        x_cloned = x.clone()
        out = x_cloned.sin()
        f(x_cloned)
        out.sum().backward()

    @requires_cuda
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_triton_kernel_inputs_buffer_reuse(self):
        def _mul2(x):
            y = torch.empty_like(x)
            mul2_kernel[(10,)](
                in_ptr0=x,
                out_ptr=y,
                n_elements=x.numel(),
                BLOCK_SIZE=1,
            )
            return y

        @torch.compile
        def f(x):
            for _ in range(4):
                # The output of one kernel is the input to the next kernel, but
                # at some point we should re-use buffers instead of allocating new ones.
                x = _mul2(x)
            return x + 1

        x = torch.randn(10, device="cuda", dtype=torch.float32)
        eager_out = f(x)
        compiled_out, (code,) = run_and_get_code(torch.compile(f), x)
        self.assertEqual(compiled_out, eager_out)

        # Check that we're allocating the minimal # of buffers.
        num_bufs_allocated = code.count(
            "empty_strided_cuda((10, ), (1, ), torch.float32)"
        )
        self.assertEqual(num_bufs_allocated, 2)

        # Check that we're re-using buffers when not allocating new ones.
        num_bufs_reused = code.count("# reuse")
        self.assertEqual(num_bufs_reused, 3)

    @requires_gpu
    def test_triton_kernel_matmul_tracking(self):
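        # The kernel fills `out` with ones; that mutation has to be tracked so
        # the subsequent matmul sees the updated buffer.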
        @triton.jit
        def ones_kernel(x_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = 1.0
            tl.store(x_ptr + offsets, x, mask=mask)

        @torch.compile
        def f(x):
            out = torch.zeros_like(x)
            ones_kernel[(4,)](out, 16, BLOCK_SIZE=16)
            return torch.mm(out, x) + 10

        x = torch.randn(4, 4, device=GPU_TYPE)
        torch_out = f(x)
        python_out = torch.mm(torch.ones(4, 4, device=GPU_TYPE), x) + 10
        self.assertEqual(torch_out, python_out)

    @requires_gpu
    def test_triton_kernel_strided_input(self):
        def f(inp):
            # left has strides [256, 1]
            left, right = torch.split(inp, [128, 128], dim=1)
            out = torch.empty_like(left)
            X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
            grid = (left.size(1) // X_BLOCK_SIZE, left.size(0) // Y_BLOCK_SIZE)
            double_strided_kernel[grid](
                in_ptr=left,
                out_ptr=out,
                in_y_stride=left.stride(0),
                out_y_stride=out.stride(0),
                X_BLOCK_SIZE=X_BLOCK_SIZE,
                Y_BLOCK_SIZE=Y_BLOCK_SIZE,
            )
            return out

        inp = torch.randn(64, 256, device=GPU_TYPE)

        eager_out = f(inp)
        compiled_out = torch.compile(f)(inp)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_strided_input_nonzero_offset(self):
        def f(inp):
            # right has strides [256, 1] and storage offset 128
            left, right = torch.split(inp, [128, 128], dim=1)
            out = torch.empty_like(right)
            X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
            grid = (right.size(1) // X_BLOCK_SIZE, right.size(0) // Y_BLOCK_SIZE)
            double_strided_kernel[grid](
                in_ptr=right,
                out_ptr=out,
                in_y_stride=right.stride(0),
                out_y_stride=out.stride(0),
                X_BLOCK_SIZE=X_BLOCK_SIZE,
                Y_BLOCK_SIZE=Y_BLOCK_SIZE,
            )
            return out

        inp = torch.randn(64, 256, device=GPU_TYPE)

        eager_out = f(inp)
        compiled_out = torch.compile(f)(inp)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_slice_and_view_input(self):
        def f(inp):
            # left has strides [256, 1]
            left = inp[:, :128]
            left = left.view(64, 4, 32)
            out = torch.empty_like(left)
            X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
            grid = (
                (left.size(1) * left.size(2)) // X_BLOCK_SIZE,
                left.size(0) // Y_BLOCK_SIZE,
            )
            double_strided_kernel[grid](
                in_ptr=left,
                out_ptr=out,
                in_y_stride=left.stride(0),
                out_y_stride=out.stride(0),
                X_BLOCK_SIZE=X_BLOCK_SIZE,
                Y_BLOCK_SIZE=Y_BLOCK_SIZE,
            )
            return out + left

        inp = torch.randn(64, 256, device=GPU_TYPE)

        eager_out = f(inp)
        compiled_out = torch.compile(f)(inp)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_fallback(self):
        def f(x, y):
            out = torch.zeros_like(x)
            out2 = torch.zeros_like(x)
            # torch.mm is ExternKernelOut
            add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
            # torch.sort creates a fallback kernel and hence MultiOutput
            add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
            return out, out2

        x = torch.randn(4, 4, device=GPU_TYPE)
        y = torch.randn(4, 4, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out = torch.compile(f)(x, y)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_out_of_order(self):
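        # The constexpr BLOCK_SIZE is declared in the middle of the signature
        # rather than at the end; the launch below passes it positionally.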
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            BLOCK_SIZE: "tl.constexpr",
            out_ptr,
            n_elements,
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x, y):
            out = torch.zeros_like(x)
            n_elements = x.numel()
            add_kernel[(n_elements,)](x, y, 4, out, n_elements)
            return out

        x = torch.randn(4, device=GPU_TYPE)
        y = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out = torch.compile(f)(x, y)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_unbacked_shape_tensor(self, backend):
        @triton.jit
        def square(
            in_ptr,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr + offsets, mask=mask)
            output = x * x
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x):
            x = x[x > 2]
            n_elements = x.numel()
            output = torch.zeros_like(x)
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            square[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output

        x = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x)
        compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_equal_to_1_arg(self, dynamic):
        @triton.jit
        def add_kernel_half_n_elements(
            in_ptr0,
            in_ptr1,
            out_ptr,
            half_n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < half_n_elements * 2
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x, y):
            out = torch.empty_like(x)
            half_n_elements = x.numel() // 2
            add_kernel_half_n_elements[(half_n_elements,)](
                x, y, out, half_n_elements, BLOCK_SIZE=16
            )
            return out

        x = torch.randn(2, device=GPU_TYPE)
        y = torch.randn(2, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out, sources = run_and_get_code(
            torch.compile(f, dynamic=dynamic), x, y
        )

        if dynamic:
            # when half_n_elements passed to the Triton kernel is
            # dynamic, equal_to_1 specialization can't be enforced
            self.assertTrue("equal_to_1=()" in sources[0])
        else:
            self.assertTrue("equal_to_1=(3,)" in sources[0])
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
        def f(x, y):
            out = torch.empty_like(x)
            n_elements = x.numel()
            scaling_factor = (n_elements**0) / 1.0
            add_kernel_with_scaling[(n_elements,)](
                x,
                y,
                out,
                n_elements,
                scaling_factor,
                BLOCK_SIZE=16,
            )
            return out

        x = torch.randn(2, device=GPU_TYPE)
        y = torch.randn(2, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out, sources = run_and_get_code(
            torch.compile(f, dynamic=dynamic), x, y
        )

        # float 1.0 (whether literal or symbolic)
        # should not be added to equal_to_1
1215        self.assertTrue("equal_to_1=()" in sources[0])
1216        self.assertEqual(compiled_out, eager_out)
1217
1218    @requires_gpu
1219    @skipIfRocm
1220    def test_triton_kernel_with_imported_symbol(self):
1221        @triton.jit
1222        def add_kernel_with_imported_symbol(
1223            in_ptr,
1224            out_ptr,
1225            n_elements,
1226            BLOCK_SIZE: "tl.constexpr",
1227        ):
1228            pid = tl.program_id(axis=0)
1229            block_start = pid * BLOCK_SIZE
1230            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1231            mask = offsets < n_elements
1232            x = tl.load(in_ptr + offsets, mask=mask)
1233            output = fast_dividef(x, 3.14)
1234            tl.store(out_ptr + offsets, output, mask=mask)
1235
1236        def f(x):
1237            out = torch.empty_like(x)
1238            n_elements = x.numel()
1239            add_kernel_with_imported_symbol[(n_elements,)](
1240                x, out, n_elements, BLOCK_SIZE=16
1241            )
1242            return out
1243
1244        x = torch.randn(4, device=GPU_TYPE)
1245        eager_out = f(x)
1246        compiled_out = torch.compile(f)(x)
1247
1248        self.assertEqual(compiled_out, eager_out)
1249
1250    @requires_gpu
1251    @skipIfRocm
1252    def test_triton_kernel_with_imported_symbol_with_custom_name(self):
1253        @triton.jit
1254        def add_kernel_with_imported_symbol(
1255            in_ptr,
1256            out_ptr,
1257            n_elements,
1258            BLOCK_SIZE: "tl.constexpr",
1259        ):
1260            pid = tl.program_id(axis=0)
1261            block_start = pid * BLOCK_SIZE
1262            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1263            mask = offsets < n_elements
1264            x = tl.load(in_ptr + offsets, mask=mask)
1265            output = my_fast_dividef(x, 3.14)
1266            tl.store(out_ptr + offsets, output, mask=mask)
1267
1268        def f(x):
1269            out = torch.empty_like(x)
1270            n_elements = x.numel()
1271            add_kernel_with_imported_symbol[(n_elements,)](
1272                x, out, n_elements, BLOCK_SIZE=16
1273            )
1274            return out
1275
1276        x = torch.randn(4, device=GPU_TYPE)
1277        eager_out = f(x)
1278        compiled_out = torch.compile(f)(x)
1279
1280        self.assertEqual(compiled_out, eager_out)
1281
1282    @requires_gpu
1283    @common_utils.parametrize("size", [4, 16])
1284    @common_utils.parametrize("dynamic", [False, True])
1285    def test_triton_kernel_different_shapes(self, size, dynamic):
1286        from torch._inductor.utils import run_and_get_code
1287
1288        def f(x, y, xx, yy):
1289            n_elements = x.numel()
1290            output_1 = torch.zeros_like(x)
1291            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
1292            add_kernel[grid](x, y, output_1, n_elements, BLOCK_SIZE=4)
1293
1294            n_elements = xx.numel()
1295            output_2 = torch.zeros_like(xx)
1296            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
1297            add_kernel[grid](xx, yy, output_2, n_elements, BLOCK_SIZE=4)
1298
1299            return output_1, output_2
1300
1301        x = torch.rand(size, device=GPU_TYPE)
1302        y = torch.rand(size, device=GPU_TYPE)
1303        xx = torch.rand(size, size, device=GPU_TYPE)
1304        yy = torch.rand(size, size, device=GPU_TYPE)
1305        args = [x, y, xx, yy]
1306
1307        eager_out = f(*args)
1308        compiled_out, (code,) = run_and_get_code(
1309            torch.compile(f, fullgraph=True, dynamic=dynamic, backend="inductor"), *args
1310        )
1311        if size == 4 and not dynamic:
1312            # Produce 2 kernels due to divisibility
1313            self.assertTrue("add_kernel_0.run" in code)
1314            self.assertTrue("add_kernel_1.run" in code)
1315        else:
1316            # size == 16 or dynamic
1317            # Only one kernel
1318            self.assertTrue("add_kernel_0.run" in code)
1319            self.assertTrue("add_kernel_1.run" not in code)
1320
1321        self.assertEqual(compiled_out, eager_out)
1322
1323    @requires_gpu
1324    def test_triton_kernel_reset_to_zero(self):
1325        @triton.autotune(
1326            configs=[
1327                triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
1328                triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
1329            ],
1330            key=["n_elements"],
1331            reset_to_zero=["out_ptr"],
1332        )
1333        @triton.jit
1334        def add_kernel_autotuned_reset(
1335            in_ptr0,
1336            in_ptr1,
1337            out_ptr,
1338            n_elements,
1339            BLOCK_SIZE: "tl.constexpr",
1340        ):
1341            pid = tl.program_id(axis=0)
1342            block_start = pid * BLOCK_SIZE
1343            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1344            mask = offsets < n_elements
1345            x = tl.load(in_ptr0 + offsets, mask=mask)
1346            y = tl.load(in_ptr1 + offsets, mask=mask)
1347            output = x + y
1348            tl.store(out_ptr + offsets, output, mask=mask)
1349
1350        @torch.compile(fullgraph=True)
1351        def f(x, y):
1352            output = torch.zeros_like(x)
1353            n_elements = output.numel()
1354            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
1355            add_kernel_autotuned_reset[grid](x, y, output, n_elements)
1356            return output
1357
1358        x = torch.randn(4, device=GPU_TYPE)
1359        msg = "Only configs and keys are supported for triton.autotune"
1360        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
1361            f(x, x)
1362
1363    @requires_gpu
1364    @common_utils.parametrize("dynamic", [False, True])
1365    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
1366    def test_triton_kernel_triton_dtype(self, dynamic, backend):
1367        @triton.jit
1368        def add_kernel_with_dtype(
1369            in_ptr0,
1370            in_ptr1,
1371            out_ptr,
1372            dtype: "tl.constexpr",
1373            n_elements,
1374            BLOCK_SIZE: "tl.constexpr",
1375        ):
1376            pid = tl.program_id(axis=0)
1377            block_start = pid * BLOCK_SIZE
1378            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1379            mask = offsets < n_elements
1380            x = tl.load(in_ptr0 + offsets, mask=mask).to(dtype)
1381            y = tl.load(in_ptr1 + offsets, mask=mask).to(dtype)
1382            output = x + y
1383            tl.store(out_ptr + offsets, output, mask=mask)
1384
1385        def f(x, y, dtype_torch, dtype_triton):
1386            output = torch.zeros_like(x).to(dtype=dtype_torch)
1387            n_elements = output.numel()
1388            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
1389            add_kernel_with_dtype[grid](
1390                x, y, output, dtype_triton, n_elements, BLOCK_SIZE=4
1391            )
1392            return output
1393
1394        x = torch.randn(4, device=GPU_TYPE)
1395        y = torch.randn(4, device=GPU_TYPE)
1396        args_list = (
1397            [x, y, torch.float32, tl.float32],
1398            [x, y, torch.bfloat16, tl.bfloat16],
1399        )
1400        for args in args_list:
1401            eager_out = f(*args)
1402            compiled_out = torch.compile(
1403                f, fullgraph=True, backend=backend, dynamic=dynamic
1404            )(*args)
1405            self.assertEqual(compiled_out, eager_out)
1406
1407    @requires_gpu
1408    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
1409    def test_triton_kernel_special_kwargs_with_autotune(self, backend):
1410        @triton.autotune(
1411            configs=[
1412                triton.Config({"BLOCK_SIZE": 128}),
1413                triton.Config({"BLOCK_SIZE": 64}),
1414            ],
1415            key=["n_elements"],
1416        )
1417        @triton.jit
1418        def add_kernel(
1419            in_ptr0,
1420            in_ptr1,
1421            out_ptr,
1422            n_elements,
1423            BLOCK_SIZE: "tl.constexpr",
1424        ):
1425            pid = tl.program_id(axis=0)
1426            block_start = pid * BLOCK_SIZE
1427            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1428            mask = offsets < n_elements
1429            x = tl.load(in_ptr0 + offsets, mask=mask)
1430            y = tl.load(in_ptr1 + offsets, mask=mask)
1431            output = x + y
1432            tl.store(out_ptr + offsets, output, mask=mask)
1433
1434        @torch.compile(fullgraph=True, backend=backend)
1435        def f(x, y):
1436            output = torch.zeros_like(x)
1437            n_elements = output.numel()
1438            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
1439            add_kernel[grid](
1440                x,
1441                y,
1442                output,
1443                n_elements,
1444                num_warps=8,
1445                num_stages=3,
1446            )
1447            return output
1448
1449        x = torch.randn(4, device=GPU_TYPE)
1450        f(x, x)
1451
1452    @requires_gpu
1453    @common_utils.parametrize("dynamic", [False, True])
1454    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
1455    def test_triton_kernel_multiple_outputs(self, dynamic, backend):
1456        @triton.jit
1457        def add_kernel(
1458            in_ptr0,
1459            in_ptr1,
1460            out_ptr,
1461            out_ptr2,
1462            n_elements,
1463            BLOCK_SIZE: "tl.constexpr",
1464        ):
1465            pid = tl.program_id(axis=0)
1466            block_start = pid * BLOCK_SIZE
1467            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1468            mask = offsets < n_elements
1469            x = tl.load(in_ptr0 + offsets, mask=mask)
1470            y = tl.load(in_ptr1 + offsets, mask=mask)
1471            output = x + y
1472            tl.store(out_ptr + offsets, output, mask=mask)
1473            tl.store(out_ptr2 + offsets, output + 1, mask=mask)
1474
1475        @torch.compile(fullgraph=True, backend=backend, dynamic=dynamic)
1476        def f(x, y, z):
1477            output = torch.empty_like(x)
1478            output2 = torch.empty_like(x)
1479            n_elements = output.numel()
1480            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
1481            add_kernel[grid](x, y, output, output2, n_elements, BLOCK_SIZE=16)
1482            # The z return is intentional: we're testing that training (autograd) still works
1483            return output, output2, z**2
1484
1485        x = torch.randn(3, requires_grad=True, device=GPU_TYPE)
1486        y = torch.randn(3, requires_grad=True, device=GPU_TYPE)
1487        z = torch.randn(3, requires_grad=True, device=GPU_TYPE)
1488        out, out2, out3 = f(x, y, z)
1489        self.assertEqual(out, x + y)
1490        self.assertEqual(out2, x + y + 1)
1491        self.assertEqual(out3, z**2)
1492
1493    @requires_gpu
1494    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
1495    def test_triton_kernel_num_ctas(self, backend):
1496        @triton.jit
1497        def kernel(X):
1498            return
1499
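        # Launch the kernel both via subscripting and via .run() to check that the
        # num_ctas launch kwarg is handled in both code paths.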
1500        @torch.compile(backend=backend)
1501        def f(x):
1502            kernel[(1,)](x, num_ctas=1)
1503            kernel.run(x, num_ctas=1, grid=(1,), warmup=False)
1504            return x
1505
1506        x = torch.randn(4, device=GPU_TYPE)
1507        f(x)
1508
1509    @requires_gpu
1510    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
1511    def test_triton_kernel_special_kwargs_without_autotune(self, backend):
1512        @triton.jit
1513        def add_kernel(
1514            in_ptr0,
1515            in_ptr1,
1516            out_ptr,
1517            n_elements,
1518            BLOCK_SIZE: "tl.constexpr",
1519        ):
1520            pid = tl.program_id(axis=0)
1521            block_start = pid * BLOCK_SIZE
1522            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1523            mask = offsets < n_elements
1524            x = tl.load(in_ptr0 + offsets, mask=mask)
1525            y = tl.load(in_ptr1 + offsets, mask=mask)
1526            output = x + y
1527            tl.store(out_ptr + offsets, output, mask=mask)
1528
1529        @torch.compile(fullgraph=True, backend=backend)
1530        def f(x, y):
1531            output = torch.zeros_like(x)
1532            n_elements = output.numel()
1533            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
1534            add_kernel[grid](
1535                x,
1536                y,
1537                output,
1538                n_elements,
1539                BLOCK_SIZE=128,
1540                num_warps=8,
1541                num_stages=3,
1542            )
1543            return output
1544
1545        x = torch.randn(4, device=GPU_TYPE)
1546        f(x, x)
1547
1548
1549def make_mutation_test(fn):
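    """Wrap a factory returning (kernel, kwargs, expected_mutated_args) into a test
    asserting that identify_mutated_tensors reports exactly those arguments."""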
1550    @requires_gpu
1551    def test_fn(self):
1552        from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors
1553
1554        kernel, inputs, outputs = fn()
1555        self.assertListEqual(
1556            identify_mutated_tensors(kernel, inputs),
1557            outputs,
1558        )
1559
1560    return test_fn
1561
1562
1563# Triton codegen suffers from scoping issues,
1564# so define shared JIT helpers at module scope here.
1565if HAS_GPU:
1566
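    # Trivial JIT helpers shared by the mutation tests below: helper_id returns its
    # argument unchanged, and helper_add_and_out returns (x + y, out_ptr).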
1567    @triton.jit
1568    def helper_id(p):
1569        return p
1570
1571    @triton.jit
1572    def helper_add_and_out(x, y, out_ptr):
1573        return x + y, out_ptr
1574
1575
1576class MutationTests(torch._inductor.test_case.TestCase):
1577    # Additional test cases are injected into this class by the loop below the class body
1578
1579    @make_mutation_test
1580    def test_out_of_order_kernel():
1581        @triton.jit
1582        def add_kernel_out_of_order(
1583            in_ptr0,
1584            n_elements,
1585            in_ptr1,
1586            out_ptr,
1587            BLOCK_SIZE: "tl.constexpr",
1588        ):
1589            pid = tl.program_id(axis=0)
1590            block_start = pid * BLOCK_SIZE
1591            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1592            mask = offsets < n_elements
1593            x = tl.load(in_ptr0 + offsets, mask=mask)
1594            y = tl.load(in_ptr1 + offsets, mask=mask)
1595            output = x + y
1596            tl.store(out_ptr + offsets, output, mask=mask)
1597
1598        t = torch.randn(4)
1599        return (
1600            add_kernel_out_of_order,
1601            {
1602                "in_ptr0": t,
1603                "n_elements": 4,
1604                "in_ptr1": t,
1605                "out_ptr": t,
1606                "BLOCK_SIZE": 4,
1607            },
1608            ["out_ptr"],
1609        )
1610
1611    @make_mutation_test
1612    def test_out_of_order_kernel_call():
1613        @triton.jit
1614        def add_kernel_out_of_order_fn1(
1615            in_ptr0,
1616            n_elements,
1617            in_ptr1,
1618            out_ptr,
1619            BLOCK_SIZE: "tl.constexpr",
1620        ):
1621            pid = tl.program_id(axis=0)
1622            block_start = pid * BLOCK_SIZE
1623            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1624            mask = offsets < n_elements
1625            add_kernel_out_of_order_fn2(
1626                in_ptr0, in_ptr1, n_elements, out_ptr, BLOCK_SIZE=BLOCK_SIZE
1627            )
1628
1629        t = torch.randn(4)
1630        return (
1631            add_kernel_out_of_order_fn1,
1632            {
1633                "in_ptr0": t,
1634                "n_elements": 4,
1635                "in_ptr1": t,
1636                "out_ptr": t,
1637                "BLOCK_SIZE": 4,
1638            },
1639            ["out_ptr"],
1640        )
1641
1642    @make_mutation_test
1643    def test_reduce_sum():
1644        @triton.jit
1645        def reduce_sum_kernel(a_ptr, c_ptr, stride_am, stride_an):
1646            offs_am = tl.arange(0, 4)
1647            offs_an = tl.arange(0, 4)
1648            a_ptrs = a_ptr + (
1649                offs_am[:, None] * stride_am + offs_an[None, :] * stride_an
1650            )
1651            a = tl.load(a_ptrs)
1652            m = tl.sum(a, axis=1)
1653            tl.store(c_ptr + tl.arange(0, 4), m)
1654
1655        t = torch.randn(4)
1656        kernel = reduce_sum_kernel
1657        kwargs = {
1658            "a_ptr": t,
1659            "c_ptr": t,
1660            "stride_am": 4,
1661            "stride_an": 4,
1662        }
1663
1664        # TODO(aakhundov): tt.reduce is now supported, but only
1665        # in the new MLIR-based Triton analysis pass (not in the
1666        # old TTIR string parsing-based one). Remove this gating
1667        # and use ["c_ptr"] as `expected` after the new Triton
1668        # pin lands both in OSS and internally.
1669        ttir_module, _ = generate_ttir(kernel, kwargs)
1670        if hasattr(ttir_module, "walk"):
1671            # with MLIR-based Triton analysis pass
1672            expected = ["c_ptr"]
1673        else:
1674            # with TTIR string parsing-based Triton analysis pass
1675            expected = ["a_ptr", "c_ptr"]
1676
1677        return (
1678            kernel,
1679            kwargs,
1680            expected,
1681        )
1682
1683    @make_mutation_test
1684    def test_argmax():
1685        @triton.jit
1686        def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an):
1687            offs_am = tl.arange(0, 4)
1688            offs_an = tl.arange(0, 4)
1689            a_ptrs = a_ptr + (
1690                offs_am[:, None] * stride_am + offs_an[None, :] * stride_an
1691            )
1692            a = tl.load(a_ptrs)
1693            m = tl.argmax(a, axis=1)
1694            tl.store(c_ptr + tl.arange(0, 4), m)
1695
1696        t = torch.randn(4)
1697        kernel = argmax_kernel
1698        kwargs = {
1699            "a_ptr": t,
1700            "c_ptr": t,
1701            "stride_am": 4,
1702            "stride_an": 4,
1703        }
1704
1705        # TODO(aakhundov): tt.reduce is now supported, but only
1706        # in the new MLIR-based Triton analysis pass (not in the
1707        # old TTIR string parsing-based one). Remove this gating
1708        # and use ["c_ptr"] as `expected` after the new Triton
1709        # pin lands both in OSS and internally.
1710        ttir_module, _ = generate_ttir(kernel, kwargs)
1711        if hasattr(ttir_module, "walk"):
1712            # with MLIR-based Triton analysis pass
1713            expected = ["c_ptr"]
1714        else:
1715            # with TTIR string parsing-based Triton analysis pass
1716            expected = ["a_ptr", "c_ptr"]
1717
1718        return (
1719            kernel,
1720            kwargs,
1721            expected,
1722        )
1723
1724    @requires_cuda
1725    @skipIfRocm
1726    def test_triton_kernel_inference_mode(self):
1727        def f(x, y, out):
1728            n_elements = x.numel()
1729            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
1730            add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=4)
1731
1732        with torch.inference_mode():
1733            x = torch.ones(32, device="cuda")
1734            y = torch.ones(32, device="cuda")
1735            out_ref = torch.zeros_like(x)
1736            out_test = torch.zeros_like(x)
1737            f(x, y, out_ref)
1738            torch.compile(f)(x, y, out_test)
1739            self.assertEqual(out_ref, out_test)
1740
1741    @make_mutation_test
1742    def test_cumsum():
1743        @triton.jit
1744        def cumsum_kernel(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
1745            rindex = tl.arange(0, RBLOCK)[None, :]
1746            xindex = tl.arange(0, XBLOCK)[:, None]
1747            data = tl.load(in_ptr + rindex)
1748            scan = tl.cumsum(data, 1)
1749            expected_max = tl.sum(data, 1)
1750            tl.device_assert(scan <= expected_max)
1751            tl.store(out_ptr + xindex * RBLOCK + rindex, scan)
1752
1753        t = torch.randn(4)
1754        kernel = cumsum_kernel
1755        kwargs = {
1756            "in_ptr": t,
1757            "out_ptr": t,
1758            "XBLOCK": 4,
1759            "RBLOCK": 16,
1760        }
1761
1762        # TODO(aakhundov): tt.scan is now supported, but only
1763        # in the new MLIR-based Triton analysis pass (not in the
1764        # old TTIR string parsing-based one). Remove this gating
1765        # and use ["out_ptr"] as `expected` after the new Triton
1766        # pin lands both in OSS and internally.
1767        ttir_module, _ = generate_ttir(kernel, kwargs)
1768        if hasattr(ttir_module, "walk"):
1769            # with MLIR-based Triton analysis pass
1770            expected = ["out_ptr"]
1771        else:
1772            # with TTIR string parsing-based Triton analysis pass
1773            expected = ["in_ptr", "out_ptr"]
1774
1775        return (
1776            kernel,
1777            kwargs,
1778            expected,
1779        )
1780
1781    @make_mutation_test
1782    def test_fn_call_one_return():
1783        @triton.jit
1784        def add_kernel_with_fn_call(
1785            in_ptr0,
1786            in_ptr1,
1787            n_elements,
1788            out_ptr,
1789            BLOCK_SIZE: "tl.constexpr",
1790        ):
1791            pid = tl.program_id(axis=0)
1792            block_start = pid * BLOCK_SIZE
1793            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1794            mask = offsets < n_elements
1795            x = tl.load(in_ptr0 + offsets, mask=mask)
1796            y = tl.load(in_ptr1 + offsets, mask=mask)
1797            output = x + y
1798            out = helper_id(out_ptr)
1799            tl.store(out + offsets, output, mask=mask)
1800
1801        t = torch.randn(4)
1802        return (
1803            add_kernel_with_fn_call,
1804            {
1805                "in_ptr0": t,
1806                "in_ptr1": t,
1807                "n_elements": 4,
1808                "out_ptr": t,
1809                "BLOCK_SIZE": 4,
1810            },
1811            ["out_ptr"],
1812        )
1813
1814    @make_mutation_test
1815    def test_fn_call_multi_return():
1816        @triton.jit
1817        def add_kernel_with_fn_call(
1818            in_ptr0,
1819            in_ptr1,
1820            n_elements,
1821            out_ptr,
1822            BLOCK_SIZE: "tl.constexpr",
1823        ):
1824            pid = tl.program_id(axis=0)
1825            block_start = pid * BLOCK_SIZE
1826            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1827            mask = offsets < n_elements
1828            x = tl.load(in_ptr0 + offsets, mask=mask)
1829            y = tl.load(in_ptr1 + offsets, mask=mask)
1830            output, out = helper_add_and_out(x, y, out_ptr)
1831            tl.store(out + offsets, output, mask=mask)
1832
1833        t = torch.randn(4)
1834        return (
1835            add_kernel_with_fn_call,
1836            {
1837                "in_ptr0": t,
1838                "in_ptr1": t,
1839                "n_elements": 4,
1840                "out_ptr": t,
1841                "BLOCK_SIZE": 4,
1842            },
1843            ["out_ptr"],
1844        )
1845
1846    @make_mutation_test
1847    def test_nested_cond_op_kernel():
1848        @triton.jit
1849        def nested_cond_op_kernel(
1850            in_ptr0,
1851            in_ptr1,
1852            out_ptr,
1853            n_elements,
1854            BLOCK_SIZE: "tl.constexpr",
1855        ):
1856            pid = tl.program_id(axis=0)
1857            block_start = pid * BLOCK_SIZE
1858            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1859            mask = offsets < n_elements
1860            x = tl.load(in_ptr0 + offsets, mask=mask)
1861            y = tl.load(in_ptr1 + offsets, mask=mask)
1862            if tl.program_id(0) == 0:
1863                if tl.program_id(1) == 0:
1864                    output = x + y
1865                    tl.store(out_ptr + offsets, output, mask=mask)
1866            else:
1867                pass
1868
1869        t = torch.randn(4)
1870        return (
1871            nested_cond_op_kernel,
1872            {
1873                "in_ptr0": t,
1874                "in_ptr1": t,
1875                "out_ptr": t,
1876                "n_elements": 4,
1877                "BLOCK_SIZE": 4,
1878            },
1879            ["out_ptr"],
1880        )
1881
1882    @make_mutation_test
1883    def test_add_for_loop():
1884        @triton.jit
1885        def add_4_times_kernel(
1886            in_ptr0,
1887            in_ptr1,
1888            out_ptr,
1889            n_elements,
1890            BLOCK_SIZE: "tl.constexpr",
1891        ):
1892            pid = tl.program_id(axis=0)
1893            block_start = pid * BLOCK_SIZE
1894            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1895            mask = offsets < n_elements
1896            x = tl.load(in_ptr0 + offsets, mask=mask)
1897            y = tl.load(in_ptr1 + offsets, mask=mask)
1898            output = tl.zeros((n_elements,), dtype=tl.float32)
1899            for i in range(4):
1900                output += x + y
1901            tl.store(out_ptr + offsets, output, mask=mask)
1902
1903        t = torch.randn(4)
1904        return (
1905            add_4_times_kernel,
1906            {
1907                "in_ptr0": t,
1908                "in_ptr1": t,
1909                "out_ptr": t,
1910                "n_elements": 4,
1911                "BLOCK_SIZE": 4,
1912            },
1913            ["out_ptr"],
1914        )
1915
1916    @make_mutation_test
1917    def test_add_for_loop2():
1918        @triton.jit
1919        def add_1_time_kernel(
1920            in_ptr0,
1921            in_ptr1,
1922            out_ptr,
1923            n_elements,
1924            BLOCK_SIZE: "tl.constexpr",
1925        ):
1926            pid = tl.program_id(axis=0)
1927            block_start = pid * BLOCK_SIZE
1928            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1929            mask = offsets < n_elements
1930            x = tl.load(in_ptr0 + offsets, mask=mask)
1931            y = tl.load(in_ptr1 + offsets, mask=mask)
1932            for i in range(0, BLOCK_SIZE):
1933                i = tl.multiple_of(i, 1)
1934            output = x + y
1935            tl.store(out_ptr + offsets, output, mask=mask)
1936
1937        t = torch.randn(4)
1938        return (
1939            add_1_time_kernel,
1940            {
1941                "in_ptr0": t,
1942                "in_ptr1": t,
1943                "out_ptr": t,
1944                "n_elements": 4,
1945                "BLOCK_SIZE": 4,
1946            },
1947            ["out_ptr"],
1948        )
1949
1950    @make_mutation_test
1951    def test_add_nested_for_loop():
1952        @triton.jit
1953        def add_4_times_kernel(
1954            in_ptr0,
1955            in_ptr1,
1956            out_ptr,
1957            n_elements,
1958            BLOCK_SIZE: "tl.constexpr",
1959        ):
1960            pid = tl.program_id(axis=0)
1961            block_start = pid * BLOCK_SIZE
1962            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1963            mask = offsets < n_elements
1964            x = tl.load(in_ptr0 + offsets, mask=mask)
1965            y = tl.load(in_ptr1 + offsets, mask=mask)
1966            output = tl.zeros((n_elements,), dtype=tl.float32)
1967            for i in range(2):
1968                for j in range(2):
1969                    output += x + y
1970            tl.store(out_ptr + offsets, output, mask=mask)
1971
1972        t = torch.randn(4)
1973        return (
1974            add_4_times_kernel,
1975            {
1976                "in_ptr0": t,
1977                "in_ptr1": t,
1978                "out_ptr": t,
1979                "n_elements": 4,
1980                "BLOCK_SIZE": 4,
1981            },
1982            ["out_ptr"],
1983        )
1984
1985    @make_mutation_test
1986    def test_add_nested_for_loop_multi_return():
1987        @triton.jit
1988        def add_4_times_kernel(
1989            in_ptr0,
1990            in_ptr1,
1991            out_ptr,
1992            n_elements,
1993            BLOCK_SIZE: "tl.constexpr",
1994        ):
1995            pid = tl.program_id(axis=0)
1996            block_start = pid * BLOCK_SIZE
1997            offsets = block_start + tl.arange(0, BLOCK_SIZE)
1998            mask = offsets < n_elements
1999            x = tl.load(in_ptr0 + offsets, mask=mask)
2000            y = tl.load(in_ptr1 + offsets, mask=mask)
2001            output1 = tl.zeros((n_elements,), dtype=tl.float32)
2002            output2 = tl.zeros((n_elements,), dtype=tl.float32)
2003            for i in range(2):
2004                for j in range(2):
2005                    output1 += y
2006                    output2 += x
2007            output = output1 + output2
2008            tl.store(out_ptr + offsets, output, mask=mask)
2009
2010        t = torch.randn(4)
2011        return (
2012            add_4_times_kernel,
2013            {
2014                "in_ptr0": t,
2015                "in_ptr1": t,
2016                "out_ptr": t,
2017                "n_elements": 4,
2018                "BLOCK_SIZE": 4,
2019            },
2020            ["out_ptr"],
2021        )
2022
2023    @make_mutation_test
2024    def test_labels():
2025        @triton.jit
2026        def kernel_with_label(
2027            in_ptr0,
2028            in_ptr1,
2029            out_ptr,
2030            n_elements,
2031            BLOCK_SIZE: "tl.constexpr",
2032        ):
2033            pid = tl.program_id(axis=0)
2034            if pid > 1:
2035                return
2036            block_start = pid * BLOCK_SIZE
2037            offsets = block_start + tl.arange(0, BLOCK_SIZE)
2038            mask = offsets < n_elements
2039            x = tl.load(in_ptr0 + offsets, mask=mask)
2040            y = tl.load(in_ptr1 + offsets, mask=mask)
2041            output = x + y
2042            tl.store(out_ptr + offsets, output, mask=mask)
2043
2044        t = torch.randn(4)
2045        return (
2046            kernel_with_label,
2047            {
2048                "in_ptr0": t,
2049                "in_ptr1": t,
2050                "out_ptr": t,
2051                "n_elements": 4,
2052                "BLOCK_SIZE": 4,
2053            },
2054            ["out_ptr"],
2055        )
2056
2057    @make_mutation_test
2058    def test_for_loop_arg():
2059        @triton.jit
2060        def fwd_kernel(
2061            X_ptr,
2062            W1_ptr,
2063            b1_ptr,
2064            O_ptr,
2065            M: tl.constexpr,
2066            C1: tl.constexpr,
2067            C2: tl.constexpr,
2068            BLOCK_SIZE_M: tl.constexpr,
2069            BLOCK_SIZE_C2: tl.constexpr,
2070        ):
2071            # Get program ids
2072            pid_m = tl.program_id(0)
2073
2074            # Compute offsets
2075            offs_c1 = tl.arange(0, C1)
2076            offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
2077
2078            # Load input data
2079            x_block_ptr = X_ptr + offs_m[:, None] * C1 + offs_c1[None, :]
2080            x = tl.load(x_block_ptr)
2081
2082            # Compute gating
2083            for c2 in range(0, tl.cdiv(C2, BLOCK_SIZE_C2)):
2084                # Compute block pointers
2085                offs_c2 = c2 * BLOCK_SIZE_C2 + tl.arange(0, BLOCK_SIZE_C2)
2086                o_block_ptr = O_ptr + offs_m[:, None] * C2 + offs_c2[None, :]
2087                w1_block_ptr = W1_ptr + offs_c1[:, None] * C2 + offs_c2[None, :]
2088                b1_block_ptr = b1_ptr + offs_c2
2089
2090                # Compute output
2091                w = tl.load(w1_block_ptr)
2092                b = tl.load(b1_block_ptr)
2093                o = tl.dot(x, w, allow_tf32=False)
2094                o += b[None, :]
2095
2096                # Store output
2097                tl.store(o_block_ptr, o)
2098
2099        t = torch.randn(64)
2100        return (
2101            fwd_kernel,
2102            {
2103                "X_ptr": t,
2104                "W1_ptr": t,
2105                "b1_ptr": t,
2106                "O_ptr": t,
2107                "M": 64,
2108                "C1": 64,
2109                "C2": 64,
2110                "BLOCK_SIZE_M": 64,
2111                "BLOCK_SIZE_C2": 64,
2112            },
2113            ["O_ptr"],
2114        )
2115
2116    @make_mutation_test
2117    def test_for_loop_arg_2():
2118        @triton.jit
2119        def fwd_kernel(
2120            x_ptr,
2121            o_ptr,
2122            M,
2123            N,
2124            stride_m,
2125            stride_n,
2126            BLOCK_B: tl.constexpr,
2127            BLOCK_M: tl.constexpr,
2128            BLOCK_N: tl.constexpr,
2129        ):
2130            # Get program ids
2131            pid_m = tl.program_id(0)
2132            X_block_ptr = tl.make_block_ptr(
2133                base=x_ptr,
2134                shape=(M, N),
2135                strides=(stride_m, stride_n),
2136                offsets=(0, 0),
2137                block_shape=(BLOCK_M, BLOCK_N),
2138                order=(1, 0),
2139            )
2140            O_block_ptr = tl.make_block_ptr(
2141                base=o_ptr,
2142                shape=(M, N),
2143                strides=(stride_m, stride_n),
2144                offsets=(0, 0),
2145                block_shape=(BLOCK_M, BLOCK_N),
2146                order=(1, 0),
2147            )
2148
2149            for _ in range(BLOCK_B):
2150                x = tl.load(X_block_ptr)
2151                tl.store(O_block_ptr, x)
2152
2153                X_block_ptr = tl.advance(X_block_ptr, (BLOCK_M, 0))
2154                O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M, 0))
2155
2156        t = torch.randn((32, 64, 128))
2157        o = torch.empty_like(t)
2158        B, M, N = t.shape
2159        return (
2160            fwd_kernel,
2161            {
2162                "x_ptr": t,
2163                "o_ptr": o,
2164                "M": M,
2165                "N": N,
2166                "stride_m": N,
2167                "stride_n": 1,
2168                "BLOCK_B": B,
2169                "BLOCK_M": M,
2170                "BLOCK_N": N,
2171            },
2172            ["o_ptr"],
2173        )
2174
2175    @make_mutation_test
2176    def test_while_loop():
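        # The block pointers are loop-carried through the while loop; only o_ptr
        # (written through the advanced O_block_ptr) should be reported as mutated.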
2177        @triton.jit
2178        def fwd_kernel(
2179            x_ptr,
2180            o_ptr,
2181            M,
2182            N,
2183            stride_m,
2184            stride_n,
2185            BLOCK_B: tl.constexpr,
2186            BLOCK_M: tl.constexpr,
2187            BLOCK_N: tl.constexpr,
2188        ):
2189            # Get program ids
2190            pid_m = tl.program_id(0)
2191            X_block_ptr = tl.make_block_ptr(
2192                base=x_ptr,
2193                shape=(M, N),
2194                strides=(stride_m, stride_n),
2195                offsets=(0, 0),
2196                block_shape=(BLOCK_M, BLOCK_N),
2197                order=(1, 0),
2198            )
2199            O_block_ptr = tl.make_block_ptr(
2200                base=o_ptr,
2201                shape=(M, N),
2202                strides=(stride_m, stride_n),
2203                offsets=(0, 0),
2204                block_shape=(BLOCK_M, BLOCK_N),
2205                order=(1, 0),
2206            )
2207
2208            i = 0
2209            while i < BLOCK_B:
2210                x = tl.load(X_block_ptr)
2211                tl.store(O_block_ptr, x)
2212
2213                X_block_ptr = tl.advance(X_block_ptr, (BLOCK_M, 0))
2214                O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M, 0))
2215                i += 1
2216
2217        t = torch.randn((32, 64, 128))
2218        o = torch.empty_like(t)
2219        B, M, N = t.shape
2220        return (
2221            fwd_kernel,
2222            {
2223                "x_ptr": t,
2224                "o_ptr": o,
2225                "M": M,
2226                "N": N,
2227                "stride_m": N,
2228                "stride_n": 1,
2229                "BLOCK_B": B,
2230                "BLOCK_M": M,
2231                "BLOCK_N": N,
2232            },
2233            ["o_ptr"],
2234        )
2235
2236
2237if HAS_GPU:
2238    t = torch.randn(4)
2239    tt = torch.randn(4, 1)
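    # Each entry is [kernel, kwargs, expected list of mutated tensor-argument names].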
2240    tests = [
2241        [
2242            add_kernel,
2243            {
2244                "in_ptr0": t,
2245                "in_ptr1": t,
2246                "out_ptr": t,
2247                "n_elements": 4,
2248                "BLOCK_SIZE": 4,
2249            },
2250            ["out_ptr"],
2251        ],
2252        [
2253            add_kernel_2d_autotuned,
2254            {
2255                "in_ptr0": t,
2256                "in_ptr1": t,
2257                "out_ptr": t,
2258                "x_elements": 4,
2259                "y_elements": 4,
2260            },
2261            ["out_ptr"],
2262        ],
2263        [
2264            indirection_kernel,
2265            {
2266                "in_ptr0": t,
2267                "out_ptr": t,
2268                "n_elements": 4,
2269                "BLOCK_SIZE": 4,
2270                "ACTIVATION": "mul2_inplace_kernel",
2271            },
2272            ["in_ptr0", "out_ptr"],
2273        ],
2274        [
2275            indirection_kernel,
2276            {
2277                "in_ptr0": t,
2278                "out_ptr": t,
2279                "n_elements": 4,
2280                "BLOCK_SIZE": 4,
2281                "ACTIVATION": "add_kernel",
2282            },
2283            ["out_ptr"],
2284        ],
2285        [
2286            mul2_inplace_kernel,
2287            {"ptr": t, "n_elements": 4, "BLOCK_SIZE": 4},
2288            ["ptr"],
2289        ],
2290        # Can't optimize since the kernel contains a tl.inline_asm_elementwise call
2291        [
2292            inline_asm_kernel,
2293            {"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4},
2294            ["X", "Y", "Z"],
2295        ],
2296        [
2297            add_kernel_with_block_ptr,
2298            {
2299                "x_ptr": t,
2300                "y_ptr": t,
2301                "output_ptr": t,
2302                "n_elements": 4,
2303                "BLOCK_SIZE": 4,
2304            },
2305            ["output_ptr"],
2306        ],
2307        [
2308            kernel_with_block_ptr_2d,
2309            {
2310                "x_ptr": tt,
2311                "output_ptr": tt,
2312                "n_elements": 4,
2313                "BLOCK_SIZE": 4,
2314            },
2315            ["output_ptr"],
2316        ],
2317        [
2318            add_kernel_with_import,
2319            {
2320                "in_ptr0": t,
2321                "in_ptr1": t,
2322                "out_ptr": t,
2323                "n_elements": 4,
2324                "BLOCK_SIZE": 4,
2325            },
2326            ["out_ptr"],
2327        ],
2328        [
2329            atomic_add_kernel,
2330            {
2331                "in_ptr0": t,
2332                "in_ptr1": t,
2333                "out_ptr": t,
2334                "n_elements": 4,
2335                "BLOCK_SIZE": 4,
2336            },
2337            ["out_ptr"],
2338        ],
2339        [
2340            add_4_times_kernel,
2341            {
2342                "in_ptr0": t,
2343                "in_ptr1": t,
2344                "out_ptr": t,
2345                "n_elements": 4,
2346                "BLOCK_SIZE": 4,
2347            },
2348            ["out_ptr"],
2349        ],
2350        [
2351            cond_op_kernel,
2352            {
2353                "in_ptr0": t,
2354                "in_ptr1": t,
2355                "out_ptr": t,
2356                "n_elements": 4,
2357                "BLOCK_SIZE": 4,
2358            },
2359            ["out_ptr"],
2360        ],
2361    ]
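    # Dynamically attach one MutationTests case per entry above.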
2362    for kernel, inputs, outputs in tests:
2363        fn = make_mutation_test(
2364            # Bind the loop variables as default arguments to avoid Python's
2365            # late-binding lambda pitfall: defaults are evaluated at lambda creation time
2366            lambda kernel=kernel, inputs=inputs, outputs=outputs: (
2367                kernel,
2368                inputs,
2369                outputs,
2370            )
2371        )
2372        name = f"test_mutations_{kernel.fn.__name__}"
2373        # Crude way to make the generated test names unique
2374        while name in MutationTests.__dict__:
2375            name += "1"
2376
2377        setattr(MutationTests, name, fn)
2378
2379
2380class CustomOpTests(torch._inductor.test_case.TestCase):
2381    """Tests for custom ops wrapping triton kernels"""
2382
2383    @requires_gpu
2384    @common_utils.parametrize("autotuned", [False, True])
2385    @common_utils.parametrize("dynamic", [False, True])
2386    def test_add_kernel(self, autotuned, dynamic):
2387        from torch._inductor.utils import run_and_get_code
2388
2389        libname = "my_cool_namespace"
2390        opname = "my_triton_operator"
2391
2392        @torch._library.triton_op(f"{libname}::{opname}", mutates_args={})
2393        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
2394            output = torch.empty_like(x)
2395            n_elements = output.numel()
2396
2397            def grid(meta):
2398                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
2399
2400            if autotuned:
2401                capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements)
2402            else:
2403                capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
2404            return output
2405
2406        def f(x, y):
2407            return add(x, y)
2408
2409        x = torch.randn(3, device=GPU_TYPE)
2410        y = torch.randn(3, device=GPU_TYPE)
2411
2412        out = f(x, y)
2413        expected = x + y
2414        self.assertEqual(out, expected)
2415        out_compiled, codes = run_and_get_code(torch.compile(f, dynamic=dynamic), x, y)
2416        self.assertEqual(out_compiled, expected)
2417        self.assertEqual(len(codes), 1)
2418
2419        # Check that we decomposed the operator away
2420        code = "\n".join(codes[0])
2421        self.assertNotIn(libname, code)
2422        self.assertNotIn(opname, code)
2423
2424    @unittest.skipIf(not has_triton_package(), "requires triton")
2425    def test_capture_triton_meta(self):
2426        import triton
2427        import triton.language as tl
2428
2429        @triton.jit
2430        def add_kernel(
2431            in_ptr0,
2432            in_ptr1,
2433            out_ptr,
2434            n_elements,
2435            BLOCK_SIZE: "tl.constexpr",
2436        ):
2437            pid = tl.program_id(axis=0)
2438            block_start = pid * BLOCK_SIZE
2439            offsets = block_start + tl.arange(0, BLOCK_SIZE)
2440            mask = offsets < n_elements
2441            x = tl.load(in_ptr0 + offsets, mask=mask)
2442            y = tl.load(in_ptr1 + offsets, mask=mask)
2443            output = x + y
2444            tl.store(out_ptr + offsets, output, mask=mask)
2445
2446        @torch._library.triton_op("mylib::add", mutates_args=())
2447        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
2448            output = torch.empty_like(x)
2449            n_elements = output.numel()
2450
2451            def grid(meta):
2452                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
2453
2454            capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
2455            return output
2456
2457        def f(x, y):
2458            return add(x, y)
2459
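        # Meta-device inputs exercise the op's tracing path without launching the
        # kernel, so this test only needs the triton package, not a GPU.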
2460        x = torch.randn(3, device="meta")
2461        y = torch.randn(3, device="meta")
2462
2463        out = f(x, y)
2464        expected = torch.empty_like(x)
2465        self.assertEqual(out, expected)
2466
2467    @requires_gpu
2468    def test_capture_triton_disabled_in_triton_op(self):
2469        import triton
2470        import triton.language as tl
2471
2472        @triton.jit
2473        def add_kernel(
2474            in_ptr0,
2475            in_ptr1,
2476            out_ptr,
2477            n_elements,
2478            BLOCK_SIZE: "tl.constexpr",
2479        ):
2480            pid = tl.program_id(axis=0)
2481            block_start = pid * BLOCK_SIZE
2482            offsets = block_start + tl.arange(0, BLOCK_SIZE)
2483            mask = offsets < n_elements
2484            x = tl.load(in_ptr0 + offsets, mask=mask)
2485            y = tl.load(in_ptr1 + offsets, mask=mask)
2486            output = x + y
2487            tl.store(out_ptr + offsets, output, mask=mask)
2488
2489        add_kernel_decorated = torch._library.capture_triton(add_kernel)
2490
2491        status = []
2492
2493        @torch._library.triton_op("mylib::add", mutates_args=())
2494        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
2495            import torch._higher_order_ops.triton_kernel_wrap
2496
2497            status.append(torch._library.triton.is_capture_triton_enabled())
2498
2499            # capture_triton should return the kernel directly if disabled
2500            result = torch._library.capture_triton(add_kernel)
2501            self.assertIs(result, add_kernel)
2502
2503            # Smoke test: even with capture_triton disabled, these kernel launches should still work
2504            output = torch.empty_like(x)
2505            output2 = torch.empty_like(x)
2506
2507            n_elements = output.numel()
2508            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
2509            add_kernel_decorated[grid](x, y, output, n_elements, BLOCK_SIZE=16)
2510
2511            add_kernel_decorated.run(
2512                x, y, output2, n_elements, BLOCK_SIZE=16, grid=grid, warmup=False
2513            )
2514
2515            return output + output2
2516
2517        x = torch.randn(3, device=GPU_TYPE)
2518        y = torch.randn(3, device=GPU_TYPE)
2519        z = add(x, y)
2520        self.assertEqual(status[-1], False)
2521        self.assertEqual(z, (x + y) * 2)
2522
2523    @requires_gpu
2524    @common_utils.parametrize("dynamic", [False, True])
2525    @common_utils.parametrize("autotune", [False, True])
2526    def test_capture_triton_special_kwargs(self, dynamic, autotune):
2527        @triton.jit
2528        def add_kernel(
2529            in_ptr0,
2530            in_ptr1,
2531            out_ptr,
2532            n_elements,
2533            BLOCK_SIZE: "tl.constexpr",
2534        ):
2535            pid = tl.program_id(axis=0)
2536            block_start = pid * BLOCK_SIZE
2537            offsets = block_start + tl.arange(0, BLOCK_SIZE)
2538            mask = offsets < n_elements
2539            x = tl.load(in_ptr0 + offsets, mask=mask)
2540            y = tl.load(in_ptr1 + offsets, mask=mask)
2541            output = x + y
2542            tl.store(out_ptr + offsets, output, mask=mask)
2543
2544        if autotune:
2545            add_kernel = triton.autotune(
2546                configs=[
2547                    triton.Config({"BLOCK_SIZE": 128}),
2548                    triton.Config({"BLOCK_SIZE": 64}),
2549                ],
2550                key=["n_elements"],
2551            )(add_kernel)
2552
2553        def f(x, y):
2554            output = torch.zeros_like(x)
2555            n_elements = output.numel()
2556            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
2557            if autotune:
2558                kwargs = {}
2559            else:
2560                kwargs = {"BLOCK_SIZE": 128}
2561            capture_triton(add_kernel)[grid](
2562                x,
2563                y,
2564                output,
2565                n_elements,
2566                num_warps=8,
2567                num_stages=3,
2568                **kwargs,
2569            )
2570            return output
2571
2572        x = torch.randn(4, device=GPU_TYPE)
2573        tracing_mode = "symbolic" if dynamic else "fake"
2574
2575        result = f(x, x)
2576        self.assertEqual(result, x + x)
2577
2578        from functorch import make_fx
2579
2580        gm = make_fx(f, tracing_mode=tracing_mode)(x, x)
2581        self.assertEqual(gm(x, x), x + x)
2582
2583
2584common_utils.instantiate_parametrized_tests(KernelTests)
2585common_utils.instantiate_parametrized_tests(CustomOpTests)
2586
2587
2588if __name__ == "__main__":
2589    from torch._inductor.test_case import run_tests
2590
2591    run_tests()
2592