# mypy: allow-untyped-defs
"""Defines bias subclasses that work with scaled_dot_product_attention"""
from enum import auto, IntEnum
from typing import Optional
from warnings import warn

import torch
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    is_flash_attention_available,
    SDPAParams,
)
from torch.nn.attention import _raise_kernel_warnings
from torch.nn.attention._utils import (
    _calculate_scale,
    _input_requires_grad,
    _postprocess_flash_output,
    _validate_sdpa_input,
)


__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]


torch._dynamo.allow_in_graph(is_flash_attention_available)
torch._dynamo.allow_in_graph(can_use_flash_attention)
torch._dynamo.allow_in_graph(can_use_efficient_attention)
torch._dynamo.allow_in_graph(SDPAParams)


class CausalVariant(IntEnum):
    r"""
    Enum for causal variants used in attention mechanisms.

    Defines two types of causal biases:

    `UPPER_LEFT`: Represents upper-left triangular bias for standard causal attention.
    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]


    `LOWER_RIGHT`: Represents lower-right triangular bias; the included values are aligned to the
    lower-right corner of the matrix.

    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Note that these variants are equivalent to each other when the sequence lengths of the query and key/value
    tensors are equal since the triangular matrix is square.
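
    For example, with `shape=(3,3)` both variants materialize to the same lower-triangular mask:

    .. code-block:: text

        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]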

    .. warning:: This enum is a prototype and subject to change.
    """

    UPPER_LEFT = auto()
    LOWER_RIGHT = auto()


class CausalBias(torch.Tensor):
    """
    A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.

    This class is used for defining causal (triangular) attention biases. For constructing the bias, there are
    two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.

    Example:

    .. code-block:: python

        import torch
        import torch.nn.functional as F

        from torch.nn.attention.bias import causal_lower_right

        bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

        # Create a lower-right causal bias
        attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

        q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
        v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

        out = F.scaled_dot_product_attention(q, k, v, attn_bias)

    .. warning:: This class is a prototype and subject to change.
    """

    def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
        """
        Initializes the CausalBias instance with a specified variant and sequence lengths.

        Args:
            variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
            seq_len_q (int): The sequence length of the query tensor.
            seq_len_kv (int): The sequence length of the key/value tensor.

        Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
        """
        assert isinstance(variant, CausalVariant)
        self.variant = variant
        self.seq_len_q = seq_len_q
        self.seq_len_kv = seq_len_kv
        if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
            warn(
                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )

    def _upper_left(self, device: torch.device) -> torch.Tensor:
        """Upper left causal bias"""
        return torch.tril(
            torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
        )

    def _lower_right(self, device: torch.device) -> torch.Tensor:
        """Lower right causal bias"""
        diagonal_offset = self.seq_len_kv - self.seq_len_q
        return torch.tril(
            torch.ones(
                self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
            ),
            diagonal=diagonal_offset,
        )

    def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Materializes the causal bias into a tensor form.

        Depending on the variant, this method generates either an upper-left or lower-right
        triangular matrix to represent the causal bias.

        Args:
            device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.

        Returns:
            torch.Tensor: The materialized bias tensor.
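
        Example (a brief sketch; the 0/1 pattern matches the `LOWER_RIGHT` example in
        :class:`CausalVariant`):

        .. code-block:: python

            bias = causal_lower_right(3, 4)
            mask = bias._materialize()  # boolean tensor, on CPU by default
            # [[1, 1, 0, 0],
            #  [1, 1, 1, 0],
            #  [1, 1, 1, 1]]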
        """
        if device is None:
            device = torch.device("cpu")
        if self.variant == CausalVariant.UPPER_LEFT:
            return self._upper_left(device)
        elif self.variant == CausalVariant.LOWER_RIGHT:
            return self._lower_right(device)

    @staticmethod
    def _dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: "CausalBias",
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
    ) -> torch.Tensor:
        r"""
        Handles the logic for computing attention with the specified causal bias.

        Args:
            query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
            key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
            value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
            attn_mask (CausalBias): The type of causal attention to apply.
                A boolean mask where a value of True indicates that the element *should* take part in attention.
                A float mask of the same type as query, key, value that is added to the attention score.
            dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
            is_causal (bool): If true, assumes upper-left causal attention masking; an error is raised
                if both attn_mask and is_causal are set.
            scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
                to :math:`\frac{1}{\sqrt{E}}`.
            enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled; by default it is set to False.

        Returns:
            output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.

        Raises:
            ValueError: If is_causal is set to True, or if the causal bias variant is not a
                CausalVariant type.

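
        For example (a sketch reusing `q`, `k`, `v` and the sequence lengths from the
        :class:`CausalBias` example), passing a :class:`CausalBias` as the `attn_mask` argument of
        :func:`torch.nn.functional.scaled_dot_product_attention` routes through `__torch_function__`
        into this method; an `UPPER_LEFT` bias (or any square bias) simply takes the built-in
        `is_causal=True` fast path and is never materialized:

        .. code-block:: python

            attn_bias = causal_upper_left(seqlen_q, seqlen_kv)
            # Dispatched here and lowered to is_causal=True
            out = F.scaled_dot_product_attention(q, k, v, attn_bias)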
        """
        if is_causal:
            raise ValueError("CausalBias should not be used with causal=True")

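        # UPPER_LEFT biases (and square biases, where the two variants coincide) match PyTorch's
        # built-in causal masking, so defer to the is_causal=True fast path without materializing
        # the mask.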
        if (
            attn_mask.seq_len_q == attn_mask.seq_len_kv
            or attn_mask.variant == CausalVariant.UPPER_LEFT
        ):
            return F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
                scale=scale,
                enable_gqa=enable_gqa,
            )
        elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
            _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
            sdpa_params = SDPAParams(
                query, key, value, None, dropout_p, is_causal, enable_gqa
            )
            if can_use_flash_attention(sdpa_params):
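                # Pad the head dimension up to a multiple of 8 for the flash kernel if needed;
                # the output is trimmed back to the original head size by _postprocess_flash_output.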
                needs_padding = query.size(-1) % 8 != 0
                og_head_size = query.size(-1)
                og_scale = _calculate_scale(og_head_size, scale)
                if needs_padding:
                    query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8))
                    key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8))
                    value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8))
                out = torch.ops.aten._scaled_dot_product_flash_attention(
                    query,
                    key,
                    value,
                    dropout_p,
                    is_causal=True,  # TODO: Flash accepts causal = True and for this particular op it means lower right
                    return_debug_mask=False,
                    scale=og_scale,
                )[0]
                return _postprocess_flash_output(out, og_head_size)
            if can_use_efficient_attention(sdpa_params):
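                # The memory-efficient kernel handles the lower-right layout natively through its
                # custom_mask_type argument, so the variant's integer value is passed straight
                # through and no mask is materialized.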
                compute_log_sumexp = False
                if _input_requires_grad(query, key, value):
                    compute_log_sumexp = True
                return torch.ops.aten._efficient_attention_forward(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    bias=None,
                    cu_seqlens_q=None,
                    cu_seqlens_k=None,
                    max_seqlen_q=None,
                    max_seqlen_k=None,
                    dropout_p=dropout_p,
                    custom_mask_type=int(attn_mask.variant),
                    compute_log_sumexp=compute_log_sumexp,
                    scale=scale,
                    seqlen_k=None,
                )[0].transpose(1, 2)
            else:
                _raise_kernel_warnings(sdpa_params)
                # Neither flash nor efficient attention can be used; the only remaining support
                # for lower-right causal masking is to materialize the bias and pass it as attn_mask.
                return F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask._materialize(query.device),
                    dropout_p=dropout_p,
                    is_causal=False,
                    scale=scale,
                    enable_gqa=enable_gqa,
                )
        else:
            raise ValueError(
                f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
            )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is a CausalBias"""
        if kwargs is None:
            kwargs = {}
        if func != torch.nn.functional.scaled_dot_product_attention:
            raise NotImplementedError(
                "CausalBias only supports scaled_dot_product_attention"
            )
        return cls._dispatch(*args, **kwargs)

    def __repr__(self):
        return self._materialize().__repr__()


def causal_upper_left(*size) -> CausalBias:
    """
    Creates an upper-left triangular causal bias.

    This function generates an upper-left triangular matrix to represent causal attention bias, with
    the included values aligned to the upper-left corner of the matrix. This is equivalent to the
    `is_causal=True` argument in `scaled_dot_product_attention`.

    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]
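
    Example (a brief sketch; assumes `q`, `k`, `v` are query/key/value tensors whose sequence
    lengths match `seqlen_q` and `seqlen_kv`, as in the :class:`CausalBias` example):

    .. code-block:: python

        from torch.nn.attention.bias import causal_upper_left

        attn_bias = causal_upper_left(seqlen_q, seqlen_kv)
        out = F.scaled_dot_product_attention(q, k, v, attn_bias)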

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The UPPER_LEFT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_upper_left only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv)


def causal_lower_right(*size) -> CausalBias:
    """
    Creates a lower-right triangular causal bias.

    This function generates a lower-right triangular matrix to represent causal attention bias with a
    diagonal offset set so that the included values are aligned to the lower-right corner of the matrix.

    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]
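
    Example (a condensed version of the :class:`CausalBias` example; `q`, `k`, `v` and the
    sequence lengths are constructed as shown there):

    .. code-block:: python

        from torch.nn.attention.bias import causal_lower_right

        attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
        out = F.scaled_dot_product_attention(q, k, v, attn_bias)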

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The LOWER_RIGHT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_lower_right only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv)
363