# Owner(s): ["module: inductor"]
# flake8: noqa: B950

import functools
from collections import namedtuple
from typing import Callable, Optional

from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch

import torch

from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch._inductor import metrics
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention._flex_attention import (
    _causal,
    _compose,
    _flex_attention,
    _generate_alibi_bias,
    _identity,
    _rel_bias,
    _rel_causal,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
from torch.utils._triton import has_triton

# Skip tests if Triton is not available
supported_platform = skipUnless(
    torch.cuda.is_available()
    and has_triton()
    and torch.version.hip is None
    and torch.cuda.get_device_capability() >= (8, 0),
    "Requires CUDA and Triton",
)

Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")

index = torch.ops.aten.index


def create_attention(score_mod):
    return functools.partial(_flex_attention, score_mod=score_mod)


test_dtypes = (
    [torch.float16, torch.bfloat16, torch.float32]
    if PLATFORM_SUPPORTS_BF16
    else [torch.float16, torch.float32]
)

test_dtypes_fast = [torch.float16]

# TODO float16 was causing ERRORs for tests on ROCm
# See https://github.com/pytorch/pytorch/issues/123531
if common_utils.TEST_WITH_ROCM:
    test_dtypes = [torch.float32]


# --------- Useful score mod functions for testing ---------
def _inverse_causal(score, b, h, m, n):
    return torch.where(m <= n, score, float("-inf"))


def _times_two(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * 2


def _squared(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * score


def _head_offset(dtype: torch.dtype):
    """Captured Buffer"""
    head_offset = torch.rand(H, device="cuda", dtype=dtype)

    def score_mod(score, b, h, m, n):
        return score * head_offset[h]

    return score_mod


def _trig(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return torch.sin(torch.cos(score)) + torch.tan(b)


def _trig2(score, b, h, m, n):
    """Branching joint graph"""
    cos_score = torch.cos(score)
    sin_score = torch.sin(score)
    z = cos_score * sin_score + torch.tan(b)
    return z


test_score_mods = [
    _identity,
    _times_two,
    _squared,
    _causal,
    _inverse_causal,
    _rel_bias,
    _rel_causal,
    _generate_alibi_bias(8),
]

captured_buffers_map = {
    "_head_offset": _head_offset,
}

B = 4
H = 8
S = 2048
D = 64


def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype = None,
):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref


class TestFlexAttention(InductorTestCase):
    def _check_equal(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
        fudge_factor: float,
        tensor_name: Optional[str] = None,
    ):
        compiled_error = (golden_out - compiled_out).abs().mean()
        ref_error = (golden_out - ref_out).abs().mean()
        if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
            self.assertTrue(False, "Output/Grad with NaN")
        if compiled_error > ref_error * fudge_factor:
            name = tensor_name if tensor_name is not None else ""
            msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
            self.assertTrue(False, msg)

    def _check_out_and_grad(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
        q_gold: torch.Tensor,
        q_ref: torch.Tensor,
        q: torch.Tensor,
        k_gold: torch.Tensor,
        k_ref: torch.Tensor,
        k: torch.Tensor,
        v_gold: torch.Tensor,
        v_ref: torch.Tensor,
        v: torch.Tensor,
    ):
        dtype = ref_out.dtype
        with torch.no_grad():
            # Note, it seems like we really are less accurate than the float32
            # computation, likely due to the online softmax
            if dtype == torch.float32:
                fudge_factor = 10.0
            else:
                fudge_factor = 1.1

            # Check output
            self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

            # Check gradients
            q_fudge_factor = 2.5 * fudge_factor
            self._check_equal(
                q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
            )
            k_fudge_factor = 4 * fudge_factor
            self._check_equal(
                k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
            )
            v_fudge_factor = 4 * fudge_factor
            self._check_equal(
                v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
            )

    def run_test(
        self,
        score_mod: Callable,
        dtype: torch.dtype = torch.float16,
        Q_B: int = B,
        Q_H: int = H,
        Q_S: int = S,
        Q_D: int = D,
        KV_B: int = B,
        KV_H: int = H,
        KV_S: int = S,
        KV_D: int = D,
    ):
        sdpa_partial = create_attention(score_mod)
        compiled_sdpa = torch.compile(sdpa_partial)
        q = torch.randn(
            (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
        )
        k = torch.randn(
            (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True
        )
        v = torch.randn(
            (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True
        )
        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
        golden_out = sdpa_partial(q_gold, k_gold, v_gold)
        ref_out = sdpa_partial(q_ref, k_ref, v_ref)
        compiled_out = compiled_sdpa(q, k, v)

        backward_grad = torch.randn((Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda")

        golden_out.backward(backward_grad.to(torch.float64))
        ref_out.backward(backward_grad)
        compiled_out.backward(backward_grad)

        self._check_out_and_grad(
            golden_out,
            ref_out,
            compiled_out,
            q_gold,
            q_ref,
            q,
            k_gold,
            k_ref,
            k,
            v_gold,
            v_ref,
            v,
        )

    def run_dynamic_test(
        self,
        score_mod: Callable,
        dtype: torch.dtype = torch.float16,
        B: int = B,
        H: int = H,
        S: int = S,
        D: int = D,
    ):
        sdpa_partial = create_attention(score_mod)
        # The first eager batch, shape (B, H, S, D)
        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1)
        q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64)
        ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref)
        golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold)

        backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")

        golden_out1.backward(backward_grad1.to(torch.float64))
        ref_out1.backward(backward_grad1)

        # The second eager batch, shape (B * 2, H, S / 2, D)
        B = int(B * 2)
        S = int(S / 2)
        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
        q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2)
        q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64)
        ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref)
        golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold)

        backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")

        golden_out2.backward(backward_grad2.to(torch.float64))
        ref_out2.backward(backward_grad2)

        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
        # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
        torch._dynamo.reset()
        # Compiling with dynamic shape in the first batch.
        compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
        compiled_out1 = compiled_sdpa(q1, k1, v1)
        compiled_out1.backward(backward_grad1)

        self._check_out_and_grad(
            golden_out1,
            ref_out1,
            compiled_out1,
            q1_gold,
            q1_ref,
            q1,
            k1_gold,
            k1_ref,
            k1,
            v1_gold,
            v1_ref,
            v1,
        )
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

        # No re-compilation, use the compiled dynamic shape version.
        compiled_out2 = compiled_sdpa(q2, k2, v2)
        compiled_out2.backward(backward_grad2)
        self._check_out_and_grad(
            golden_out2,
            ref_out2,
            compiled_out2,
            q2_gold,
            q2_ref,
            q2,
            k2_gold,
            k2_ref,
            k2,
            v2_gold,
            v2_ref,
            v2,
        )
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

    def run_automatic_dynamic_test(
        self,
        score_mod: Callable,
        dtype: torch.dtype = torch.float16,
        B: int = B,
        H: int = H,
        S: int = S,
        D: int = D,
    ):
        sdpa_partial = create_attention(score_mod)
        # The first eager batch, shape (B, H, S, D)
        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out1 = sdpa_partial(
            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
        )
        ref_out1 = sdpa_partial(q1, k1, v1)

        # The second eager batch, shape (B * 2, H, S / 2, D)
        B = int(B * 2)
        S = int(S / 2)
        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out2 = sdpa_partial(
            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
        )
        ref_out2 = sdpa_partial(q2, k2, v2)

        # The third eager batch, shape (B * 4, H, S / 4, D)
        B = int(B * 2)
        S = int(S / 2)
        q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out3 = sdpa_partial(
            q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
        )
        ref_out3 = sdpa_partial(q3, k3, v3)

        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
        # We check dynamo counters["frames"]["ok"] to ensure:
        # 1, the first batch is compiled with static shape
        # 2, the second batch is compiled with dynamic shape
        # 3, no re-compilation in the third batch
        torch._dynamo.reset()

        # Note, it seems like we really are less accurate than the float32
        # computation, likely due to the online softmax
        if dtype == torch.float32:
            fudge_factor = 10.0
        else:
            fudge_factor = 1.1

        # The first batch.
        compiled_sdpa = torch.compile(sdpa_partial)
        compiled_out1 = compiled_sdpa(q1, k1, v1)
        self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

        # The second batch (automatic dynamic).
        compiled_out2 = compiled_sdpa(q2, k2, v2)
        self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

        # The third batch (no re-compilation).
        compiled_out3 = compiled_sdpa(q3, k3, v3)
        self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable):
        self.run_dynamic_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods_automatic_dynamic(
        self, dtype: torch.dtype, score_mod: Callable
    ):
        self.run_automatic_dynamic_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods_different_seqlen(
        self, dtype: torch.dtype, score_mod: Callable
    ):
        self.run_test(
            score_mod,
            dtype,
            B,
            H,
            S // 2,  # Seqlen of Q is different from seqlen of K/V
            D,
            B,
            H,
            S,
            D,
        )

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_skip_odd_keys(self, dtype: torch.dtype):
        def score_mod(score, b, h, q, kv):
            return torch.where(kv % 2 == 0, score, float("-inf"))

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_function_composition(self, dtype: torch.dtype):
        def score_mod_1(score, b, h, m, n):
            return score + (m - n)

        def score_mod_2(score, b, h, m, n):
            return torch.where(m <= n, score, float("-inf"))

        composed_score_mod = _compose(score_mod_1, score_mod_2)

        self.run_test(composed_score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_captured_buffers(self, dtype: torch.dtype):
        head_offset = torch.rand(H, device="cuda", dtype=dtype)

        def score_mod(score, b, h, m, n):
            return score + head_offset[h]

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_captured_buffers_all_dims(self, dtype: torch.dtype):
        head_scale = torch.randn(H, device="cuda")
        batch_scale = torch.randn(B, device="cuda")
        tok_scale = torch.randn(S, device="cuda")

        def all_bias(score, batch, head, token_q, token_kv):
            score = score + tok_scale[token_q]
            score = score + batch_scale[batch]
            score = score + head_scale[head]
            return score

        self.run_test(all_bias, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_seq_masking(self, dtype):
        seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
        seq_idx[S // 2 :] = 1

        def seq_mask_mod(score, b, h, q, kv):
            return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf"))

        self.run_test(seq_mask_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_seq_only(self, dtype):
        bias = torch.randn(S, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_seq_batch(self, dtype):
        bias = torch.randn(B, S, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[b, q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_from_bias_head_seq_batch(self, dtype):
        bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + bias[b, h, q, kv]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_load_rel_bias(self, dtype):
        rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + rel_bias[(q - kv) + S]

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_dependent_causal_bidirectional(self, dtype):
        num_bidirectional = torch.randint(0, S, (B,), device="cuda", dtype=torch.int32)

        def bias_mod(score, b, h, q, kv):
            causal_attention = q >= kv
            cur_num_bidirectional = num_bidirectional[b]
            bidirectional_attention_on_video = (q <= cur_num_bidirectional) & (
                kv <= cur_num_bidirectional
            )
            return torch.where(
                bidirectional_attention_on_video | causal_attention,
                score,
                -float("inf"),
            )

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_natten_2d(self, dtype):
        H = 32
        W = S // H
        WINDOW = 3
        assert W * H == S

        def get_x_y(idx):
            # This should be a floor divide, but we don't support that properly
            return idx / W, idx % W

        def natten_mask(score, b, h, q, kv):
            q_x, q_y = get_x_y(q)
            kv_x, kv_y = get_x_y(kv)
            return torch.where(
                ((q_x - kv_x).abs() <= WINDOW) | ((q_y - kv_y).abs() <= WINDOW),
                score,
                float("-inf"),
            )

        self.run_test(natten_mask, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_subgraph_respect_decompostion(self, dtype):
        from torch._decomp import core_aten_decompositions
        from torch.fx.experimental.proxy_tensor import make_fx

        def score_mod_func(score, b, h, q, kv):
            return score - q // (1 + kv)

        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 128, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()
        # floor_div is not decomposed if the decomposition_table is empty
        flex_attention = functools.partial(_flex_attention, score_mod=score_mod_func)
        gm = make_fx(flex_attention, decomposition_table={})(query, key, value)
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
    floor_divide = torch.ops.aten.floor_divide.default(arg3_1, add); arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, floor_divide); arg0_1 = floor_divide = None
    return sub""",
        )

        # floor_div is decomposed with core_aten_decompositions
        gm = make_fx(flex_attention, decomposition_table=core_aten_decompositions())(
            query, key, value
        )
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
    div = torch.ops.aten.div.Tensor_mode(arg3_1, add, rounding_mode = 'floor'); arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, div); arg0_1 = div = None
    return sub""",
        )

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_silu_on_score(self, dtype):
        def silu_score(score, b, h, q, kv):
            return torch.nn.functional.silu(score)

        self.run_test(silu_score, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_padded_dense_causal(self, dtype):
        seq_len = torch.arange(B, device="cuda", dtype=torch.int32) + 1

        def create_padded_dense_wrapper(orig_score_mod):
            def njt_score_mod(qk, b, h, q, kv):
                return torch.where(
                    qk <= seq_len[b], orig_score_mod(qk, b, h, q, kv), -float("inf")
                )

            return njt_score_mod

        causal_njt = create_padded_dense_wrapper(_causal)

        self.run_test(causal_njt, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_scale(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_recompile_changed_score_mod(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)
        ADD = True

        def score_mod_scale(qk, b, h, q, kv):
            if ADD:
                return qk + scale
            else:
                return qk * scale

        self.run_test(score_mod_scale, dtype)
        ADD = False
        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @expectedFailure  # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_reduction(self, dtype):
        scale = torch.randn((B, 8), device="cuda")

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale[b].sum(dim=-1)

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    def test_multiple_score_mod_calls(self):
        query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
        keys = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(2)
        ]
        values = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(2)
        ]

        def scoremod_1(qk, b, h, q, kv):
            return qk + (q - kv)

        def scoremod_2(qk, b, h, q, kv):
            return torch.where(q >= kv, qk, -float("inf"))

        def f(q, k1, k2, v1, v2):
            q2 = _flex_attention(q, k1, v1, score_mod=scoremod_1)
            return _flex_attention(q2, k2, v2, score_mod=scoremod_2)

        out = f(query, *keys, *values)
        out2 = torch.compile(f)(query, *keys, *values)
        tolerance = Tolerances(atol=2e-1, rtol=2e-1)
        torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)

    @supported_platform
    def test_multiple_score_mod_calls2(self):
        query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
        keys = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(3)
        ]
        values = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(3)
        ]

        def scoremod_1(qk, b, h, q, kv):
            return qk + (q - kv)

        def scoremod_2(qk, b, h, q, kv):
            return torch.where(q >= kv, qk, -float("inf"))

        attention1 = functools.partial(_flex_attention, score_mod=scoremod_1)

        def f(q, k1, k2, k3, v1, v2, v3):
            q2 = attention1(q, k1, v1)
            q3 = _flex_attention(q2, k2, v2, score_mod=scoremod_2)
            return _flex_attention(q3, k3, v3, score_mod=scoremod_1)

        out = f(query, *keys, *values)
        out2 = torch.compile(f)(query, *keys, *values)
        self.assertTrue((out - out2).abs().mean() < 1e-2)

    @supported_platform
    def test_inputs_are_realized(self):
        def f(q, k, v):
            x = torch.randn(1024, device="cuda")
            x = x * 2

            def func(qk, b, h, q, kv):
                return qk + x[q]

            return _flex_attention(q.sin(), k, v, score_mod=func).cos()

        q, k, v = (
            torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
            for _ in range(3)
        )
        ref = f(q, k, v)
        out = torch.compile(f)(q, k, v)
        self.assertTrue((ref - out).abs().mean() < 1e-2)
        gradOut = torch.randn_like(q)

        ref_grads = torch.autograd.grad(ref, (q, k, v), gradOut)
        out_grads = torch.autograd.grad(out, (q, k, v), gradOut)
        for ref, out in zip(ref_grads, out_grads):
            self.assertTrue((ref - out).abs().mean() < 1e-2)

    @supported_platform
    def test_epilogue_fused(self):
        @torch.compile
        def f(q, k, v):
            out = _flex_attention(q, k, v)
            return out.cos()

        q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))
        metrics.reset()
        f(q, k, v)
        accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize
        num_accesses = 4  # q, k, v reads, one output.
        # TODO: Get rid of this fudge factor
        # We need this fudge factor for now, since
        # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow)
        # 2. We also write the extraneous logsumexp
        num_accesses += 2
        self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses)

    @supported_platform
    @skip("Triton bug ")  # https://github.com/pytorch/pytorch/issues/124571
    @common_utils.parametrize("dtype", test_dtypes)
    def test_njt_causal(self, dtype):
        offsets = torch.tensor(
            [0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
        )
        seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
        for idx in range(len(offsets) - 1):
            seq_idx[offsets[idx] : offsets[idx + 1]] = idx

        def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
            def njt_score_mod(qk, b, h, q, kv):
                q_nested = q - offsets[seq_idx[q]]
                kv_nested = kv - offsets[seq_idx[kv]]
                return orig_score_mod(qk, b, h, q_nested, kv_nested)

            return njt_score_mod

        causal_njt = create_njt_wrapper(_causal, offsets, seq_idx)

        self.run_test(causal_njt, dtype)

    @supported_platform
    def test_mixed_dtypes_fails(self):
        query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
        key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
        value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
        with self.assertRaisesRegex(
            ValueError, "Expected query, key, and value to have the same dtype"
        ):
            _flex_attention(query, key, value, _identity)

    @supported_platform
    @patch.object(torch._inductor.config, "max_autotune", True)
    def test_max_autotune(self):
        def score_mod(score, b, h, m, n):
            return score * 2

        self.run_test(score_mod)

    @supported_platform
    @skip("TODO: Figure out why this is erroring")
    @patch.object(torch._inductor.config, "max_autotune", True)
    def test_max_autotune_with_captured(self):
        head_scale = torch.randn(H, device="cuda")
        batch_scale = torch.randn(B, device="cuda")
        tok_scale = torch.randn(S, device="cuda")

        def bias_mod(score, batch, head, token_q, token_kv):
            score = score + tok_scale[token_q]
            score = score + batch_scale[batch]
            score = score + head_scale[head]
            return score

        self.run_test(bias_mod)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", [_identity, _causal])
    def test_logsumexp_correctness(self, dtype, score_mod):
        @torch.compile
        def sdpa_hop(q, k, v, score_mod):
            return flex_attention_hop(q, k, v, score_mod)

        @torch.compile(backend="aot_eager")
        def eager_sdpa_hop(q, k, v, score_mod):
            """The main entrypoint for FlexAttention doesn't return LSE.
            Besides dropping LSE it also ensures that the hop is compiled with aot-eager
            backend. We need to replicate this.
            """
            return flex_attention_hop(q, k, v, score_mod)

        make_tensor = functools.partial(
            torch.randn,
            (B, H, S, D),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_tensor(), make_tensor(), make_tensor()

        ref_out, ref_lse = eager_sdpa_hop(
            q.to(torch.float64), k.to(torch.float64), v.to(torch.float64), score_mod
        )
        compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod)

        # Comparing LSE for the ref and the compiled version
        # The compiled uses a change of base trick to more efficiently compute the LSE
        # this means that the base for the LSE computed by ref is e while for the compiled
        # version it is 2. To compare we use the change of base formula
        # log_2(x_compiled) = log_e(x_ref) * log_2(e) where
        # x_ref = sum(_i e^(scores[i]))
        # x_compiled = sum(_i 2^(log2(e) * scores[i]))

        self.assertTrue(ref_lse.dtype == torch.float64)
        self.assertTrue(compiled_lse.dtype == torch.float32)
        ref_lse = ref_lse * torch.log2(torch.tensor(torch.e))

        tolerance = Tolerances(atol=2e-2, rtol=2e-2)
        torch.testing.assert_close(
            ref_out.to(dtype=torch.float32),
            compiled_out.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )
        torch.testing.assert_close(
            ref_lse.to(dtype=torch.float32),
            compiled_lse.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )

    @supported_platform
    def test_logsumexp_only_return(self):
        make_tensor = functools.partial(
            torch.randn,
            (B, H, S, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_tensor(), make_tensor(), make_tensor()

        @torch.compile
        def func(q, k, v, score_mod):
            _, lse = flex_attention_hop(q, k, v, score_mod)
            lse_2 = lse * 2
            return lse_2

        _, code = run_and_get_code(func, q, k, v, _identity)
        # Ensure that two kernels are generated
        FileCheck().check_count(".run(", 2, True).run(code[0])

    @supported_platform
    def test_logsumexp_is_not_fused(self):
        make_tensor = functools.partial(
            torch.randn,
            (B, H, S, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_tensor(), make_tensor(), make_tensor()

        @torch.compile
        def func(q, k, v, score_mod):
            out, lse = flex_attention_hop(q, k, v, score_mod)
            lse_2 = lse * 2
            return out, lse_2

        _, code = run_and_get_code(func, q, k, v, _identity)
        # Ensure that two kernels are generated
        FileCheck().check_count(".run(", 2, True).run(code[0])

    @supported_platform
    @common_utils.parametrize(
        "score_mod", [_identity, _causal, _times_two, _squared, _trig, _trig2]
    )
    def test_aot_eager_gradcheck(self, score_mod):
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()

        func = torch.compile(_flex_attention, backend="aot_eager", fullgraph=True)

        self.assertTrue(
            torch.autograd.gradcheck(
                func, (query, key, value, score_mod), raise_exception=True
            )
        )

    @supported_platform
    @common_utils.parametrize("score_mod_name", ["_head_offset"])
    @common_utils.parametrize("mode", ["eager", "aot_eager"])
    def test_captured_score_mod_aot_eager_gradcheck(
        self, score_mod_name: str, mode: str
    ):
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()

        func = torch.compile(_flex_attention, backend=mode, fullgraph=True)
        score_mod = captured_buffers_map[score_mod_name](torch.float64)

        self.assertTrue(
            torch.autograd.gradcheck(
                func, (query, key, value, score_mod), raise_exception=True
            )
        )

    @supported_platform
    def test_fw_bw_graph_correctness(self):
        cnt = CompileCounterWithBackend("aot_eager")
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()

        func = torch.compile(_flex_attention, backend=cnt, fullgraph=True)
        out = func(query, key, value, _squared)
        out.sum().backward()
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(len(cnt.graphs), 1)
        graph = cnt.graphs[0]
        norm_graph = normalize_gm(graph.print_readable(print_output=False))
        self.assertExpectedInline(
            norm_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_args_0_: "f64[2, 2, 8, 4]", L_args_1_: "f64[2, 2, 8, 4]", L_args_2_: "f64[2, 2, 8, 4]"):
        l_args_0_ = L_args_0_
        l_args_1_ = L_args_1_
        l_args_2_ = L_args_2_

        new_empty: "f64[]" = l_args_0_.new_empty([], requires_grad = True)
        new_empty_1: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
        new_empty_2: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
        new_empty_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
        new_empty_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
        flex_attention_0 = self.flex_attention_0
        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = None
        out: "f64[2, 2, 8, 4]" = flex_attention[0]; flex_attention = None
        return (out,)

    class GraphModule(torch.nn.Module):
        def forward(self, new_empty: "f64[]", new_empty_1: "i32[]", new_empty_2: "i32[]", new_empty_3: "i32[]", new_empty_4: "i32[]"):
            mul: "f64[]" = new_empty * new_empty; new_empty = None
            return mul
""",  # noqa: B950
        )
        # Save the AOT graphs
        aot_graphs = []
        from torch._inductor import compile_fx

        def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
            aot_graphs.append(graph)
            return graph

        backend = functools.partial(
            compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
        )
        func = torch.compile(func, backend=backend, fullgraph=True)
        out = func(query, key, value, _squared)
        out.sum().backward()

        joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False))

        self.assertExpectedInline(
            joint_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
        fw_graph = self.fw_graph
        joint_graph = self.joint_graph
        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = None
        getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
        getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
        getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None
        return [getitem_2, getitem_3, getitem_4]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
            return mul

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1)
            mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
            mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None
            add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
            return [add, None, None, None, None]
""",  # noqa: B950
        )


common_utils.instantiate_parametrized_tests(TestFlexAttention)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()