import argparse
import csv
import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List, Optional, Tuple

import numpy as np
from tabulate import tabulate
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    _create_empty_block_mask,
    create_block_mask,
    create_mask,
    flex_attention,
)


torch._dynamo.config.automatic_dynamic_shapes = False
# Needed since changing args to the benchmarked function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


from torch._inductor.runtime.benchmarking import benchmarker


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    # benchmark_gpu reports milliseconds; * 1e3 converts to microseconds
    return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) * 1e3


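# Usage sketch for benchmark_torch_function_in_microseconds (hypothetical
# tensors q, k, v, for illustration only):
#   t_us = benchmark_torch_function_in_microseconds(
#       F.scaled_dot_product_attention, q, k, v
#   )
#   # t_us is the measured GPU time per call, in microseconds.

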
@dataclass(frozen=True)
class ExperimentConfig:
    shape: Tuple[int, ...]
    score_mod: Callable
    mask_mod: Callable
    dtype: torch.dtype
    calculate_bwd_time: bool
    cal_bandwidth: bool

    def __post_init__(self):
        assert (
            len(self.shape) == 6
        ), "Shape must be of length 6"  # [B, Hq, M, Hkv, N, D]

    def asdict(self):
        # Convert the dataclass instance to a dictionary
        d = asdict(self)
        # Remove the 'calculate_bwd_time' and 'cal_bandwidth' keys
        d.pop("calculate_bwd_time", None)
        d.pop("cal_bandwidth", None)
        d["shape(B,Hq,M,Hkv,N,D)"] = d.pop("shape")
        return d


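# Example ExperimentConfig (hypothetical values, for illustration only):
#   ExperimentConfig(
#       shape=(2, 16, 1024, 2, 1024, 64),  # B=2, Hq=16, M=1024, Hkv=2, N=1024, D=64
#       score_mod=None,
#       mask_mod=None,
#       dtype=torch.bfloat16,
#       calculate_bwd_time=False,
#       cal_bandwidth=False,
#   )
# Hq/Hkv = 16/2 corresponds to an 8:1 GQA layout.

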
@dataclass(frozen=True)
class Times:
    eager_time: float
    compiled_time: float


@dataclass(frozen=True)
class ExperimentResults:
    fwd_times: Times
    bwd_times: Optional[Times]


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        dict1 = self.config.asdict()
        dict2 = asdict(self.results)
        return {**dict1, **dict2}


def generate_inputs(
    batch_size: int,
    q_heads: int,
    q_sequence_length: int,
    kv_heads: int,
    kv_sequence_length: int,
    head_dim: int,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
):
    q_shape = (batch_size, q_sequence_length, q_heads * head_dim)
    kv_shape = (batch_size, kv_sequence_length, kv_heads * head_dim)

    assert q_heads % kv_heads == 0

    make_q = partial(
        torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad
    )
    make_kv = partial(
        torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad
    )
    query = (
        make_q().view(batch_size, q_sequence_length, q_heads, head_dim).transpose(1, 2)
    )
    key = (
        make_kv()
        .view(batch_size, kv_sequence_length, kv_heads, head_dim)
        .transpose(1, 2)
    )
    value = (
        make_kv()
        .view(batch_size, kv_sequence_length, kv_heads, head_dim)
        .transpose(1, 2)
    )
    return query, key, value


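# Shape example for generate_inputs (hypothetical sizes): with batch_size=2,
# q_heads=16, q_sequence_length=1024, kv_heads=2, kv_sequence_length=1024,
# head_dim=64, the returned tensors are
#   query: (2, 16, 1024, 64)
#   key:   (2, 2, 1024, 64)
#   value: (2, 2, 1024, 64)
# i.e. (batch, heads, seq_len, head_dim) after the view + transpose.

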
def run_single_experiment(
    config: ExperimentConfig,
    dynamic=False,
    max_autotune=False,
) -> ExperimentResults:
    device = torch.device("cuda")
    batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
    query, key, value = generate_inputs(
        batch_size,
        q_heads,
        q_seq_len,
        kv_heads,
        kv_seq_len,
        head_dim,
        config.dtype,
        device,
        requires_grad=config.calculate_bwd_time,
    )

    kwargs = {}
    if get_func_name(config.mask_mod) == "causal":
        kwargs["is_causal"] = True

    def eager_sdpa(query, key, value, attn_mask):
        out = F.scaled_dot_product_attention(query, key, value, attn_mask, **kwargs)
        return out.reshape(batch_size, q_heads, q_seq_len, head_dim)

    if max_autotune:
        compiled_sdpa = torch.compile(
            flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs"
        )
    else:
        compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic)

    score_mod = config.score_mod
    mask_mod = config.mask_mod

    if mask_mod:
        block_mask = create_block_mask(
            mask_mod, 1, 1, q_seq_len, kv_seq_len, query.device
        )
    else:
        block_mask = _create_empty_block_mask(query, key)

    # Causal masking goes through is_causal in eager SDPA, so a dense attn_mask
    # is only materialized for the remaining mask mods.
    if mask_mod and get_func_name(mask_mod) != "causal":
        attn_mask = create_mask(mask_mod, 1, 1, query.shape[-2], key.shape[-2])
    else:
        attn_mask = None

    # Broadcast key/value across the query heads for the eager baseline, which
    # does not handle the GQA layout directly.
    b_key = torch.repeat_interleave(key, q_heads // kv_heads, dim=1)
    b_value = torch.repeat_interleave(value, q_heads // kv_heads, dim=1)

    forward_eager_time = benchmark_torch_function_in_microseconds(
        eager_sdpa, query, b_key, b_value, attn_mask
    )
    forward_compiled_time = benchmark_torch_function_in_microseconds(
        compiled_sdpa,
        query,
        key,
        value,
        score_mod=score_mod,
        block_mask=block_mask,
        enable_gqa=True,
    )

    out_eager = eager_sdpa(query, b_key, b_value, attn_mask)
    out_compile = compiled_sdpa(
        query,
        b_key,
        b_value,
        score_mod=score_mod,
        block_mask=block_mask,
        enable_gqa=True,
    )

    # Eager SDPA cannot apply a score_mod, so outputs are only comparable when
    # score_mod is None.
    if score_mod is None:
        torch.testing.assert_close(out_eager, out_compile, atol=1e-2, rtol=1e-2)

    if config.calculate_bwd_time:
        out_eager = eager_sdpa(query, b_key, b_value, attn_mask)
        dOut = torch.randn_like(out_eager)
        backward_eager_time = benchmark_torch_function_in_microseconds(
            out_eager.backward, dOut, retain_graph=True
        )

        out_compile = compiled_sdpa(
            query,
            key,
            value,
            score_mod=score_mod,
            block_mask=block_mask,
            enable_gqa=True,
        )
        dOut = torch.randn_like(out_compile)
        backward_compile_time = benchmark_torch_function_in_microseconds(
            out_compile.backward, dOut, retain_graph=True
        )

        return ExperimentResults(
            fwd_times=Times(forward_eager_time, forward_compiled_time),
            bwd_times=Times(backward_eager_time, backward_compile_time),
        )
    else:
        return ExperimentResults(
            fwd_times=Times(forward_eager_time, forward_compiled_time),
            bwd_times=None,
        )


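# Example call to run_single_experiment (hypothetical config; requires a CUDA
# device):
#   cfg = ExperimentConfig(
#       shape=(4, 16, 1, 2, 4096, 64),  # decoding-style: q_seq_len == 1
#       score_mod=None,
#       mask_mod=None,
#       dtype=torch.bfloat16,
#       calculate_bwd_time=False,
#       cal_bandwidth=False,
#   )
#   res = run_single_experiment(cfg)  # res.fwd_times holds eager/compiled times in us

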
def calculate_speedup(results: ExperimentResults, type: str) -> float:
    if type == "fwd":
        return results.fwd_times.eager_time / results.fwd_times.compiled_time
    elif type == "bwd":
        assert results.bwd_times is not None
        return results.bwd_times.eager_time / results.bwd_times.compiled_time
    else:
        raise ValueError(f"Invalid type {type}")


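# Worked example for calculate_speedup (hypothetical numbers): with an eager
# forward of 120 us and a compiled forward of 80 us, the "fwd" speedup is
# 120 / 80 = 1.5x.

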
def calculate_bandwidth(
    config: ExperimentConfig, results: ExperimentResults, type: str
) -> float:
    if type == "fwd":
        batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
        query_size = (
            batch_size
            * q_heads
            * q_seq_len
            * head_dim
            * torch.finfo(config.dtype).bits
            / 8
        )
        kv_size = (
            batch_size
            * kv_heads
            * kv_seq_len
            * head_dim
            * torch.finfo(config.dtype).bits
            / 8
            * 2
        )
        output_size = query_size
        total_size = (query_size + kv_size + output_size) / 1e9  # In GB
        time_in_seconds = results.fwd_times.compiled_time / 1e6
        return total_size / time_in_seconds / 1e3  # GB/s -> TB/s
    else:
        raise ValueError(f"Invalid type {type}")


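# Worked example for calculate_bandwidth (hypothetical numbers): for
# shape (8, 16, 2048, 16, 2048, 64) in bfloat16 (2 bytes per element):
#   query_size  = 8 * 16 * 2048 * 64 * 2 bytes ~ 0.0336 GB
#   kv_size     = 2 * query_size (K and V)     ~ 0.0671 GB
#   output_size = query_size                   ~ 0.0336 GB
#   total_size ~ 0.1342 GB
# At a compiled forward time of 100 us: 0.1342 / 1e-4 / 1e3 ~ 1.34 TB/s.

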
def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> float:
    (B, Hq, M, Hkv, N, D) = config.shape
    qk_flops = M * N * D * 2
    softmax_flops = M * N * 2  # Not counting online softmax overhead
    o_flops = M * D * N * 2
    # Not counting split-k overhead
    total_flops = B * Hq * (qk_flops + softmax_flops + o_flops)
    return total_flops / results.fwd_times.compiled_time / 1e6  # in TFLOPs/s


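# Worked example for calculate_tflops (hypothetical numbers): for
# shape (8, 16, 2048, 16, 2048, 64):
#   qk_flops      = 2048 * 2048 * 64 * 2 ~ 5.37e8
#   o_flops       = 2048 * 64 * 2048 * 2 ~ 5.37e8
#   softmax_flops = 2048 * 2048 * 2      ~ 8.4e6
#   total_flops   = 8 * 16 * (qk + softmax + o) ~ 1.39e11
# At a compiled forward time of 1000 us this is ~ 139 TFLOPs/s.

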
def get_func_name(func):
    if func is None:
        return "None"
    func_str = str(func)
    if "<locals>" in func_str:
        # For locally defined functions
        return func_str.split("<locals>.")[-1].split(" at ")[0]
    else:
        # For regular functions
        return func.__name__


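# Example for get_func_name: a locally defined function stringifies as
#   "<function generate_mask_mods.<locals>.causal at 0x7f...>"
# which the <locals> branch reduces to "causal".

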
def set_func_name(func, name):
    func.__name__ = name


def get_average_speedups(results: List[Experiment], type: str):
    # Calculate speedups
    speedups = [calculate_speedup(r.results, type) for r in results]

    # Find indices of max and min speedups
    max_speedup_index = np.argmax(speedups)
    min_speedup_index = np.argmin(speedups)

    # Get the config dictionaries
    max_config_dict = results[max_speedup_index].config.asdict()
    min_config_dict = results[min_speedup_index].config.asdict()

    # Extract readable names from the score_mod/mask_mod callables
    max_config_dict["score_mod"] = get_func_name(max_config_dict["score_mod"])
    max_config_dict["mask_mod"] = get_func_name(max_config_dict["mask_mod"])
    min_config_dict["score_mod"] = get_func_name(min_config_dict["score_mod"])
    min_config_dict["mask_mod"] = get_func_name(min_config_dict["mask_mod"])

    # Create table data; the Average row fills the config columns with None
    table_data = [
        {
            "Type": "Average",
            "Speedup": np.mean(speedups),
            **dict.fromkeys(max_config_dict),
        },
        {"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
        {"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
    ]

    return table_data


def print_results(results: List[Experiment], save_path: Optional[str] = None):
    table_data = defaultdict(list)
    for experiment in results:
        for key, value in experiment.asdict().items():
            if key == "fwd_times":
                for name, time in value.items():
                    table_data[f"fwd_{name}"].append(float(time))
            elif key == "bwd_times":
                if experiment.config.calculate_bwd_time:
                    for name, time in value.items():
                        table_data[f"bwd_{name}"].append(float(time))
            else:
                table_data[key].append(value)

    # Calculate speedups
    fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results]
    table_data["fwd_speedup"] = fwd_speedups

    # Calculate mem + computational throughput
    if results[0].config.cal_bandwidth:
        fwd_bandwidth = [
            calculate_bandwidth(r.config, r.results, type="fwd") for r in results
        ]
        table_data["fwd_mem_bw (TB/s)"] = fwd_bandwidth
        fwd_tflops = [calculate_tflops(r.config, r.results) for r in results]
        table_data["TFlops/s"] = fwd_tflops

    if results[0].config.calculate_bwd_time:
        bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results]
        table_data["bwd_speedup"] = bwd_speedups

    table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
    table_data["mask_mod"] = [get_func_name(func) for func in table_data["mask_mod"]]

    print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
    print("\n")
    print("FWD Speedups".center(125, "="))
    print("\n")
    average_data = get_average_speedups(results, type="fwd")
    print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))

    if results[0].config.calculate_bwd_time:
        print("\n")
        print("BWD Speedups".center(125, "="))
        print("\n")
        average_data = get_average_speedups(results, type="bwd")
        print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))

    if save_path is not None:
        with open(save_path, "w", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=table_data.keys())
            writer.writeheader()
            for i in range(len(next(iter(table_data.values())))):
                row = {k: v[i] for k, v in table_data.items()}
                writer.writerow(row)
        print(f"\nResults saved to {save_path}")


def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
    def noop(score, b, h, m, n):
        return score

    def causal_mask(score, b, h, token_q, token_kv):
        return torch.where(token_q >= token_kv, score, float("-inf"))

    def relative_bias(score, b, h, m, n):
        return score + (m - n)

    def head_bias(score, b, h, m, n):
        return score + 2 * h

    # "noop", "causal", and "offset" are expressed purely as mask mods (see
    # generate_mask_mods), so the score is left untouched (None) for those.
    function_dict = {
        "noop": None,
        "causal": None,
        "offset": None,
        "rel": relative_bias,
        "head_bias": head_bias,
    }
    return [function_dict[name] for name in score_mods]


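# Example: generate_score_mods(["causal", "rel"]) returns
# [None, relative_bias]; causality is handled by the matching mask mod, so the
# causal score mod is None.

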
def generate_mask_mods(score_mods: List[str]) -> List[Callable | None]:
    def noop(b, h, m, n):
        return True

    def causal(b, h, m, n):
        return m >= n

    def gen_offset(off):
        def offset(b, h, m, n):
            return m + off >= n

        return offset

    # "noop", "rel", and "head_bias" attend to all positions, so no mask mod is
    # needed (None). "offset" maps to a generator that is instantiated later.
    mask_mod_dict = {
        "noop": None,
        "causal": causal,
        "offset": gen_offset,
        "rel": None,
        "head_bias": None,
    }
    return [mask_mod_dict[name] for name in score_mods]


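# Example: generate_mask_mods(["offset"]) returns [gen_offset]; the generator
# is instantiated later (e.g. gen_offset(512)), producing a mask where query
# position m may attend to kv positions n <= m + 512.

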
# Mirrors the fixed batch-size/sequence-length sweep used by the
# flash-attention benchmarks (fixed model dim, paired score/mask mods).
# Note: not invoked by main() below, which uses generate_experiment_configs.
def generate_flash_configs(
    calculate_bwd: bool,
    dtype: torch.dtype,
    batch_sizes: List[int],
    num_heads: List[Tuple[int, int]],
    seq_lens: List[int],
    head_dims: List[int],
    score_mods_str: List[str],
    decoding: bool,
    kv_cache_size: List[int],
    cal_bandwidth: bool,
) -> List[ExperimentConfig]:
    assert not (calculate_bwd and decoding), "Decoding does not support backward"

    bs_seqlen_vals = [
        (32, 512),
        (16, 1024),
        (8, 2048),
        (4, 4096),
        (2, 8192),
        (1, 16384),
    ]
    headdim_vals = [64, 128]
    dim = 2048

    score_mods = generate_score_mods(score_mods_str)
    mask_mods = generate_mask_mods(score_mods_str)
    all_configs = []

    for (
        (batch_size, seq_len),
        head_dim,
        (score_mod, mask_mod),
    ) in itertools.product(
        bs_seqlen_vals,
        headdim_vals,
        zip(score_mods, mask_mods),
    ):
        n_heads = dim // head_dim

        if decoding:
            q_seq_len, kv_seq_len = 1, seq_len
        else:
            q_seq_len = kv_seq_len = seq_len

        all_configs.append(
            ExperimentConfig(
                shape=(
                    batch_size,
                    n_heads,
                    q_seq_len,
                    n_heads,
                    kv_seq_len,
                    head_dim,
                ),
                score_mod=score_mod,
                mask_mod=mask_mod,
                dtype=dtype,
                calculate_bwd_time=calculate_bwd,
                cal_bandwidth=cal_bandwidth,
            )
        )

    return all_configs


def generate_experiment_configs(
    calculate_bwd: bool,
    dtype: torch.dtype,
    batch_sizes: List[int],
    num_heads: List[Tuple[int, int]],
    seq_lens: List[int],
    head_dims: List[int],
    score_mods_str: List[str],
    decoding: bool,
    kv_cache_size: List[int],
    cal_bandwidth: bool,
) -> List[ExperimentConfig]:
    assert not (calculate_bwd and decoding), "Decoding does not support backward"

    if decoding:
        q_kv_seq_lens = [(1, i) for i in seq_lens]  # only testing query length == 1
    else:
        q_kv_seq_lens = [(i, i) for i in seq_lens]  # only testing q_len == kv_len
    dtypes = [dtype]
    score_mods = generate_score_mods(score_mods_str)
    mask_mods = generate_mask_mods(score_mods_str)
    all_configs = []
    for (
        bsz,
        (q_heads, kv_heads),
        (q_seq_len, kv_seq_len),
        head_dim,
        (score_mod, mask_mod),
        dtype,
    ) in itertools.product(
        kv_cache_size if kv_cache_size else batch_sizes,
        num_heads,
        q_kv_seq_lens,
        head_dims,
        zip(score_mods, mask_mods),
        dtypes,
    ):
        if kv_cache_size:
            # Derive the batch size so that the total K+V cache is roughly
            # `bsz` MiB: two tensors (K and V) of
            # kv_heads * kv_seq_len * head_dim elements each.
            head_size_bytes = torch.finfo(dtype).bits / 8 * head_dim
            bsz = int(
                (bsz * 1024 * 1024) // (kv_heads * kv_seq_len * head_size_bytes * 2)
            )
            if bsz <= 0:
                continue

        assert q_heads % kv_heads == 0

        if mask_mod and get_func_name(mask_mod) == "gen_offset":
            mask_mod = mask_mod(kv_seq_len // 2)

        all_configs.append(
            ExperimentConfig(
                shape=(bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim),
                score_mod=score_mod,
                mask_mod=mask_mod,
                dtype=dtype,
                calculate_bwd_time=calculate_bwd,
                cal_bandwidth=cal_bandwidth,
            )
        )

    return all_configs


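# Worked example for the kv-cache-derived batch size (hypothetical numbers):
# with --kv-cache-size 512 (MiB), bfloat16, kv_heads=2, kv_seq_len=4096,
# head_dim=64:
#   head_size_bytes = 2 * 64 = 128
#   per-sequence K+V bytes = 2 * 4096 * 128 * 2 = 2 MiB
#   bsz = 512 MiB // 2 MiB = 256

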
def main(args):
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    results = []
    for config in tqdm(
        generate_experiment_configs(
            args.calculate_bwd,
            args.dtype,
            args.b,
            args.nh,
            args.s,
            args.d,
            args.mods,
            args.decoding,
            args.kv_cache_size,
            args.throughput,
        )
    ):
        results.append(
            Experiment(
                config,
                run_single_experiment(
                    config,
                    dynamic=args.dynamic,
                    max_autotune=args.max_autotune,
                ),
            )
        )

    print_results(results, args.save_path)


def heads_input_type(s):
    try:
        hq, hkv = map(int, s.split(","))
        return hq, hkv
    except Exception as e:
        raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e


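# Example: heads_input_type("16,2") returns (16, 2), i.e. 16 query heads
# sharing 2 kv heads.

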
if __name__ == "__main__":
    # Set up the argument parser
    parser = argparse.ArgumentParser(
        description="Run sweep over sizes and score mods for flex attention"
    )
    parser.add_argument(
        "--dynamic",
        action="store_true",
        help="Runs a dynamic shapes version of compiled flex attention.",
    )
    parser.add_argument(
        "--calculate-bwd", action="store_true", help="Calculate backward pass times"
    )

    parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16")

    parser.add_argument(
        "-b", type=int, nargs="+", help="batch sizes", default=[2, 8, 16]
    )
    parser.add_argument(
        "-nh",
        type=heads_input_type,
        nargs="+",
        help="# of q-heads,kv-heads",
        default=[(16, 16), (16, 2)],
    )
    parser.add_argument(
        "-s", type=int, nargs="+", help="sequence lengths", default=[512, 1024, 4096]
    )
    parser.add_argument("-d", type=int, nargs="+", help="head dims", default=[64, 128])
    parser.add_argument(
        "-mods",
        type=str,
        nargs="+",
        help="score mods",
        default=["noop", "causal", "rel", "head_bias"],
    )
    parser.add_argument(
        "--max-autotune", action="store_true", help="Turn on max-autotune"
    )
    parser.add_argument(
        "--decoding",
        action="store_true",
        help="Benchmark decoding (query sequence length = 1)",
    )
    parser.add_argument(
        "--kv-cache-size",
        type=int,
        nargs="+",
        required=False,
        help="""
key/value cache size in MiB.
When specified, -b is ignored and the batch size is derived from the kv cache size instead.
""",
    )
    parser.add_argument(
        "--throughput",
        action="store_true",
        help="Calculate kernel memory bandwidth & computational throughput.",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        help="Path to save the results CSV file (optional)",
        default=None,
    )
    # Parse arguments
    args = parser.parse_args()
    args.dtype = getattr(torch, args.dtype)

    main(args)
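# Example invocations (using the flags defined above):
#   python score_mod.py
#   python score_mod.py --calculate-bwd -b 8 -s 2048 --throughput
#   python score_mod.py --decoding --kv-cache-size 512 --max-autotune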