# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from typing import Any, Callable, Dict, Sequence, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_mutation,
    autograd_not_implemented,
    reenter_make_fx,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.graph_module import GraphModule
from torch.overrides import TorchFunctionMode


# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
def _construct_strides(
    sizes: Sequence[int],
    fill_order: Sequence[int],
) -> Sequence[int]:
    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
    # Initialize strides
    assert len(sizes) == len(
        fill_order
    ), "Length of sizes must match the length of the fill order"
    strides = [0] * len(sizes)

    # Start with stride 1 for the innermost dimension
    current_stride = 1

    # Iterate through the fill order populating strides
    for dim in fill_order:
        strides[dim] = current_stride
        current_stride *= sizes[dim]

    return strides

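# A minimal worked example of the helper above (illustrative only, not executed):
#     >>> _construct_strides((2, 3, 4), (2, 1, 0))
#     [12, 4, 1]
#     >>> _construct_strides((2, 3, 4), (0, 1, 2))
#     [1, 2, 6]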

def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor:
    """
    Create a new tensor with the same data and shape as the input,
    but with strides permuted based on the input tensor's stride order.

    Args:
        out (torch.Tensor): The output tensor of attention.
        query_strides (Tuple[int, ...]): The strides of the input query tensor.

    Returns:
        torch.Tensor: A new tensor with the same shape and data as the input,
        but with strides permuted based on the query tensor's stride order.
    """
    from torch._inductor.ir import get_stride_order, stride_order2fill_order

    stride_order = get_stride_order(query_strides)
    fill_order = stride_order2fill_order(stride_order)
    assert out.storage_offset() == 0, "Only support storage_offset == 0"
    out_strides = _construct_strides(out.shape, fill_order)
    new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides)
    new_out.copy_(out)
    return new_out

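# Illustrative sketch of the stride permutation above (not executed here): the
# attention output is given the same memory layout as the query, e.g.
#     >>> q = torch.randn(2, 8, 4, 16).transpose(1, 2)  # shape (2, 4, 8, 16), strides (512, 16, 64, 1)
#     >>> out = torch.zeros(2, 4, 8, 16)                 # contiguous, strides (512, 128, 16, 1)
#     >>> _permute_strides(out, q.stride()).stride()
#     (512, 16, 64, 1)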

class TransformGetItemToIndex(TorchFunctionMode):
    # This is needed since we want to support calling
    # A[q_idx], where q_idx is a scalar tensor in score_mod.
    # Today, when q_idx is a scalar tensor, we implicitly convert it to a Python
    # scalar and create a view. We do not want that behavior in this case, so we
    # use this TorchFunctionMode to override that behavior for score_mod
    # wherever we're running it.
    def __torch_function__(self, func, types, args, kwargs=None):
        if func == torch.Tensor.__getitem__:
            index_args = pytree.tree_leaves(args[1])
            if all(isinstance(x, torch.Tensor) for x in index_args):
                return torch.ops.aten.index(args[0], index_args)
        return func(*args, **(kwargs or {}))

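# Illustrative sketch of what this mode enables inside score_mod (assuming a
# captured 1-D tensor `bias`; not executed here):
#     >>> bias = torch.randn(128)
#     >>> q_idx = torch.tensor(3)
#     >>> with TransformGetItemToIndex():
#     ...     val = bias[q_idx]  # routed to torch.ops.aten.index, stays a tensor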

class FlexAttentionHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention")

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        score_mod: Callable,
        block_mask: Tuple,
        scale: float,
        kernel_options: Dict[str, Any],
        score_mod_other_buffers: Tuple = (),
        mask_mod_other_buffers: Tuple = (),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if not all(
            isinstance(buf, torch.Tensor)
            for buf in score_mod_other_buffers + mask_mod_other_buffers
        ):
            raise RuntimeError("Other buffers must be tensors.")
        return super().__call__(
            query,
            key,
            value,
            score_mod,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention = FlexAttentionHOP()

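# The HOP above is usually reached through the public wrapper in
# torch.nn.attention.flex_attention rather than called directly. A hedged usage
# sketch (the exact public signature lives in that module, not here):
#     >>> from torch.nn.attention.flex_attention import flex_attention as _flex
#     >>> def rel_bias(score, b, h, q_idx, kv_idx):
#     ...     return score + (q_idx - kv_idx)
#     >>> q = k = v = torch.randn(2, 4, 128, 64)
#     >>> out = _flex(q, k, v, score_mod=rel_bias)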

class FlexAttentionBackwardHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention_backward")

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        logsumexp: torch.Tensor,
        grad_out: torch.Tensor,
        grad_logsumexp: torch.Tensor,
        fw_graph: Union[Callable, GraphModule],
        joint_graph: GraphModule,
        block_mask: Tuple,
        scale: float,
        kernel_options: Dict[str, Any],
        score_mod_other_buffers: Tuple = (),
        mask_mod_other_buffers: Tuple = (),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if not all(
            isinstance(buf, torch.Tensor)
            for buf in score_mod_other_buffers + mask_mod_other_buffers
        ):
            raise RuntimeError("Other buffers must be tensors.")
        return super().__call__(
            query,
            key,
            value,
            out,
            logsumexp,
            grad_out,
            grad_logsumexp,
            fw_graph,
            joint_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention_backward = FlexAttentionBackwardHOP()

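# Note on the `block_mask` tuple threaded through both HOPs: as consumed below
# (see FlexAttentionAutogradOp and sdpa_dense_backward), it is laid out as
#     (kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices,
#      q_num_blocks, q_indices, full_q_num_blocks, full_q_indices,
#      KV_BLOCK_SIZE, Q_BLOCK_SIZE, mask_mod)
# i.e. eight block-index tensors, two integer block sizes, and the mask_mod
# callable as the last element.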

def _math_attention_inner(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32

    scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)

    b = torch.arange(0, scores.size(0), device=scores.device)
    h = torch.arange(0, scores.size(1), device=scores.device)
    m = torch.arange(0, scores.size(2), device=scores.device)
    n = torch.arange(0, scores.size(3), device=scores.device)

    captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
    from torch.nn.attention.flex_attention import _vmap_for_bhqkv

    # first input is score
    score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim)

    mask_mod = block_mask[-1]
    mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
    mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)

    with TransformGetItemToIndex():
        scores = (scores * scale).to(working_precision)
        post_mod_scores = torch.where(
            mask_mod(b, h, m, n, *mask_mod_other_buffers),
            score_mod(scores, b, h, m, n, *score_mod_other_buffers),
            torch.tensor(-float("inf"), dtype=working_precision, device=scores.device),
        )

    return scores, post_mod_scores

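# Conceptually, the vmap-wrapped application above is equivalent to the scalar
# loop below (an intuition-only sketch over the scaled scores; far too slow to
# actually run):
#     for b_i in range(B):
#         for h_i in range(H):
#             for m_i in range(Q_LEN):
#                 for n_i in range(KV_LEN):
#                     s = scores[b_i, h_i, m_i, n_i]
#                     keep = mask_mod(b_i, h_i, m_i, n_i, *mask_mod_other_buffers)
#                     post_mod_scores[b_i, h_i, m_i, n_i] = (
#                         score_mod(s, b_i, h_i, m_i, n_i, *score_mod_other_buffers)
#                         if keep
#                         else -float("inf")
#                     )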

def math_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Eager implementation

    This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions.
    We then apply the vectorized score_mod function to the scores matrix. Each wrap of vmap vectorizes over one of
    the batch, head, m, or n dimensions, so we apply vmap 4 times to vectorize over all 4 dimensions.

    Args:
        query: The query tensor
        key: The key tensor
        value: The value tensor
        score_mod: The score_mod function
        other_buffers: Other buffers that are passed to the score_mod function
    """
    # broadcast key & value along the head dim for GQA
    G = query.size(1) // key.size(1)
    value = torch.repeat_interleave(value, G, dim=1)
    key = torch.repeat_interleave(key, G, dim=1)

    _, post_mod_scores = _math_attention_inner(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )

    # Set fully masked rows' sumexp to 0.0
    logsumexp = post_mod_scores.logsumexp(dim=-1)
    masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
    logsumexp = torch.where(masked_rows, -float("inf"), logsumexp)

    post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)

    return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)

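# A hedged example of the score_mod signature that math_attention vectorizes:
# `score` is a scalar tensor and b/h/q_idx/kv_idx are scalar index tensors.
#     >>> def causal_score_mod(score, b, h, q_idx, kv_idx):
#     ...     return torch.where(q_idx >= kv_idx, score, score.new_tensor(-float("inf")))
# Note that the returned logsumexp is in base 2 (natural-log lse divided by
# ln(2)); sdpa_dense_backward undoes this change of base below.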

@flex_attention.py_impl(DispatchKey.CompositeExplicitAutograd)
def sdpa_dense(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    out, lse = math_attention(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    out = _permute_strides(out, query.stride())
    return out, lse


def trace_flex_attention(
    proxy_mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Traces the flex_attention operator with the given score_mod function and other_buffers.

    Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function.
    This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We
    access this graph module in inductor to inline the score_mod function into the Triton template.
    """
    example_out = flex_attention(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    example_vals = [
        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_mod = block_mask[-1]
    with TransformGetItemToIndex():
        score_graph = reenter_make_fx(score_mod)(
            *example_vals, *score_mod_other_buffers
        )
        mask_graph = reenter_make_fx(mask_mod)(
            *mask_example_vals, *mask_mod_other_buffers
        )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
    qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_score")
    proxy_mode.tracer.root.register_module(qualname, score_graph)
    mask_qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_mask")
    proxy_mode.tracer.root.register_module(mask_qualname, mask_graph)
    node_args = (
        query,
        key,
        value,
        score_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", flex_attention, proxy_args, {}
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@flex_attention.py_impl(ProxyTorchDispatchMode)
def flex_attention_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert mode is not None, "Mode should always be enabled for python fallback key"
    return trace_flex_attention(
        mode,
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )


@flex_attention.py_functionalize_impl
def flex_attention_functionalize(
    ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Defines the functionalization rules for the flex_attention operator.

    Right now we are unwrapping each tensor and then redispatching to the next key;
    however, we want to guard against any mutations in the score_mod function to the
    other_buffers, since those are free variables.
    """
    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
    block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
    score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
    mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)

    # Appease the mypy overlords
    assert isinstance(query_unwrapped, torch.Tensor)
    assert isinstance(key_unwrapped, torch.Tensor)
    assert isinstance(value_unwrapped, torch.Tensor)
    assert isinstance(block_mask_unwrapped, tuple)
    assert isinstance(score_mod_other_buffers_unwrapped, tuple)
    assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
    assert all(
        isinstance(item, torch.Tensor)
        for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
    )

    example_vals = (
        [torch.zeros((), dtype=query.dtype)]
        + [torch.zeros((), dtype=torch.int) for _ in range(4)]
        + list(score_mod_other_buffers_unwrapped)
    )
    with ctx.redispatch_to_next() as m:
        functional_score_mod = ctx.functionalize(score_mod)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        with TransformGetItemToIndex():
            mutates = _has_potential_branch_input_mutation(
                functional_score_mod, example_vals, pre_dispatch
            )
        # We only care about mutations of existing buffers, since we can't replay those.
        # However, we can just error if anything is detected.
        if mutates:
            raise UnsupportedAliasMutationException("Mutations detected in score_mod")

        out = flex_attention(
            query_unwrapped,
            key_unwrapped,
            value_unwrapped,
            functional_score_mod,
            block_mask_unwrapped,
            scale,
            kernel_options,
            score_mod_other_buffers_unwrapped,
            mask_mod_other_buffers_unwrapped,
        )
    return ctx.wrap_tensors(out)  # type: ignore[return-value, arg-type]


@flex_attention.py_impl(FakeTensorMode)
def flex_attention_fake_tensor_mode(
    mode: FakeTensorMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    with mode:
        v_head_dim = value.size(-1)
        batch_size, num_heads, seq_len_q, q_head_dim = query.shape
        logsumexp = query.new_empty(
            batch_size, num_heads, seq_len_q, dtype=torch.float32
        )
        out_shape = (batch_size, num_heads, seq_len_q, v_head_dim)
        out = query.new_empty(out_shape)
        out = _permute_strides(out, query.stride())
        return out, logsumexp


# ---------------------------- Autograd Implementation ----------------------------
def create_fw_bw_graph(score_mod, index_values, other_buffers):
    # See Note:[HOP create fw_bw graph]

    # All of these imports need to be here in order to avoid circular dependencies
    from torch._dispatch.python import suspend_functionalization
    from torch._functorch.aot_autograd import AOTConfig, create_joint
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
    from torch._subclasses.functional_tensor import disable_functional_mode
    from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():

            def _from_fun(t):
                return torch.empty_strided(
                    t.size(),
                    t.stride(),
                    device=t.device,
                    dtype=t.dtype,
                    requires_grad=t.requires_grad,
                )

            # If someone runs this hop under the default compiler backend ("eager"),
            # then this path will be run with the actual user inputs. We convert them
            # to fake tensors in order to not perform any actual compute.
            from torch._guards import detect_fake_mode

            fake_mode = detect_fake_mode(index_values)
            if fake_mode is None:
                fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

            with fake_mode:
                unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
                unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)

            assert all(isinstance(t, FakeTensor) for t in unwrapped_score_mod_indexes)
            assert all(isinstance(t, FakeTensor) for t in unwrapped_other_buffers)

            example_flat_out = pytree.tree_map(
                _from_fun,
                score_mod(*unwrapped_score_mod_indexes, *unwrapped_other_buffers),
            )
            if not isinstance(example_flat_out, torch.Tensor):
                raise RuntimeError(
                    "Expected output of score_mod to be a tensor. "
                    f"Got type {type(example_flat_out)}."
                )
            example_grad = _from_fun(example_flat_out)

        def joint_f(score, b, h, m, n, example_grad, *other_buffers):
            def fw_with_masks(*args):
                fw_out = score_mod(*args)
                out_requires_grad = fw_out.requires_grad
                return ((fw_out,), (out_requires_grad,))

            joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
            args = [score, b, h, m, n] + list(other_buffers)
            optional_grad = [example_grad] if example_grad.requires_grad else []
            _, grads = joint(args, optional_grad)

            return grads

        joint_graph = make_fx(joint_f)(
            *unwrapped_score_mod_indexes, example_grad, *unwrapped_other_buffers
        )
        return score_mod, joint_graph

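# The joint graph returned above takes (score, b, h, q_idx, kv_idx, grad_out,
# *other_buffers) and returns the gradients of score_mod's inputs; only the
# gradient with respect to `score` is consumed by sdpa_dense_backward below.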

class FlexAttentionAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        any_buffer_requires_grad = any(
            buffer.requires_grad
            for buffer in score_mod_other_buffers + mask_mod_other_buffers
        )
        assert (
            not any_buffer_requires_grad
        ), "Captured buffers that require grad are not yet supported."
        ctx._fw_graph = fw_graph
        ctx._joint_graph = joint_graph
        ctx._mask_graph = block_mask[-1]
        # KV_BLOCK_SIZE and Q_BLOCK_SIZE are integers, so can't use ctx.save_for_backward
        ctx._KV_BLOCK_SIZE = block_mask[8]
        ctx._Q_BLOCK_SIZE = block_mask[9]
        ctx.scale = scale
        ctx.kernel_options = kernel_options
        ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
        with torch._C._AutoDispatchBelowAutograd():
            out, logsumexp = flex_attention(
                query,
                key,
                value,
                fw_graph,
                block_mask,
                scale,
                kernel_options,
                score_mod_other_buffers,
                mask_mod_other_buffers,
            )

        ctx.save_for_backward(
            query,
            key,
            value,
            out,
            logsumexp,
            *block_mask[:8],
            *score_mod_other_buffers,
            *mask_mod_other_buffers,
        )
        return out, logsumexp

    @staticmethod
    def backward(ctx, grad_out, grad_logsumexp):
        fw_args = ctx.saved_tensors
        (
            query,
            key,
            value,
            out,
            logsumexp,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
            *other_buffers,
        ) = fw_args
        fw_graph = ctx._fw_graph
        joint_graph = ctx._joint_graph
        mask_graph = ctx._mask_graph
        KV_BLOCK_SIZE = ctx._KV_BLOCK_SIZE
        Q_BLOCK_SIZE = ctx._Q_BLOCK_SIZE
        scale = ctx.scale
        kernel_options = ctx.kernel_options
        score_mod_other_buffers = tuple(
            other_buffers[: ctx._score_mod_other_buffers_len]
        )
        mask_mod_other_buffers = tuple(
            other_buffers[ctx._score_mod_other_buffers_len :]
        )
        # We have asserted that other_buffers do not require grad in the forward
        none_grads = [None] * 7
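        # The remaining 7 forward inputs (fw_graph, joint_graph, block_mask, scale,
        # kernel_options, score_mod_other_buffers, mask_mod_other_buffers) are
        # non-differentiable, so they receive None gradients.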
        grad_query, grad_key, grad_value = flex_attention_backward(
            query,
            key,
            value,
            out,
            logsumexp,
            grad_out,
            grad_logsumexp,
            fw_graph,
            joint_graph,
            (
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
                q_num_blocks,
                q_indices,
                full_q_num_blocks,
                full_q_indices,
                KV_BLOCK_SIZE,
                Q_BLOCK_SIZE,
                mask_graph,
            ),
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )
        return grad_query, grad_key, grad_value, *none_grads


@flex_attention.py_impl(DispatchKey.Autograd)
def flex_attention_autograd(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    with TransformGetItemToIndex():
        input_requires_grad = any(t.requires_grad for t in (query, key, value))
        if torch.is_grad_enabled() and input_requires_grad:
            example_vals = [
                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
            fw_graph, bw_graph = create_fw_bw_graph(
                score_mod, example_vals, score_mod_other_buffers
            )
        else:
            fw_graph, bw_graph = score_mod, None
        out, logsumexp = FlexAttentionAutogradOp.apply(
            query,
            key,
            value,
            fw_graph,
            bw_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )
    return out, logsumexp

# ---------------------------- Backward HOP Implementation ----------------------------


@flex_attention_backward.py_impl(DispatchKey.CompositeExplicitAutograd)
def sdpa_dense_backward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Callable,  # GraphModule type hint?
    joint_graph: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple,
    mask_mod_other_buffers: Tuple,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Allocate the gradient outputs (in the original key/value shapes) before repeat_interleave
    actual_grad_query = torch.empty_like(query)
    actual_grad_key = torch.empty_like(key)
    actual_grad_value = torch.empty_like(value)

    G = query.size(1) // key.size(1)
    key = torch.repeat_interleave(key, G, dim=1)
    value = torch.repeat_interleave(value, G, dim=1)

    # We're undoing the log -> log2 change of base in the forwards
    logsumexp = logsumexp * math.log(2)
    # The backwards formula for the log -> log2 change of base in the forwards
    grad_logsumexp = grad_logsumexp / math.log(2)
    scores, post_mod_scores = _math_attention_inner(
        query,
        key,
        value,
        fw_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    masked_out_rows = logsumexp == -float("inf")
    softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1))
    softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores)

    grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out

    grad_softmax_scores = grad_out @ value.transpose(-2, -1)

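    # Softmax backward with logsumexp as an extra differentiable output (a sketch
    # of the math this block implements): with P = softmax(S) and O = P @ V, let
    # D_i = sum_j P_ij * dP_ij = sum(O_i * dO_i). Then
    #     dS_ij = P_ij * (dP_ij - D_i + dlse_i),
    # where the dlse_i term accounts for the gradient flowing into logsumexp.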
    sum_scores = torch.sum(out * grad_out, -1, keepdim=True)
    grad_score_mod = softmax_scores * (
        grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1)
    )

    b = torch.arange(0, scores.size(0), device=scores.device)
    h = torch.arange(0, scores.size(1), device=scores.device)
    m = torch.arange(0, scores.size(2), device=scores.device)
    n = torch.arange(0, scores.size(3), device=scores.device)

    mask_graph = block_mask[-1]
    # Gradient of the inline score_mod function, with respect to the scores
    captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
    out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers)
    from torch.nn.attention.flex_attention import _vmap_for_bhqkv

    # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...]
    # score and gradOut are "fully" batched
    joint_score_mod = _vmap_for_bhqkv(
        joint_graph,
        prefix=(0,),
        suffix=(0,) + captured_buffers_in_dim,
        out_dims=out_dims,
    )
    with TransformGetItemToIndex():
        grad_scores, *_ = joint_score_mod(
            scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers
        )
    grad_scores = grad_scores * scale
    grad_scores = grad_scores.to(query.dtype)

    mask_mod = _vmap_for_bhqkv(
        mask_graph, prefix=(), suffix=(None,) * len(mask_mod_other_buffers)
    )
    with TransformGetItemToIndex():
        mask_scores = mask_mod(b, h, m, n, *mask_mod_other_buffers)
        grad_scores = torch.where(
            mask_scores, grad_scores, torch.tensor(0, dtype=query.dtype)
        )

    grad_query = grad_scores @ key
    grad_key = grad_scores.transpose(-2, -1) @ query

    # Reduce DK, DV along broadcasted heads.
    grad_key = grad_key.view(
        grad_key.size(0), -1, G, grad_key.size(-2), grad_key.size(-1)
    )
    grad_value = grad_value.view(
        grad_value.size(0), -1, G, grad_value.size(-2), grad_value.size(-1)
    )

    grad_key = torch.sum(grad_key, 2, keepdim=False)
    grad_value = torch.sum(grad_value, 2, keepdim=False)

    actual_grad_query.copy_(grad_query)
    actual_grad_key.copy_(grad_key)
    actual_grad_value.copy_(grad_value)

    return actual_grad_query, actual_grad_key, actual_grad_value


def trace_flex_attention_backward(
    proxy_mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """We already have the forward graph and joint graph from the forward pass, so we create a proxy and attach both graphs."""
    example_out = flex_attention_backward(
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )

    fw_example_vals = [
        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
    bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)]
    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_graph = block_mask[-1]
    with TransformGetItemToIndex():
        fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
        joint_graph = reenter_make_fx(joint_graph)(
            *bw_example_vals, *score_mod_other_buffers
        )
        mask_graph = reenter_make_fx(mask_graph)(
            *mask_example_vals, *mask_mod_other_buffers
        )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
    proxy_mode.tracer.root.register_module("fw_graph", fw_graph)  # type: ignore[arg-type]
    proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
    proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
    node_args = (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        flex_attention_backward,
        proxy_args,
        {},
        name="flex_attention_backward",
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@flex_attention_backward.py_impl(ProxyTorchDispatchMode)
def flex_attention_backward_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert mode is not None, "Mode should always be enabled for python fallback key"
    return trace_flex_attention_backward(
        mode,
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )


@flex_attention_backward.py_functionalize_impl
def flex_attention_backward_functionalize(
    ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Defines the functionalization rules for the flex_attention_backward operator.

    Right now we are unwrapping each tensor and then redispatching to the next key.
    Since we know the forward score_mod function is guaranteed to be free of mutations
    to the other_buffers, we skip that mutation check and go straight to redispatching.
    """
    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
    out_unwrapped = ctx.unwrap_tensors(out)
    logsumexp_unwrapped = ctx.unwrap_tensors(logsumexp)
    grad_out_unwrapped = ctx.unwrap_tensors(grad_out)
    grad_logsumexp_unwrapped = ctx.unwrap_tensors(grad_logsumexp)
    block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
    score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
    mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)

    # Appease the mypy overlords
    assert isinstance(query_unwrapped, torch.Tensor)
    assert isinstance(key_unwrapped, torch.Tensor)
    assert isinstance(value_unwrapped, torch.Tensor)
    assert isinstance(out_unwrapped, torch.Tensor)
    assert isinstance(logsumexp_unwrapped, torch.Tensor)
    assert isinstance(grad_out_unwrapped, torch.Tensor)
    assert isinstance(grad_logsumexp_unwrapped, torch.Tensor)
    assert isinstance(block_mask_unwrapped, tuple)
    assert isinstance(score_mod_other_buffers_unwrapped, tuple)
    assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
    assert all(
        isinstance(item, torch.Tensor)
        for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
    )

    with ctx.redispatch_to_next() as m:
        functional_fw_graph = ctx.functionalize(fw_graph)
        functional_joint_graph = ctx.functionalize(joint_graph)

        grad_query, grad_key, grad_value = flex_attention_backward(
            query_unwrapped,
            key_unwrapped,
            value_unwrapped,
            out_unwrapped,
            logsumexp_unwrapped,
            grad_out_unwrapped,
            grad_logsumexp_unwrapped,
            functional_fw_graph,  # type: ignore[arg-type]
            functional_joint_graph,  # type: ignore[arg-type]
            block_mask_unwrapped,
            scale,
            kernel_options,
            score_mod_other_buffers_unwrapped,
            mask_mod_other_buffers_unwrapped,
        )

    return ctx.wrap_tensors((grad_query, grad_key, grad_value))  # type: ignore[return-value,arg-type]


@flex_attention_backward.py_impl(FakeTensorMode)
def flex_attention_backward_fake_tensor_mode(
    mode: FakeTensorMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    with mode:
        grad_query = torch.empty_like(query)
        grad_key = torch.empty_like(key)
        grad_value = torch.empty_like(value)
        return grad_query, grad_key, grad_value


flex_attention_backward.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(flex_attention_backward, deferred_error=True)
)