# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from typing import Any, Callable, Dict, Sequence, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_mutation,
    autograd_not_implemented,
    reenter_make_fx,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.graph_module import GraphModule
from torch.overrides import TorchFunctionMode


# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
def _construct_strides(
    sizes: Sequence[int],
    fill_order: Sequence[int],
) -> Sequence[int]:
    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
    # Initialize strides
    assert len(sizes) == len(
        fill_order
    ), "Length of sizes must match the length of the fill order"
    strides = [0] * len(sizes)

    # Start with stride 1 for the innermost dimension
    current_stride = 1

    # Iterate through the fill order, populating strides
    for dim in fill_order:
        strides[dim] = current_stride
        current_stride *= sizes[dim]

    return strides


def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor:
    """
    Create a new tensor with the same data and shape as the input,
    but with strides permuted based on the input tensor's stride order.

    Args:
        out (torch.Tensor): The output tensor of attention.
        query_strides (Tuple[int, ...]): The strides of the input query tensor.

    Returns:
        torch.Tensor: A new tensor with the same shape and data as the input,
            but with strides permuted based on the query tensor's stride order.
    """
    from torch._inductor.ir import get_stride_order, stride_order2fill_order

    stride_order = get_stride_order(query_strides)
    fill_order = stride_order2fill_order(stride_order)
    assert out.storage_offset() == 0, "Only support storage_offset == 0"
    out_strides = _construct_strides(out.shape, fill_order)
    new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides)
    new_out.copy_(out)
    return new_out


class TransformGetItemToIndex(TorchFunctionMode):
    # This is needed since we want to support calling
    # A[q_idx], where q_idx is a scalar tensor in score_mod.
    # Today, when q_idx is a scalar tensor, we implicitly convert it to a python
    # scalar and create a view. We do not want that behavior in this case, so we
    # use this TorchFunctionMode to override that behavior for score_mod
    # wherever we're running it.
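    #
    # Rough usage sketch (illustrative only, not part of this class's API): inside a
    # `with TransformGetItemToIndex():` block, indexing with a 0-dim tensor such as
    #     q_idx = torch.tensor(3)
    #     bias[q_idx]
    # is rerouted by __torch_function__ below to torch.ops.aten.index(bias, [q_idx]),
    # instead of q_idx being converted to a Python int first.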
    def __torch_function__(self, func, types, args, kwargs=None):
        if func == torch.Tensor.__getitem__:
            index_args = pytree.tree_leaves(args[1])
            if all(isinstance(x, torch.Tensor) for x in index_args):
                return torch.ops.aten.index(args[0], index_args)
        return func(*args, **(kwargs or {}))


class FlexAttentionHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention")

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        score_mod: Callable,
        block_mask: Tuple,
        scale: float,
        kernel_options: Dict[str, Any],
        score_mod_other_buffers: Tuple = (),
        mask_mod_other_buffers: Tuple = (),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if not all(
            isinstance(buf, torch.Tensor)
            for buf in score_mod_other_buffers + mask_mod_other_buffers
        ):
            raise RuntimeError("Other buffers must be tensors.")
        return super().__call__(
            query,
            key,
            value,
            score_mod,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention = FlexAttentionHOP()


class FlexAttentionBackwardHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention_backward")

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        logsumexp: torch.Tensor,
        grad_out: torch.Tensor,
        grad_logsumexp: torch.Tensor,
        fw_graph: Union[Callable, GraphModule],
        joint_graph: GraphModule,
        block_mask: Tuple,
        scale: float,
        kernel_options: Dict[str, Any],
        score_mod_other_buffers: Tuple = (),
        mask_mod_other_buffers: Tuple = (),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if not all(
            isinstance(buf, torch.Tensor)
            for buf in score_mod_other_buffers + mask_mod_other_buffers
        ):
            raise RuntimeError("Other buffers must be tensors.")
        return super().__call__(
            query,
            key,
            value,
            out,
            logsumexp,
            grad_out,
            grad_logsumexp,
            fw_graph,
            joint_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention_backward = FlexAttentionBackwardHOP()


def _math_attention_inner(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32

    scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)

    b = torch.arange(0, scores.size(0), device=scores.device)
    h = torch.arange(0, scores.size(1), device=scores.device)
    m = torch.arange(0, scores.size(2), device=scores.device)
    n = torch.arange(0, scores.size(3), device=scores.device)

    captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
    from torch.nn.attention.flex_attention import _vmap_for_bhqkv

    # first input is score
    score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim)

    mask_mod = block_mask[-1]
    mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
    mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)

    with TransformGetItemToIndex():
        scores = (scores * scale).to(working_precision)
        post_mod_scores = torch.where(
            mask_mod(b, h, m, n, *mask_mod_other_buffers),
            score_mod(scores, b, h, m, n, *score_mod_other_buffers),
            torch.tensor(-float("inf"), dtype=working_precision, device=scores.device),
        )

    return scores, post_mod_scores


def math_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Eager implementation

    This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions.
    We then apply the vectorized score_mod function to the scores matrix. Each wrap of vmap applies one of the
    batch, head, m, or n dimensions. We need to apply vmap 4 times to vectorize over all 4 dimensions.

    Args:
        query: The query tensor
        key: The key tensor
        value: The value tensor
        score_mod: The score_mod function
        other_buffers: Other buffers that are passed to the score_mod function
    """
    # Broadcast key & value along the head dim for GQA
    G = query.size(1) // key.size(1)
    value = torch.repeat_interleave(value, G, dim=1)
    key = torch.repeat_interleave(key, G, dim=1)

    _, post_mod_scores = _math_attention_inner(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )

    # Set fully masked rows' logsumexp to -inf (i.e. their sumexp to 0.0)
    logsumexp = post_mod_scores.logsumexp(dim=-1)
    masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
    logsumexp = torch.where(masked_rows, -float("inf"), logsumexp)

    post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)

    return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)


@flex_attention.py_impl(DispatchKey.CompositeExplicitAutograd)
def sdpa_dense(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    out, lse = math_attention(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    out = _permute_strides(out, query.stride())
    return out, lse


def trace_flex_attention(
    proxy_mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Traces the flex_attention operator with the given score_mod function and other_buffers.

    Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function.
    This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We
    access this graph module in inductor to inline the score_mod function into the triton template.
    """
    example_out = flex_attention(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    example_vals = [
        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_mod = block_mask[-1]
    with TransformGetItemToIndex():
        score_graph = reenter_make_fx(score_mod)(
            *example_vals, *score_mod_other_buffers
        )
        mask_graph = reenter_make_fx(mask_mod)(
            *mask_example_vals, *mask_mod_other_buffers
        )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
    qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_score")
    proxy_mode.tracer.root.register_module(qualname, score_graph)
    mask_qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_mask")
    proxy_mode.tracer.root.register_module(mask_qualname, mask_graph)
    node_args = (
        query,
        key,
        value,
        score_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", flex_attention, proxy_args, {}
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@flex_attention.py_impl(ProxyTorchDispatchMode)
def flex_attention_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert mode is not None, "Mode should always be enabled for python fallback key"
    return trace_flex_attention(
        mode,
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )


@flex_attention.py_functionalize_impl
def flex_attention_functionalize(
    ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Defines the functionalization rules for the flex_attention operator.

    Right now we are unwrapping each tensor and then redispatching to the next; however, we want to
    guard against any mutations in the score_mod function to the other_buffers, since those
    are free variables.
    """
    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
    block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
    score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
    mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)

    # Appease the mypy overlords
    assert isinstance(query_unwrapped, torch.Tensor)
    assert isinstance(key_unwrapped, torch.Tensor)
    assert isinstance(value_unwrapped, torch.Tensor)
    assert isinstance(block_mask_unwrapped, tuple)
    assert isinstance(score_mod_other_buffers_unwrapped, tuple)
    assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
    assert all(
        isinstance(item, torch.Tensor)
        for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
    )

    example_vals = (
        [torch.zeros((), dtype=query.dtype)]
        + [torch.zeros((), dtype=torch.int) for _ in range(4)]
        + list(score_mod_other_buffers_unwrapped)
    )
    with ctx.redispatch_to_next() as m:
        functional_score_mod = ctx.functionalize(score_mod)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        with TransformGetItemToIndex():
            mutates = _has_potential_branch_input_mutation(
                functional_score_mod, example_vals, pre_dispatch
            )
        # We only care about mutations of existing buffers, since we can't replay these.
        # However, we can just error if anything is detected.
        if mutates:
            raise UnsupportedAliasMutationException("Mutations detected in score_mod")

        out = flex_attention(
            query_unwrapped,
            key_unwrapped,
            value_unwrapped,
            functional_score_mod,
            block_mask_unwrapped,
            scale,
            kernel_options,
            score_mod_other_buffers_unwrapped,
            mask_mod_other_buffers_unwrapped,
        )
    return ctx.wrap_tensors(out)  # type: ignore[return-value, arg-type]


@flex_attention.py_impl(FakeTensorMode)
def flex_attention_fake_tensor_mode(
    mode: FakeTensorMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    with mode:
        v_head_dim = value.size(-1)
        batch_size, num_heads, seq_len_q, q_head_dim = query.shape
        logsumexp = query.new_empty(
            batch_size, num_heads, seq_len_q, dtype=torch.float32
        )
        out_shape = (batch_size, num_heads, seq_len_q, v_head_dim)
        out = query.new_empty(out_shape)
        out = _permute_strides(out, query.stride())
        return out, logsumexp


# ---------------------------- Autograd Implementation ----------------------------
def create_fw_bw_graph(score_mod, index_values, other_buffers):
    # See Note:[HOP create fw_bw graph]

    # All of these imports need to be here in order to avoid circular dependencies
    from torch._dispatch.python import suspend_functionalization
    from torch._functorch.aot_autograd import AOTConfig, create_joint
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
    from torch._subclasses.functional_tensor import disable_functional_mode
    from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():

            def _from_fun(t):
                return torch.empty_strided(
                    t.size(),
                    t.stride(),
                    device=t.device,
                    dtype=t.dtype,
                    requires_grad=t.requires_grad,
                )

            # If someone runs this hop under the default compiler backend ("eager"),
            # then this path will be run with the actual user inputs. We convert them
            # to fake tensors in order to not perform any actual compute.
            from torch._guards import detect_fake_mode

            fake_mode = detect_fake_mode(index_values)
            if fake_mode is None:
                fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

            with fake_mode:
                unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
                unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)

            assert all(isinstance(t, FakeTensor) for t in unwrapped_score_mod_indexes)
            assert all(isinstance(t, FakeTensor) for t in unwrapped_other_buffers)

            example_flat_out = pytree.tree_map(
                _from_fun,
                score_mod(*unwrapped_score_mod_indexes, *unwrapped_other_buffers),
            )
            if not isinstance(example_flat_out, torch.Tensor):
                raise RuntimeError(
                    "Expected output of score_mod to be a tensor. "
                    f"Got type {type(example_flat_out)}."
                )
            example_grad = _from_fun(example_flat_out)

            def joint_f(score, b, h, m, n, example_grad, *other_buffers):
                def fw_with_masks(*args):
                    fw_out = score_mod(*args)
                    out_requires_grad = fw_out.requires_grad
                    return ((fw_out,), (out_requires_grad,))

                joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
                args = [score, b, h, m, n] + list(other_buffers)
                optional_grad = [example_grad] if example_grad.requires_grad else []
                _, grads = joint(args, optional_grad)

                return grads

            joint_graph = make_fx(joint_f)(
                *unwrapped_score_mod_indexes, example_grad, *unwrapped_other_buffers
            )
            return score_mod, joint_graph


class FlexAttentionAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        any_buffer_requires_grad = any(
            buffer.requires_grad
            for buffer in score_mod_other_buffers + mask_mod_other_buffers
        )
        assert (
            not any_buffer_requires_grad
        ), "Captured buffers that require grad are not yet supported."
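        # Non-tensor state (the traced forward/joint/mask graphs, block sizes, scale,
        # and kernel options) is stashed directly on ctx below; the tensor inputs and
        # the block-mask index tensors go through ctx.save_for_backward instead.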
        ctx._fw_graph = fw_graph
        ctx._joint_graph = joint_graph
        ctx._mask_graph = block_mask[-1]
        # KV_BLOCK_SIZE and Q_BLOCK_SIZE are integers, so can't use ctx.save_for_backward
        ctx._KV_BLOCK_SIZE = block_mask[8]
        ctx._Q_BLOCK_SIZE = block_mask[9]
        ctx.scale = scale
        ctx.kernel_options = kernel_options
        ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
        with torch._C._AutoDispatchBelowAutograd():
            out, logsumexp = flex_attention(
                query,
                key,
                value,
                fw_graph,
                block_mask,
                scale,
                kernel_options,
                score_mod_other_buffers,
                mask_mod_other_buffers,
            )

        ctx.save_for_backward(
            query,
            key,
            value,
            out,
            logsumexp,
            *block_mask[:8],
            *score_mod_other_buffers,
            *mask_mod_other_buffers,
        )
        return out, logsumexp

    @staticmethod
    def backward(ctx, grad_out, grad_logsumexp):
        fw_args = ctx.saved_tensors
        (
            query,
            key,
            value,
            out,
            logsumexp,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
            *other_buffers,
        ) = fw_args
        fw_graph = ctx._fw_graph
        joint_graph = ctx._joint_graph
        mask_graph = ctx._mask_graph
        KV_BLOCK_SIZE = ctx._KV_BLOCK_SIZE
        Q_BLOCK_SIZE = ctx._Q_BLOCK_SIZE
        scale = ctx.scale
        kernel_options = ctx.kernel_options
        score_mod_other_buffers = tuple(
            other_buffers[: ctx._score_mod_other_buffers_len]
        )
        mask_mod_other_buffers = tuple(
            other_buffers[ctx._score_mod_other_buffers_len :]
        )
        # We have asserted that other_buffers do not require grad in the forward
        none_grads = [None] * 7
        grad_query, grad_key, grad_value = flex_attention_backward(
            query,
            key,
            value,
            out,
            logsumexp,
            grad_out,
            grad_logsumexp,
            fw_graph,
            joint_graph,
            (
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
                q_num_blocks,
                q_indices,
                full_q_num_blocks,
                full_q_indices,
                KV_BLOCK_SIZE,
                Q_BLOCK_SIZE,
                mask_graph,
            ),
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )
        return grad_query, grad_key, grad_value, *none_grads


@flex_attention.py_impl(DispatchKey.Autograd)
def flex_attention_autograd(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    with TransformGetItemToIndex():
        input_requires_grad = any(t.requires_grad for t in (query, key, value))
        if torch.is_grad_enabled() and input_requires_grad:
            example_vals = [
                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
            fw_graph, bw_graph = create_fw_bw_graph(
                score_mod, example_vals, score_mod_other_buffers
            )
        else:
            fw_graph, bw_graph = score_mod, None
        out, logsumexp = FlexAttentionAutogradOp.apply(
            query,
            key,
            value,
            fw_graph,
            bw_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )
    return out, logsumexp


# ---------------------------- Backward HOP Implementation ----------------------------
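# As used throughout this file, ``block_mask`` is a plain tuple whose last element is
# the mask_mod callable (or its traced graph) and whose leading elements are the block
# sparse metadata. Based on how the forward/backward paths index it above, the layout
# is (descriptive note only, not a separate API):
#
#     (kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices,
#      q_num_blocks, q_indices, full_q_num_blocks, full_q_indices,
#      KV_BLOCK_SIZE, Q_BLOCK_SIZE, mask_mod)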
@flex_attention_backward.py_impl(DispatchKey.CompositeExplicitAutograd)
def sdpa_dense_backward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Callable,  # GraphModule type hint?
    joint_graph: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple,
    mask_mod_other_buffers: Tuple,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Get outputs before calling repeat interleave
    actual_grad_query = torch.empty_like(query)
    actual_grad_key = torch.empty_like(key)
    actual_grad_value = torch.empty_like(value)

    G = query.size(1) // key.size(1)
    key = torch.repeat_interleave(key, G, dim=1)
    value = torch.repeat_interleave(value, G, dim=1)

    # We're undoing the log -> log2 change of base in the forwards
    logsumexp = logsumexp * math.log(2)
    # The backwards formula for the log -> log2 change of base in the forwards
    grad_logsumexp = grad_logsumexp / math.log(2)
    scores, post_mod_scores = _math_attention_inner(
        query,
        key,
        value,
        fw_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    masked_out_rows = logsumexp == -float("inf")
    softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1))
    softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores)

    grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out

    grad_softmax_scores = grad_out @ value.transpose(-2, -1)

    sum_scores = torch.sum(out * grad_out, -1, keepdim=True)
    grad_score_mod = softmax_scores * (
        grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1)
    )

    b = torch.arange(0, scores.size(0), device=scores.device)
    h = torch.arange(0, scores.size(1), device=scores.device)
    m = torch.arange(0, scores.size(2), device=scores.device)
    n = torch.arange(0, scores.size(3), device=scores.device)

    mask_graph = block_mask[-1]
    # Gradient of the inline score_mod function, with respect to the scores
    captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
    out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers)
    from torch.nn.attention.flex_attention import _vmap_for_bhqkv

    # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...]
    # score and gradOut are "fully" batched
    joint_score_mod = _vmap_for_bhqkv(
        joint_graph,
        prefix=(0,),
        suffix=(0,) + captured_buffers_in_dim,
        out_dims=out_dims,
    )
    with TransformGetItemToIndex():
        grad_scores, *_ = joint_score_mod(
            scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers
        )
    grad_scores = grad_scores * scale
    grad_scores = grad_scores.to(query.dtype)

    mask_mod = _vmap_for_bhqkv(
        mask_graph, prefix=(), suffix=(None,) * len(mask_mod_other_buffers)
    )
    with TransformGetItemToIndex():
        mask_scores = mask_mod(b, h, m, n, *mask_mod_other_buffers)
        grad_scores = torch.where(
            mask_scores, grad_scores, torch.tensor(0, dtype=query.dtype)
        )

    grad_query = grad_scores @ key
    grad_key = grad_scores.transpose(-2, -1) @ query

    # Reduce DK, DV along broadcasted heads.
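    # key/value were repeat_interleaved G times along the head dim above, so grad_key
    # and grad_value currently carry Hkv * G heads. Unflatten that dim to (Hkv, G) and
    # sum over the G axis to recover gradients for the original (unbroadcasted) KV heads.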
    grad_key = grad_key.view(
        grad_key.size(0), -1, G, grad_key.size(-2), grad_key.size(-1)
    )
    grad_value = grad_value.view(
        grad_value.size(0), -1, G, grad_value.size(-2), grad_value.size(-1)
    )

    grad_key = torch.sum(grad_key, 2, keepdim=False)
    grad_value = torch.sum(grad_value, 2, keepdim=False)

    actual_grad_query.copy_(grad_query)
    actual_grad_key.copy_(grad_key)
    actual_grad_value.copy_(grad_value)

    return actual_grad_query, actual_grad_key, actual_grad_value


def trace_flex_attention_backward(
    proxy_mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """We already have the forward graph and joint graph from the forward pass, so we create a proxy and attach both graphs."""
    example_out = flex_attention_backward(
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )

    fw_example_vals = [
        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
    bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)]
    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_graph = block_mask[-1]
    with TransformGetItemToIndex():
        fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
        joint_graph = reenter_make_fx(joint_graph)(
            *bw_example_vals, *score_mod_other_buffers
        )
        mask_graph = reenter_make_fx(mask_graph)(
            *mask_example_vals, *mask_mod_other_buffers
        )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
    proxy_mode.tracer.root.register_module("fw_graph", fw_graph)  # type: ignore[arg-type]
    proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
    proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
    node_args = (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        flex_attention_backward,
        proxy_args,
        {},
        name="flex_attention_backward",
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@flex_attention_backward.py_impl(ProxyTorchDispatchMode)
def flex_attention_backward_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert mode is not None, "Mode should always be enabled for python fallback key"
    return trace_flex_attention_backward(
        mode,
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )


@flex_attention_backward.py_functionalize_impl
def flex_attention_backward_functionalize(
    ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Defines the functionalization rules for the flex_attention_backward operator.

    Right now we are unwrapping each tensor and then redispatching to the next.
    Since the forward score_mod function is assured to be free of mutations to the
    other_buffers, we skip that mutation check and go straight to redispatching.
    """
    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
    out_unwrapped = ctx.unwrap_tensors(out)
    logsumexp_unwrapped = ctx.unwrap_tensors(logsumexp)
    grad_out_unwrapped = ctx.unwrap_tensors(grad_out)
    grad_logsumexp_unwrapped = ctx.unwrap_tensors(grad_logsumexp)
    block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
    score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
    mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)

    # Appease the mypy overlords
    assert isinstance(query_unwrapped, torch.Tensor)
    assert isinstance(key_unwrapped, torch.Tensor)
    assert isinstance(value_unwrapped, torch.Tensor)
    assert isinstance(out_unwrapped, torch.Tensor)
    assert isinstance(logsumexp_unwrapped, torch.Tensor)
    assert isinstance(grad_out_unwrapped, torch.Tensor)
    assert isinstance(grad_logsumexp_unwrapped, torch.Tensor)
    assert isinstance(block_mask_unwrapped, tuple)
    assert isinstance(score_mod_other_buffers_unwrapped, tuple)
    assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
    assert all(
        isinstance(item, torch.Tensor)
        for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
    )

    with ctx.redispatch_to_next() as m:
        functional_fw_graph = ctx.functionalize(fw_graph)
        functional_joint_graph = ctx.functionalize(joint_graph)

        grad_query, grad_key, grad_value = flex_attention_backward(
            query_unwrapped,
            key_unwrapped,
            value_unwrapped,
            out_unwrapped,
            logsumexp_unwrapped,
            grad_out_unwrapped,
            grad_logsumexp_unwrapped,
            functional_fw_graph,  # type: ignore[arg-type]
            functional_joint_graph,  # type: ignore[arg-type]
            block_mask_unwrapped,
            scale,
            kernel_options,
            score_mod_other_buffers_unwrapped,
            mask_mod_other_buffers_unwrapped,
        )

    return ctx.wrap_tensors((grad_query, grad_key, grad_value))  # type: ignore[return-value,arg-type]


@flex_attention_backward.py_impl(FakeTensorMode)
def flex_attention_backward_fake_tensor_mode(
    mode: FakeTensorMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    with mode:
        grad_query = torch.empty_like(query)
        grad_key = torch.empty_like(key)
        grad_value = torch.empty_like(value)
        return grad_query, grad_key, grad_value


flex_attention_backward.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(flex_attention_backward, deferred_error=True)
)
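

# Rough usage sketch of the eager math path (commented out; illustrative only -- the
# public entry point is torch.nn.attention.flex_attention.flex_attention, and the
# minimal block_mask tuple below only satisfies what math_attention itself reads,
# namely block_mask[-1]):
#
#     B, H, S, D = 2, 4, 128, 64
#     q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
#
#     def score_mod(score, b, h, q_idx, kv_idx):
#         return score  # identity: plain scaled-dot-product attention
#
#     def causal_mask(b, h, q_idx, kv_idx):
#         return q_idx >= kv_idx
#
#     out, lse = math_attention(
#         q, k, v, score_mod, block_mask=(causal_mask,), scale=D**-0.5,
#         kernel_options={},
#     )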