# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# flake8: noqa C101
"""This module implements the user facing API for flex_attention in PyTorch."""
import functools
import inspect
import itertools
import math
import operator
from contextlib import nullcontext
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch._higher_order_ops.flex_attention import (
    flex_attention as flex_attention_hop,
    TransformGetItemToIndex,
)
from torch._higher_order_ops.utils import _set_compilation_env
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_pre_dispatch_torch_function_mode,
)
from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input
from torch.utils._pytree import tree_map_only


__all__ = [
    "BlockMask",
    "flex_attention",
    "create_block_mask",
    "create_mask",
    "or_masks",
    "and_masks",
    "noop_mask",
]

_score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]


class _ModificationType(Enum):
    """Enum for the type of modification function.

    - SCORE_MOD: a score_mod function, which accepts a score as its first argument.
    - MASK_MOD: a mask_mod function, which does not accept a score and is only used
      for generating a block mask.
    """

    SCORE_MOD = 1
    MASK_MOD = 2
    UNKNOWN = 3


def _get_mod_type(fn: Callable) -> _ModificationType:
    """Get the type of modification function.

    This function inspects the number of positional arguments (parameters without
    defaults) to determine the type of modification function. If the function has
    5 positional arguments, it is considered a score_mod function. If it has 4
    positional arguments, it is considered a mask_mod function.
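
    Example (an illustrative sketch of the two signatures; ``rel_bias`` and ``causal``
    are ad-hoc names used only here):

    .. code-block:: python

        def rel_bias(score, b, h, q_idx, kv_idx):  # 5 positional args -> SCORE_MOD
            return score + (q_idx - kv_idx)

        def causal(b, h, q_idx, kv_idx):  # 4 positional args -> MASK_MOD
            return q_idx >= kv_idx

        assert _get_mod_type(rel_bias) == _ModificationType.SCORE_MOD
        assert _get_mod_type(causal) == _ModificationType.MASK_MOD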
    """
    num_positional_args = sum(
        1
        for param in inspect.signature(fn).parameters.values()
        if param.default == inspect.Parameter.empty
    )
    assert num_positional_args == 5 or num_positional_args == 4
    if num_positional_args == 5:
        return _ModificationType.SCORE_MOD
    elif num_positional_args == 4:
        return _ModificationType.MASK_MOD
    else:
        return _ModificationType.UNKNOWN


# Need to define it here so that Dynamo doesn't skip it
def _vmap_for_bhqkv(
    fn: Callable,
    prefix: Tuple[Optional[int], ...],
    suffix: Tuple[Optional[int], ...] = (),
    out_dims: Union[int, List[Optional[int]]] = 0,
    group_dim: bool = False,
):
    """Used to vmap both score_mods and mask_mods over 4-dimensional/5-dimensional inputs.
    Maps over the [b, hq, q_idx, kv_idx] or [b, hkv, g, q_idx, kv_idx] dimensions.

    Args:
        fn (callable): The function to vmap.
        prefix (tuple): The prefix of the vmap. For score_mod functions,
                        this should be set to (0,). For mask_mods, set it to ().
        suffix (tuple): We need to add (0,) if gradOut is being mapped over,
                        and (None,) * len(other_buffers).
        out_dims (tuple): For forward cases, keep this as the default 0 since
                          we are only returning 1 output. For backwards, the joint
                          graph returns grads for B, H, Q_idx, KV_idx and other_buffers,
                          so we set this to (0, None, None, None, None) + (None,) * len(other_buffers).

    Returns:
        callable: The vmapped function.
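
    Example (a minimal sketch of how a mask_mod is broadcast over all indices, as
    done in ``create_mask``; the ``causal`` helper is just for illustration):

    .. code-block:: python

        def causal(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        b = torch.arange(1)
        h = torch.arange(1)
        q = torch.arange(8)
        kv = torch.arange(8)
        vmapped = _vmap_for_bhqkv(causal, prefix=())
        mask = vmapped(b, h, q, kv)  # bool tensor of shape [1, 1, 8, 8]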
    """
    # We vmap a function 4 times (5 times when group_dim is set), broadcasting the
    # [b, h, q_idx, kv_idx] dimensions
    dimensions: List[Tuple[None | int, None | int, None | int, None | int]] = []
    dimensions = [
        (None, None, None, 0),
        (None, None, 0, None),
        (None, 0, None, None),
    ]

    if group_dim:
        dimensions += [
            (None, 0, None, None),
        ]

    dimensions += [
        (0, None, None, None),
    ]

    for dims in dimensions:
        fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=out_dims)
    return fn


def _identity(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return score


def noop_mask(
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    """A no-op mask_mod: always returns True, so no positions are masked out."""
    return batch.new_ones(size=(), dtype=torch.bool, device=batch.device)


_DEFAULT_SPARSE_BLOCK_SIZE = 128
_LARGE_SPARSE_BLOCK_SIZE = 1 << 30


def _ordered_to_dense(num_blocks_in_row: Tensor, col_indices: Tensor):
    num_rows = col_indices.shape[-2]
    num_cols = col_indices.shape[-1]
    batch_dims = num_blocks_in_row.shape[:-1]
    device = num_blocks_in_row.device

    def create_dense_one(kv_num_blocks, kv_indices):
        dense_mask = kv_indices.new_zeros(num_rows, num_cols + 1, dtype=torch.int32)

        row_indices = torch.arange(num_rows, dtype=torch.int, device=device).unsqueeze(
            -1
        )
        col_range = torch.arange(num_cols, dtype=torch.int, device=device)
        index_mask = col_range < kv_num_blocks.unsqueeze(-1)

        # We write to one spot "out of bounds" (the extra column) for invalid indices
        valid_indices = torch.where(index_mask, kv_indices, num_cols)

        # Set the values in dense_mask to 1 where the indices are valid
        dense_mask[row_indices, valid_indices] = 1
        return dense_mask[:, :num_cols].contiguous()

    create_dense_batched = create_dense_one
    for _ in range(len(batch_dims)):
        create_dense_batched = torch.vmap(create_dense_batched, in_dims=(0, 0))

    out = create_dense_batched(num_blocks_in_row, col_indices)
    return out


def _dense_to_ordered(dense_mask) -> Tuple:
    dense_mask = dense_mask.to(dtype=torch.int32)
    num_blocks_in_row = dense_mask.sum(dim=-1)
    col_indices = torch.argsort(dense_mask, dim=-1, descending=True, stable=True)
    return (
        num_blocks_in_row.to(torch.int32).contiguous(),
        col_indices.to(torch.int32).contiguous(),
    )


def _transpose_ordered(num_blocks_in_row: Tensor, col_indices: Tensor):
    dense = _ordered_to_dense(num_blocks_in_row, col_indices)
    return _dense_to_ordered(dense.transpose(-2, -1))

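
# Example (illustrative sketch): the "ordered" format converted above. For a dense
# per-block mask
#
#     [[1, 0, 0],
#      [1, 1, 0]]
#
# _dense_to_ordered returns num_blocks_in_row = [1, 2] and col_indices whose first
# num_blocks_in_row[i] entries of row i are the positions of the non-zero blocks
# (e.g. [[0, ...], [0, 1, ...]]); the remaining entries are unspecified padding.
# _ordered_to_dense inverts this, and _transpose_ordered round-trips through the
# dense form to swap the Q and KV axes.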

class BlockMask:
    r"""
    BlockMask is our format for representing a block-sparse attention mask.
    It is somewhat of a cross between BCSR and a non-sparse format.

    Basics
    ------
    A block-sparse mask means that instead of representing the sparsity of
    individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is
    considered sparse only if every element within that block is sparse.
    This aligns well with hardware, which generally expects to perform
    contiguous loads and computation.

    This format is primarily optimized for 1. simplicity, and 2. kernel
    efficiency. Notably, it is *not* optimized for size, as this mask is always
    reduced by a factor of KV_BLOCK_SIZE * Q_BLOCK_SIZE. If the size is a
    concern, the tensors can be reduced in size by increasing the block size.

    The essentials of our format are:

    num_blocks_in_row: Tensor[ROWS]:
    Describes the number of blocks present in each row.

    col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]:
    `col_indices[i]` is the sequence of block positions for row i. The values of
    this row after `col_indices[i][num_blocks_in_row[i]]` are undefined.

    For example, to reconstruct the original tensor from this format:

    .. code-block:: python

        dense_mask = torch.zeros(ROWS, COLS)
        for row in range(ROWS):
            for block_idx in range(num_blocks_in_row[row]):
                dense_mask[row, col_indices[row, block_idx]] = 1

    Notably, this format makes it easier to implement a reduction along the
    *rows* of the mask.

    Details
    -------
    The basics of our format require only kv_num_blocks and kv_indices. However,
    we have up to 8 tensors on this object, representing 4 pairs:

    1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as
    we reduce along the KV dimension.

    2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and
    purely an optimization. As it turns out, applying masking to every block
    is quite expensive! If we specifically know which blocks are "full" and
    don't require masking at all, then we can skip applying mask_mod to these
    blocks. This requires the user to split out a separate mask_mod from the
    score_mod. For causal masks, this is about a 15% speedup.

    3. [GENERATED] (q_num_blocks, q_indices): Required for the backwards pass,
    as computing dKV requires iterating over the mask along the Q dimension.
    These are autogenerated from 1.

    4. [GENERATED] (full_q_num_blocks, full_q_indices): Same as above, but for
    the backwards pass. These are autogenerated from 2.
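
    Example (an illustrative sketch): building and inspecting a BlockMask for a
    causal pattern (the ``causal_mask`` helper is just for illustration).

    .. code-block:: python

        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        block_mask = create_block_mask(causal_mask, 1, 1, 512, 512, device="cuda")
        print(block_mask.sparsity())   # percentage of blocks skipped entirely
        print(block_mask.to_dense())   # [1, 1, 4, 4] lower-triangular block pattern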
    """
    kv_num_blocks: Tensor
    kv_indices: Tensor
    full_kv_num_blocks: Optional[Tensor]
    full_kv_indices: Optional[Tensor]
    q_num_blocks: Optional[Tensor]
    q_indices: Optional[Tensor]
    full_q_num_blocks: Optional[Tensor]
    full_q_indices: Optional[Tensor]
    BLOCK_SIZE: Tuple[int, int]
    mask_mod: _mask_mod_signature

    def __init__(
        self,
        kv_num_blocks: Tensor,
        kv_indices: Tensor,
        full_kv_num_blocks: Optional[Tensor],
        full_kv_indices: Optional[Tensor],
        q_num_blocks: Optional[Tensor],
        q_indices: Optional[Tensor],
        full_q_num_blocks: Optional[Tensor],
        full_q_indices: Optional[Tensor],
        BLOCK_SIZE: Tuple[int, int],
        mask_mod: _mask_mod_signature,
    ):
        if kv_indices.dim() < 2:
            raise RuntimeError("BlockMask must have at least 2 dimensions")
        assert kv_num_blocks is not None, "kv_num_blocks must be provided"
        assert kv_indices is not None, "kv_indices must be provided"
        assert q_num_blocks is not None, "q_num_blocks must be provided"
        assert q_indices is not None, "q_indices must be provided"
        assert (full_kv_num_blocks is None) == (
            full_kv_indices is None
        ), "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
        assert (full_q_num_blocks is None) == (
            full_q_indices is None
        ), "full_q_num_blocks and full_q_indices must be both provided or omitted"

        self.kv_num_blocks = kv_num_blocks
        self.kv_indices = kv_indices
        self.full_kv_num_blocks = full_kv_num_blocks
        self.full_kv_indices = full_kv_indices
        self.q_num_blocks = q_num_blocks
        self.q_indices = q_indices
        self.full_q_num_blocks = full_q_num_blocks
        self.full_q_indices = full_q_indices
        self.BLOCK_SIZE = BLOCK_SIZE
        self.mask_mod = mask_mod

    @classmethod
    def from_kv_blocks(
        cls,
        kv_num_blocks: Tensor,
        kv_indices: Tensor,
        full_kv_num_blocks: Optional[Tensor] = None,
        full_kv_indices: Optional[Tensor] = None,
        BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
        mask_mod: Optional[_mask_mod_signature] = None,
    ):
        """
        Creates a BlockMask instance from key-value block information.

        Args:
            kv_num_blocks (Tensor): Number of kv_blocks in each Q_BLOCK_SIZE row tile.
            kv_indices (Tensor): Indices of key-value blocks in each Q_BLOCK_SIZE row tile.
            full_kv_num_blocks (Optional[Tensor]): Number of full kv_blocks in each Q_BLOCK_SIZE row tile.
            full_kv_indices (Optional[Tensor]): Indices of full key-value blocks in each Q_BLOCK_SIZE row tile.
            BLOCK_SIZE (Union[int, Tuple[int, int]]): Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles.
            mask_mod (Optional[Callable]): Function to modify the mask.

        Returns:
            BlockMask: Instance with full Q information generated via _transpose_ordered.

        Raises:
            RuntimeError: If kv_indices has < 2 dimensions.
            AssertionError: If only one of full_kv_* args is provided.
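
        Example (a minimal sketch; entries of kv_indices beyond kv_num_blocks[i]
        are padding):

        .. code-block:: python

            kv_num_blocks = torch.tensor([[[1, 2]]], dtype=torch.int32)
            kv_indices = torch.tensor([[[[0, 0], [0, 1]]]], dtype=torch.int32)
            block_mask = BlockMask.from_kv_blocks(kv_num_blocks, kv_indices)
            # block_mask.to_dense() -> [[[[1, 0], [1, 1]]]] (a 2x2 causal block pattern)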
        """
        if kv_indices.dim() < 2:
            raise RuntimeError("BlockMask must have at least 2 dimensions")

        assert (full_kv_num_blocks is None) == (
            full_kv_indices is None
        ), "full_kv_num_blocks and full_kv_indices must be both provided or omitted"

        # Generate q_num_blocks and q_indices
        q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
        if full_kv_num_blocks is not None:
            assert full_kv_indices is not None
            full_q_num_blocks, full_q_indices = _transpose_ordered(
                full_kv_num_blocks, full_kv_indices
            )
        else:
            full_q_num_blocks, full_q_indices = None, None

        if isinstance(BLOCK_SIZE, int):
            BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)

        mask_mod = mask_mod if mask_mod is not None else noop_mask

        return cls(
            kv_num_blocks=kv_num_blocks,
            kv_indices=kv_indices,
            full_kv_num_blocks=full_kv_num_blocks,
            full_kv_indices=full_kv_indices,
            q_num_blocks=q_num_blocks,
            q_indices=q_indices,
            full_q_num_blocks=full_q_num_blocks,
            full_q_indices=full_q_indices,
            BLOCK_SIZE=BLOCK_SIZE,
            mask_mod=mask_mod,
        )

    def as_tuple(self, flatten: bool = True):
        """
        Returns a tuple of the attributes of the BlockMask.

        Args:
            flatten (bool): If True, the (KV_BLOCK_SIZE, Q_BLOCK_SIZE) tuple is flattened
                into the returned tuple.
        """
        block_size = (
            (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) if flatten else (self.BLOCK_SIZE,)
        )

        return (
            self.kv_num_blocks,
            self.kv_indices,
            self.full_kv_num_blocks,
            self.full_kv_indices,
            self.q_num_blocks,
            self.q_indices,
            self.full_q_num_blocks,
            self.full_q_indices,
            *block_size,
            self.mask_mod,
        )

    def __str__(self):
        s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n"
        mask_str = self.to_string().strip()
        s += mask_str
        s += "\n)"
        return s

    def __getitem__(self, index) -> "BlockMask":
        """
        Returns a new BlockMask instance by getting the mask for the given index position.

        Args:
            index: Index to apply to all attributes.

        Example Usage:
            .. code-block:: python

                def causal_mask(b, h, q_idx, kv_idx):
                    return q_idx >= kv_idx

                block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device="cuda")
                assert block_mask.kv_num_blocks.shape == (4,2,4)
                assert block_mask.kv_indices.shape == (4,2,4,4)

                # Index on batch dimension
                new_block_mask = block_mask[0]
                assert new_block_mask.kv_num_blocks.shape == (2,4)
                assert new_block_mask.kv_indices.shape == (2,4,4)

                # Index on batch and head dimension
                new_block_mask = block_mask[0, 1]
                assert new_block_mask.kv_num_blocks.shape == (4,)
                assert new_block_mask.kv_indices.shape == (4,4)

                # Slicing on batch and head dimension
                new_block_mask = block_mask[0:2, 1:2]
                assert new_block_mask.kv_num_blocks.shape == (2,1,4)
                assert new_block_mask.kv_indices.shape == (2,1,4,4)

                # Slicing on batch, head, and query dimension
                new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)]
                assert new_block_mask.kv_num_blocks.shape == (2,1,1)
                assert new_block_mask.kv_indices.shape == (2,1,1,4)
        """
        new_kv_num_blocks = self.kv_num_blocks[index]
        new_kv_indices = self.kv_indices[index]
        if self.full_kv_num_blocks is not None:
            assert self.full_kv_indices is not None
            new_full_kv_num_blocks = self.full_kv_num_blocks[index]
            new_full_kv_indices = self.full_kv_indices[index]
        else:
            new_full_kv_num_blocks = None
            new_full_kv_indices = None
        return BlockMask.from_kv_blocks(
            new_kv_num_blocks,
            new_kv_indices,
            new_full_kv_num_blocks,
            new_full_kv_indices,
            BLOCK_SIZE=self.BLOCK_SIZE,
            mask_mod=None,
        )

    def __repr__(self):
        def shape_or_none(x: Optional[torch.Tensor]):
            return x.shape if x is not None else None

        return (
            f"BlockMask(\n"
            f"    kv_num_blocks={self.kv_num_blocks.shape},\n"
            f"    kv_indices={self.kv_indices.shape},\n"
            f"    full_kv_num_blocks={shape_or_none(self.full_kv_num_blocks)},\n"
            f"    full_kv_indices={shape_or_none(self.full_kv_indices)},\n"
            f"    q_num_blocks={shape_or_none(self.q_num_blocks)},\n"
            f"    q_indices={shape_or_none(self.q_indices)},\n"
            f"    full_q_num_blocks={shape_or_none(self.full_q_num_blocks)},\n"
            f"    full_q_indices={shape_or_none(self.full_q_indices)},\n"
            f"    BLOCK_SIZE={self.BLOCK_SIZE},\n"
            f"    shape={self.shape},\n"
            f"    sparsity={self.sparsity():.2f}%,\n"
            f"    mask_mod={self.mask_mod.__name__ if hasattr(self.mask_mod, '__name__') else self.mask_mod}\n"
            f")"
        )

    @property
    def shape(self):
        """Returns the shape of the mask."""
        *batch_dims, num_q_blocks, num_kv_blocks = self.kv_indices.shape
        q_length = num_q_blocks * self.BLOCK_SIZE[0]
        kv_length = num_kv_blocks * self.BLOCK_SIZE[1]
        return tuple(batch_dims + [q_length, kv_length])

    def numel(self):
        """Returns the number of elements (not accounting for sparsity) in the mask."""
        shape = self.shape

        def _prod(xs):
            return functools.reduce(operator.mul, xs, 1)

        return _prod(shape)

    def sparsity(self) -> float:
        """Computes the percentage of blocks that are sparse (i.e. not computed)"""
        total_size = self.numel()
        computed_blocks = self.kv_num_blocks.sum()
        if self.full_kv_num_blocks is not None:
            computed_blocks += self.full_kv_num_blocks.sum()

        computed_size = computed_blocks.item() * self.BLOCK_SIZE[0] * self.BLOCK_SIZE[1]
        dense_ratio = computed_size / total_size
        return 100 * (1 - dense_ratio)

    def to_dense(self) -> Tensor:
        """Returns a dense, block-level mask that is equivalent to this block mask."""
        partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
        if self.full_kv_num_blocks is not None:
            assert self.full_kv_indices is not None
            return partial_dense | _ordered_to_dense(
                self.full_kv_num_blocks, self.full_kv_indices
            )
        return partial_dense

    def to_string(self, grid_size=(20, 20), limit=4):
        """Returns a string representation of the block mask. Quite nifty.

        If grid_size is -1, prints out an uncompressed version. Warning, it can be quite big!
        """
        dense_mask = self.to_dense()
        *batch_dims, num_rows, num_cols = dense_mask.shape
        if isinstance(grid_size, int):
            max_rows = grid_size
            max_cols = grid_size
        elif grid_size == -1:
            max_rows = num_rows
            max_cols = num_cols
        else:
            max_rows, max_cols = grid_size

        def create_block_vis(*batch_idx):
            descriptors = []

            descriptors.append(f"{batch_idx}")

            vis = ", ".join(reversed(descriptors)) + "\n"

            def summarize_section(section):
                percentage = section.float().mean().item()
                if percentage == 1:
                    return "█"
                elif percentage == 0:
                    return " "
                else:
                    return "░"

            def cdiv(a, b):
                return (a + (b - 1)) // b

            row_step = max(1, cdiv(num_rows, max_rows))
            col_step = max(1, cdiv(num_cols, max_cols))

            for r in range(0, num_rows, row_step):
                for c in range(0, num_cols, col_step):
                    cur_mask = dense_mask
                    for idx in batch_idx:
                        cur_mask = cur_mask[idx]
                    char = summarize_section(
                        cur_mask[r : r + row_step, c : c + col_step]
                    )
                    vis += char * 2
                vis += "\n"
            return vis

        total_vis = []
        for idx, batch_idx in enumerate(
            itertools.product(*[range(i) for i in batch_dims])
        ):
            if idx == limit:
                total_vis.append("...")
                total_vis.append("To print out more, set BlockMask.to_string(limit=N)")
                total_vis.append(
                    "You can also index (BlockMask[batch, head]) to choose a specific batch or head"
                )
                break
            block_vis = create_block_vis(*batch_idx)
            total_vis.append(block_vis)

        return "\n".join(total_vis)

    def to(self, device: Union[torch.device, str]) -> "BlockMask":
        """Moves the BlockMask to the specified device.

        Args:
            device (torch.device or str): The target device to move the BlockMask to.
                Can be a torch.device object or a string (e.g., 'cpu', 'cuda:0').

        Returns:
            BlockMask: A new BlockMask instance with all tensor components moved
            to the specified device.

        Note:
            This method does not modify the original BlockMask in-place.
            Instead, it returns a new BlockMask instance where individual tensor attributes
            may or may not be moved to the specified device, depending on their
            current device placement.
        """
        mapped_attributes = tree_map_only(
            torch.Tensor,
            lambda x: x.to(device),
            self.as_tuple(flatten=False),
        )
        return BlockMask(*mapped_attributes)


def _broadcast_to_dim(x, dim):
    while x.dim() < dim:
        x = x.unsqueeze(0)
    return x


def _round_up_to_multiple(x, multiple):
    return (x + multiple - 1) // multiple * multiple


def _convert_mask_to_block_mask(
    mask: Tensor,
    KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
    Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
    separate_full_blocks: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    assert mask.dtype == torch.bool
    mask = _broadcast_to_dim(mask, 4)
    B, H, Q, KV = mask.shape
    assert Q % Q_BLOCK_SIZE == 0
    assert KV % KV_BLOCK_SIZE == 0
    mask = mask.view(
        B, H, Q // Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV // KV_BLOCK_SIZE, KV_BLOCK_SIZE
    )  # [B, H, Q//Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, KV_BLOCK_SIZE]
    mask = mask.permute(
        0, 1, 2, 4, 3, 5
    )  # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, Q_BLOCK_SIZE, KV_BLOCK_SIZE]
    mask_block_sum = mask.sum(
        dim=[-2, -1]
    )  # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE]
    if separate_full_blocks:
        full_block_sum = Q_BLOCK_SIZE * KV_BLOCK_SIZE
        full_blocks = mask_block_sum == full_block_sum
        partial_blocks = (mask_block_sum > 0) & (mask_block_sum < full_block_sum)
        partial_blocks = partial_blocks.to(dtype=torch.int8)
        full_blocks = full_blocks.to(dtype=torch.int8)
        return partial_blocks, full_blocks
    else:
        partial_blocks = mask_block_sum > 0
        partial_blocks = partial_blocks.to(dtype=torch.int8)
        return partial_blocks, None

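
# Example (illustrative sketch): _convert_mask_to_block_mask takes a dense boolean
# mask of shape [B, H, Q, KV] and classifies each Q_BLOCK_SIZE x KV_BLOCK_SIZE tile:
#
#     mask = create_mask(noop_mask, 1, 1, 256, 256, device="cpu")
#     partial, full = _convert_mask_to_block_mask(mask, separate_full_blocks=True)
#     # partial, full: int8 tensors of shape [1, 1, 2, 2]; here every tile is full,
#     # so `full` is all ones and `partial` is all zeros.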

def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
    """Returns a mask_mod that's the union of provided mask_mods"""
    if not all(callable(arg) for arg in mask_mods):
        raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")

    def or_mask(b, h, q_idx, kv_idx):
        result = b.new_zeros((), dtype=torch.bool)
        for mask in mask_mods:
            result = result | mask(b, h, q_idx, kv_idx)
        return result

    return or_mask


def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
    """Returns a mask_mod that's the intersection of provided mask_mods"""
    if not all(callable(arg) for arg in mask_mods):
        raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")

    def and_mask(b, h, q_idx, kv_idx):
        result = b.new_ones((), dtype=torch.bool)
        for mask in mask_mods:
            result = result & mask(b, h, q_idx, kv_idx)
        return result

    return and_mask

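
# Example (illustrative sketch): composing mask_mods. The helper mask_mods below
# (`causal`, `sliding_window`) are hypothetical names used only for illustration:
#
#     def causal(b, h, q_idx, kv_idx):
#         return q_idx >= kv_idx
#
#     def sliding_window(b, h, q_idx, kv_idx):
#         return (q_idx - kv_idx) <= 256
#
#     local_causal = and_masks(causal, sliding_window)
#     block_mask = create_block_mask(local_causal, None, None, 2048, 2048, device="cuda")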

def _convert_block_mask_to_mask(
    block_mask,
    KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
    Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
) -> Tensor:
    assert block_mask.dim() == 4
    B, H, Q, KV = block_mask.shape
    block_mask = block_mask.expand(Q_BLOCK_SIZE, KV_BLOCK_SIZE, *block_mask.shape)
    block_mask = block_mask.permute(2, 3, 4, 0, 5, 1).reshape(
        B, H, Q * Q_BLOCK_SIZE, KV * KV_BLOCK_SIZE
    )
    return block_mask


def _create_sparse_block_from_block_mask(
    block_mask: Tuple[Tensor, Optional[Tensor]],
    mask_mod: Optional[Callable],
    KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
    Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
) -> BlockMask:
    partial_blocks, full_blocks = block_mask

    partial_bm = _dense_to_ordered(partial_blocks)
    if full_blocks is not None:
        full_bm = _dense_to_ordered(full_blocks)
    else:
        full_bm = (None, None)

    return BlockMask.from_kv_blocks(
        partial_bm[0],
        partial_bm[1],
        full_bm[0],
        full_bm[1],
        BLOCK_SIZE=(KV_BLOCK_SIZE, Q_BLOCK_SIZE),
        mask_mod=mask_mod,
    )


def create_mask(
    mod_fn: Union[_score_mod_signature, _mask_mod_signature],
    B: Optional[int],
    H: Optional[int],
    Q_LEN: int,
    KV_LEN: int,
    device: str = "cuda",
    _compile: bool = False,
) -> Tensor:
    r"""This function creates a mask tensor from a mod_fn function.

    Args:
        mod_fn (Union[_score_mod_signature, _mask_mod_signature]): A score_mod or mask_mod function.
        B (Optional[int]): Batch size. If None, defaults to 1.
        H (Optional[int]): Number of query heads. If None, defaults to 1.
        Q_LEN (int): Sequence length of query.
        KV_LEN (int): Sequence length of key/value.
        device (str): Device to run the mask creation on.

    Returns:
        mask (Tensor): A mask tensor with shape (B, H, Q_LEN, KV_LEN).
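
    Example (illustrative; the ``causal_mask`` helper is just for illustration):

    .. code-block:: python

        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        mask = create_mask(causal_mask, 1, 1, 128, 128, device="cuda")
        # mask: bool tensor of shape (1, 1, 128, 128), True on and below the diagonal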
    """
    if B is None:
        B = 1
    if H is None:
        H = 1
    b = torch.arange(0, B, device=device)
    h = torch.arange(0, H, device=device)
    m = torch.arange(0, Q_LEN, device=device)
    n = torch.arange(0, KV_LEN, device=device)
    # TODO: fix this
    # We lack instantiation support for __torch_function__ mode under compile
    if _compile:
        ctx = nullcontext()
    else:
        ctx = TransformGetItemToIndex()  # type: ignore[assignment]
    mod_type = _get_mod_type(mod_fn)

    with ctx:
        if mod_type == _ModificationType.SCORE_MOD:
            score_mod = mod_fn
            score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,))  # first input is score
            out = score_mod(torch.zeros(B, H, Q_LEN, KV_LEN, device=device), b, h, m, n)
            mask = torch.where(torch.isneginf(out), False, True)
            return mask
        elif mod_type == _ModificationType.MASK_MOD:
            mask_mod = mod_fn
            mask_mod = _vmap_for_bhqkv(mask_mod, prefix=())
            mask = mask_mod(b, h, m, n)
            return mask
        else:
            raise AssertionError


def _create_block_mask_inner(
    mask_mod: Callable,
    B: int,
    H: int,
    Q_LEN: int,
    KV_LEN: int,
    device: str,
    KV_BLOCK_SIZE: int,
    Q_BLOCK_SIZE: int,
):
    r"""Workaround for being unable to instantiate __torch_function__ mode under compile.
    `create_block_mask` will compile this inner function and wrap the call to it
    in the __torch_function__ mode.
    """
    mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
    partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
        mask_tensor,
        KV_BLOCK_SIZE=KV_BLOCK_SIZE,
        Q_BLOCK_SIZE=Q_BLOCK_SIZE,
        separate_full_blocks=True,
    )
    return partial_block_mask, full_block_mask


def create_block_mask(
    mask_mod: _mask_mod_signature,
    B: Optional[int],
    H: Optional[int],
    Q_LEN: int,
    KV_LEN: int,
    device: str = "cuda",
    BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
    _compile=False,
) -> BlockMask:
    r"""This function creates a BlockMask from a mask_mod function.

    Args:
        mask_mod (Callable): mask_mod function. This is a callable that defines the
            masking pattern for the attention mechanism. It takes four arguments:
            b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index).
            It should return a boolean tensor indicating which attention connections are allowed (True)
            or masked out (False).
        B (Optional[int]): Batch size. If None, defaults to 1.
        H (Optional[int]): Number of query heads. If None, defaults to 1.
        Q_LEN (int): Sequence length of query.
        KV_LEN (int): Sequence length of key/value.
        device (str): Device to run the mask creation on.
        BLOCK_SIZE (Union[int, Tuple[int, int]]): Block size of the block mask. A single
            int is used for both the query and key/value block sizes; a tuple is
            interpreted as (Q_BLOCK_SIZE, KV_BLOCK_SIZE).
        _compile (bool): Whether to compile the mask creation.

    Returns:
        BlockMask: A BlockMask object that contains the block mask information.

    Example Usage:
        .. code-block:: python

            def causal_mask(b, h, q_idx, kv_idx):
                return q_idx >= kv_idx

            block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda")
            query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
            key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
            value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
            output = flex_attention(query, key, value, block_mask=block_mask)
    """
    mod_type = _get_mod_type(mask_mod)
    assert (
        mod_type == _ModificationType.MASK_MOD
    ), f"create_block_mask requires a mask_mod function! Got {mask_mod}"
    inner_func = _create_block_mask_inner
    if B is None:
        B = 1
    if H is None:
        H = 1
    if isinstance(BLOCK_SIZE, int):
        Q_BLOCK_SIZE = BLOCK_SIZE
        KV_BLOCK_SIZE = BLOCK_SIZE
    else:
        Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE

    if Q_LEN < 128:
        Q_BLOCK_SIZE = Q_LEN
    else:
        Q_LEN = _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
    KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
    if _compile:
        inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
    with TransformGetItemToIndex():
        partial_block_mask, full_block_mask = inner_func(
            mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE
        )
        block_mask = _create_sparse_block_from_block_mask(
            (partial_block_mask, full_block_mask), mask_mod
        )
    return block_mask


def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
    r"""Default block mask for flex attention.
    If users don't specify any block sparse mask info, we create this
    empty block sparse mask, which is a BlockMask with a single block spanning
    the full length of the query and key tensors.
    """
    device = query.device
    return BlockMask.from_kv_blocks(
        kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
        kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
        BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE,
    )


def _apply_kernel_options(
    query: Tensor, key: Tensor, value: Tensor, return_lse: bool, kernel_options
):
    kernel_options = {} if kernel_options is None else dict(kernel_options)

    kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False)
    kernel_options.setdefault("PRESCALE_QK", False)

    # Whether the forward kernel needs to return logsumexp is decided internally by this rule.
    assert "OUTPUT_LOGSUMEXP" not in kernel_options
    kernel_options["OUTPUT_LOGSUMEXP"] = True
    if not return_lse:
        any_inputs_require_grad = (
            query.requires_grad or key.requires_grad or value.requires_grad
        )
        output_logsumexp = any_inputs_require_grad and torch.is_grad_enabled()
        kernel_options["OUTPUT_LOGSUMEXP"] = output_logsumexp

    return kernel_options


def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor):
    if query.size(-1) != key.size(-1):
        raise ValueError(
            f"Expect query and key/value to have the same embedding dimension "
            f"but got E={query.size(-1)} and E={key.size(-1)}."
        )
    # TODO this config segfaults with Triton without:
    # https://github.com/triton-lang/triton/pull/4540
    if not (
        _supported_head_dim(query.size(-1)) and _supported_head_dim(value.size(-1))
    ):
        raise ValueError(
            f"NYI: Currently non-power-of-2 embedding dimensions are not supported. "
            f"Got E={query.size(-1)} and Ev={value.size(-1)}."
        )
    if value.size(-1) > query.size(-1):
        raise ValueError(
            f"NYI: Currently value embedding dimension must be less than or equal to query embedding dimension. "
            f"Got Ev={value.size(-1)} and E={query.size(-1)}."
        )


def flex_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    score_mod: Optional[_score_mod_signature] = None,
    block_mask: Optional[BlockMask] = None,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    r"""This function implements scaled dot product attention with an arbitrary attention score modification function.

    This function computes the scaled dot product attention between query, key, and value tensors with a user-defined
    attention score modification function. The attention score modification function is applied after the attention
    scores have been calculated between the query and key tensors.

    The ``score_mod`` function should have the following signature:

    .. code-block:: python

        def score_mod(
            score: Tensor,
            batch: Tensor,
            head: Tensor,
            q_idx: Tensor,
            k_idx: Tensor
        ) -> Tensor:

    Where:
        - ``score``: A scalar tensor representing the attention score,
          with the same data type and device as the query, key, and value tensors.
        - ``batch``, ``head``, ``q_idx``, ``k_idx``: Scalar tensors indicating
          the batch index, query head index, query index, and key/value index, respectively.
          These should have the ``torch.int`` data type and be located on the same device as the score tensor.

    Args:
        query (Tensor): Query tensor; shape :math:`(B, Hq, L, E)`.
        key (Tensor): Key tensor; shape :math:`(B, Hkv, S, E)`.
        value (Tensor): Value tensor; shape :math:`(B, Hkv, S, Ev)`.
        score_mod (Optional[Callable]): Function to modify attention scores. By default no score_mod is applied.
        block_mask (Optional[BlockMask]): BlockMask object that controls the block sparsity pattern of the attention.
        scale (Optional[float]): Scaling factor applied prior to softmax. If None, the default value is set to :math:`\frac{1}{\sqrt{E}}`.
        enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads.
        return_lse (bool): Whether to return the logsumexp of the attention scores. Default is False.
        kernel_options (Optional[Dict[str, Any]]): Options to pass into the Triton kernels.

    Returns:
        output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`.

    Shape legend:
        - :math:`B: \text{Batch size}`
        - :math:`S: \text{Source sequence length}`
        - :math:`L: \text{Target sequence length}`
        - :math:`E: \text{Embedding dimension of the query and key}`
        - :math:`Ev: \text{Embedding dimension of the value}`

    .. warning::
        `torch.nn.attention.flex_attention` is a prototype feature in PyTorch.
        Please look forward to a more stable implementation in a future version of PyTorch.
        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

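    Example Usage (an illustrative sketch; ``causal_mask`` and ``rel_bias`` are ad-hoc helpers):
        .. code-block:: python

            def causal_mask(b, h, q_idx, kv_idx):
                return q_idx >= kv_idx

            def rel_bias(score, b, h, q_idx, kv_idx):
                return score + (q_idx - kv_idx)

            block_mask = create_block_mask(causal_mask, 2, 8, 2048, 2048, device="cuda")
            query = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
            key = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
            value = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
            output = flex_attention(query, key, value, score_mod=rel_bias, block_mask=block_mask)
            # Passing return_lse=True additionally returns the logsumexp of the attention scores.
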
    """
    # Some basic input validation
    _validate_sdpa_input(query, key, value)
    _validate_embed_dim(query, key, value)
    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
        raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
    if (not enable_gqa) and query.size(-3) != key.size(-3):
        raise ValueError(
            f"Expect query and key/value to have the same number of heads "
            f"but got Hq={query.size(-3)} and Hkv={key.size(-3)}. "
            f"Try setting enable_gqa=True for GQA."
        )
    if enable_gqa:
        Hq = query.size(1)
        Hkv = key.size(1)
        if Hq % Hkv != 0:
            raise ValueError(
                f"Expect number of query heads to be a multiple of kv heads for GQA "
                f"but got Hq={Hq} and Hkv={Hkv}."
            )

    if score_mod is None:
        score_mod = _identity
    if block_mask is None:
        block_mask = _create_empty_block_mask(query, key)
    if scale is None:
        scale = 1.0 / math.sqrt(query.size(-1))

    kernel_options = _apply_kernel_options(
        query,
        key,
        value,
        return_lse,
        kernel_options,
    )

    if torch.compiler.is_dynamo_compiling():
        # Mark head_dim and number of heads to be static
        for x in [query, key, value]:
            torch._dynamo.mark_static(x, -3)
            torch._dynamo.mark_static(x, -1)
        out, lse = flex_attention_hop(
            query, key, value, score_mod, block_mask.as_tuple(), scale, kernel_options
        )
        if return_lse:
            return out, lse * math.log(2)
        else:
            return out

    if not torch._dynamo.is_dynamo_supported():
        raise RuntimeError("flex_attention requires dynamo support")

    # Dynamo is expecting a callable with "__code__" attribute.
    # We cannot directly pass hop to it. So we wrap it in a dummy function.
    def _flex_attention_hop_wrapper(*args, **kwargs):
        return flex_attention_hop(*args, **kwargs)

    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            with _temp_remove_pre_dispatch_torch_function_mode():
                out, lse = torch.compile(
                    _flex_attention_hop_wrapper, backend="eager", fullgraph=True
                )(
                    query,
                    key,
                    value,
                    score_mod,
                    block_mask.as_tuple(),
                    scale,
                    kernel_options,
                )
                if return_lse:
                    return out, lse * math.log(2)
                else:
                    return out