# mypy: allow-untyped-defs
"""Defines bias subclasses that work with scaled_dot_product_attention"""
from enum import auto, IntEnum
from typing import Optional
from warnings import warn

import torch
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    is_flash_attention_available,
    SDPAParams,
)
from torch.nn.attention import _raise_kernel_warnings
from torch.nn.attention._utils import (
    _calculate_scale,
    _input_requires_grad,
    _postprocess_flash_output,
    _validate_sdpa_input,
)


__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]


torch._dynamo.allow_in_graph(is_flash_attention_available)
torch._dynamo.allow_in_graph(can_use_flash_attention)
torch._dynamo.allow_in_graph(can_use_efficient_attention)
torch._dynamo.allow_in_graph(SDPAParams)


class CausalVariant(IntEnum):
    r"""
    Enum for causal variants used in attention mechanisms.

    Defines two types of causal biases:

    `UPPER_LEFT`: Represents upper-left triangular bias for standard causal attention.
    The equivalent PyTorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]


    `LOWER_RIGHT`: Represents lower-right triangular bias; the included values are aligned to the lower
    right corner of the matrix.

    The equivalent PyTorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Note that these variants are equivalent to each other when the sequence lengths of the query and key/value
    tensors are equal since the triangular matrix is square.

    .. warning:: This enum is a prototype and subject to change.
    """

    UPPER_LEFT = auto()
    LOWER_RIGHT = auto()

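
# NOTE: The helper below is a hypothetical illustration, not part of the public
# API. It materializes both variants directly with torch.tril to show the
# property stated in the docstring above: UPPER_LEFT and LOWER_RIGHT coincide
# whenever the mask is square (seq_len_q == seq_len_kv) and differ only in
# alignment otherwise.
def _demo_variant_equivalence(seq_len_q: int = 3, seq_len_kv: int = 3) -> bool:
    upper_left = torch.tril(
        torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool)
    )
    lower_right = torch.tril(
        torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool),
        diagonal=seq_len_kv - seq_len_q,
    )
    # True for square shapes, e.g. (3, 3); False for e.g. (3, 4).
    return bool(torch.equal(upper_left, lower_right))
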
class CausalBias(torch.Tensor):
    """
    A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.

    This class is used for defining causal (triangular) attention biases. For constructing the bias, there exist
    two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.

    Example:

    .. code-block:: python

        from torch.nn.attention.bias import causal_lower_right

        bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

        # Create a lower-right causal bias
        attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

        q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
        v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

        out = F.scaled_dot_product_attention(q, k, v, attn_bias)

    .. warning:: This class is a prototype and subject to change.
    """

    def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
        """
        Initializes the CausalBias instance with a specified variant and sequence lengths.

        Args:
            variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
            seq_len_q (int): The sequence length of the query tensor.
            seq_len_kv (int): The sequence length of the key/value tensor.

        Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
        """
        assert isinstance(variant, CausalVariant)
        self.variant = variant
        self.seq_len_q = seq_len_q
        self.seq_len_kv = seq_len_kv
        if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
            warn(
                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )

    def _upper_left(self, device: torch.device) -> torch.Tensor:
        """Upper left causal bias"""
        return torch.tril(
            torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
        )

    def _lower_right(self, device: torch.device) -> torch.Tensor:
        """Lower right causal bias"""
        diagonal_offset = self.seq_len_kv - self.seq_len_q
        return torch.tril(
            torch.ones(
                self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
            ),
            diagonal=diagonal_offset,
        )

    def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Materializes the causal bias into a tensor form.

        Depending on the variant, this method generates either an upper-left or lower-right
        triangular matrix to represent the causal bias.

        Args:
            device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.

        Returns:
            torch.Tensor: The materialized bias tensor.
        """
        if device is None:
            device = torch.device("cpu")
        if self.variant == CausalVariant.UPPER_LEFT:
            return self._upper_left(device)
        elif self.variant == CausalVariant.LOWER_RIGHT:
            return self._lower_right(device)

    @staticmethod
    def _dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: "CausalBias",
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
    ) -> torch.Tensor:
        r"""
        Handles the logic for computing attention with the specified causal bias.

        Args:
            query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
            key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
            value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
            attn_mask (CausalBias): The type of causal attention to apply.
                A boolean mask where a value of True indicates that the element *should* take part in attention.
                A float mask of the same type as query, key, value that is added to the attention score.
            dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
            is_causal (bool): If True, assumes upper-left causal attention masking; an error is raised if both
                attn_mask and is_causal are set.
            scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
                to :math:`\frac{1}{\sqrt{E}}`.
            enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled; by default it is set to False.

        Returns:
            output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.

        Raises:
            ValueError: If the causal bias variant is not a CausalVariant type.

        """
        if is_causal:
            raise ValueError("CausalBias should not be used with causal=True")

        if (
            attn_mask.seq_len_q == attn_mask.seq_len_kv
            or attn_mask.variant == CausalVariant.UPPER_LEFT
        ):
            # Square masks and the UPPER_LEFT variant are exactly the is_causal=True pattern.
            return F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
                scale=scale,
                enable_gqa=enable_gqa,
            )
        elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
            _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
            sdpa_params = SDPAParams(
                query, key, value, None, dropout_p, is_causal, enable_gqa
            )
            if can_use_flash_attention(sdpa_params):
                # Pad the head dimension up to a multiple of 8 for the flash kernel,
                # then slice the output back to the original head size.
                needs_padding = query.size(-1) % 8 != 0
                og_head_size = query.size(-1)
                og_scale = _calculate_scale(og_head_size, scale)
                if needs_padding:
                    query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8))
                    key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8))
                    value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8))
                out = torch.ops.aten._scaled_dot_product_flash_attention(
                    query,
                    key,
                    value,
                    dropout_p,
                    is_causal=True,  # TODO: Flash accepts causal = True and for this particular op it means lower right
                    return_debug_mask=False,
                    scale=og_scale,
                )[0]
                return _postprocess_flash_output(out, og_head_size)
            if can_use_efficient_attention(sdpa_params):
                # The log-sum-exp is only needed when gradients will be computed.
                compute_log_sumexp = False
                if _input_requires_grad(query, key, value):
                    compute_log_sumexp = True
                return torch.ops.aten._efficient_attention_forward(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    bias=None,
                    cu_seqlens_q=None,
                    cu_seqlens_k=None,
                    max_seqlen_q=None,
                    max_seqlen_k=None,
                    dropout_p=dropout_p,
                    custom_mask_type=int(attn_mask.variant),
                    compute_log_sumexp=compute_log_sumexp,
                    scale=scale,
                    seqlen_k=None,
                )[0].transpose(1, 2)
            else:
                _raise_kernel_warnings(sdpa_params)
                # We can't use a fused kernel here; the only remaining support for
                # the lower-right bias is via materializing the mask.
                return F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask._materialize(query.device),
                    dropout_p=dropout_p,
                    is_causal=False,
                    scale=scale,
                    enable_gqa=enable_gqa,
                )
        else:
            raise ValueError(
                f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
            )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is a CausalBias"""
        if kwargs is None:
            kwargs = {}
        if func != torch.nn.functional.scaled_dot_product_attention:
            raise NotImplementedError(
                "CausalBias only supports scaled_dot_product_attention"
            )
        return cls._dispatch(*args, **kwargs)

    def __repr__(self):
        return self._materialize().__repr__()

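
# NOTE: Hypothetical illustration, not part of the public API. Assuming CPU
# tensors in float32 (so neither the flash nor the efficient kernel is
# selected), dispatching through a LOWER_RIGHT bias should match an explicit
# call that passes the materialized boolean mask. The factory
# causal_lower_right is defined later in this module and is resolved when the
# helper is actually called.
def _demo_lower_right_dispatch() -> bool:
    torch.manual_seed(0)
    q = torch.randn(2, 4, 3, 8)
    k = torch.randn(2, 4, 5, 8)
    v = torch.randn(2, 4, 5, 8)
    bias = causal_lower_right(3, 5)
    out_dispatched = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    out_reference = F.scaled_dot_product_attention(
        q, k, v, attn_mask=bias._materialize(q.device)
    )
    return bool(torch.allclose(out_dispatched, out_reference))
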
def causal_upper_left(*size) -> CausalBias:
    """
    Creates an upper-left triangular causal bias.

    This function generates an upper-left triangular matrix to represent causal attention bias with a
    diagonal offset set so that the included values are aligned to the upper left corner of the matrix.
    This is equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`.

    The equivalent PyTorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The UPPER_LEFT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_upper_left only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv)


def causal_lower_right(*size) -> CausalBias:
    """
    Creates a lower-right triangular causal bias.

    This function generates a lower-right triangular matrix to represent causal attention bias with a
    diagonal offset set so that the included values are aligned to the lower right corner of the matrix.

    The equivalent PyTorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The LOWER_RIGHT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_lower_right only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv)
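

# NOTE: Minimal usage sketch (hypothetical; CPU/float32 for portability). Per
# the docstrings above, the UPPER_LEFT factory reproduces ``is_causal=True``,
# while the LOWER_RIGHT factory aligns the causal pattern to the bottom-right
# corner when the key/value sequence is longer than the query sequence.
if __name__ == "__main__":
    q = torch.randn(2, 4, 3, 8)
    k = torch.randn(2, 4, 5, 8)
    v = torch.randn(2, 4, 5, 8)

    upper_left = causal_upper_left(3, 5)
    lower_right = causal_lower_right(3, 5)

    out_ul = F.scaled_dot_product_attention(q, k, v, attn_mask=upper_left)
    out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(torch.allclose(out_ul, out_causal))  # expected: True

    # repr() shows the materialized boolean mask of a CausalBias.
    print(repr(lower_right))
    print(F.scaled_dot_product_attention(q, k, v, attn_mask=lower_right).shape)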