# xref: /aosp_15_r20/external/pytorch/torch/nn/attention/_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
"""Defines utilities for interacting with scaled_dot_product_attention"""
import math
from typing import List, Optional, Union

import torch


__all__: List[str] = []


def _input_requires_grad(*tensors: torch.Tensor) -> bool:
    """Returns True if any of the tensors requires grad"""
    return any(t.requires_grad for t in tensors)


def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor:
    """Unpads the last dimension of a FlashAttention output back to the original head size"""
    if inpt_tensor.size(-1) != og_size:
        return inpt_tensor[..., :og_size]
    return inpt_tensor


def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float:
    """
    Returns the explicit scale if one was provided, otherwise 1/sqrt(head_dim_size).
    For FlashAttention the head dimension is padded to a multiple of 8, so the scale
    must be computed from the original head size, not the padded one.
    """
    if scale is not None:
        return scale
    return 1.0 / math.sqrt(head_dim_size)


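# Illustrative sketch (not part of the original module): the helper below shows how
# _calculate_scale and _postprocess_flash_output fit together when FlashAttention
# pads the head dimension to a multiple of 8. The tensor shapes and the head dim of
# 61 are assumptions chosen for the example, not requirements of this file.
def _example_flash_pad_and_scale() -> None:
    og_head_dim = 61
    query = torch.randn(2, 8, 128, og_head_dim)
    # The softmax scale is computed from the original head size, not the padded one.
    scale = _calculate_scale(og_head_dim, None)
    # Pad the last dimension up to the next multiple of 8 (61 -> 64), as a Flash
    # kernel would require ...
    padded = torch.nn.functional.pad(query, (0, -og_head_dim % 8))
    # ... then slice the padded output back down to the original head size.
    out = _postprocess_flash_output(padded, og_head_dim)
    assert out.size(-1) == og_head_dim
    assert scale == 1.0 / math.sqrt(og_head_dim)

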
_SUPPORTED_HEAD_DIMS = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]


def _supported_head_dim(n: Union[int, torch.SymInt]) -> bool:
    """Returns True if the head dim is supported by FlexAttention"""
    return n in _SUPPORTED_HEAD_DIMS


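# Illustrative note (not part of the original module): the check above accepts only
# the power-of-two values listed in _SUPPORTED_HEAD_DIMS, so 64 passes and 96 does
# not. The helper below is a tiny sketch of that behavior.
def _example_supported_head_dim() -> None:
    assert _supported_head_dim(64)
    assert not _supported_head_dim(96)

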
def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
):
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to have the same device, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2-dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
        )
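
# Illustrative usage sketch (not part of the original module): _validate_sdpa_input
# raises a ValueError when query/key/value disagree on dtype or device, or are fewer
# than 2-dimensional. The tensors below are assumptions chosen for the example.
def _example_validate_sdpa_input() -> None:
    query = torch.randn(2, 8, 128, 64)
    key = torch.randn(2, 8, 128, 64)
    value = torch.randn(2, 8, 128, 64)
    _validate_sdpa_input(query, key, value)  # passes: same dtype, device, and rank >= 2
    try:
        _validate_sdpa_input(query, key.to(torch.float16), value)
    except ValueError:
        pass  # mismatched dtypes are rejected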