# mypy: ignore-errors

import unittest

from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
from torch.utils._triton import has_triton


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")

if has_triton():
    import triton
    from triton import language as tl

    # Define here so that multiple tests can take advantage of it
    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_optional_param(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        ARGS_PASSED: "tl.constexpr",
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        if ARGS_PASSED == "two":
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
        else:
            output = x
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned_weird_param_order(
        in_ptr0,
        in_ptr1,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        out_ptr,
    ):
        # out_ptr is after an autotuned param that's declared as tl.constexpr.
        # This param ordering can create bugs if not handled correctly.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
            ),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_2d_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        x_elements,
        y_elements,
        BLOCK_SIZE_X: "tl.constexpr",
        BLOCK_SIZE_Y: "tl.constexpr",
    ):
        xoffset = tl.program_id(0) * BLOCK_SIZE_X
        xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]
        xmask = xindex < x_elements
        yoffset = tl.program_id(1) * BLOCK_SIZE_Y
        yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]
        ymask = yindex < y_elements
        x1 = xindex
        y0 = yindex
        tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)
        tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)
        tmp2 = tmp0 + tmp1
        tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)

    def _dummy_early_config_prune(configs, *_, **__):
        return configs

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
        warmup=10,
        rep=20,
        prune_configs_by={"early_config_prune": _dummy_early_config_prune},
    )
    @triton.jit
    def add_kernel_autotuned_with_unsupported_args(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_scaling(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        scaling_factor,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = (x + y) * scaling_factor
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def mul2_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = 2 * x
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def mul2_inplace_kernel(
        ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(ptr + offsets, mask=mask)
        output = 2 * x
        tl.store(ptr + offsets, output, mask=mask)

    @triton.jit
    def zero_negs(x):
        return tl.where(x >= 0, x, 0)

    @triton.jit
    def indirection_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        ACTIVATION: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        if ACTIVATION == "mul2_inplace_kernel":
            mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        elif ACTIVATION == "add_kernel":
            add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        x = tl.load(in_ptr0 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x, mask=mask)

    @triton.jit
    def double_strided_kernel(
        in_ptr,
        out_ptr,
        in_y_stride,
        out_y_stride,
        X_BLOCK_SIZE: "tl.constexpr",
        Y_BLOCK_SIZE: "tl.constexpr",
    ):
        xid = tl.program_id(axis=0)
        yid = tl.program_id(axis=1)
        x_start = xid * X_BLOCK_SIZE
        y_start = yid * Y_BLOCK_SIZE
        x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)
        y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)
        src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]
        dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]
        src = tl.load(in_ptr + src_offsets)
        tl.store(out_ptr + dst_offsets, src * 2.0)

    @triton.jit
    def inline_asm_kernel(X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"):
        x = tl.load(X + tl.arange(0, BLOCK))
        y = tl.load(Y + tl.arange(0, BLOCK))
        s = tl.full([BLOCK], n, tl.int32)
        z = tl.inline_asm_elementwise(
            "shf.l.wrap.b32 $0, $1, $2, $3;",
            "=r,r, r, r",
            [x, y, s],
            dtype=tl.int32,
            is_pure=True,
            pack=1,
        )
        tl.store(Z + tl.arange(0, BLOCK), z)

    @triton.jit
    def add_kernel_with_block_ptr(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        y = tl.load(
            tl.make_block_ptr(
                base=y_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        output = x + y
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            output,
            boundary_check=[0],
        )

    @triton.jit
    def kernel_with_block_ptr_2d(
        x_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            boundary_check=[0],
        )
        output = x
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            output,
            boundary_check=[0],
        )

    from triton.language import load, store

    @triton.jit
    def add_kernel_with_import(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = load(in_ptr0 + offsets, mask=mask)
        y = load(in_ptr1 + offsets, mask=mask)
        output = x + y
        store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def cond_op_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        if tl.program_id(0) == 0:
            output = x + y
        else:
            output = x * y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def atomic_add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.atomic_add(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_4_times_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        for i in range(2):
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)
        i = 2
        while i > 0:
            i -= 1
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_out_of_order_fn2(
        in_ptr0,
        in_ptr1,
        n_elements,
        out_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)