1import argparse 2import csv 3import itertools 4from collections import defaultdict 5from dataclasses import asdict, dataclass 6from functools import partial 7from typing import Callable, List, Optional, Tuple 8 9import numpy as np 10from tabulate import tabulate 11from tqdm import tqdm 12 13import torch 14import torch.nn.functional as F 15from torch.nn.attention.flex_attention import ( 16 _create_empty_block_mask, 17 create_block_mask, 18 create_mask, 19 flex_attention, 20) 21 22 23torch._dynamo.config.automatic_dynamic_shapes = False 24# Needed since changing args to function causes recompiles 25torch._dynamo.config.cache_size_limit = 1000 26 27 28from torch._inductor.runtime.benchmarking import benchmarker 29 30 31def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float: 32 # warmup 33 for _ in range(5): 34 func(*args, **kwargs) 35 return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) * 1e3 36 37 38@dataclass(frozen=True) 39class ExperimentConfig: 40 shape: Tuple[int] 41 score_mod: Callable 42 mask_mod: Callable 43 dtype: torch.dtype 44 calculate_bwd_time: bool 45 cal_bandwidth: bool 46 47 def __post_init__(self): 48 assert ( 49 len(self.shape) == 6 50 ), "Shape must be of length 6" # [B, Hq, M, Hkv, N, D] 51 52 def asdict(self): 53 # Convert the dataclass instance to a dictionary 54 d = asdict(self) 55 # Remove the 'calculate_bwd_time' and `cal_bandwidth` key 56 d.pop("calculate_bwd_time", None) 57 d.pop("cal_bandwidth", None) 58 d["shape(B,Hq,M,Hkv,N,D)"] = d.pop("shape") 59 return d 60 61 62@dataclass(frozen=True) 63class Times: 64 eager_time: float 65 compiled_time: float 66 67 68@dataclass(frozen=True) 69class ExperimentResults: 70 fwd_times: Times 71 bwd_times: Optional[Times] 72 73 74@dataclass(frozen=True) 75class Experiment: 76 config: ExperimentConfig 77 results: ExperimentResults 78 79 def asdict(self): 80 dict1 = self.config.asdict() 81 dict2 = asdict(self.results) 82 return {**dict1, **dict2} 83 84 85def generate_inputs( 86 batch_size: int, 87 q_heads: int, 88 q_sequence_length: int, 89 kv_heads: int, 90 kv_sequence_length: int, 91 head_dim: int, 92 dtype: torch.dtype, 93 device: torch.device, 94 requires_grad: bool, 95): 96 q_shape = (batch_size, q_sequence_length, q_heads * head_dim) 97 kv_shape = (batch_size, kv_sequence_length, kv_heads * head_dim) 98 99 assert q_heads % kv_heads == 0 100 101 num_h_groups = q_heads // kv_heads 102 103 make_q = partial( 104 torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad 105 ) 106 make_kv = partial( 107 torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad 108 ) 109 query = ( 110 make_q().view(batch_size, q_sequence_length, q_heads, head_dim).transpose(1, 2) 111 ) 112 key = ( 113 make_kv() 114 .view(batch_size, kv_sequence_length, kv_heads, head_dim) 115 .transpose(1, 2) 116 ) 117 value = ( 118 make_kv() 119 .view(batch_size, kv_sequence_length, kv_heads, head_dim) 120 .transpose(1, 2) 121 ) 122 return query, key, value 123 124 125def run_single_experiment( 126 config: ExperimentConfig, 127 dynamic=False, 128 max_autotune=False, 129) -> ExperimentResults: 130 device = torch.device("cuda") 131 batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape 132 query, key, value = generate_inputs( 133 batch_size, 134 q_heads, 135 q_seq_len, 136 kv_heads, 137 kv_seq_len, 138 head_dim, 139 config.dtype, 140 device, 141 requires_grad=config.calculate_bwd_time, 142 ) 143 144 kwargs = {} 145 if get_func_name(config.mask_mod) == "causal": 146 kwargs["is_causal"] = True 147 148 def eager_sdpa(query, key, value, attn_mask): 149 out = F.scaled_dot_product_attention(query, key, value, attn_mask, **kwargs) 150 return out.reshape(batch_size, q_heads, q_seq_len, head_dim) 151 152 if max_autotune: 153 compiled_sdpa = torch.compile( 154 flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs" 155 ) 156 else: 157 compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic) 158 159 score_mod = config.score_mod 160 mask_mod = config.mask_mod 161 162 if mask_mod: 163 block_mask = create_block_mask( 164 mask_mod, 1, 1, q_seq_len, kv_seq_len, query.device 165 ) 166 else: 167 block_mask = _create_empty_block_mask(query, key) 168 169 if mask_mod and get_func_name(mask_mod) != "causal": 170 attn_mask = create_mask(mask_mod, 1, 1, query.shape[-2], key.shape[-2]) 171 else: 172 attn_mask = None 173 174 # Broadcast query/key for eager. 175 b_key = torch.repeat_interleave(key, q_heads // kv_heads, dim=1) 176 b_value = torch.repeat_interleave(value, q_heads // kv_heads, dim=1) 177 178 forward_eager_time = benchmark_torch_function_in_microseconds( 179 eager_sdpa, query, b_key, b_value, attn_mask 180 ) 181 forward_compiled_time = benchmark_torch_function_in_microseconds( 182 compiled_sdpa, 183 query, 184 key, 185 value, 186 score_mod=score_mod, 187 block_mask=block_mask, 188 enable_gqa=True, 189 ) 190 191 out_eager = eager_sdpa(query, b_key, b_value, attn_mask) 192 out_compile = compiled_sdpa( 193 query, 194 b_key, 195 b_value, 196 score_mod=score_mod, 197 block_mask=block_mask, 198 enable_gqa=True, 199 ) 200 201 if score_mod is None: 202 torch.testing.assert_close(out_eager, out_compile, atol=1e-2, rtol=1e-2) 203 204 if config.calculate_bwd_time: 205 out_eager = eager_sdpa(query, b_key, b_value, attn_mask) 206 dOut = torch.randn_like(out_eager) 207 backward_eager_time = benchmark_torch_function_in_microseconds( 208 out_eager.backward, dOut, retain_graph=True 209 ) 210 211 out_compile = compiled_sdpa( 212 query, 213 key, 214 value, 215 score_mod=score_mod, 216 block_mask=block_mask, 217 enable_gqa=True, 218 ) 219 dOut = torch.randn_like(out_compile) 220 backward_compile_time = benchmark_torch_function_in_microseconds( 221 out_compile.backward, dOut, retain_graph=True 222 ) 223 224 return ExperimentResults( 225 fwd_times=Times(forward_eager_time, forward_compiled_time), 226 bwd_times=Times(backward_eager_time, backward_compile_time), 227 ) 228 else: 229 return ExperimentResults( 230 fwd_times=Times(forward_eager_time, forward_compiled_time), 231 bwd_times=None, 232 ) 233 234 235def calculate_speedup(results: ExperimentResults, type: str) -> float: 236 if type == "fwd": 237 return results.fwd_times.eager_time / results.fwd_times.compiled_time 238 elif type == "bwd": 239 assert results.bwd_times is not None 240 return results.bwd_times.eager_time / results.bwd_times.compiled_time 241 else: 242 raise ValueError(f"Invalid type {type}") 243 244 245def calculate_bandwidth( 246 config: ExperimentConfig, results: ExperimentResults, type: str 247) -> float: 248 if type == "fwd": 249 batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape 250 query_size = ( 251 batch_size 252 * q_heads 253 * q_seq_len 254 * head_dim 255 * torch.finfo(config.dtype).bits 256 / 8 257 ) 258 kv_size = ( 259 batch_size 260 * kv_heads 261 * kv_seq_len 262 * head_dim 263 * torch.finfo(config.dtype).bits 264 / 8 265 * 2 266 ) 267 output_size = query_size 268 total_size = (query_size + kv_size + output_size) / 1e9 # In GB 269 time_in_seconds = results.fwd_times.compiled_time / 1e6 270 return total_size / time_in_seconds / 1e3 271 else: 272 raise ValueError(f"Invalid type {type}") 273 274 275def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> float: 276 (B, Hq, M, Hkv, N, D) = config.shape 277 qk_flops = M * N * D * 2 278 softmax_flops = M * N * 2 # Not counting online softmax overhead 279 o_flops = M * D * N * 2 280 # Not counting split k overhead 281 total_flops = B * Hq * (qk_flops + softmax_flops + o_flops) 282 return total_flops / results.fwd_times.compiled_time / 1e6 # in TFLOPs/ 283 284 285def get_func_name(func): 286 if func is None: 287 return "None" 288 func_str = str(func) 289 if "<locals>" in func_str: 290 # For locally defined functions 291 return func_str.split("<locals>.")[-1].split(" at ")[0] 292 else: 293 # For regular functions 294 return func.__name__ 295 296 297def set_func_name(func, name): 298 func.__name__ = name 299 300 301def get_average_speedups(results: List[Experiment], type: str): 302 # Calculate speedups 303 speedups = [calculate_speedup(r.results, type) for r in results] 304 305 # Find indices of max and min speedups 306 max_speedup_index = np.argmax(speedups) 307 min_speedup_index = np.argmin(speedups) 308 309 # Get the config dictionaries 310 max_config_dict = results[max_speedup_index].config.asdict() 311 min_config_dict = results[min_speedup_index].config.asdict() 312 313 # Extract function names from score_mod strings 314 max_config_dict["score_mod"] = get_func_name(max_config_dict["score_mod"]) 315 max_config_dict["mask_mod"] = get_func_name(max_config_dict["mask_mod"]) 316 min_config_dict["score_mod"] = get_func_name(min_config_dict["score_mod"]) 317 min_config_dict["mask_mod"] = get_func_name(min_config_dict["mask_mod"]) 318 319 # Create table data 320 table_data = [ 321 { 322 "Type": "Average", 323 "Speedup": np.mean(speedups), 324 **dict.fromkeys(max_config_dict), 325 }, 326 {"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict}, 327 {"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict}, 328 ] 329 330 return table_data 331 332 333def print_results(results: List[Experiment], save_path: Optional[str] = None): 334 table_data = defaultdict(list) 335 for experiment in results: 336 for key, value in experiment.asdict().items(): 337 if key == "fwd_times": 338 for name, time in value.items(): 339 table_data[f"fwd_{name}"].append(float(time)) 340 elif key == "bwd_times": 341 if experiment.config.calculate_bwd_time: 342 for name, time in value.items(): 343 table_data[f"bwd_{name}"].append(float(time)) 344 else: 345 table_data[key].append(value) 346 347 # Calculate speedups 348 fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results] 349 table_data["fwd_speedup"] = fwd_speedups 350 351 # Calculate mem + computational throughput 352 if results[0].config.cal_bandwidth: 353 fwd_bandwidth = [ 354 calculate_bandwidth(r.config, r.results, type="fwd") for r in results 355 ] 356 table_data["fwd_mem_bw (TB/s)"] = fwd_bandwidth 357 fwd_tflops = [calculate_tflops(r.config, r.results) for r in results] 358 table_data["TFlops/s"] = fwd_tflops 359 360 if results[0].config.calculate_bwd_time: 361 bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results] 362 table_data["bwd_speedup"] = bwd_speedups 363 364 table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]] 365 table_data["mask_mod"] = [get_func_name(func) for func in table_data["mask_mod"]] 366 367 print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f")) 368 print("\n") 369 print("FWD Speedups".center(125, "=")) 370 print("\n") 371 average_data = get_average_speedups(results, type="fwd") 372 print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f")) 373 374 if results[0].config.calculate_bwd_time: 375 print("\n") 376 print("BWD Speedups".center(125, "=")) 377 print("\n") 378 average_data = get_average_speedups(results, type="bwd") 379 print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f")) 380 381 if save_path is not None: 382 with open(save_path, "w", newline="") as csvfile: 383 writer = csv.DictWriter(csvfile, fieldnames=table_data.keys()) 384 writer.writeheader() 385 for i in range(len(next(iter(table_data.values())))): 386 row = {k: v[i] for k, v in table_data.items()} 387 writer.writerow(row) 388 print(f"\nResults saved to {save_path}") 389 390 391def generate_score_mods(score_mods: List[str]) -> List[Callable | None]: 392 def noop(score, b, h, m, n): 393 return score 394 395 def causal_mask(score, b, h, token_q, token_kv): 396 return torch.where(token_q >= token_kv, score, float("-inf")) 397 398 def relative_bias(score, b, h, m, n): 399 return score + (m - n) 400 401 def head_bias(score, b, h, m, n): 402 return score + 2 * h 403 404 function_dict = { 405 "noop": None, 406 "causal": None, 407 "offset": None, 408 "rel": relative_bias, 409 "head_bias": head_bias, 410 } 411 return [function_dict[name] for name in score_mods] 412 413 414def generate_mask_mods(score_mods: List[str]) -> List[Callable | None]: 415 def noop(b, h, m, n): 416 return True 417 418 def causal(b, h, m, n): 419 return m >= n 420 421 def gen_offset(off): 422 def offset(b, h, m, n): 423 return m + off >= n 424 425 return offset 426 427 mask_mod_dict = { 428 "noop": None, 429 "causal": causal, 430 "offset": gen_offset, 431 "rel": None, 432 "head_bias": None, 433 } 434 return [mask_mod_dict[name] for name in score_mods] 435 436 437def generate_flash_configs( 438 calculate_bwd: bool, 439 dtype: torch.dtype, 440 batch_sizes: List[int], 441 num_heads: List[Tuple[int, int]], 442 seq_lens: List[int], 443 head_dims: List[int], 444 score_mods_str: List[str], 445 decoding: bool, 446 kv_cache_size: List[int], 447 cal_bandwidth: bool, 448) -> List[ExperimentConfig]: 449 assert not (calculate_bwd and decoding), "Decoding does not support backward" 450 451 bs_seqlen_vals = [ 452 (32, 512), 453 (16, 1024), 454 (8, 2048), 455 (4, 4096), 456 (2, 8192), 457 (1, 16384), 458 ] 459 causal_vals = [False, True] 460 headdim_vals = [64, 128] 461 dim = 2048 462 463 score_mods = generate_score_mods(score_mods_str) 464 mask_mods = generate_mask_mods(score_mods_str) 465 all_configs = [] 466 467 for ( 468 (batch_size, seq_len), 469 causal, 470 head_dim, 471 score_mod, 472 mask_mod, 473 ) in itertools.product( 474 bs_seqlen_vals, 475 causal_vals, 476 headdim_vals, 477 score_mods, 478 mask_mods, 479 ): 480 num_heads = dim // head_dim 481 482 if decoding: 483 q_seq_len, kv_seq_len = 1, seq_len 484 else: 485 q_seq_len = kv_seq_len = seq_len 486 487 all_configs.append( 488 ExperimentConfig( 489 shape=( 490 batch_size, 491 num_heads, 492 q_seq_len, 493 num_heads, 494 kv_seq_len, 495 head_dim, 496 ), 497 score_mod=score_mod, 498 mask_mod=mask_mod, 499 dtype=dtype, 500 calculate_bwd_time=calculate_bwd, 501 cal_bandwidth=cal_bandwidth, 502 ) 503 ) 504 505 return all_configs 506 507 508def generate_experiment_configs( 509 calculate_bwd: bool, 510 dtype: torch.dtype, 511 batch_sizes: List[int], 512 num_heads: List[Tuple[int, int]], 513 seq_lens: List[int], 514 head_dims: List[int], 515 score_mods_str: List[str], 516 decoding: bool, 517 kv_cache_size: List[int], 518 cal_bandwidth: bool, 519) -> List[ExperimentConfig]: 520 assert not (calculate_bwd and decoding), "Decoding does not support backward" 521 522 if decoding: 523 q_kv_seq_lens = [(1, i) for i in seq_lens] # only testing query length == 1 524 else: 525 q_kv_seq_lens = [(i, i) for i in seq_lens] # only testing q_len == kv_len 526 dtypes = [dtype] 527 score_mods = generate_score_mods(score_mods_str) 528 mask_mods = generate_mask_mods(score_mods_str) 529 all_configs = [] 530 for ( 531 bsz, 532 (q_heads, kv_heads), 533 (q_seq_len, kv_seq_len), 534 head_dim, 535 (score_mod, mask_mod), 536 dtype, 537 ) in itertools.product( 538 kv_cache_size if kv_cache_size else batch_sizes, 539 num_heads, 540 q_kv_seq_lens, 541 head_dims, 542 zip(score_mods, mask_mods), 543 dtypes, 544 ): 545 if kv_cache_size: 546 head_size_bytes = torch.finfo(dtype).bits / 8 * head_dim 547 bsz = int( 548 (bsz * 1024 * 1024) // (kv_heads * kv_seq_len * head_size_bytes * 2) 549 ) 550 if bsz <= 0: 551 continue 552 553 assert q_heads % kv_heads == 0 554 555 if mask_mod and get_func_name(mask_mod) == "gen_offset": 556 mask_mod = mask_mod(kv_seq_len // 2) 557 558 all_configs.append( 559 ExperimentConfig( 560 shape=(bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim), 561 score_mod=score_mod, 562 mask_mod=mask_mod, 563 dtype=dtype, 564 calculate_bwd_time=calculate_bwd, 565 cal_bandwidth=cal_bandwidth, 566 ) 567 ) 568 569 return all_configs 570 571 572def main(args): 573 seed = 123 574 np.random.seed(seed) 575 torch.manual_seed(seed) 576 results = [] 577 for config in tqdm( 578 generate_experiment_configs( 579 args.calculate_bwd, 580 args.dtype, 581 args.b, 582 args.nh, 583 args.s, 584 args.d, 585 args.mods, 586 args.decoding, 587 args.kv_cache_size, 588 args.throughput, 589 ) 590 ): 591 results.append( 592 Experiment( 593 config, 594 run_single_experiment( 595 config, 596 dynamic=args.dynamic, 597 max_autotune=args.max_autotune, 598 ), 599 ) 600 ) 601 602 print_results(results, args.save_path) 603 604 605def heads_input_type(s): 606 try: 607 hq, hkv = map(int, s.split(",")) 608 return hq, hkv 609 except Exception as e: 610 raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e 611 612 613if __name__ == "__main__": 614 # Set up the argument parser 615 parser = argparse.ArgumentParser( 616 description="Run sweep over sizes and score mods for flex attention" 617 ) 618 parser.add_argument( 619 "--dynamic", 620 action="store_true", 621 help="Runs a dynamic shapes version of compiled flex attention.", 622 ) 623 parser.add_argument( 624 "--calculate-bwd", action="store_true", help="Calculate backward pass times" 625 ) 626 627 parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16") 628 629 parser.add_argument( 630 "-b", type=int, nargs="+", help="batch sizes", default=[2, 8, 16] 631 ) 632 parser.add_argument( 633 "-nh", 634 type=heads_input_type, 635 nargs="+", 636 help="# of q-heads,kv-heads", 637 default=[(16, 16), (16, 2)], 638 ) 639 parser.add_argument( 640 "-s", type=int, nargs="+", help="sequence lengths", default=[512, 1024, 4096] 641 ) 642 parser.add_argument("-d", type=int, nargs="+", help="head dims", default=[64, 128]) 643 parser.add_argument( 644 "-mods", 645 type=str, 646 nargs="+", 647 help="score mods", 648 default=["noop", "causal", "rel", "head_bias"], 649 ) 650 parser.add_argument( 651 "--max-autotune", action="store_true", help="Turn on max-autotune" 652 ) 653 parser.add_argument( 654 "--decoding", 655 action="store_true", 656 help="Benchmark Decoding (query sequence length = 1)", 657 ) 658 parser.add_argument( 659 "--kv-cache-size", 660 type=int, 661 nargs="+", 662 required=False, 663 help=""" 664key/value cache size in MiB. 665Ignores -b batch size and calculate batch size from kv_cache size instead when specified. 666""", 667 ) 668 parser.add_argument( 669 "--throughput", 670 action="store_true", 671 help="Calculate kernel memory bandwidth & computational throughput. ", 672 ) 673 parser.add_argument( 674 "--save-path", 675 type=str, 676 help="Path to save the results JSON file (optional)", 677 default=None, 678 ) 679 # Parse arguments 680 args = parser.parse_args() 681 args.dtype = getattr(torch, args.dtype) 682 683 main(args) 684