# Owner(s): ["module: inductor"]

import sys
import unittest

import torch
from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


try:
    import triton  # noqa: F401
    import triton.language as tl
except ImportError:
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires triton")  # noqa: B904

from torch._inductor import config
from torch._inductor.runtime.hints import (
    DeviceProperties,
    HeuristicType,
    TRITON_MAX_BLOCK,
)
from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_heuristics import CachingAutotuner, triton_config
from torch._inductor.test_case import run_tests, TestCase


class TestTritonHeuristics(TestCase):
    device_type = GPU_TYPE

    def test_triton_config(self):
        """
        Make sure block size does not exceed the maximum defined in inductor config.
        """
        cfg = triton_config([2048, 2], 64, 64)
        for label in "XYZ":
            key = f"{label}BLOCK"
            if key not in cfg.kwargs:
                continue
            self.assertTrue(cfg.kwargs[key] <= TRITON_MAX_BLOCK[label])

    def _test_artificial_zgrid(self):
        def forward(primals_1, primals_2, primals_5):
            view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
            primals_5 = None
            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
            clone = torch.ops.aten.clone.default(
                permute, memory_format=torch.contiguous_format
            )
            permute = None
            view_1 = torch.ops.aten.reshape.default(clone, [-1, 4])
            clone = None
            permute_1 = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm = torch.ops.aten.addmm.default(primals_2, view_1, permute_1)
            primals_2 = None
            return addmm

        s0 = 16777472
        s1 = 8

        args = [
            torch.rand([2, 4], device=GPU_TYPE),
            torch.rand([2], device=GPU_TYPE),
            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        torch._dynamo.mark_dynamic(args[-1], 0)
        foo_c = torch.compile(forward)

        self.assertEqual(forward(*args), foo_c(*args))

        args = [
            torch.rand([2, 4], device=GPU_TYPE),
            torch.rand([2], device=GPU_TYPE),
            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        self.assertEqual(forward(*args), foo_c(*args))

    @skipIfXpu
    def test_artificial_zgrid(self):
        self._test_artificial_zgrid()

    @skipIfXpu
    @config.patch("cpp_wrapper", True)
    def test_artificial_grid_cpp_wrapper(self):
        self._test_artificial_zgrid()

    def _get_cos_kernel_caching_autotuner_args(self):
        from triton.compiler.compiler import AttrsDescriptor

        @triton.jit
        def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
            xnumel = 16
            xoffset = tl.program_id(0) * XBLOCK
            xindex = xoffset + tl.arange(0, XBLOCK)[:]
            xmask = xindex < xnumel
            x0 = xindex
            tmp0 = tl.load(in_ptr0 + (x0), xmask)
            tmp1 = tl_math.cos(tmp0)
            tl.store(out_ptr0 + (x0), tmp1, xmask)

        triton_meta = {
            "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
            "device": DeviceProperties.create(torch.device("cuda")),
            "constants": {},
            "configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())],
        }

        configs = [
            triton_config([16], 64),
            triton_config([256], 64),
        ]

        inductor_meta = {}

        return {
            "fn": triton_,
            "triton_meta": triton_meta,
            "configs": configs,
            "save_cache_hook": False,
            "mutated_arg_names": [],
            "heuristic_type": HeuristicType.POINTWISE,
            "inductor_meta": inductor_meta,
        }

    @skipIfXpu
    def test_pre_hook_assert(self):
        # CachingAutotuner should raise an assertion error if any of the configs
        # passed to it have pre-hooks
        args = self._get_cos_kernel_caching_autotuner_args()

        def pre_hook(kwargs):
            if "in_ptr0" in kwargs:
                kwargs["in_ptr0"].zero_()

        for cfg in args["configs"]:
            cfg.pre_hook = pre_hook

        with self.assertRaisesRegex(AssertionError, "pre_hook"):
            autotuner = CachingAutotuner(**args)


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()