# Owner(s): ["module: inductor"]
# flake8: noqa: B950

import functools
from collections import namedtuple
from contextlib import nullcontext
from typing import Callable, Optional
from unittest import expectedFailure, skipUnless
from unittest.mock import patch

import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention.flex_attention import (
    _create_empty_block_mask,
    _identity,
    BlockMask,
    create_block_mask,
    flex_attention,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM
from torch.utils._triton import has_triton


# Skip tests unless CUDA (SM 8.0+) and Triton are available
supported_platform = skipUnless(
    torch.cuda.is_available()
    and has_triton()
    and torch.cuda.get_device_capability() >= (8, 0),
    "Requires CUDA and Triton",
)

Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")

index = torch.ops.aten.index
Tensor = torch.Tensor


def create_attention(score_mod, block_mask, enable_gqa=False):
    return functools.partial(
        flex_attention,
        score_mod=score_mod,
        block_mask=block_mask,
        enable_gqa=enable_gqa,
    )


def create_block_mask_test(score_mod, query, key):
    block_mask = create_block_mask(
        score_mod, 1, 1, query.shape[-2], key.shape[-2], query.device
    )
    return block_mask


test_dtypes = (
    [torch.float16, torch.bfloat16, torch.float32]
    if PLATFORM_SUPPORTS_BF16
    else [torch.float16, torch.float32]
)

test_dtypes_fast = [torch.float16]


# --------- Useful score mod functions for testing ---------
# A score_mod takes (score, batch, head, q_idx, kv_idx) and returns a modified score.
def _causal(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return torch.where(token_q >= token_kv, score, float("-inf"))


def _generate_windowed(offset):
    def _windowed(score, b, h, q, kv):
        return torch.where(q + offset >= kv, score, float("-inf"))

    return _windowed


def _get_windowed_sdpa_mask(Mq, Mkv, offset):
    return torch.tril(torch.ones(Mkv, Mkv, dtype=torch.bool, device="cuda"))[
        offset : offset + Mq
    ]


def _rel_bias(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return score + (token_q - token_kv)


def _rel_causal(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf"))


def _generate_alibi_bias(num_heads: int):
    def _alibi_bias(
        score: Tensor,
        batch: Tensor,
        head: Tensor,
        token_q: Tensor,
        token_kv: Tensor,
    ) -> Tensor:
        scale = torch.exp2(-((head + 1) * 8.0 / num_heads))
        return score + (token_kv - token_q) * scale

    return _alibi_bias


def _inverse_causal(score, b, h, m, n):
    return torch.where(m <= n, score, float("-inf"))


def _times_two(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * 2


def _squared(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * score


def _head_offset(dtype: torch.dtype):
    """Captured Buffer"""
    head_offset = torch.rand(Hq, device="cuda", dtype=dtype)
device="cuda", dtype=dtype) 143 144 def score_mod(score, b, h, m, n): 145 return score * head_offset[h] 146 147 return score_mod 148 149 150def _trig(score, b, h, m, n): 151 """Joint graph needed for correctness""" 152 return torch.sin(torch.cos(score)) + torch.tan(b) 153 154 155def _trig2(score, b, h, m, n): 156 """Branching joint graph""" 157 cos_score = torch.cos(score) 158 sin_score = torch.sin(score) 159 z = cos_score * sin_score + torch.tan(b) 160 return z 161 162 163test_score_mods = [ 164 _identity, 165 _times_two, 166 _squared, 167 _causal, 168 _inverse_causal, 169 _rel_bias, 170 _rel_causal, 171 _generate_alibi_bias(8), 172 _generate_windowed(1000), 173] 174 175captured_buffers_map = { 176 "_head_offset": _head_offset, 177} 178 179B = 4 180S = 2048 181D = 64 182 183 184test_Hq_Hkv = [ 185 (16, 1), 186 (8, 2), 187 (16, 16), 188] 189 190(Hq, Hkv) = (16, 8) 191 192 193def query_key_value_clones( 194 query: torch.Tensor, 195 key: torch.Tensor, 196 value: torch.Tensor, 197 dtype: torch.dtype = None, 198): 199 """Clones the query, key, and value tensors and moves them to the specified dtype.""" 200 if dtype is None: 201 dtype = query.dtype 202 query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad) 203 key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad) 204 value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad) 205 return query_ref, key_ref, value_ref 206 207 208class TestFlexDecoding(InductorTestCase): 209 def _check_equal( 210 self, 211 golden_out: torch.Tensor, 212 ref_out: torch.Tensor, 213 compiled_out: torch.Tensor, 214 fudge_factor: float, 215 tensor_name: Optional[str] = None, 216 ): 217 compiled_error = (golden_out - compiled_out).abs().mean() 218 ref_error = (golden_out - ref_out).abs().mean() 219 if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any(): 220 self.assertTrue(False, "Output/Grad with NaN") 221 if ref_error < (1e-4) * golden_out.abs().mean(): 222 print( 223 "very small ref error of ", 224 (ref_error.to(torch.float64) * (1e5) / golden_out.abs().mean()), 225 ) 226 tolerance = Tolerances(atol=2e-1, rtol=2e-1) 227 torch.testing.assert_close( 228 golden_out.to(dtype=compiled_out.dtype), 229 compiled_out, 230 atol=tolerance.atol, 231 rtol=tolerance.rtol, 232 ) 233 elif compiled_error > ref_error * fudge_factor: 234 name = tensor_name if tensor_name is not None else "" 235 msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." 
            self.assertTrue(False, msg)

    def _check_out(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
    ):
        dtype = ref_out.dtype
        with torch.no_grad():
            # Note: we really do seem to be less accurate than the float32
            # computation, likely due to the online softmax
            if dtype == torch.float32:
                fudge_factor = 10.0
            else:
                fudge_factor = 1.1

            # Check the output
            self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

    def run_test(
        self,
        score_mod: Optional[Callable],
        dtype: torch.dtype = torch.float16,
        Q_B: int = B,
        Q_H: int = Hq,
        Q_S: int = 1,
        Q_D: int = D,
        KV_B: int = B,
        KV_H: int = Hkv,
        KV_S: int = S,
        V_D: int = D,
        block_mask: Optional[BlockMask] = None,
    ):
        assert (
            score_mod is not None or block_mask is not None
        ), "Must provide score_mod or block_mask"
        assert Q_H % KV_H == 0
        if TEST_WITH_ROCM and Q_H != KV_H:
            self.skipTest("enable_gqa=True is unsupported on ROCM, for now")
        q = torch.randn(
            (Q_B, Q_H, Q_S, Q_D),
            dtype=dtype,
            device="cuda",
            requires_grad=False,
        )
        k = torch.randn(
            (KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=False
        )
        v = torch.randn(
            (KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=False
        )
        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)

        sdpa_partial = create_attention(
            score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
        )
        compiled_sdpa = torch.compile(sdpa_partial)
        golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
        ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
        compiled_out, compiled_lse = compiled_sdpa(q, k, v, return_lse=True)

        self._check_out(
            golden_out,
            ref_out,
            compiled_out,
        )
        self._check_out(
            gold_lse,
            ref_lse,
            compiled_lse,
        )

    def run_test_with_call(
        self,
        sdpa_call: Callable,
        golden_call: Optional[Callable] = None,
        dtype: torch.dtype = torch.float16,
        Q_B: int = B,
        Q_H: int = Hq,
        Q_S: int = 1,
        Q_D: int = D,
        KV_B: int = B,
        KV_H: int = Hkv,
        KV_S: int = S,
        V_D: int = D,
    ):
        if not golden_call:
            golden_call = sdpa_call
        q = torch.randn(
            (Q_B, KV_H, Q_S * (Q_H // KV_H), Q_D),
            dtype=dtype,
            device="cuda",
            requires_grad=False,
        )
        k = torch.randn(
            (KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=False
        )
        v = torch.randn(
            (KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=False
        )
        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)

        compiled_sdpa = torch.compile(sdpa_call)
        golden_out = golden_call(q_gold, k_gold, v_gold)
        ref_out = golden_call(q_ref, k_ref, v_ref)
        compiled_out = compiled_sdpa(q, k, v)

        self._check_out(
            golden_out,
            ref_out,
            compiled_out,
        )

    @supported_platform
    @expectedFailure
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_bw_decoding_fails(self, dtype):
        make_kv = functools.partial(
            torch.randn,
            (2, 2, 128, 4),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        make_q = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            dtype=dtype,
device="cuda", 368 requires_grad=True, 369 ) 370 q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q() 371 372 block_mask = _create_empty_block_mask(q, k) 373 374 @torch.compile 375 def sdpa_hop(q, k, v, score_mod, block_mask): 376 return flex_attention(q, k, v, score_mod) 377 378 output = sdpa_hop(q, k, v, _identity, block_mask) 379 380 output.backward(backward_grad) 381 382 @supported_platform 383 @common_utils.parametrize("dtype", test_dtypes) 384 @common_utils.parametrize("score_mod", test_score_mods) 385 @common_utils.parametrize("head_dims", test_Hq_Hkv) 386 def test_builtin_score_mods( 387 self, dtype: torch.dtype, score_mod: Callable, head_dims 388 ): 389 Hq, Hkv = head_dims 390 assert Hq % Hkv == 0 391 self.run_test(score_mod, dtype, Q_H=Hq, KV_H=Hkv) 392 393 def input_strides_1(B, H, S, D): 394 return ((H * S * D, S * D, D, 1), 997) # offset 395 396 def input_strides_2(B, H, S, D): 397 return ((H * D, D, B * H * D, 1), 499) # transposed dimensions 398 399 def input_strides_3(B, H, S, D): 400 return ((S * (D + 1), B * S * (D + 1), (D + 1), 1), 293) # additional buffer 401 402 def input_strides_4(B, H, S, D): 403 return ((1, D, (B + 1) * (H + 1) * D, 1), 97) # shared dimension 404 405 test_input_strides = [ 406 input_strides_1, 407 input_strides_2, 408 input_strides_3, 409 input_strides_4, 410 ] 411 412 @supported_platform 413 @common_utils.parametrize("dtype", test_dtypes_fast) 414 @common_utils.parametrize("k_s", test_input_strides) 415 @common_utils.parametrize("v_s", test_input_strides) 416 @common_utils.parametrize("head_dims", test_Hq_Hkv) 417 def test_strided_inputs(self, dtype: torch.dtype, k_s, v_s, head_dims): 418 Hq, Hkv = head_dims 419 assert Hq % Hkv == 0 420 q1 = torch.randn((B * Hq * D), dtype=dtype, device="cuda") 421 k1 = torch.randn((B * Hkv * S * D * 4), dtype=dtype, device="cuda") 422 v1 = torch.randn((B * Hkv * S * D * 4), dtype=dtype, device="cuda") 423 424 k_shape = (B, Hkv, S, D) 425 v_shape = (B, Hkv, S, D) 426 427 q = q1.view(1, Hq, B, D).transpose(0, 2) 428 429 k_strides, k_offset = k_s(B, Hkv, S, D) 430 k_max = [x * (y - 1) for x, y in zip(k_strides, k_shape)] 431 assert sum(k_max) + k_offset < B * Hkv * S * D * 4 432 assert k_strides[-1] == 1 433 k = torch.as_strided(k1, k_shape, k_strides, k_offset) 434 435 v_strides, v_offset = v_s(B, Hkv, S, D) 436 v_max = [x * (y - 1) for x, y in zip(v_strides, v_shape)] 437 assert sum(v_max) + v_offset < B * Hkv * S * D * 4 438 assert v_strides[-1] == 1 439 v = torch.as_strided(v1, v_shape, v_strides, v_offset) 440 441 sdpa_partial = create_attention( 442 score_mod=_generate_alibi_bias(8), 443 block_mask=None, 444 enable_gqa=(not Hq == Hkv), 445 ) 446 compiled_sdpa = torch.compile(sdpa_partial) 447 ref_out = sdpa_partial(q, k, v) 448 compiled_out = compiled_sdpa(q, k, v) 449 450 tolerance = Tolerances(atol=2e-1, rtol=2e-1) 451 torch.testing.assert_close( 452 ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol 453 ) 454 455 @supported_platform 456 @common_utils.parametrize("dtype", test_dtypes) 457 def test_skip_odd_keys(self, dtype: torch.dtype): 458 def score_mod(score, b, h, q, kv): 459 return torch.where(kv % 2 == 0, score, float("-inf")) 460 461 self.run_test(score_mod, dtype) 462 463 @supported_platform 464 @common_utils.parametrize("dtype", test_dtypes) 465 def test_function_composition(self, dtype: torch.dtype): 466 def score_mod_1(score, b, h, m, n): 467 return score + (m - n) 468 469 def score_mod_2(score, b, h, m, n): 470 return torch.where(m <= n, score, float("-inf")) 471 472 def 
            return score_mod_2(score_mod_1(score, b, h, m, n), b, h, m, n)

        self.run_test(composed_score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_captured_buffers(self, dtype: torch.dtype):
        head_offset = torch.rand(Hq, device="cuda", dtype=dtype)

        def score_mod(score, b, h, m, n):
            return score + head_offset[h]

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_captured_buffers_all_dims(self, dtype: torch.dtype):
        head_scale = torch.randn(Hq, device="cuda")
        batch_scale = torch.randn(B, device="cuda")
        kv_scale = torch.randn(S, device="cuda")
        q_scale = torch.randn(1, device="cuda")

        def all_bias(score, batch, head, token_q, token_kv):
            score = score + kv_scale[token_kv]
            score = score + q_scale[token_q]
            score = score + head_scale[head]
            score = score + batch_scale[batch]
            return score

        self.run_test(all_bias, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_seq_masking(self, dtype):
        seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
        seq_idx[S // 2 :] = 1

        def seq_mask_mod(score, b, h, q, kv):
            return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf"))

        self.run_test(seq_mask_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_seq_only(self, dtype):
        bias = torch.randn(1, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_seq_batch(self, dtype):
        bias = torch.randn(B, 1, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[b, q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_head_seq_batch(self, dtype):
        bias = torch.randn(
            B,
            Hq,
            1,
            S,
            device="cuda",
            dtype=dtype,
        )

        def bias_mod(score, b, h, q, kv):
            return score + bias[b, h, q, kv]

        self.run_test(bias_mod, dtype)

    # TODO this config segfaults with Triton without:
    # https://github.com/triton-lang/triton/pull/4540
    @supported_platform
    @common_utils.parametrize("score_mod", test_score_mods)
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)])
    def test_non_equal_head_dims(self, dtype, score_mod, head_dims):
        qk_d, v_d = head_dims
        context = nullcontext() if qk_d > v_d else self.assertRaises(ValueError)
        with context:
            self.run_test(score_mod, dtype, B, Hq, 1, qk_d, B, Hkv, S, V_D=v_d)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_subgraph_respect_decompostion(self, dtype):
        from torch._decomp import core_aten_decompositions
        from torch.fx.experimental.proxy_tensor import make_fx

        def score_mod_func(score, b, h, q, kv):
            return score - q // (1 + kv)

        make_kv = functools.partial(
            torch.randn,
            (2, 2, 128, 4),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        make_q = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        query, key, value = make_q(), make_kv(), make_kv()
        # floor_div is not decomposed when the decomposition_table is empty
        attention = functools.partial(flex_attention, score_mod=score_mod_func)
        gm = make_fx(attention, decomposition_table={})(query, key, value)
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1);  arg4_1 = None
    floor_divide = torch.ops.aten.floor_divide.default(arg3_1, add);  arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, floor_divide);  arg0_1 = floor_divide = None
    return sub""",
        )

        # floor_div is decomposed with core_aten_decompositions
        gm = make_fx(attention, decomposition_table=core_aten_decompositions())(
            query, key, value
        )
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1);  arg4_1 = None
    div = torch.ops.aten.div.Tensor_mode(arg3_1, add, rounding_mode = 'floor');  arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, div);  arg0_1 = div = None
    return sub""",
        )

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_silu_on_score(self, dtype):
        def silu_score(score, b, h, q, kv):
            return torch.nn.functional.silu(score)

        self.run_test(silu_score, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_padded_dense_causal(self, dtype):
        seq_len = torch.arange(B, device="cuda", dtype=torch.int32) + 1

        def create_padded_dense_wrapper(orig_score_mod):
            def njt_score_mod(qk, b, h, q, kv):
                return torch.where(
                    qk <= seq_len[b], orig_score_mod(qk, b, h, q, kv), -float("inf")
                )

            return njt_score_mod

        causal_njt = create_padded_dense_wrapper(_causal)

        self.run_test(causal_njt, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_scale(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_recompile_changed_score_mod(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)
        ADD = True

        def score_mod_scale(qk, b, h, q, kv):
            if ADD:
                return qk + scale
            else:
                return qk * scale

        self.run_test(score_mod_scale, dtype)
        ADD = False
        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @expectedFailure  # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_reduction(self, dtype):
        scale = torch.randn((B, 8), device="cuda")

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale[b].sum(dim=-1)

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    def test_multiple_score_mod_calls(self):
        query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device="cuda")
        keys = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(2)
        ]
        values = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
device="cuda") 686 for _ in range(2) 687 ] 688 689 def scoremod_1(qk, b, h, q, kv): 690 return qk + (q - kv) 691 692 def scoremod_2(qk, b, h, q, kv): 693 return torch.where(q >= kv, qk, -float("inf")) 694 695 def f(q, k1, k2, v1, v2): 696 q2 = flex_attention(q, k1, v1, score_mod=scoremod_1) 697 return flex_attention(q2, k2, v2, score_mod=scoremod_2) 698 699 out = f(query, *keys, *values) 700 out2 = torch.compile(f)(query, *keys, *values) 701 tolerance = Tolerances(atol=2e-1, rtol=2e-1) 702 torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol) 703 704 @supported_platform 705 def test_multiple_score_mod_calls2(self): 706 query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device="cuda") 707 keys = [ 708 torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda") 709 for _ in range(3) 710 ] 711 values = [ 712 torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda") 713 for _ in range(3) 714 ] 715 716 def scoremod_1(qk, b, h, q, kv): 717 return qk + (q - kv) 718 719 def scoremod_2(qk, b, h, q, kv): 720 return torch.where(q >= kv, qk, -float("inf")) 721 722 attention1 = functools.partial(flex_attention, score_mod=scoremod_1) 723 724 def f(q, k1, k2, k3, v1, v2, v3): 725 q2 = attention1(q, k1, v1) 726 q3 = flex_attention(q2, k2, v2, score_mod=scoremod_2) 727 return flex_attention(q3, k3, v3, score_mod=scoremod_1) 728 729 out = f(query, *keys, *values) 730 out2 = torch.compile(f)(query, *keys, *values) 731 self.assertTrue((out - out2).abs().mean() < 1e-2) 732 733 @supported_platform 734 @common_utils.parametrize("dtype", test_dtypes) 735 def test_njt_causal(self, dtype): 736 offsets = torch.tensor( 737 [0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32 738 ) 739 seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32) 740 for idx in range(len(offsets) - 1): 741 seq_idx[offsets[idx] : offsets[idx + 1]] = idx 742 743 def create_njt_wrapper(orig_score_mod, offsets, seq_idx): 744 def njt_score_mod(qk, b, h, q, kv): 745 q_nested = q - offsets[seq_idx[q]] 746 kv_nested = kv - offsets[seq_idx[kv]] 747 return orig_score_mod(qk, b, h, q_nested, kv_nested) 748 749 return njt_score_mod 750 751 causal_njt = create_njt_wrapper(_causal, offsets, seq_idx) 752 753 self.run_test(causal_njt, dtype) 754 755 @supported_platform 756 def test_mixed_dtypes_fails(self): 757 query = torch.randn((1, 1, 8, 64), dtype=torch.float32, device="cuda") 758 key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda") 759 value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda") 760 with self.assertRaisesRegex( 761 ValueError, "Expected query, key, and value to have the same dtype" 762 ): 763 flex_attention(query, key, value, _identity) 764 765 @supported_platform 766 @patch.object(torch._inductor.config, "max_autotune", True) 767 def test_max_autotune(self): 768 def score_mod(score, b, h, m, n): 769 return score * 2 770 771 self.run_test(score_mod) 772 773 @supported_platform 774 @patch.object(torch._inductor.config, "max_autotune", True) 775 def test_max_autotune_with_captured(self): 776 head_scale = torch.randn(Hq, device="cuda") 777 batch_scale = torch.randn(B, device="cuda") 778 tok_scale = torch.randn(S, device="cuda") 779 q_scale = torch.randn(1, device="cuda") 780 781 def bias_mod(score, batch, head, token_q, token_kv): 782 score = score + tok_scale[token_kv] 783 score = score + q_scale[token_q] 784 score = score + batch_scale[batch] 785 score = score + head_scale[head] 786 return score 787 788 self.run_test(bias_mod) 789 790 @skipIfRocm 791 
    @supported_platform
    def test_fully_masked_out_rows_0_check_gqa(self):
        # Ensure fully masked out rows won't cause NaNs.
        query = torch.randn(
            (B, Hq, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )
        key = torch.randn(
            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )
        value = torch.randn(
            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )

        M = S // 2

        def mask_mod(b, h, q, kv):
            return q < M

        block_mask = create_block_mask(mask_mod, 1, 1, S, S)

        flex = torch.compile(flex_attention, dynamic=False)

        out, lse = flex(
            query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True
        )
        self.assertEqual(out[:, :, M:, :].sum(), 0)
        self.assertTrue((lse[:, :, M:] == -float("inf")).all())

        loss = out.sum() + lse.sum()
        loss.backward()
        self.assertEqual(query.grad[:, :, M:, :].sum(), 0)

    @supported_platform
    def test_windowed_no_mask_vs_sdpa(self):
        score_mod = _generate_windowed(1000)
        attention = functools.partial(flex_attention, score_mod=score_mod)

        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)

        sdpa_attention = functools.partial(
            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )

        self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

    @supported_platform
    def test_windowed_full_mask_vs_sdpa(self):
        def mask_mod(b, h, q, kv):
            return q + 1000 >= kv

        score_mod = _generate_windowed(1000)

        block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
        attention = functools.partial(
            flex_attention, block_mask=block_mask, score_mod=score_mod
        )

        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
        sdpa_attention = functools.partial(
            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )

        self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

    @supported_platform
    def test_windowed_partial_block_vs_sdpa(self):
        def mask_mod(b, h, q, kv):
            return q + 1000 >= kv

        block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
        attention = functools.partial(flex_attention, block_mask=block_mask)

        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
        sdpa_attention = functools.partial(
            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )

        self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", [_identity, _causal])
    def test_logsumexp_correctness(self, dtype, score_mod):
        make_kv = functools.partial(
            torch.randn,
            (B, Hkv, S, D),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        make_q = functools.partial(
            torch.randn,
            (B, Hkv, Hq // Hkv, D),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_q(), make_kv(), make_kv()

        @torch.compile
        def sdpa_hop(q, k, v, score_mod):
            return flex_attention(q, k, v, score_mod, return_lse=True)

        @torch.compile(backend="aot_eager")
        def eager_sdpa_hop(q, k, v, score_mod):
            return flex_attention(q, k, v, score_mod, return_lse=True)

        ref_out, ref_lse = eager_sdpa_hop(
            q.to(torch.float64),
            k.to(torch.float64),
            v.to(torch.float64),
            score_mod,
        )
        compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod)

        self.assertTrue(ref_lse.dtype == torch.float64)
        self.assertTrue(compiled_lse.dtype == torch.float32)

        tolerance = Tolerances(atol=2e-2, rtol=2e-2)
        torch.testing.assert_close(
            ref_out.to(dtype=torch.float32),
            compiled_out.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )
        torch.testing.assert_close(
            ref_lse.to(dtype=torch.float32),
            compiled_lse.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )

    @supported_platform
    def test_logsumexp_only_return(self):
        make_q = functools.partial(
            torch.randn,
            (B, Hkv, Hq // Hkv, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        make_kv = functools.partial(
            torch.randn,
            (B, Hkv, S, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )

        q, k, v = make_q(), make_kv(), make_kv()

        @torch.compile
        def func(q, k, v, score_mod):
            _, lse = flex_attention(q, k, v, score_mod, return_lse=True)
            lse_2 = lse * 2
            return lse_2

        _, code = run_and_get_code(func, q, k, v, _identity)
        # Ensure that we're still generating the flex_attention kernel
        FileCheck().check_count(".run(primals_1, primals_2, primals_3", 1, True).run(
            code[0]
        )

    @supported_platform
    def test_non_sparse_mulitple_block_size(self):
        def generate_causal_offset(offset: torch.Tensor):
            def causal_offset_mask(b, h, q_idx, kv_idx):
                return (offset + q_idx) >= kv_idx

            return causal_offset_mask

        def noop(score, b, h, q_idx, kv_idx):
            return score

        mod = generate_causal_offset(
            torch.tensor(192, device="cuda", dtype=torch.int32)
        )
        block_mask = create_block_mask(mod, 1, 1, 1, 65)

        self.run_test(
            score_mod=None,
            dtype=torch.float32,
            block_mask=block_mask,
            Q_B=1,
            Q_H=1,
            Q_S=1,
            Q_D=16,
            KV_B=1,
            KV_H=1,
            KV_S=65,
            V_D=16,
        )

    @supported_platform
    def test_do_not_trigger_dynamic_shapes_on_empty_block_mask(self):
        torch._dynamo.reset()
        H = Hq
        q = torch.randn(B, H, 1, D, device="cuda")
        for i in range(5):
            k = torch.randn(B, H, S + i, D, device="cuda")
            v = torch.randn(B, H, S + i, D, device="cuda")
            compiled_flex_attention = torch.compile(flex_attention)
            ref = flex_attention(q, k, v)
            res = compiled_flex_attention(q, k, v)
            tolerance = Tolerances(atol=2e-1, rtol=2e-1)
            torch.testing.assert_close(
                ref, res, atol=tolerance.atol, rtol=tolerance.rtol
            )
            # Ensure no further recompilation after the second, automatic-dynamic-shapes compilation.
            if i == 0:
                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
            else:
                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)


common_utils.instantiate_parametrized_tests(TestFlexDecoding)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()