# Owner(s): ["module: inductor"]

import sys
import unittest

import torch
from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


try:
    import triton  # noqa: F401
    import triton.language as tl
except ImportError:
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires triton")  # noqa: B904

from torch._inductor import config
from torch._inductor.runtime.hints import (
    DeviceProperties,
    HeuristicType,
    TRITON_MAX_BLOCK,
)
from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_heuristics import CachingAutotuner, triton_config
from torch._inductor.test_case import run_tests, TestCase


class TestTritonHeuristics(TestCase):
    device_type = GPU_TYPE

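    # Note: triton_config(size_hints, x, y) below requests a 64x64 block for
    # size hints [2048, 2]; the heuristic is expected to cap each block
    # dimension at TRITON_MAX_BLOCK.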
    def test_triton_config(self):
        """
        Make sure block size does not exceed the maximum defined in inductor config.
        """
        cfg = triton_config([2048, 2], 64, 64)
        for label in "XYZ":
            key = f"{label}BLOCK"
            if key not in cfg.kwargs:
                continue
            self.assertTrue(cfg.kwargs[key] <= TRITON_MAX_BLOCK[label])

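    # Shared helper for test_artificial_zgrid and test_artificial_grid_cpp_wrapper.
    # The dynamically-marked input is made large enough that, presumably, the
    # compiled pointwise kernel needs more program blocks than fit in one
    # launch-grid dimension, so the grid has to spill into the y/z dimensions
    # (hence the "artificial" z-grid in the test names).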
    def _test_artificial_zgrid(self):
        def forward(primals_1, primals_2, primals_5):
            view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
            primals_5 = None
            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
            clone = torch.ops.aten.clone.default(
                permute, memory_format=torch.contiguous_format
            )
            permute = None
            view_1 = torch.ops.aten.reshape.default(clone, [-1, 4])
            clone = None
            permute_1 = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm = torch.ops.aten.addmm.default(primals_2, view_1, permute_1)
            primals_2 = None
            return addmm

        s0 = 16777472
        s1 = 8

        args = [
            torch.rand([2, 4], device=GPU_TYPE),
            torch.rand([2], device=GPU_TYPE),
            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        torch._dynamo.mark_dynamic(args[-1], 0)
        foo_c = torch.compile(forward)

        self.assertEqual(forward(*args), foo_c(*args))

        args = [
            torch.rand([2, 4], device=GPU_TYPE),
            torch.rand([2], device=GPU_TYPE),
            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        self.assertEqual(forward(*args), foo_c(*args))

    @skipIfXpu
    def test_artificial_zgrid(self):
        self._test_artificial_zgrid()

    @skipIfXpu
    @config.patch("cpp_wrapper", True)
    def test_artificial_grid_cpp_wrapper(self):
        self._test_artificial_zgrid()

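    # Builds the constructor kwargs for a CachingAutotuner wrapping a simple
    # elementwise cos kernel, with two candidate pointwise configs for the
    # autotuner to choose from.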
    def _get_cos_kernel_caching_autotuner_args(self):
        from triton.compiler.compiler import AttrsDescriptor

        @triton.jit
        def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
            xnumel = 16
            xoffset = tl.program_id(0) * XBLOCK
            xindex = xoffset + tl.arange(0, XBLOCK)[:]
            xmask = xindex < xnumel
            x0 = xindex
            tmp0 = tl.load(in_ptr0 + (x0), xmask)
            tmp1 = tl_math.cos(tmp0)
            tl.store(out_ptr0 + (x0), tmp1, xmask)

        triton_meta = {
            "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
            "device": DeviceProperties.create(torch.device("cuda")),
            "constants": {},
            "configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())],
        }

        configs = [
            triton_config([16], 64),
            triton_config([256], 64),
        ]

        inductor_meta = {}

        return {
            "fn": triton_,
            "triton_meta": triton_meta,
            "configs": configs,
            "save_cache_hook": False,
            "mutated_arg_names": [],
            "heuristic_type": HeuristicType.POINTWISE,
            "inductor_meta": inductor_meta,
        }

    @skipIfXpu
    def test_pre_hook_assert(self):
        # CachingAutotuner should raise an assertion error if any of the
        # configs passed to it have pre-hooks
        args = self._get_cos_kernel_caching_autotuner_args()

        def pre_hook(kwargs):
            if "in_ptr0" in kwargs:
                kwargs["in_ptr0"].zero_()

        for cfg in args["configs"]:
            cfg.pre_hook = pre_hook

        with self.assertRaisesRegex(AssertionError, "pre_hook"):
            autotuner = CachingAutotuner(**args)  # noqa: F841


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()