# Owner(s): ["module: inductor"]
# flake8: noqa: B950

import functools
from collections import namedtuple
from contextlib import nullcontext
from typing import Callable, Optional
from unittest import expectedFailure, skipUnless
from unittest.mock import patch

import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention.flex_attention import (
    _create_empty_block_mask,
    _identity,
    BlockMask,
    create_block_mask,
    flex_attention,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM
from torch.utils._triton import has_triton


# Skip tests unless CUDA (compute capability >= 8.0) and Triton are available
supported_platform = skipUnless(
    torch.cuda.is_available()
    and has_triton()
    and torch.cuda.get_device_capability() >= (8, 0),
    "Requires CUDA and Triton",
)

Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")

index = torch.ops.aten.index
Tensor = torch.Tensor


def create_attention(score_mod, block_mask, enable_gqa=False):
    return functools.partial(
        flex_attention,
        score_mod=score_mod,
        block_mask=block_mask,
        enable_gqa=enable_gqa,
    )


def create_block_mask_test(score_mod, query, key):
    block_mask = create_block_mask(
        score_mod, 1, 1, query.shape[-2], key.shape[-2], query.device
    )
    return block_mask


test_dtypes = (
    [torch.float16, torch.bfloat16, torch.float32]
    if PLATFORM_SUPPORTS_BF16
    else [torch.float16, torch.float32]
)

test_dtypes_fast = [torch.float16]


# --------- Useful score mod functions for testing ---------
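# Each score_mod receives the pre-softmax attention score plus the (batch, head,
# q_index, kv_index) position of that score and returns a modified score;
# returning float("-inf") masks that position out.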
def _causal(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return torch.where(token_q >= token_kv, score, float("-inf"))


def _generate_windowed(offset):
    def _windowed(score, b, h, q, kv):
        return torch.where(q + offset >= kv, score, float("-inf"))

    return _windowed


def _get_windowed_sdpa_mask(Mq, Mkv, offset):
    return torch.tril(torch.ones(Mkv, Mkv, dtype=torch.bool, device="cuda"))[
        offset : offset + Mq
    ]


def _rel_bias(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return score + (token_q - token_kv)


def _rel_causal(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf"))


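# ALiBi-style bias: head h uses slope 2 ** (-8 * (h + 1) / num_heads) (computed
# via exp2 below) to bias scores linearly with the query/key distance.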
def _generate_alibi_bias(num_heads: int):
    def _alibi_bias(
        score: Tensor,
        batch: Tensor,
        head: Tensor,
        token_q: Tensor,
        token_kv: Tensor,
    ) -> Tensor:
        scale = torch.exp2(-((head + 1) * 8.0 / num_heads))
        return score + (token_kv - token_q) * scale

    return _alibi_bias


def _inverse_causal(score, b, h, m, n):
    return torch.where(m <= n, score, float("-inf"))


def _times_two(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * 2


def _squared(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * score


def _head_offset(dtype: torch.dtype):
    """Captured Buffer"""
    head_offset = torch.rand(Hq, device="cuda", dtype=dtype)

    def score_mod(score, b, h, m, n):
        return score * head_offset[h]

    return score_mod


def _trig(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return torch.sin(torch.cos(score)) + torch.tan(b)


def _trig2(score, b, h, m, n):
    """Branching joint graph"""
    cos_score = torch.cos(score)
    sin_score = torch.sin(score)
    z = cos_score * sin_score + torch.tan(b)
    return z


test_score_mods = [
    _identity,
    _times_two,
    _squared,
    _causal,
    _inverse_causal,
    _rel_bias,
    _rel_causal,
    _generate_alibi_bias(8),
    _generate_windowed(1000),
]

captured_buffers_map = {
    "_head_offset": _head_offset,
}

B = 4
S = 2048
D = 64


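# (query heads, key/value heads) pairs; unequal counts exercise grouped-query
# attention (GQA), where several query heads share a single KV head.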
test_Hq_Hkv = [
    (16, 1),
    (8, 2),
    (16, 16),
]

(Hq, Hkv) = (16, 8)


def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
):
    """Clones the query, key, and value tensors and casts them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref


class TestFlexDecoding(InductorTestCase):
    def _check_equal(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
        fudge_factor: float,
        tensor_name: Optional[str] = None,
    ):
        compiled_error = (golden_out - compiled_out).abs().mean()
        ref_error = (golden_out - ref_out).abs().mean()
        if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
            self.assertTrue(False, "Output/Grad with NaN")
        if ref_error < (1e-4) * golden_out.abs().mean():
            print(
                "very small ref error of ",
                (ref_error.to(torch.float64) * (1e5) / golden_out.abs().mean()),
            )
            tolerance = Tolerances(atol=2e-1, rtol=2e-1)
            torch.testing.assert_close(
                golden_out.to(dtype=compiled_out.dtype),
                compiled_out,
                atol=tolerance.atol,
                rtol=tolerance.rtol,
            )
        elif compiled_error > ref_error * fudge_factor:
            name = tensor_name if tensor_name is not None else ""
            msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
            self.assertTrue(False, msg)

    def _check_out(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
    ):
        dtype = ref_out.dtype
        with torch.no_grad():
            # Note: we really do seem to be less accurate than the float32
            # computation, likely due to the online softmax.
            if dtype == torch.float32:
                fudge_factor = 10.0
            else:
                fudge_factor = 1.1

            # Check output
            self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

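    # Decoding setup: by default a single query token (Q_S=1) with Hq query heads
    # attends to a KV_S-token cache with Hkv heads. Compiled output and LSE are
    # checked against an eager reference and a float64 "golden" reference.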
    def run_test(
        self,
        score_mod: Optional[Callable],
        dtype: torch.dtype = torch.float16,
        Q_B: int = B,
        Q_H: int = Hq,
        Q_S: int = 1,
        Q_D: int = D,
        KV_B: int = B,
        KV_H: int = Hkv,
        KV_S: int = S,
        V_D: int = D,
        block_mask: Optional[BlockMask] = None,
    ):
        assert (
            score_mod is not None or block_mask is not None
        ), "Must provide score_mod or block_mask"
        assert Q_H % KV_H == 0
        if TEST_WITH_ROCM and Q_H != KV_H:
            self.skipTest("enable_gqa=True is unsupported on ROCM, for now")
        q = torch.randn(
            (Q_B, Q_H, Q_S, Q_D),
            dtype=dtype,
            device="cuda",
            requires_grad=False,
        )
        k = torch.randn(
            (KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=False
        )
        v = torch.randn(
            (KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=False
        )
        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)

        sdpa_partial = create_attention(
            score_mod, block_mask, enable_gqa=(Q_H != KV_H)
        )
        compiled_sdpa = torch.compile(sdpa_partial)
        golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
        ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
        compiled_out, compiled_lse = compiled_sdpa(q, k, v, return_lse=True)

        self._check_out(
            golden_out,
            ref_out,
            compiled_out,
        )
        self._check_out(
            gold_lse,
            ref_lse,
            compiled_lse,
        )

    def run_test_with_call(
        self,
        sdpa_call: Callable,
        golden_call: Optional[Callable] = None,
        dtype: torch.dtype = torch.float16,
        Q_B: int = B,
        Q_H: int = Hq,
        Q_S: int = 1,
        Q_D: int = D,
        KV_B: int = B,
        KV_H: int = Hkv,
        KV_S: int = S,
        V_D: int = D,
    ):
        if not golden_call:
            golden_call = sdpa_call
        q = torch.randn(
            (Q_B, KV_H, Q_S * (Q_H // KV_H), Q_D),
            dtype=dtype,
            device="cuda",
            requires_grad=False,
        )
        k = torch.randn(
            (KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=False
        )
        v = torch.randn(
            (KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=False
        )
        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)

        compiled_sdpa = torch.compile(sdpa_call)
        golden_out = golden_call(q_gold, k_gold, v_gold)
        ref_out = golden_call(q_ref, k_ref, v_ref)
        compiled_out = compiled_sdpa(q, k, v)

        self._check_out(
            golden_out,
            ref_out,
            compiled_out,
        )

    @supported_platform
    @expectedFailure
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_bw_decoding_fails(self, dtype):
        make_kv = functools.partial(
            torch.randn,
            (2, 2, 128, 4),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        make_q = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q()

        block_mask = _create_empty_block_mask(q, k)

        @torch.compile
        def sdpa_hop(q, k, v, score_mod, block_mask):
            return flex_attention(q, k, v, score_mod)

        output = sdpa_hop(q, k, v, _identity, block_mask)

        output.backward(backward_grad)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    @common_utils.parametrize("head_dims", test_Hq_Hkv)
    def test_builtin_score_mods(
        self, dtype: torch.dtype, score_mod: Callable, head_dims
    ):
        Hq, Hkv = head_dims
        assert Hq % Hkv == 0
        self.run_test(score_mod, dtype, Q_H=Hq, KV_H=Hkv)
392
393    def input_strides_1(B, H, S, D):
394        return ((H * S * D, S * D, D, 1), 997)  # offset
395
396    def input_strides_2(B, H, S, D):
397        return ((H * D, D, B * H * D, 1), 499)  # transposed dimensions
398
399    def input_strides_3(B, H, S, D):
400        return ((S * (D + 1), B * S * (D + 1), (D + 1), 1), 293)  # additional buffer
401
402    def input_strides_4(B, H, S, D):
403        return ((1, D, (B + 1) * (H + 1) * D, 1), 97)  # shared dimension
404
405    test_input_strides = [
406        input_strides_1,
407        input_strides_2,
408        input_strides_3,
409        input_strides_4,
410    ]
411
412    @supported_platform
413    @common_utils.parametrize("dtype", test_dtypes_fast)
414    @common_utils.parametrize("k_s", test_input_strides)
415    @common_utils.parametrize("v_s", test_input_strides)
416    @common_utils.parametrize("head_dims", test_Hq_Hkv)
417    def test_strided_inputs(self, dtype: torch.dtype, k_s, v_s, head_dims):
418        Hq, Hkv = head_dims
419        assert Hq % Hkv == 0
420        q1 = torch.randn((B * Hq * D), dtype=dtype, device="cuda")
421        k1 = torch.randn((B * Hkv * S * D * 4), dtype=dtype, device="cuda")
422        v1 = torch.randn((B * Hkv * S * D * 4), dtype=dtype, device="cuda")
423
424        k_shape = (B, Hkv, S, D)
425        v_shape = (B, Hkv, S, D)
426
427        q = q1.view(1, Hq, B, D).transpose(0, 2)
428
429        k_strides, k_offset = k_s(B, Hkv, S, D)
430        k_max = [x * (y - 1) for x, y in zip(k_strides, k_shape)]
431        assert sum(k_max) + k_offset < B * Hkv * S * D * 4
432        assert k_strides[-1] == 1
433        k = torch.as_strided(k1, k_shape, k_strides, k_offset)
434
435        v_strides, v_offset = v_s(B, Hkv, S, D)
436        v_max = [x * (y - 1) for x, y in zip(v_strides, v_shape)]
437        assert sum(v_max) + v_offset < B * Hkv * S * D * 4
438        assert v_strides[-1] == 1
439        v = torch.as_strided(v1, v_shape, v_strides, v_offset)
440
441        sdpa_partial = create_attention(
442            score_mod=_generate_alibi_bias(8),
443            block_mask=None,
444            enable_gqa=(not Hq == Hkv),
445        )
446        compiled_sdpa = torch.compile(sdpa_partial)
447        ref_out = sdpa_partial(q, k, v)
448        compiled_out = compiled_sdpa(q, k, v)
449
450        tolerance = Tolerances(atol=2e-1, rtol=2e-1)
451        torch.testing.assert_close(
452            ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
453        )
454
    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_skip_odd_keys(self, dtype: torch.dtype):
        def score_mod(score, b, h, q, kv):
            return torch.where(kv % 2 == 0, score, float("-inf"))

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_function_composition(self, dtype: torch.dtype):
        def score_mod_1(score, b, h, m, n):
            return score + (m - n)

        def score_mod_2(score, b, h, m, n):
            return torch.where(m <= n, score, float("-inf"))

        def composed_score_mod(score, b, h, m, n):
            return score_mod_2(score_mod_1(score, b, h, m, n), b, h, m, n)

        self.run_test(composed_score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_captured_buffers(self, dtype: torch.dtype):
        head_offset = torch.rand(Hq, device="cuda", dtype=dtype)

        def score_mod(score, b, h, m, n):
            return score + head_offset[h]

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_captured_buffers_all_dims(self, dtype: torch.dtype):
        head_scale = torch.randn(Hq, device="cuda")
        batch_scale = torch.randn(B, device="cuda")
        kv_scale = torch.randn(S, device="cuda")
        q_scale = torch.randn(1, device="cuda")

        def all_bias(score, batch, head, token_q, token_kv):
            score = score + kv_scale[token_kv]
            score = score + q_scale[token_q]
            score = score + head_scale[head]
            score = score + batch_scale[batch]
            return score

        self.run_test(all_bias, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_seq_masking(self, dtype):
        seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
        seq_idx[S // 2 :] = 1

        def seq_mask_mod(score, b, h, q, kv):
            return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf"))

        self.run_test(seq_mask_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_seq_only(self, dtype):
        bias = torch.randn(1, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_seq_batch(self, dtype):
        bias = torch.randn(B, 1, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[b, q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_head_seq_batch(self, dtype):
        bias = torch.randn(
            B,
            Hq,
            1,
            S,
            device="cuda",
            dtype=dtype,
        )

        def bias_mod(score, b, h, q, kv):
            return score + bias[b, h, q, kv]

        self.run_test(bias_mod, dtype)

    # TODO this config segfaults with Triton without:
    # https://github.com/triton-lang/triton/pull/4540
    @supported_platform
    @common_utils.parametrize("score_mod", test_score_mods)
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)])
    def test_non_equal_head_dims(self, dtype, score_mod, head_dims):
        qk_d, v_d = head_dims
        context = nullcontext() if qk_d > v_d else self.assertRaises(ValueError)
        with context:
            self.run_test(score_mod, dtype, B, Hq, 1, qk_d, B, Hkv, S, V_D=v_d)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_subgraph_respect_decomposition(self, dtype):
        from torch._decomp import core_aten_decompositions
        from torch.fx.experimental.proxy_tensor import make_fx

        def score_mod_func(score, b, h, q, kv):
            return score - q // (1 + kv)

        make_kv = functools.partial(
            torch.randn,
            (2, 2, 128, 4),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        make_q = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        query, key, value = make_q(), make_kv(), make_kv()
        # floor_div is not decomposed when the decomposition_table is empty
        attention = functools.partial(flex_attention, score_mod=score_mod_func)
        gm = make_fx(attention, decomposition_table={})(query, key, value)
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1);  arg4_1 = None
    floor_divide = torch.ops.aten.floor_divide.default(arg3_1, add);  arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, floor_divide);  arg0_1 = floor_divide = None
    return sub""",
        )

        # floor_div is decomposed with core_aten_decompositions
        gm = make_fx(attention, decomposition_table=core_aten_decompositions())(
            query, key, value
        )
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1);  arg4_1 = None
    div = torch.ops.aten.div.Tensor_mode(arg3_1, add, rounding_mode = 'floor');  arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, div);  arg0_1 = div = None
    return sub""",
        )

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_silu_on_score(self, dtype):
        def silu_score(score, b, h, q, kv):
            return torch.nn.functional.silu(score)

        self.run_test(silu_score, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_padded_dense_causal(self, dtype):
        seq_len = torch.arange(B, device="cuda", dtype=torch.int32) + 1

        def create_padded_dense_wrapper(orig_score_mod):
            def njt_score_mod(qk, b, h, q, kv):
                return torch.where(
                    qk <= seq_len[b], orig_score_mod(qk, b, h, q, kv), -float("inf")
                )

            return njt_score_mod

        causal_njt = create_padded_dense_wrapper(_causal)

        self.run_test(causal_njt, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_scale(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_recompile_changed_score_mod(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)
        ADD = True

        def score_mod_scale(qk, b, h, q, kv):
            if ADD:
                return qk + scale
            else:
                return qk * scale

        self.run_test(score_mod_scale, dtype)
        ADD = False
        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @expectedFailure  # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_reduction(self, dtype):
        scale = torch.randn((B, 8), device="cuda")

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale[b].sum(dim=-1)

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    def test_multiple_score_mod_calls(self):
        query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device="cuda")
        keys = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(2)
        ]
        values = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(2)
        ]

        def scoremod_1(qk, b, h, q, kv):
            return qk + (q - kv)

        def scoremod_2(qk, b, h, q, kv):
            return torch.where(q >= kv, qk, -float("inf"))

        def f(q, k1, k2, v1, v2):
            q2 = flex_attention(q, k1, v1, score_mod=scoremod_1)
            return flex_attention(q2, k2, v2, score_mod=scoremod_2)

        out = f(query, *keys, *values)
        out2 = torch.compile(f)(query, *keys, *values)
        tolerance = Tolerances(atol=2e-1, rtol=2e-1)
        torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)

    @supported_platform
    def test_multiple_score_mod_calls2(self):
        query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device="cuda")
        keys = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(3)
        ]
        values = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(3)
        ]

        def scoremod_1(qk, b, h, q, kv):
            return qk + (q - kv)

        def scoremod_2(qk, b, h, q, kv):
            return torch.where(q >= kv, qk, -float("inf"))

        attention1 = functools.partial(flex_attention, score_mod=scoremod_1)

        def f(q, k1, k2, k3, v1, v2, v3):
            q2 = attention1(q, k1, v1)
            q3 = flex_attention(q2, k2, v2, score_mod=scoremod_2)
            return flex_attention(q3, k3, v3, score_mod=scoremod_1)

        out = f(query, *keys, *values)
        out2 = torch.compile(f)(query, *keys, *values)
        self.assertTrue((out - out2).abs().mean() < 1e-2)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_njt_causal(self, dtype):
        offsets = torch.tensor(
            [0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
        )
        seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
        for idx in range(len(offsets) - 1):
            seq_idx[offsets[idx] : offsets[idx + 1]] = idx

        def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
            def njt_score_mod(qk, b, h, q, kv):
                q_nested = q - offsets[seq_idx[q]]
                kv_nested = kv - offsets[seq_idx[kv]]
                return orig_score_mod(qk, b, h, q_nested, kv_nested)

            return njt_score_mod

        causal_njt = create_njt_wrapper(_causal, offsets, seq_idx)

        self.run_test(causal_njt, dtype)

    @supported_platform
    def test_mixed_dtypes_fails(self):
        query = torch.randn((1, 1, 8, 64), dtype=torch.float32, device="cuda")
        key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
        value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
        with self.assertRaisesRegex(
            ValueError, "Expected query, key, and value to have the same dtype"
        ):
            flex_attention(query, key, value, _identity)

    @supported_platform
    @patch.object(torch._inductor.config, "max_autotune", True)
    def test_max_autotune(self):
        def score_mod(score, b, h, m, n):
            return score * 2

        self.run_test(score_mod)

    @supported_platform
    @patch.object(torch._inductor.config, "max_autotune", True)
    def test_max_autotune_with_captured(self):
        head_scale = torch.randn(Hq, device="cuda")
        batch_scale = torch.randn(B, device="cuda")
        tok_scale = torch.randn(S, device="cuda")
        q_scale = torch.randn(1, device="cuda")

        def bias_mod(score, batch, head, token_q, token_kv):
            score = score + tok_scale[token_kv]
            score = score + q_scale[token_q]
            score = score + batch_scale[batch]
            score = score + head_scale[head]
            return score

        self.run_test(bias_mod)

    @skipIfRocm
    @supported_platform
    def test_fully_masked_out_rows_0_check_gqa(self):
        # Ensure fully masked out rows won't cause NaNs.
        query = torch.randn(
            (B, Hq, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )
        key = torch.randn(
            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )
        value = torch.randn(
            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )

        M = S // 2

        def mask_mod(b, h, q, kv):
            return q < M

        block_mask = create_block_mask(mask_mod, 1, 1, S, S)

        flex = torch.compile(flex_attention, dynamic=False)

        out, lse = flex(
            query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True
        )
        self.assertEqual(out[:, :, M:, :].sum(), 0)
        self.assertTrue((lse[:, :, M:] == -float("inf")).all())

        loss = out.sum() + lse.sum()
        loss.backward()
        self.assertEqual(query.grad[:, :, M:, :].sum(), 0)

    @supported_platform
    def test_windowed_no_mask_vs_sdpa(self):
        score_mod = _generate_windowed(1000)
        attention = functools.partial(flex_attention, score_mod=score_mod)

        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)

        sdpa_attention = functools.partial(
            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )

        self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

    @supported_platform
    def test_windowed_full_mask_vs_sdpa(self):
        def mask_mod(b, h, q, kv):
            return q + 1000 >= kv

        score_mod = _generate_windowed(1000)

        block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
        attention = functools.partial(
            flex_attention, block_mask=block_mask, score_mod=score_mod
        )

        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
        sdpa_attention = functools.partial(
            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )

        self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

    @supported_platform
    def test_windowed_partial_block_vs_sdpa(self):
        def mask_mod(b, h, q, kv):
            return q + 1000 >= kv

        block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
        attention = functools.partial(flex_attention, block_mask=block_mask)

        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
        sdpa_attention = functools.partial(
            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )

        self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", [_identity, _causal])
    def test_logsumexp_correctness(self, dtype, score_mod):
        make_kv = functools.partial(
            torch.randn,
            (B, Hkv, S, D),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        make_q = functools.partial(
            torch.randn,
            (B, Hkv, Hq // Hkv, D),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_q(), make_kv(), make_kv()

        @torch.compile
        def sdpa_hop(q, k, v, score_mod):
            return flex_attention(q, k, v, score_mod, return_lse=True)

        @torch.compile(backend="aot_eager")
        def eager_sdpa_hop(q, k, v, score_mod):
            return flex_attention(q, k, v, score_mod, return_lse=True)

        ref_out, ref_lse = eager_sdpa_hop(
            q.to(torch.float64),
            k.to(torch.float64),
            v.to(torch.float64),
            score_mod,
        )
        compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod)

        self.assertTrue(ref_lse.dtype == torch.float64)
        self.assertTrue(compiled_lse.dtype == torch.float32)

        tolerance = Tolerances(atol=2e-2, rtol=2e-2)
        torch.testing.assert_close(
            ref_out.to(dtype=torch.float32),
            compiled_out.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )
        torch.testing.assert_close(
            ref_lse.to(dtype=torch.float32),
            compiled_lse.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )

    @supported_platform
    def test_logsumexp_only_return(self):
        make_q = functools.partial(
            torch.randn,
            (B, Hkv, Hq // Hkv, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        make_kv = functools.partial(
            torch.randn,
            (B, Hkv, S, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )

        q, k, v = make_q(), make_kv(), make_kv()

        @torch.compile
        def func(q, k, v, score_mod):
            _, lse = flex_attention(q, k, v, score_mod, return_lse=True)
            lse_2 = lse * 2
            return lse_2

        _, code = run_and_get_code(func, q, k, v, _identity)
        # Ensure that we're still generating the flex_attention kernel
        FileCheck().check_count(".run(primals_1, primals_2, primals_3", 1, True).run(
            code[0]
        )

    @supported_platform
    def test_non_sparse_multiple_block_size(self):
        def generate_causal_offset(offset: torch.Tensor):
            def causal_offset_mask(b, h, q_idx, kv_idx):
                return (offset + q_idx) >= kv_idx

            return causal_offset_mask

        def noop(score, b, h, q_idx, kv_idx):
            return score

        mod = generate_causal_offset(
            torch.tensor(192, device="cuda", dtype=torch.int32)
        )
        block_mask = create_block_mask(mod, 1, 1, 1, 65)

        self.run_test(
            score_mod=None,
            dtype=torch.float32,
            block_mask=block_mask,
            Q_B=1,
            Q_H=1,
            Q_S=1,
            Q_D=16,
            KV_B=1,
            KV_H=1,
            KV_S=65,
            V_D=16,
        )

    @supported_platform
    def test_do_not_trigger_dynamic_shapes_on_empty_block_mask(self):
        torch._dynamo.reset()
        H = Hq
        q = torch.randn(B, H, 1, D, device="cuda")
        for i in range(5):
            k = torch.randn(B, H, S + i, D, device="cuda")
            v = torch.randn(B, H, S + i, D, device="cuda")
            compiled_flex_attention = torch.compile(flex_attention)
            ref = flex_attention(q, k, v)
            res = compiled_flex_attention(q, k, v)
            tolerance = Tolerances(atol=2e-1, rtol=2e-1)
            torch.testing.assert_close(
                ref, res, atol=tolerance.atol, rtol=tolerance.rtol
            )
            # The second iteration triggers one automatic dynamic shapes recompile;
            # after that there should be no further recompilation.
            if i == 0:
                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
            else:
                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)


common_utils.instantiate_parametrized_tests(TestFlexDecoding)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()