# mypy: ignore-errors

import unittest

from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
from torch.utils._triton import has_triton


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")

if has_triton():
    import triton
    from triton import language as tl

    # Define here so that multiple tests can take advantage of it
    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)
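
    # Illustrative launch sketch for add_kernel (not part of the test suite;
    # assumes `import torch`, a CUDA device, and 1D tensors of equal length):
    #
    #   x = torch.randn(1024, device="cuda")
    #   y = torch.randn(1024, device="cuda")
    #   out = torch.empty_like(x)
    #   n = out.numel()
    #   grid = (triton.cdiv(n, 256),)
    #   add_kernel[grid](x, y, out, n, BLOCK_SIZE=256)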

    @triton.jit
    def add_kernel_with_optional_param(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        ARGS_PASSED: "tl.constexpr",
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        if ARGS_PASSED == "two":
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
        else:
            output = x
        tl.store(out_ptr + offsets, output, mask=mask)
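
    # Illustrative launch sketch (same setup as the add_kernel example above).
    # ARGS_PASSED is a string constexpr, so each value compiles a specialized
    # variant; any value other than "two" (e.g. the hypothetical "one" below)
    # skips the second load:
    #
    #   add_kernel_with_optional_param[grid](x, y, out, n, ARGS_PASSED="two", BLOCK_SIZE=256)
    #   add_kernel_with_optional_param[grid](x, x, out, n, ARGS_PASSED="one", BLOCK_SIZE=256)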

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)
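
    # Illustrative launch sketch for the autotuned variant: BLOCK_SIZE is
    # chosen by @triton.autotune, so it is not passed at the call site and
    # the grid is a callable over the selected meta-parameters:
    #
    #   grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    #   add_kernel_autotuned[grid](x, y, out, n)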

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned_weird_param_order(
        in_ptr0,
        in_ptr1,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        out_ptr,
    ):
        # out_ptr is after an autotuned param that's declared as tl.constexpr.
        # This param ordering can create bugs if not handled correctly.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
            ),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_2d_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        x_elements,
        y_elements,
        BLOCK_SIZE_X: "tl.constexpr",
        BLOCK_SIZE_Y: "tl.constexpr",
    ):
        xoffset = tl.program_id(0) * BLOCK_SIZE_X
        xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]
        xmask = xindex < x_elements
        yoffset = tl.program_id(1) * BLOCK_SIZE_Y
        yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]
        ymask = yindex < y_elements
        x1 = xindex
        y0 = yindex
        tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)
        tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)
        tmp2 = tmp0 + tmp1
        tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)
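
    # Illustrative 2D launch sketch (assumes 2D inputs on a CUDA device and
    # the same setup as the earlier examples); both grid axes are derived
    # from the autotuned meta-parameters:
    #
    #   grid = lambda meta: (
    #       triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
    #       triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
    #   )
    #   add_kernel_2d_autotuned[grid](x, y, out, x_elements, y_elements)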

    def _dummy_early_config_prune(configs, *_, **__):
        return configs

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
        warmup=10,
        rep=20,
        prune_configs_by={"early_config_prune": _dummy_early_config_prune},
    )
    @triton.jit
    def add_kernel_autotuned_with_unsupported_args(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_scaling(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        scaling_factor,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = (x + y) * scaling_factor
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def mul2_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = 2 * x
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def mul2_inplace_kernel(
        ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(ptr + offsets, mask=mask)
        output = 2 * x
        tl.store(ptr + offsets, output, mask=mask)

    @triton.jit
    def zero_negs(x):
        return tl.where(x >= 0, x, 0)

    @triton.jit
    def indirection_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        ACTIVATION: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        if ACTIVATION == "mul2_inplace_kernel":
            mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        elif ACTIVATION == "add_kernel":
            add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        x = tl.load(in_ptr0 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x, mask=mask)
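
    # Illustrative launch sketch: ACTIVATION is a string constexpr that
    # selects which jitted helper is inlined as a device-side call (setup as
    # in the add_kernel example above):
    #
    #   indirection_kernel[grid](x, out, n, BLOCK_SIZE=256, ACTIVATION="mul2_inplace_kernel")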

    @triton.jit
    def double_strided_kernel(
        in_ptr,
        out_ptr,
        in_y_stride,
        out_y_stride,
        X_BLOCK_SIZE: "tl.constexpr",
        Y_BLOCK_SIZE: "tl.constexpr",
    ):
        xid = tl.program_id(axis=0)
        yid = tl.program_id(axis=1)
        x_start = xid * X_BLOCK_SIZE
        y_start = yid * Y_BLOCK_SIZE
        x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)
        y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)
        src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]
        dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]
        src = tl.load(in_ptr + src_offsets)
        tl.store(out_ptr + dst_offsets, src * 2.0)
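
    # Illustrative launch sketch (assumes hypothetical 2D tensors `src` and
    # `dst` of shape (Y, X) on a CUDA device). The kernel does no bounds
    # masking, so X and Y should be multiples of the block sizes:
    #
    #   grid = (X // 16, Y // 16)
    #   double_strided_kernel[grid](src, dst, src.stride(0), dst.stride(0),
    #                               X_BLOCK_SIZE=16, Y_BLOCK_SIZE=16)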

    @triton.jit
    def inline_asm_kernel(X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"):
        x = tl.load(X + tl.arange(0, BLOCK))
        y = tl.load(Y + tl.arange(0, BLOCK))
        s = tl.full([BLOCK], n, tl.int32)
        z = tl.inline_asm_elementwise(
            "shf.l.wrap.b32 $0, $1, $2, $3;",
            "=r,r, r, r",
            [x, y, s],
            dtype=tl.int32,
            is_pure=True,
            pack=1,
        )
        tl.store(Z + tl.arange(0, BLOCK), z)

    @triton.jit
    def add_kernel_with_block_ptr(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        y = tl.load(
            tl.make_block_ptr(
                base=y_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        output = x + y
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            output,
            boundary_check=[0],
        )

    @triton.jit
    def kernel_with_block_ptr_2d(
        x_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            boundary_check=[0],
        )
        output = x
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            output,
            boundary_check=[0],
        )

    from triton.language import load, store

    @triton.jit
    def add_kernel_with_import(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = load(in_ptr0 + offsets, mask=mask)
        y = load(in_ptr1 + offsets, mask=mask)
        output = x + y
        store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def cond_op_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        if tl.program_id(0) == 0:
            output = x + y
        else:
            output = x * y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def atomic_add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.atomic_add(out_ptr + offsets, output, mask=mask)
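
    # Illustrative launch sketch: tl.atomic_add accumulates into whatever is
    # already in out_ptr, so the output should be zero-initialized if a plain
    # sum is expected (setup as in the add_kernel example above):
    #
    #   out = torch.zeros_like(x)
    #   atomic_add_kernel[grid](x, y, out, n, BLOCK_SIZE=256)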

    @triton.jit
    def add_4_times_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        for i in range(2):
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)
        i = 2
        while i > 0:
            i -= 1
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_out_of_order_fn2(
        in_ptr0,
        in_ptr1,
        n_elements,
        out_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)