# mypy: allow-untyped-defs
"""Defines utilities for interacting with scaled_dot_product_attention."""
import math
from typing import List, Optional, Union

import torch


__all__: List[str] = []


def _input_requires_grad(*tensors: torch.Tensor) -> bool:
    """Returns True if any of the tensors requires grad"""
    return any(t.requires_grad for t in tensors)


def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor:
    """Unpads the last dimension back to the original head size"""
    if inpt_tensor.size(-1) != og_size:
        return inpt_tensor[..., :og_size]
    return inpt_tensor


def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float:
    """
    For FlashAttention we pad the head dimension to be a multiple of 8, so the output
    must be scaled by the original head size and not the padded one.
    """
    if scale is not None:
        return scale
    return 1.0 / math.sqrt(head_dim_size)


# Power-of-two head dimensions accepted by FlexAttention.
_SUPPORTED_HEAD_DIMS = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]


def _supported_head_dim(n: Union[int, torch.SymInt]) -> bool:
    """Returns True if the head dim is supported by FlexAttention"""
    return n in _SUPPORTED_HEAD_DIMS


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> None:
    """Raises ValueError unless query, key, and value share a dtype and device
    and are each at least 2-dimensional."""
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to be on the same device, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
        )
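

if __name__ == "__main__":
    # Illustrative smoke test only -- a minimal sketch of how the helpers
    # above compose around a FlashAttention-style call. The tensor shapes,
    # the head_dim of 36, and the explicit torch.nn.functional.pad step are
    # assumptions chosen to exercise each helper; the real callers live in
    # the SDPA dispatch code, not here.
    q = torch.randn(2, 8, 128, 36)  # (batch, num_heads, seq_len, head_dim)
    k = torch.randn(2, 8, 128, 36)
    v = torch.randn(2, 8, 128, 36)

    # Raises ValueError if dtypes, devices, or ranks are inconsistent.
    _validate_sdpa_input(q, k, v)

    # The softmax scale must come from the original head size, even when the
    # head dimension is later padded to a multiple of 8 for FlashAttention.
    og_size = q.size(-1)
    assert _calculate_scale(og_size, None) == 1.0 / math.sqrt(og_size)
    assert _calculate_scale(og_size, 0.5) == 0.5  # an explicit scale wins

    # Pad the head dim from 36 up to 40 (the next multiple of 8), then slice
    # the padding back off with _postprocess_flash_output.
    padded = torch.nn.functional.pad(q, (0, 4))
    assert padded.size(-1) == 40
    assert _postprocess_flash_output(padded, og_size).size(-1) == og_size

    # FlexAttention only accepts the power-of-two head dims listed above.
    assert _supported_head_dim(64) and not _supported_head_dim(36)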