1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3# flake8: noqa C101 4"""This module implements the user facing API for flex_attention in PyTorch.""" 5import functools 6import inspect 7import itertools 8import math 9import operator 10from contextlib import nullcontext 11from enum import Enum 12from typing import Any, Callable, Dict, List, Optional, Tuple, Union 13 14import torch 15from torch import Tensor 16from torch._higher_order_ops.flex_attention import ( 17 flex_attention as flex_attention_hop, 18 TransformGetItemToIndex, 19) 20from torch._higher_order_ops.utils import _set_compilation_env 21from torch.fx.experimental.proxy_tensor import ( 22 _temp_remove_pre_dispatch_torch_function_mode, 23) 24from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input 25from torch.utils._pytree import tree_map_only 26 27 28__all__ = [ 29 "BlockMask", 30 "flex_attention", 31 "create_block_mask", 32 "create_mask", 33 "or_masks", 34 "and_masks", 35 "noop_mask", 36] 37 38_score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor] 39_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] 40 41 42class _ModificationType(Enum): 43 """Enum for the type of modification function. 44 - SCORE_MOD: score_mod function which accepts a score as the first argument 45 - mask_mod: mask function which does not accept a score and is only used for generating 46 block mask 47 """ 48 49 SCORE_MOD = 1 50 MASK_MOD = 2 51 UNKNOWN = 3 52 53 54def _get_mod_type(fn: Callable) -> _ModificationType: 55 """Get the type of modification function. 56 This function inspects the number of positional arguments of the function to determine 57 the type of modification function. If the function has 5 positional arguments, it is 58 considered as a score_mod function. If the function has 4 positional arguments, it is 59 considered as a mask function. 60 """ 61 num_positional_args = sum( 62 1 63 for param in inspect.signature(fn).parameters.values() 64 if param.default == inspect.Parameter.empty 65 ) 66 assert num_positional_args == 5 or num_positional_args == 4 67 if num_positional_args == 5: 68 return _ModificationType.SCORE_MOD 69 elif num_positional_args == 4: 70 return _ModificationType.MASK_MOD 71 else: 72 return _ModificationType.UNKNOWN 73 74 75# Need to define it here so that Dynamo doesn't skip it 76def _vmap_for_bhqkv( 77 fn: Callable, 78 prefix: Tuple[Optional[int], ...], 79 suffix: Tuple[Optional[int], ...] = (), 80 out_dims: Union[int, List[Optional[int]]] = 0, 81 group_dim: bool = False, 82): 83 """Used to vmap both score_mods and mask_mods over 4-dimensional/5-dimension inputs. 84 Mapping over the [b, hq, q_idx, kv_idx] or [b, hkv, g, q_idx, kv_idx] dimensions. 85 86 Args: 87 fn (callable): The function to vmap. 88 prefix (tuple): The prefix of the vmap. For score mod functions, 89 this should be set to (0,). For mask_mods = () 90 suffix (tuple): We need to add (0,) if gradOut is being mapped over, 91 and (None,) * len(other_buffers). 92 out_dims (tuple): For forward cases, keep this as the default 0 since 93 we are only returning 1 output. For backwards, the joint 94 graph returns grads for B, H, Q_idx, KV_idx and other_buffers, 95 so we set this to (0, None, None, None, None) + (None,) * len(other_buffers). 96 97 Returns: 98 callable: The vmapped function. 99 """ 100 # We vamp a function 4 times, broadcasting the [b, h, q_idx, kv_idx] dimensions 101 dimensions: List[Tuple[None | int, None | int, None | int, None | int]] = [] 102 dimensions = [ 103 (None, None, None, 0), 104 (None, None, 0, None), 105 (None, 0, None, None), 106 ] 107 108 if group_dim: 109 dimensions += [ 110 (None, 0, None, None), 111 ] 112 113 dimensions += [ 114 (0, None, None, None), 115 ] 116 117 for dims in dimensions: 118 fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=out_dims) 119 return fn 120 121 122def _identity( 123 score: Tensor, 124 batch: Tensor, 125 head: Tensor, 126 token_q: Tensor, 127 token_kv: Tensor, 128) -> Tensor: 129 return score 130 131 132def noop_mask( 133 batch: Tensor, 134 head: Tensor, 135 token_q: Tensor, 136 token_kv: Tensor, 137) -> Tensor: 138 """Returns a noop mask_mod""" 139 return batch.new_ones(size=(), dtype=torch.bool, device=batch.device) 140 141 142_DEFAULT_SPARSE_BLOCK_SIZE = 128 143_LARGE_SPARSE_BLOCK_SIZE = 1 << 30 144 145 146def _ordered_to_dense(num_blocks_in_row: Tensor, col_indices: Tensor): 147 num_rows = col_indices.shape[-2] 148 num_cols = col_indices.shape[-1] 149 batch_dims = num_blocks_in_row.shape[:-1] 150 device = num_blocks_in_row.device 151 152 def create_dense_one(kv_num_blocks, kv_indices): 153 dense_mask = kv_indices.new_zeros(num_rows, num_cols + 1, dtype=torch.int32) 154 155 row_indices = torch.arange(num_rows, dtype=torch.int, device=device).unsqueeze( 156 -1 157 ) 158 col_range = torch.arange(num_cols, dtype=torch.int, device=device) 159 index_mask = col_range < kv_num_blocks.unsqueeze(-1) 160 161 # We write to one spot "out of bounds" 162 valid_indices = torch.where(index_mask, kv_indices, num_cols) 163 164 # set the values in 'a' to 1 where the indices are valid 165 dense_mask[row_indices, valid_indices] = 1 166 return dense_mask[:, :num_cols].contiguous() 167 168 create_dense_batched = create_dense_one 169 for _ in range(len(batch_dims)): 170 create_dense_batched = torch.vmap(create_dense_batched, in_dims=(0, 0)) 171 172 out = create_dense_batched(num_blocks_in_row, col_indices) 173 return out 174 175 176def _dense_to_ordered(dense_mask) -> Tuple: 177 dense_mask = dense_mask.to(dtype=torch.int32) 178 num_blocks_in_row = dense_mask.sum(dim=-1) 179 col_indices = torch.argsort(dense_mask, dim=-1, descending=True, stable=True) 180 return ( 181 num_blocks_in_row.to(torch.int32).contiguous(), 182 col_indices.to(torch.int32).contiguous(), 183 ) 184 185 186def _transpose_ordered(num_blocks_in_row: Tensor, col_indices: Tensor): 187 dense = _ordered_to_dense(num_blocks_in_row, col_indices) 188 return _dense_to_ordered(dense.transpose(-2, -1)) 189 190 191class BlockMask: 192 r""" 193 BlockMask is our format for representing a block-sparse attention mask. 194 It is somewhat of a cross in-between BCSR and a non-sparse format. 195 196 Basics 197 ------ 198 A block-sparse mask means that instead of representing the sparsity of 199 individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is 200 considered sparse only if every element within that block is sparse. 201 This aligns well with hardware, which generally expects to perform 202 contiguous loads and computation. 203 204 This format is primarily optimized for 1. simplicity, and 2. kernel 205 efficiency. Notably, it is *not* optimized for size, as this mask is always 206 reduced by a factor of KV_BLOCK_SIZE * Q_BLOCK_SIZE. If the size is a 207 concern, the tensors can be reduced in size by increasing the block size. 208 209 The essentials of our format are: 210 211 num_blocks_in_row: Tensor[ROWS]: 212 Describes the number of blocks present in each row. 213 214 col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]: 215 `col_indices[i]` is the sequence of block positions for row i. The values of 216 this row after `col_indices[i][num_blocks_in_row[i]]` are undefined. 217 218 For example, to reconstruct the original tensor from this format: 219 220 .. code-block:: python 221 222 dense_mask = torch.zeros(ROWS, COLS) 223 for row in range(ROWS): 224 for block_idx in range(num_blocks_in_row[row]): 225 dense_mask[row, col_indices[row, block_idx]] = 1 226 227 Notably, this format makes it easier to implement a reduction along the 228 *rows* of the mask. 229 230 Details 231 ------- 232 The basics of our format require only kv_num_blocks and kv_indices. But, we 233 have up to 8 tensors on this object. This represents 4 pairs: 234 235 1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as 236 we reduce along the KV dimension. 237 238 2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and 239 purely an optimization. As it turns out, applying masking to every block 240 is quite expensive! If we specifically know which blocks are "full" and 241 don't require masking at all, then we can skip applying mask_mod to these 242 blocks. This requires the user to split out a separate mask_mod from the 243 score_mod. For causal masks, this is about a 15% speedup. 244 245 3. [GENERATED] (q_num_blocks, q_indices): Required for the backwards pass, 246 as computing dKV requires iterating along the mask along the Q dimension. These are autogenerated from 1. 247 248 4. [GENERATED] (full_q_num_blocks, full_q_indices): Same as above, but for 249 the backwards pass. These are autogenerated from 2. 250 """ 251 kv_num_blocks: Tensor 252 kv_indices: Tensor 253 full_kv_num_blocks: Optional[Tensor] 254 full_kv_indices: Optional[Tensor] 255 q_num_blocks: Optional[Tensor] 256 q_indices: Optional[Tensor] 257 full_q_num_blocks: Optional[Tensor] 258 full_q_indices: Optional[Tensor] 259 BLOCK_SIZE: Tuple[int, int] 260 mask_mod: _mask_mod_signature 261 262 def __init__( 263 self, 264 kv_num_blocks: Tensor, 265 kv_indices: Tensor, 266 full_kv_num_blocks: Optional[Tensor], 267 full_kv_indices: Optional[Tensor], 268 q_num_blocks: Optional[Tensor], 269 q_indices: Optional[Tensor], 270 full_q_num_blocks: Optional[Tensor], 271 full_q_indices: Optional[Tensor], 272 BLOCK_SIZE: Tuple[int, int], 273 mask_mod: _mask_mod_signature, 274 ): 275 if kv_indices.dim() < 2: 276 raise RuntimeError("BlockMask must have at least 2 dimensions") 277 assert kv_num_blocks is not None, "kv_num_blocks must be provided" 278 assert kv_indices is not None, "kv_indices must be provided" 279 assert q_num_blocks is not None, "q_num_blocks must be provided" 280 assert q_indices is not None, "q_indices must be provided" 281 assert (full_kv_num_blocks is None) == ( 282 full_kv_indices is None 283 ), "full_kv_num_blocks and full_kv_indices must be both provided or omitted" 284 assert (full_q_num_blocks is None) == ( 285 full_q_indices is None 286 ), "full_q_num_blocks and full_q_indices must be both provided or omitted" 287 288 self.kv_num_blocks = kv_num_blocks 289 self.kv_indices = kv_indices 290 self.full_kv_num_blocks = full_kv_num_blocks 291 self.full_kv_indices = full_kv_indices 292 self.q_num_blocks = q_num_blocks 293 self.q_indices = q_indices 294 self.full_q_num_blocks = full_q_num_blocks 295 self.full_q_indices = full_q_indices 296 self.BLOCK_SIZE = BLOCK_SIZE 297 self.mask_mod = mask_mod 298 299 @classmethod 300 def from_kv_blocks( 301 cls, 302 kv_num_blocks: Tensor, 303 kv_indices: Tensor, 304 full_kv_num_blocks: Optional[Tensor] = None, 305 full_kv_indices: Optional[Tensor] = None, 306 BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE, 307 mask_mod: Optional[_mask_mod_signature] = None, 308 ): 309 """ 310 Creates a BlockMask instance from key-value block information. 311 312 Args: 313 kv_num_blocks (Tensor): Number of kv_blocks in each Q_BLOCK_SIZE row tile. 314 kv_indices (Tensor): Indices of key-value blocks in each Q_BLOCK_SIZE row tile. 315 full_kv_num_blocks (Optional[Tensor]): Number of full kv_blocks in each Q_BLOCK_SIZE row tile. 316 full_kv_indices (Optional[Tensor]): Indices of full key-value blocks in each Q_BLOCK_SIZE row tile. 317 BLOCK_SIZE (Union[int, Tuple[int, int]]): Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles. 318 mask_mod (Optional[Callable]): Function to modify the mask. 319 320 Returns: 321 BlockMask: Instance with full Q information generated via _transposed_ordered 322 323 Raises: 324 RuntimeError: If kv_indices has < 2 dimensions. 325 AssertionError: If only one of full_kv_* args is provided. 326 """ 327 if kv_indices.dim() < 2: 328 raise RuntimeError("BlockMask must have at least 2 dimensions") 329 330 assert (full_kv_num_blocks is None) == ( 331 full_kv_indices is None 332 ), "full_kv_num_blocks and full_kv_indices must be both provided or omitted" 333 334 # Generate q_num_blocks and q_indices 335 q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices) 336 if full_kv_num_blocks is not None: 337 assert full_kv_indices is not None 338 full_q_num_blocks, full_q_indices = _transpose_ordered( 339 full_kv_num_blocks, full_kv_indices 340 ) 341 else: 342 full_q_num_blocks, full_q_indices = None, None 343 344 if isinstance(BLOCK_SIZE, int): 345 BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE) 346 347 mask_mod = mask_mod if mask_mod is not None else noop_mask 348 349 return cls( 350 kv_num_blocks=kv_num_blocks, 351 kv_indices=kv_indices, 352 full_kv_num_blocks=full_kv_num_blocks, 353 full_kv_indices=full_kv_indices, 354 q_num_blocks=q_num_blocks, 355 q_indices=q_indices, 356 full_q_num_blocks=full_q_num_blocks, 357 full_q_indices=full_q_indices, 358 BLOCK_SIZE=BLOCK_SIZE, 359 mask_mod=mask_mod, 360 ) 361 362 def as_tuple(self, flatten: bool = True): 363 """ 364 Returns a tuple of the attributes of the BlockMask. 365 366 Args: 367 flatten (bool): If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE) 368 """ 369 block_size = ( 370 (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) if flatten else (self.BLOCK_SIZE,) 371 ) 372 373 return ( 374 self.kv_num_blocks, 375 self.kv_indices, 376 self.full_kv_num_blocks, 377 self.full_kv_indices, 378 self.q_num_blocks, 379 self.q_indices, 380 self.full_q_num_blocks, 381 self.full_q_indices, 382 *block_size, 383 self.mask_mod, 384 ) 385 386 def __str__(self): 387 s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n" 388 mask_str = self.to_string().strip() 389 s += mask_str 390 s += "\n)" 391 return s 392 393 def __getitem__(self, index) -> "BlockMask": 394 """ 395 Returns a new BlockMask instance by getting the mask for the given index position. 396 397 Args: 398 index: Index to apply to all attributes. 399 400 Example Usage: 401 .. code-block:: python 402 403 def causal_mask(b, h, q_idx, kv_idx): 404 return q_idx >= kv_idx 405 406 block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device="cuda") 407 assert block_mask.kv_num_blocks.shape == (4,2,4) 408 assert block_mask.kv_indices.shape == (4,2,4,4) 409 410 # Index on batch dimension 411 new_block_mask = block_mask[0] 412 assert new_block_mask.kv_num_blocks.shape == (2,4) 413 assert new_block_mask.kv_indices.shape == (2,4,4) 414 415 # Index on batch and head dimension 416 new_block_mask = block_mask[0, 1] 417 assert new_block_mask.kv_num_blocks.shape == (4,) 418 assert new_block_mask.kv_indices.shape == (4,4) 419 420 # slicing on batch and head dimension 421 new_block_mask = block_mask[0:2, 1:2] 422 assert new_block_mask.kv_num_blocks.shape == (2,1,4) 423 assert new_block_mask.kv_indices.shape == (2,1,4,4) 424 425 # slicing on batch, head, and query dimension 426 new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)] 427 assert new_block_mask.kv_num_blocks.shape == (2,1,1) 428 assert new_block_mask.kv_indices.shape == (2,1,1,4) 429 """ 430 new_kv_num_blocks = self.kv_num_blocks[index] 431 new_kv_indices = self.kv_indices[index] 432 if self.full_kv_num_blocks is not None: 433 assert self.full_kv_indices is not None 434 new_full_kv_num_blocks = self.full_kv_num_blocks[index] 435 new_full_kv_indices = self.full_kv_indices[index] 436 else: 437 new_full_kv_num_blocks = None 438 new_full_kv_indices = None 439 return BlockMask.from_kv_blocks( 440 new_kv_num_blocks, 441 new_kv_indices, 442 new_full_kv_num_blocks, 443 new_full_kv_indices, 444 BLOCK_SIZE=self.BLOCK_SIZE, 445 mask_mod=None, 446 ) 447 448 def __repr__(self): 449 def shape_or_none(x: Optional[torch.Tensor]): 450 return x.shape if x is not None else None 451 452 return ( 453 f"BlockMask(\n" 454 f" kv_num_blocks={self.kv_num_blocks.shape},\n" 455 f" kv_indices={self.kv_indices.shape},\n" 456 f" full_kv_num_blocks={shape_or_none(self.full_kv_num_blocks )},\n" 457 f" full_kv_indices={shape_or_none(self.full_kv_indices)},\n" 458 f" q_num_blocks={shape_or_none(self.q_num_blocks)},\n" 459 f" q_indices={shape_or_none(self.q_indices)},\n" 460 f" full_q_num_blocks={shape_or_none(self.full_q_num_blocks)},\n" 461 f" full_q_indices={shape_or_none(self.full_q_indices)},\n" 462 f" BLOCK_SIZE={self.BLOCK_SIZE},\n" 463 f" shape={self.shape},\n" 464 f" sparsity={self.sparsity():.2f}%,\n" 465 f" mask_mod={self.mask_mod.__name__ if hasattr(self.mask_mod, '__name__') else self.mask_mod}\n" 466 f")" 467 ) 468 469 @property 470 def shape(self): 471 """Returns the shape of the mask.""" 472 *batch_dims, q_length, _ = self.kv_indices.shape 473 q_length = self.kv_indices.shape[-2] * self.BLOCK_SIZE[0] 474 kv_length = self.kv_indices.shape[-1] * self.BLOCK_SIZE[1] 475 return tuple(batch_dims + [q_length, kv_length]) 476 477 def numel(self): 478 """Returns the number of elements (not accounting for sparsity) in the mask.""" 479 shape = self.shape 480 481 def _prod(xs): 482 return functools.reduce(operator.mul, xs, 1) 483 484 return _prod(shape) 485 486 def sparsity(self) -> float: 487 """Computes the percentage of blocks that are sparse (i.e. not computed)""" 488 total_size = self.numel() 489 computed_blocks = self.kv_num_blocks.sum() 490 if self.full_kv_num_blocks is not None: 491 computed_blocks += self.full_kv_num_blocks.sum() 492 493 computed_size = computed_blocks.item() * self.BLOCK_SIZE[0] * self.BLOCK_SIZE[1] 494 dense_ratio = computed_size / total_size 495 return 100 * (1 - dense_ratio) 496 497 def to_dense(self) -> Tensor: 498 """Returns a dense block that is equivalent to the block mask.""" 499 partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices) 500 if self.full_kv_num_blocks is not None: 501 assert self.full_kv_indices is not None 502 return partial_dense | _ordered_to_dense( 503 self.full_kv_num_blocks, self.full_kv_indices 504 ) 505 return partial_dense 506 507 def to_string(self, grid_size=(20, 20), limit=4): 508 """Returns a string representation of the block mask. Quite nifty. 509 510 If grid_size is None, prints out an uncompressed version. Warning, it can be quite big! 511 """ 512 dense_mask = self.to_dense() 513 *batch_dims, num_rows, num_cols = dense_mask.shape 514 if isinstance(grid_size, int): 515 max_rows = grid_size 516 max_cols = grid_size 517 elif grid_size == -1: 518 max_rows = num_rows 519 max_cols = num_cols 520 else: 521 max_rows, max_cols = grid_size 522 523 def create_block_vis(*batch_idx): 524 descriptors = [] 525 526 descriptors.append(f"{batch_idx}") 527 528 vis = ", ".join(reversed(descriptors)) + "\n" 529 530 def summarize_section(section): 531 percentage = section.float().mean().item() 532 if percentage == 1: 533 return "█" 534 elif percentage == 0: 535 return " " 536 else: 537 return "░" 538 539 def cdiv(a, b): 540 return (a + (b - 1)) // b 541 542 row_step = max(1, cdiv(num_rows, max_rows)) 543 col_step = max(1, cdiv(num_cols, max_cols)) 544 545 for r in range(0, num_rows, row_step): 546 for c in range(0, num_cols, col_step): 547 cur_mask = dense_mask 548 for idx in batch_idx: 549 cur_mask = cur_mask[idx] 550 char = summarize_section( 551 cur_mask[r : r + row_step, c : c + col_step] 552 ) 553 vis += char * 2 554 vis += "\n" 555 return vis 556 557 total_vis = [] 558 for idx, batch_idx in enumerate( 559 itertools.product(*[range(i) for i in batch_dims]) 560 ): 561 if idx == limit: 562 total_vis.append("...") 563 total_vis.append("To print out more, set BlockMask.to_string(limit=N)") 564 total_vis.append( 565 "You can also index (BlockMask[batch, head]) to choose a specific batch or head" 566 ) 567 break 568 block_vis = create_block_vis(*batch_idx) 569 total_vis.append(block_vis) 570 571 return "\n".join(total_vis) 572 573 def to(self, device: Union[torch.device, str]) -> "BlockMask": 574 """Moves the BlockMask to the specified device. 575 576 Args: 577 device (torch.device or str): The target device to move the BlockMask to. 578 Can be a torch.device object or a string (e.g., 'cpu', 'cuda:0'). 579 580 Returns: 581 BlockMask: A new BlockMask instance with all tensor components moved 582 to the specified device. 583 584 Note: 585 This method does not modify the original BlockMask in-place. 586 Instead, it returns a new BlockMask instance where invidual tensor attributes 587 may or may not be moved to the specified device, depending on their 588 current device placement. 589 """ 590 mapped_attributes = tree_map_only( 591 torch.Tensor, 592 lambda x: x.to(device), 593 self.as_tuple(flatten=False), 594 ) 595 return BlockMask(*mapped_attributes) 596 597 598def _broadcast_to_dim(x, dim): 599 while x.dim() < dim: 600 x = x.unsqueeze(0) 601 return x 602 603 604def _round_up_to_multiple(x, multiple): 605 return (x + multiple - 1) // multiple * multiple 606 607 608def _convert_mask_to_block_mask( 609 mask: Tensor, 610 KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, 611 Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, 612 separate_full_blocks: bool = False, 613) -> Tuple[Tensor, Optional[Tensor]]: 614 assert mask.dtype == torch.bool 615 mask = _broadcast_to_dim(mask, 4) 616 B, H, Q, KV = mask.shape 617 assert Q % Q_BLOCK_SIZE == 0 618 assert KV % KV_BLOCK_SIZE == 0 619 mask = mask.view( 620 B, H, Q // Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV // KV_BLOCK_SIZE, KV_BLOCK_SIZE 621 ) # [B, H, Q//Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, KV_BLOCK_SIZE] 622 mask = mask.permute( 623 0, 1, 2, 4, 3, 5 624 ) # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, Q_BLOCK_SIZE, KV_BLOCK_SIZE] 625 mask_block_sum = mask.sum( 626 dim=[-2, -1] 627 ) # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE] 628 if separate_full_blocks: 629 full_block_sum = Q_BLOCK_SIZE * KV_BLOCK_SIZE 630 full_blocks = mask_block_sum == full_block_sum 631 partial_blocks = (mask_block_sum > 0) & (mask_block_sum < full_block_sum) 632 partial_blocks = partial_blocks.to(dtype=torch.int8) 633 full_blocks = full_blocks.to(dtype=torch.int8) 634 return partial_blocks, full_blocks 635 else: 636 partial_blocks = mask_block_sum > 0 637 partial_blocks = partial_blocks.to(dtype=torch.int8) 638 return partial_blocks, None 639 640 641def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature: 642 """Returns a mask_mod that's the union of provided mask_mods""" 643 if not all(callable(arg) for arg in mask_mods): 644 raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}") 645 646 def or_mask(b, h, q_idx, kv_idx): 647 result = b.new_zeros((), dtype=torch.bool) 648 for mask in mask_mods: 649 result = result | mask(b, h, q_idx, kv_idx) 650 return result 651 652 return or_mask 653 654 655def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature: 656 """Returns a mask_mod that's the intersection of provided mask_mods""" 657 if not all(callable(arg) for arg in mask_mods): 658 raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}") 659 660 def and_mask(b, h, q_idx, kv_idx): 661 result = b.new_ones((), dtype=torch.bool) 662 for mask in mask_mods: 663 result = result & mask(b, h, q_idx, kv_idx) 664 return result 665 666 return and_mask 667 668 669def _convert_block_mask_to_mask( 670 block_mask, 671 KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, 672 Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, 673) -> Tensor: 674 assert block_mask.dim() == 4 675 B, H, Q, KV = block_mask.shape 676 block_mask = block_mask.expand(Q_BLOCK_SIZE, KV_BLOCK_SIZE, *block_mask.shape) 677 block_mask = block_mask.permute(2, 3, 4, 0, 5, 1).reshape( 678 B, H, Q * Q_BLOCK_SIZE, KV * KV_BLOCK_SIZE 679 ) 680 return block_mask 681 682 683def _create_sparse_block_from_block_mask( 684 block_mask: Tuple[Tensor, Optional[Tensor]], 685 mask_mod: Optional[Callable], 686 KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE, 687 Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE, 688) -> BlockMask: 689 partial_blocks, full_blocks = block_mask 690 691 partial_bm = _dense_to_ordered(partial_blocks) 692 if full_blocks is not None: 693 full_bm = _dense_to_ordered(full_blocks) 694 else: 695 full_bm = (None, None) 696 697 return BlockMask.from_kv_blocks( 698 partial_bm[0], 699 partial_bm[1], 700 full_bm[0], 701 full_bm[1], 702 BLOCK_SIZE=(KV_BLOCK_SIZE, Q_BLOCK_SIZE), 703 mask_mod=mask_mod, 704 ) 705 706 707def create_mask( 708 mod_fn: Union[_score_mod_signature, _mask_mod_signature], 709 B: Optional[int], 710 H: Optional[int], 711 Q_LEN: int, 712 KV_LEN: int, 713 device: str = "cuda", 714 _compile: bool = False, 715) -> Tensor: 716 r"""This function creates a mask tensor from a mod_fn function. 717 718 Args: 719 mod_fn (Union[_score_mod_signature, _mask_mod_signature]): Function to modify attention scores. 720 B (int): Batch size. 721 H (int): Number of query heads. 722 Q_LEN (int): Sequence length of query. 723 KV_LEN (int): Sequence length of key/value. 724 device (str): Device to run the mask creation on. 725 726 Returns: 727 mask (Tensor): A mask tensor with shape (B, H, M, N). 728 """ 729 if B is None: 730 B = 1 731 if H is None: 732 H = 1 733 b = torch.arange(0, B, device=device) 734 h = torch.arange(0, H, device=device) 735 m = torch.arange(0, Q_LEN, device=device) 736 n = torch.arange(0, KV_LEN, device=device) 737 # TODO: fix this 738 # Lack instantiation support for __torch_function__ mode support under compile 739 if _compile: 740 ctx = nullcontext() 741 else: 742 ctx = TransformGetItemToIndex() # type: ignore[assignment] 743 mod_type = _get_mod_type(mod_fn) 744 745 with ctx: 746 if mod_type == _ModificationType.SCORE_MOD: 747 score_mod = mod_fn 748 score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,)) # first input is score 749 out = score_mod(torch.zeros(B, H, Q_LEN, KV_LEN, device=device), b, h, m, n) 750 mask = torch.where(torch.isneginf(out), False, True) 751 return mask 752 elif mod_type == _ModificationType.MASK_MOD: 753 mask_mod = mod_fn 754 mask_mod = _vmap_for_bhqkv(mask_mod, prefix=()) 755 mask = mask_mod(b, h, m, n) 756 return mask 757 else: 758 raise AssertionError 759 760 761def _create_block_mask_inner( 762 mask_mod: Callable, 763 B: int, 764 H: int, 765 Q_LEN: int, 766 KV_LEN: int, 767 device: str, 768 KV_BLOCK_SIZE: int, 769 Q_BLOCK_SIZE: int, 770): 771 r"""Work around for being unable to instantiate __torch_function__ mode under compile. 772 `create_block_mask` will compile this inner function and wrap the call to this 773 with the __torch_function__ mode. 774 """ 775 mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True) 776 partial_block_mask, full_block_mask = _convert_mask_to_block_mask( 777 mask_tensor, 778 KV_BLOCK_SIZE=KV_BLOCK_SIZE, 779 Q_BLOCK_SIZE=Q_BLOCK_SIZE, 780 separate_full_blocks=True, 781 ) 782 return partial_block_mask, full_block_mask 783 784 785def create_block_mask( 786 mask_mod: _mask_mod_signature, 787 B: Optional[int], 788 H: Optional[int], 789 Q_LEN: int, 790 KV_LEN: int, 791 device: str = "cuda", 792 BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE, 793 _compile=False, 794) -> BlockMask: 795 r"""This function creates a block mask tuple from a mask_mod function. 796 797 Args: 798 mask_mod (Callable): mask_mod function. This is a callable that defines the 799 masking pattern for the attention mechanism. It takes four arguments: 800 b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index). 801 It should return a boolean tensor indicating which attention connections are allowed (True) 802 or masked out (False). 803 B (int): Batch size. 804 H (int): Number of query heads. 805 Q_LEN (int): Sequence length of query. 806 KV_LEN (int): Sequence length of key/value. 807 device (str): Device to run the mask creation on. 808 KV_BLOCK_SIZE (int): Block size of block mask for each query. 809 Q_BLOCK_SIZE (int): Block size of block mask for each key/value. 810 _compile (bool): Whether to compile the mask creation. 811 812 Returns: 813 BlockMask: A BlockMask object that contains the block mask information. 814 815 Example Usage: 816 .. code-block:: python 817 818 def causal_mask(b, h, q_idx, kv_idx): 819 return q_idx >= kv_idx 820 821 block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda") 822 query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) 823 key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) 824 value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) 825 output = flex_attention(query, key, value, block_mask=block_mask) 826 """ 827 mod_type = _get_mod_type(mask_mod) 828 assert ( 829 mod_type == _ModificationType.MASK_MOD 830 ), f"create-block_mask requires a mask_mod function! Got {mask_mod}" 831 inner_func = _create_block_mask_inner 832 if B is None: 833 B = 1 834 if H is None: 835 H = 1 836 if isinstance(BLOCK_SIZE, int): 837 Q_BLOCK_SIZE = BLOCK_SIZE 838 KV_BLOCK_SIZE = BLOCK_SIZE 839 else: 840 Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE 841 842 if Q_LEN < 128: 843 Q_BLOCK_SIZE = Q_LEN 844 else: 845 Q_LEN = _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE) 846 KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE) 847 if _compile: 848 inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False) 849 with TransformGetItemToIndex(): 850 partial_block_mask, full_block_mask = inner_func( 851 mask_mod, B, H, Q_LEN, KV_LEN, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE 852 ) 853 block_mask = _create_sparse_block_from_block_mask( 854 (partial_block_mask, full_block_mask), mask_mod 855 ) 856 return block_mask 857 858 859def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask: 860 r"""Default block mask for flex attention. 861 If users don't specify any block sparse mask info, we create this 862 empty block sparse mask. Which creates a BlockMask with 1 block that is the full length 863 of the query and key tensors. 864 """ 865 device = query.device 866 return BlockMask.from_kv_blocks( 867 kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device), 868 kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device), 869 BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE, 870 ) 871 872 873def _apply_kernel_options( 874 query: Tensor, key: Tensor, value: Tensor, return_lse: bool, kernel_options 875): 876 kernel_options = {} if kernel_options is None else dict(kernel_options) 877 878 kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) 879 kernel_options.setdefault("PRESCALE_QK", False) 880 881 # If foward kernel needs to return logsumexp is decided by this rule internally. 882 assert "OUTPUT_LOGSUMEXP" not in kernel_options 883 kernel_options["OUTPUT_LOGSUMEXP"] = True 884 if not return_lse: 885 any_inputs_require_grad = ( 886 query.requires_grad or key.requires_grad or value.requires_grad 887 ) 888 output_logsumexp = any_inputs_require_grad and torch.is_grad_enabled() 889 kernel_options["OUTPUT_LOGSUMEXP"] = output_logsumexp 890 891 return kernel_options 892 893 894def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): 895 if query.size(-1) != key.size(-1): 896 raise ValueError( 897 f"Expect query and key/value to have the same embedding dimension " 898 f"but got E={query.size(-1)} and E={key.size(-1)}." 899 ) 900 # TODO this config segfaults with Triton without: 901 # https://github.com/triton-lang/triton/pull/4540 902 if not ( 903 _supported_head_dim(query.size(-1)) and _supported_head_dim(value.size(-1)) 904 ): 905 raise ValueError( 906 f"NYI: Currently non power of 2 embedding dimension are not supported. " 907 f"Got E={query.size(-1)} and Ev={value.size(-1)}." 908 ) 909 if value.size(-1) > query.size(-1): 910 raise ValueError( 911 f"NYI: Currently value embedding dimension must be less than or equal to query embedding dimension. " 912 f"Got Ev={value.size(-1)} and E={query.size(-1)}." 913 ) 914 915 916def flex_attention( 917 query: Tensor, 918 key: Tensor, 919 value: Tensor, 920 score_mod: Optional[_score_mod_signature] = None, 921 block_mask: Optional[BlockMask] = None, 922 scale: Optional[float] = None, 923 enable_gqa: bool = False, 924 return_lse: bool = False, 925 kernel_options: Optional[Dict[str, Any]] = None, 926) -> Union[Tensor, Tuple[Tensor, Tensor]]: 927 r"""This function implements scaled dot product attention with an arbitrary attention score modification function. 928 929 This function computes the scaled dot product attention between query, key, and value tensors with a user-defined 930 attention score modification function. The attention score modification function will be applied after the attention 931 scores have been calculated between the query and key tensors. The attention scores are calculated as follows: 932 933 The ``score_mod`` function should have the following signature: 934 935 .. code-block:: python 936 937 def score_mod( 938 score: Tensor, 939 batch: Tensor, 940 head: Tensor, 941 q_idx: Tensor, 942 k_idx: Tensor 943 ) -> Tensor: 944 945 Where: 946 - ``score``: A scalar tensor representing the attention score, 947 with the same data type and device as the query, key, and value tensors. 948 - ``batch``, ``head``, ``q_idx``, ``k_idx``: Scalar tensors indicating 949 the batch index, query head index, query index, and key/value index, respectively. 950 These should have the ``torch.int`` data type and be located on the same device as the score tensor. 951 952 Args: 953 query (Tensor): Query tensor; shape :math:`(B, Hq, L, E)`. 954 key (Tensor): Key tensor; shape :math:`(B, Hkv, S, E)`. 955 value (Tensor): Value tensor; shape :math:`(B, Hkv, S, Ev)`. 956 score_mod (Optional[Callable]): Function to modify attention scores. By default no score_mod is applied. 957 block_mask (Optional[BlockMask]): BlockMask object that controls the blocksparsity pattern of the attention. 958 scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`. 959 enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads. 960 return_lse (bool): Whether to return the logsumexp of the attention scores. Default is False. 961 kernel_options (Optional[Dict[str, Any]]): Options to pass into the Triton kernels. 962 963 Returns: 964 output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`. 965 966 Shape legend: 967 - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` 968 - :math:`S: \text{Source sequence length}` 969 - :math:`L: \text{Target sequence length}` 970 - :math:`E: \text{Embedding dimension of the query and key}` 971 - :math:`Ev: \text{Embedding dimension of the value}` 972 973 .. warning:: 974 `torch.nn.attention.flex_attention` is a prototype feature in PyTorch. 975 Please look forward to a more stable implementation in a future version of PyTorch. 976 Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype 977 978 """ 979 # Some basic input validation 980 _validate_sdpa_input(query, key, value) 981 _validate_embed_dim(query, key, value) 982 if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: 983 raise NotImplementedError("NYI: query, key, and value must be 4D tensors") 984 if (not enable_gqa) and query.size(-3) != key.size(-3): 985 raise ValueError( 986 f"Expect query and key/value to have the same number of heads " 987 f"but got Hq={query.size(-3)} and Hkv={key.size(-3)}. " 988 f"Try setting enable_gqa=True for GQA." 989 ) 990 if enable_gqa: 991 Hq = query.size(1) 992 Hkv = key.size(1) 993 if Hq % Hkv != 0: 994 raise ValueError( 995 f"Expect number of query heads to be a multiple of kv heads for GQA " 996 f"but got Hq={Hq} and Hkv={Hkv}." 997 ) 998 999 if score_mod is None: 1000 score_mod = _identity 1001 if block_mask is None: 1002 block_mask = _create_empty_block_mask(query, key) 1003 if scale is None: 1004 scale = 1.0 / math.sqrt(query.size(-1)) 1005 1006 kernel_options = _apply_kernel_options( 1007 query, 1008 key, 1009 value, 1010 return_lse, 1011 kernel_options, 1012 ) 1013 1014 if torch.compiler.is_dynamo_compiling(): 1015 # mark head_dim and number of heads to be static 1016 for x in [query, key, value]: 1017 torch._dynamo.mark_static(x, -3) 1018 torch._dynamo.mark_static(x, -1) 1019 out, lse = flex_attention_hop( 1020 query, key, value, score_mod, block_mask.as_tuple(), scale, kernel_options 1021 ) 1022 if return_lse: 1023 return out, lse * math.log(2) 1024 else: 1025 return out 1026 1027 if not torch._dynamo.is_dynamo_supported(): 1028 raise RuntimeError("flex_attention requires dynamo support") 1029 1030 # Dynamo is expecting a callable with "__code__" attribute. 1031 # We cannot directly pass hop to it. So we wrap it in a dummy function. 1032 def _flex_attention_hop_wrapper(*args, **kwargs): 1033 return flex_attention_hop(*args, **kwargs) 1034 1035 with _set_compilation_env(): 1036 with torch._dynamo.utils.disable_cache_limit(): 1037 with _temp_remove_pre_dispatch_torch_function_mode(): 1038 out, lse = torch.compile( 1039 _flex_attention_hop_wrapper, backend="eager", fullgraph=True 1040 )( 1041 query, 1042 key, 1043 value, 1044 score_mod, 1045 block_mask.as_tuple(), 1046 scale, 1047 kernel_options, 1048 ) 1049 if return_lse: 1050 return out, lse * math.log(2) 1051 else: 1052 return out 1053