# Owner(s): ["module: inductor"]
# flake8: noqa: B950

import functools
from collections import namedtuple
from typing import Callable, Optional

from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch

import torch

from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch._inductor import metrics
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention._flex_attention import (
    _causal,
    _compose,
    _flex_attention,
    _generate_alibi_bias,
    _identity,
    _rel_bias,
    _rel_causal,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
from torch.utils._triton import has_triton

# Skip tests unless CUDA (SM80+, non-ROCm) and Triton are available
supported_platform = skipUnless(
    torch.cuda.is_available()
    and has_triton()
    and torch.version.hip is None
    and torch.cuda.get_device_capability() >= (8, 0),
    "Requires CUDA (SM80+, non-ROCm) and Triton",
)

Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")

index = torch.ops.aten.index


def create_attention(score_mod):
    return functools.partial(_flex_attention, score_mod=score_mod)


test_dtypes = (
    [torch.float16, torch.bfloat16, torch.float32]
    if PLATFORM_SUPPORTS_BF16
    else [torch.float16, torch.float32]
)

test_dtypes_fast = [torch.float16]

# TODO float16 was causing ERRORs for tests on ROCm
# See https://github.com/pytorch/pytorch/issues/123531
if common_utils.TEST_WITH_ROCM:
    test_dtypes = [torch.float32]


# --------- Useful score mod functions for testing ---------
def _inverse_causal(score, b, h, m, n):
    return torch.where(m <= n, score, float("-inf"))


def _times_two(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * 2


def _squared(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * score


def _head_offset(dtype: torch.dtype):
    """Captured Buffer"""
    head_offset = torch.rand(H, device="cuda", dtype=dtype)

    def score_mod(score, b, h, m, n):
        return score * head_offset[h]

    return score_mod


def _trig(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return torch.sin(torch.cos(score)) + torch.tan(b)


def _trig2(score, b, h, m, n):
    """Branching joint graph"""
    cos_score = torch.cos(score)
    sin_score = torch.sin(score)
    z = cos_score * sin_score + torch.tan(b)
    return z


test_score_mods = [
    _identity,
    _times_two,
    _squared,
    _causal,
    _inverse_causal,
    _rel_bias,
    _rel_causal,
    _generate_alibi_bias(8),
]

captured_buffers_map = {
    "_head_offset": _head_offset,
}

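# Note for readers: a score_mod takes the scalar attention score plus the
# (batch, head, q_index, kv_index) positions and returns a modified score.
# As a rough usage sketch (illustrative only, not exercised by the tests below;
# the tensor names here are made up):
#
#     causal_attention = create_attention(_causal)
#     out = causal_attention(query, key, value)
#     # equivalent to: _flex_attention(query, key, value, score_mod=_causal)
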
B = 4
H = 8
S = 2048
D = 64


def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
):
    """Clones the query, key, and value tensors and casts them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref


class TestFlexAttention(InductorTestCase):
    def _check_equal(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
        fudge_factor: float,
        tensor_name: Optional[str] = None,
    ):
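        # Accuracy is judged relative to a float64 "golden" output: the compiled
        # kernel passes if its mean absolute error against the golden output is
        # within `fudge_factor` times the eager reference's error.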
        compiled_error = (golden_out - compiled_out).abs().mean()
        ref_error = (golden_out - ref_out).abs().mean()
        if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
            self.assertTrue(False, "Output/Grad with NaN")
        if compiled_error > ref_error * fudge_factor:
            name = tensor_name if tensor_name is not None else ""
            msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
            self.assertTrue(False, msg)

    def _check_out_and_grad(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
        q_gold: torch.Tensor,
        q_ref: torch.Tensor,
        q: torch.Tensor,
        k_gold: torch.Tensor,
        k_ref: torch.Tensor,
        k: torch.Tensor,
        v_gold: torch.Tensor,
        v_ref: torch.Tensor,
        v: torch.Tensor,
    ):
        dtype = ref_out.dtype
        with torch.no_grad():
            # Note: we really do seem to be less accurate than the float32
            # computation, likely due to the online softmax
            if dtype == torch.float32:
                fudge_factor = 10.0
            else:
                fudge_factor = 1.1

            # Check output
            self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

            # Check gradients
            q_fudge_factor = 2.5 * fudge_factor
            self._check_equal(
                q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
            )
            k_fudge_factor = 4 * fudge_factor
            self._check_equal(
                k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
            )
            v_fudge_factor = 4 * fudge_factor
            self._check_equal(
                v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
            )

    def run_test(
        self,
        score_mod: Callable,
        dtype: torch.dtype = torch.float16,
        Q_B: int = B,
        Q_H: int = H,
        Q_S: int = S,
        Q_D: int = D,
        KV_B: int = B,
        KV_H: int = H,
        KV_S: int = S,
        KV_D: int = D,
    ):
        sdpa_partial = create_attention(score_mod)
        compiled_sdpa = torch.compile(sdpa_partial)
        q = torch.randn(
            (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
        )
        k = torch.randn(
            (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True
        )
        v = torch.randn(
            (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True
        )
        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
        golden_out = sdpa_partial(q_gold, k_gold, v_gold)
        ref_out = sdpa_partial(q_ref, k_ref, v_ref)
        compiled_out = compiled_sdpa(q, k, v)

        backward_grad = torch.randn((Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda")

        golden_out.backward(backward_grad.to(torch.float64))
        ref_out.backward(backward_grad)
        compiled_out.backward(backward_grad)

        self._check_out_and_grad(
            golden_out,
            ref_out,
            compiled_out,
            q_gold,
            q_ref,
            q,
            k_gold,
            k_ref,
            k,
            v_gold,
            v_ref,
            v,
        )

    def run_dynamic_test(
        self,
        score_mod: Callable,
        dtype: torch.dtype = torch.float16,
        B: int = B,
        H: int = H,
        S: int = S,
        D: int = D,
    ):
        sdpa_partial = create_attention(score_mod)
        # The first eager batch, shape (B, H, S, D)
        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1)
        q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64)
        ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref)
        golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold)

        backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")

        golden_out1.backward(backward_grad1.to(torch.float64))
        ref_out1.backward(backward_grad1)

        # The second eager batch, shape (B * 2, H, S / 2, D)
        B = int(B * 2)
        S = int(S / 2)
        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2)
        q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64)
        ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref)
        golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold)

        backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")

        golden_out2.backward(backward_grad2.to(torch.float64))
        ref_out2.backward(backward_grad2)

        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
        # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
        torch._dynamo.reset()
        # Compiling with dynamic shape in the first batch.
        compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
        compiled_out1 = compiled_sdpa(q1, k1, v1)
        compiled_out1.backward(backward_grad1)

        self._check_out_and_grad(
            golden_out1,
            ref_out1,
            compiled_out1,
            q1_gold,
            q1_ref,
            q1,
            k1_gold,
            k1_ref,
            k1,
            v1_gold,
            v1_ref,
            v1,
        )
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

        # No re-compilation, use the compiled dynamic shape version.
        compiled_out2 = compiled_sdpa(q2, k2, v2)
        compiled_out2.backward(backward_grad2)
        self._check_out_and_grad(
            golden_out2,
            ref_out2,
            compiled_out2,
            q2_gold,
            q2_ref,
            q2,
            k2_gold,
            k2_ref,
            k2,
            v2_gold,
            v2_ref,
            v2,
        )
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

    def run_automatic_dynamic_test(
        self,
        score_mod: Callable,
        dtype: torch.dtype = torch.float16,
        B: int = B,
        H: int = H,
        S: int = S,
        D: int = D,
    ):
        sdpa_partial = create_attention(score_mod)
        # The first eager batch, shape (B, H, S, D)
        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out1 = sdpa_partial(
            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
        )
        ref_out1 = sdpa_partial(q1, k1, v1)

        # The second eager batch, shape (B * 2, H, S / 2, D)
        B = int(B * 2)
        S = int(S / 2)
        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out2 = sdpa_partial(
            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
        )
        ref_out2 = sdpa_partial(q2, k2, v2)

        # The third eager batch, shape (B * 4, H, S / 4, D)
        B = int(B * 2)
        S = int(S / 2)
        q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out3 = sdpa_partial(
            q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
        )
        ref_out3 = sdpa_partial(q3, k3, v3)

        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
        # We check dynamo counters["frames"]["ok"] to ensure:
        # 1. the first batch is compiled with a static shape,
        # 2. the second batch is compiled with a dynamic shape, and
        # 3. there is no re-compilation for the third batch.
        torch._dynamo.reset()

        # Note: we really do seem to be less accurate than the float32
        # computation, likely due to the online softmax
        if dtype == torch.float32:
            fudge_factor = 10.0
        else:
            fudge_factor = 1.1

        # The first batch.
        compiled_sdpa = torch.compile(sdpa_partial)
        compiled_out1 = compiled_sdpa(q1, k1, v1)
        self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

        # The second batch (automatic dynamic).
        compiled_out2 = compiled_sdpa(q2, k2, v2)
        self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

        # The third batch (no re-compilation).
        compiled_out3 = compiled_sdpa(q3, k3, v3)
        self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable):
        self.run_dynamic_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods_automatic_dynamic(
        self, dtype: torch.dtype, score_mod: Callable
    ):
        self.run_automatic_dynamic_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods_different_seqlen(
        self, dtype: torch.dtype, score_mod: Callable
    ):
        self.run_test(
            score_mod,
            dtype,
            B,
            H,
            S // 2,  # Seqlen of Q is different from seqlen of K/V
            D,
            B,
            H,
            S,
            D,
        )

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_skip_odd_keys(self, dtype: torch.dtype):
        def score_mod(score, b, h, q, kv):
            return torch.where(kv % 2 == 0, score, float("-inf"))

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_function_composition(self, dtype: torch.dtype):
        def score_mod_1(score, b, h, m, n):
            return score + (m - n)

        def score_mod_2(score, b, h, m, n):
            return torch.where(m <= n, score, float("-inf"))

        composed_score_mod = _compose(score_mod_1, score_mod_2)

        self.run_test(composed_score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_captured_buffers(self, dtype: torch.dtype):
        head_offset = torch.rand(H, device="cuda", dtype=dtype)

        def score_mod(score, b, h, m, n):
            return score + head_offset[h]

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_captured_buffers_all_dims(self, dtype: torch.dtype):
        head_scale = torch.randn(H, device="cuda")
        batch_scale = torch.randn(B, device="cuda")
        tok_scale = torch.randn(S, device="cuda")

        def all_bias(score, batch, head, token_q, token_kv):
            score = score + tok_scale[token_q]
            score = score + batch_scale[batch]
            score = score + head_scale[head]
            return score

        self.run_test(all_bias, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_seq_masking(self, dtype):
        seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
        seq_idx[S // 2 :] = 1

        def seq_mask_mod(score, b, h, q, kv):
            return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf"))

        self.run_test(seq_mask_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_seq_only(self, dtype):
        bias = torch.randn(S, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_seq_batch(self, dtype):
        bias = torch.randn(B, S, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[b, q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_head_seq_batch(self, dtype):
        bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[b, h, q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_rel_bias(self, dtype):
        rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + rel_bias[(q - kv) + S]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_dependent_causal_bidirectional(self, dtype):
        num_bidirectional = torch.randint(0, S, (B,), device="cuda", dtype=torch.int32)

        def bias_mod(score, b, h, q, kv):
            causal_attention = q >= kv
            cur_num_bidirectional = num_bidirectional[b]
            bidirectional_attention_on_video = (q <= cur_num_bidirectional) & (
                kv <= cur_num_bidirectional
            )
            return torch.where(
                bidirectional_attention_on_video | causal_attention,
                score,
                -float("inf"),
            )

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_natten_2d(self, dtype):
        H = 32
        W = S // H
        WINDOW = 3
        assert W * H == S

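        # Roughly speaking, this treats the length-S sequence as a 2-D grid of
        # W-token rows and keeps a score whenever the query and key positions are
        # within WINDOW of each other along either grid axis (a NATTEN-style
        # neighborhood pattern).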
        def get_x_y(idx):
            # This should be a floor divide, but we don't support that properly
            return idx / W, idx % W

        def natten_mask(score, b, h, q, kv):
            q_x, q_y = get_x_y(q)
            kv_x, kv_y = get_x_y(kv)
            return torch.where(
                ((q_x - kv_x).abs() <= WINDOW) | ((q_y - kv_y).abs() <= WINDOW),
                score,
                float("-inf"),
            )

        self.run_test(natten_mask, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_subgraph_respect_decompostion(self, dtype):
        from torch._decomp import core_aten_decompositions
        from torch.fx.experimental.proxy_tensor import make_fx

        def score_mod_func(score, b, h, q, kv):
            return score - q // (1 + kv)

        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 128, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()
        # floor_div is not decomposed when the decomposition_table is empty
        flex_attention = functools.partial(_flex_attention, score_mod=score_mod_func)
        gm = make_fx(flex_attention, decomposition_table={})(query, key, value)
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1);  arg4_1 = None
    floor_divide = torch.ops.aten.floor_divide.default(arg3_1, add);  arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, floor_divide);  arg0_1 = floor_divide = None
    return sub""",
        )

        # floor_div is decomposed when core_aten_decompositions is used
        gm = make_fx(flex_attention, decomposition_table=core_aten_decompositions())(
            query, key, value
        )
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1);  arg4_1 = None
    div = torch.ops.aten.div.Tensor_mode(arg3_1, add, rounding_mode = 'floor');  arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, div);  arg0_1 = div = None
    return sub""",
        )

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_silu_on_score(self, dtype):
        def silu_score(score, b, h, q, kv):
            return torch.nn.functional.silu(score)

        self.run_test(silu_score, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_padded_dense_causal(self, dtype):
        seq_len = torch.arange(B, device="cuda", dtype=torch.int32) + 1

        def create_padded_dense_wrapper(orig_score_mod):
            def njt_score_mod(qk, b, h, q, kv):
                return torch.where(
                    qk <= seq_len[b], orig_score_mod(qk, b, h, q, kv), -float("inf")
                )

            return njt_score_mod

        causal_njt = create_padded_dense_wrapper(_causal)

        self.run_test(causal_njt, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_scale(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_recompile_changed_score_mod(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)
        ADD = True

        def score_mod_scale(qk, b, h, q, kv):
            if ADD:
                return qk + scale
            else:
                return qk * scale

        self.run_test(score_mod_scale, dtype)
        ADD = False
        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @expectedFailure  # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_reduction(self, dtype):
        scale = torch.randn((B, 8), device="cuda")

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale[b].sum(dim=-1)

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    def test_multiple_score_mod_calls(self):
        query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
        keys = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(2)
        ]
        values = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(2)
        ]

        def scoremod_1(qk, b, h, q, kv):
            return qk + (q - kv)

        def scoremod_2(qk, b, h, q, kv):
            return torch.where(q >= kv, qk, -float("inf"))

        def f(q, k1, k2, v1, v2):
            q2 = _flex_attention(q, k1, v1, score_mod=scoremod_1)
            return _flex_attention(q2, k2, v2, score_mod=scoremod_2)

        out = f(query, *keys, *values)
        out2 = torch.compile(f)(query, *keys, *values)
        tolerance = Tolerances(atol=2e-1, rtol=2e-1)
        torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)

    @supported_platform
    def test_multiple_score_mod_calls2(self):
        query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
        keys = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(3)
        ]
        values = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(3)
        ]

        def scoremod_1(qk, b, h, q, kv):
            return qk + (q - kv)

        def scoremod_2(qk, b, h, q, kv):
            return torch.where(q >= kv, qk, -float("inf"))

        attention1 = functools.partial(_flex_attention, score_mod=scoremod_1)

        def f(q, k1, k2, k3, v1, v2, v3):
            q2 = attention1(q, k1, v1)
            q3 = _flex_attention(q2, k2, v2, score_mod=scoremod_2)
            return _flex_attention(q3, k3, v3, score_mod=scoremod_1)

        out = f(query, *keys, *values)
        out2 = torch.compile(f)(query, *keys, *values)
        self.assertTrue((out - out2).abs().mean() < 1e-2)

    @supported_platform
    def test_inputs_are_realized(self):
        def f(q, k, v):
            x = torch.randn(1024, device="cuda")
            x = x * 2

            def func(qk, b, h, q, kv):
                return qk + x[q]

            return _flex_attention(q.sin(), k, v, score_mod=func).cos()

        q, k, v = (
            torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
            for _ in range(3)
        )
        ref = f(q, k, v)
        out = torch.compile(f)(q, k, v)
        self.assertTrue((ref - out).abs().mean() < 1e-2)
        gradOut = torch.randn_like(q)

        ref_grads = torch.autograd.grad(ref, (q, k, v), gradOut)
        out_grads = torch.autograd.grad(out, (q, k, v), gradOut)
        for ref, out in zip(ref_grads, out_grads):
            self.assertTrue((ref - out).abs().mean() < 1e-2)

    @supported_platform
    def test_epilogue_fused(self):
        @torch.compile
        def f(q, k, v):
            out = _flex_attention(q, k, v)
            return out.cos()

        q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))
        metrics.reset()
        f(q, k, v)
        accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize
        num_accesses = 4  # q, k, v reads, one output.
        # TODO: Get rid of this fudge factor
        # We need this fudge factor for now, since
        # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow)
        # 2. We also write the extraneous logsumexp
        num_accesses += 2
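        # With these shapes, each tensor is 1 * 8 * 1024 * 64 fp32 elements, i.e.
        # 2 MiB, so the byte budget asserted below is 6 tensors' worth (~12 MiB).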
        self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses)

    @supported_platform
    @skip("Triton bug ")  # https://github.com/pytorch/pytorch/issues/124571
    @common_utils.parametrize("dtype", test_dtypes)
    def test_njt_causal(self, dtype):
        offsets = torch.tensor(
            [0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
        )
        seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
        for idx in range(len(offsets) - 1):
            seq_idx[offsets[idx] : offsets[idx + 1]] = idx

        def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
            def njt_score_mod(qk, b, h, q, kv):
                q_nested = q - offsets[seq_idx[q]]
                kv_nested = kv - offsets[seq_idx[kv]]
                return orig_score_mod(qk, b, h, q_nested, kv_nested)

            return njt_score_mod

        causal_njt = create_njt_wrapper(_causal, offsets, seq_idx)

        self.run_test(causal_njt, dtype)

    @supported_platform
    def test_mixed_dtypes_fails(self):
        query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
        key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
        value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
        with self.assertRaisesRegex(
            ValueError, "Expected query, key, and value to have the same dtype"
        ):
            _flex_attention(query, key, value, _identity)

    @supported_platform
    @patch.object(torch._inductor.config, "max_autotune", True)
    def test_max_autotune(self):
        def score_mod(score, b, h, m, n):
            return score * 2

        self.run_test(score_mod)

    @supported_platform
    @skip("TODO: Figure out why this is erroring")
    @patch.object(torch._inductor.config, "max_autotune", True)
    def test_max_autotune_with_captured(self):
        head_scale = torch.randn(H, device="cuda")
        batch_scale = torch.randn(B, device="cuda")
        tok_scale = torch.randn(S, device="cuda")

        def bias_mod(score, batch, head, token_q, token_kv):
            score = score + tok_scale[token_q]
            score = score + batch_scale[batch]
            score = score + head_scale[head]
            return score

        self.run_test(bias_mod)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", [_identity, _causal])
    def test_logsumexp_correctness(self, dtype, score_mod):
        @torch.compile
        def sdpa_hop(q, k, v, score_mod):
            return flex_attention_hop(q, k, v, score_mod)

        @torch.compile(backend="aot_eager")
        def eager_sdpa_hop(q, k, v, score_mod):
            """The main entrypoint for FlexAttention doesn't return the LSE.
            Besides dropping the LSE, it also ensures that the hop is compiled with the
            aot-eager backend. We need to replicate this.
            """
            return flex_attention_hop(q, k, v, score_mod)

        make_tensor = functools.partial(
            torch.randn,
            (B, H, S, D),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_tensor(), make_tensor(), make_tensor()

        ref_out, ref_lse = eager_sdpa_hop(
            q.to(torch.float64), k.to(torch.float64), v.to(torch.float64), score_mod
        )
        compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod)

        # Comparing LSE for the ref and the compiled version
        # The compiled version uses a change-of-base trick to compute the LSE more
        # efficiently: the base for the LSE computed by ref is e, while for the compiled
        # version it is 2. To compare, we use the change-of-base formula
        # log_2(x_compiled) = log_e(x_ref) * log_2(e) where
        # x_ref      = sum(_i e^(scores[i]))
        # x_compiled = sum(_i 2^(log2(e) * scores[i]))
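        # Since 2^(log2(e) * s) == e^s, the two sums are equal, so
        # lse_compiled == lse_ref * log2(e); that rescaling is applied to ref_lse below.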

        self.assertTrue(ref_lse.dtype == torch.float64)
        self.assertTrue(compiled_lse.dtype == torch.float32)
        ref_lse = ref_lse * torch.log2(torch.tensor(torch.e))

        tolerance = Tolerances(atol=2e-2, rtol=2e-2)
        torch.testing.assert_close(
            ref_out.to(dtype=torch.float32),
            compiled_out.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )
        torch.testing.assert_close(
            ref_lse.to(dtype=torch.float32),
            compiled_lse.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )

    @supported_platform
    def test_logsumexp_only_return(self):
        make_tensor = functools.partial(
            torch.randn,
            (B, H, S, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_tensor(), make_tensor(), make_tensor()

        @torch.compile
        def func(q, k, v, score_mod):
            _, lse = flex_attention_hop(q, k, v, score_mod)
            lse_2 = lse * 2
            return lse_2

        _, code = run_and_get_code(func, q, k, v, _identity)
        # Ensure that two kernels are generated
        FileCheck().check_count(".run(", 2, True).run(code[0])

    @supported_platform
    def test_logsumexp_is_not_fused(self):
        make_tensor = functools.partial(
            torch.randn,
            (B, H, S, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_tensor(), make_tensor(), make_tensor()

        @torch.compile
        def func(q, k, v, score_mod):
            out, lse = flex_attention_hop(q, k, v, score_mod)
            lse_2 = lse * 2
            return out, lse_2

        _, code = run_and_get_code(func, q, k, v, _identity)
        # Ensure that two kernels are generated
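        # (presumably the flex attention kernel itself plus a separate pointwise
        # kernel for `lse * 2`, i.e. the logsumexp consumer is not fused in)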
        FileCheck().check_count(".run(", 2, True).run(code[0])

    @supported_platform
    @common_utils.parametrize(
        "score_mod", [_identity, _causal, _times_two, _squared, _trig, _trig2]
    )
    def test_aot_eager_gradcheck(self, score_mod):
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()

        func = torch.compile(_flex_attention, backend="aot_eager", fullgraph=True)

        self.assertTrue(
            torch.autograd.gradcheck(
                func, (query, key, value, score_mod), raise_exception=True
            )
        )

    @supported_platform
    @common_utils.parametrize("score_mod_name", ["_head_offset"])
    @common_utils.parametrize("mode", ["eager", "aot_eager"])
    def test_captured_score_mod_aot_eager_gradcheck(
        self, score_mod_name: str, mode: str
    ):
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()

        func = torch.compile(_flex_attention, backend=mode, fullgraph=True)
        score_mod = captured_buffers_map[score_mod_name](torch.float64)

        self.assertTrue(
            torch.autograd.gradcheck(
                func, (query, key, value, score_mod), raise_exception=True
            )
        )

    @supported_platform
    def test_fw_bw_graph_correctness(self):
        cnt = CompileCounterWithBackend("aot_eager")
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()

        func = torch.compile(_flex_attention, backend=cnt, fullgraph=True)
        out = func(query, key, value, _squared)
        out.sum().backward()
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(len(cnt.graphs), 1)
        graph = cnt.graphs[0]
        norm_graph = normalize_gm(graph.print_readable(print_output=False))
        self.assertExpectedInline(
            norm_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_args_0_: "f64[2, 2, 8, 4]", L_args_1_: "f64[2, 2, 8, 4]", L_args_2_: "f64[2, 2, 8, 4]"):
        l_args_0_ = L_args_0_
        l_args_1_ = L_args_1_
        l_args_2_ = L_args_2_

        new_empty: "f64[]" = l_args_0_.new_empty([], requires_grad = True)
        new_empty_1: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
        new_empty_2: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
        new_empty_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
        new_empty_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
        flex_attention_0 = self.flex_attention_0
        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0);  l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = None
        out: "f64[2, 2, 8, 4]" = flex_attention[0];  flex_attention = None
        return (out,)

    class GraphModule(torch.nn.Module):
        def forward(self, new_empty: "f64[]", new_empty_1: "i32[]", new_empty_2: "i32[]", new_empty_3: "i32[]", new_empty_4: "i32[]"):
            mul: "f64[]" = new_empty * new_empty;  new_empty = None
            return mul
""",  # noqa: B950
        )
        # Save the AOT graphs
        aot_graphs = []
        from torch._inductor import compile_fx

        def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
            aot_graphs.append(graph)
            return graph

        backend = functools.partial(
            compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
        )
        func = torch.compile(func, backend=backend, fullgraph=True)
        out = func(query, key, value, _squared)
        out.sum().backward()

        joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False))

        self.assertExpectedInline(
            joint_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
        fw_graph = self.fw_graph
        joint_graph = self.joint_graph
        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph);  primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = None
        getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
        getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
        getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2];  flex_attention_backward = None
        return [getitem_2, getitem_3, getitem_4]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
            return mul

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1)
            mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
            mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1);  arg5_1 = arg0_1 = None
            add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1);  mul_2 = mul_1 = None
            return [add, None, None, None, None]
""",  # noqa: B950
        )


common_utils.instantiate_parametrized_tests(TestFlexAttention)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()