import itertools
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple

from tabulate import tabulate
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention


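# Microbenchmark for torch.nn.functional.scaled_dot_product_attention: sweep a
# grid of shapes, dtypes, and causality settings (optionally pinned to one SDPA
# backend), time the forward and backward passes, and print a summary table.
# Requires `tabulate` and `tqdm`; ExperimentConfig defaults to a CUDA device.
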
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # Warm up so one-time costs (CUDA context init, kernel autotuning) are
    # excluded from the measurement.
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    # Timer reports seconds; convert the median measurement to microseconds.
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6

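# A minimal usage sketch of the helper above (hypothetical tensors, not part of
# the benchmark itself):
#
#   a = torch.randn(1024, 1024, device="cuda")
#   t_us = benchmark_torch_function_in_microseconds(torch.mm, a, a)
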
@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    backend: SDPBackend
    device: torch.device = torch.device("cuda")

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj

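# Note: head_dim is derived, so embed_dim should be divisible by num_heads
# (e.g. the default embed_dim=2048 with num_heads=16 gives head_dim=128).
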
@dataclass(frozen=True)
class ExperimentResults:
    # Median runtimes in microseconds.
    forward_time: float
    backward_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        # Use the dataclasses' own asdict methods so derived fields
        # (e.g. head_dim) are included in the flattened row.
        dict1 = self.config.asdict()
        dict2 = self.results.asdict()
        return {**dict1, **dict2}


def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Q/K/V in the (batch, num_heads, seq_len, head_dim) layout expected by
    # scaled_dot_product_attention; requires_grad so the backward pass can be
    # benchmarked as well.
    q = torch.randn(
        (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    k = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    v = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    return q, k, v

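# For the largest default config (batch_size=8, num_heads=16, seq_len=1024,
# embed_dim=2048 -> head_dim=128), q/k/v are each of shape (8, 16, 1024, 128).
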
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    q, k, v = get_input(config)
    is_causal = config.is_causal
    # Restrict dispatch to the requested backend; with backend=None, let SDPA
    # choose among all enabled implementations.
    context = (
        sdpa_kernel(config.backend) if config.backend is not None else nullcontext()
    )
    with context:
        forward_time = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=is_causal,
            attn_mask=None,
        )
        out_torch = scaled_dot_product_attention(
            q, k, v, is_causal=is_causal, attn_mask=None
        )
        dOut = torch.randn_like(out_torch)
        # retain_graph=True so backward() can be replayed across timing runs.
        backward_time = benchmark_torch_function_in_microseconds(
            out_torch.backward, dOut, retain_graph=True
        )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )

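# To benchmark one implementation in isolation, set `backends` below to
# specific SDPBackend members instead of None, e.g. (a sketch):
#
#   backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
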
def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    backends = [None]  # If set to None, all backends are enabled
    dtypes = [
        torch.bfloat16,
    ]
    is_causal = [True, False]
    all_configs = []
    for (
        bsz,
        heads,
        (q_seq_len, kv_seq_len),
        embed_dim,
        causal,
        dtype,
        backend,
    ) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes, backends
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                is_causal=causal,
                dtype=dtype,
                backend=backend,
            )
        )

    return all_configs

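# With the defaults above this is a 2 (batch size) x 4 (seq len) x 2 (causal)
# sweep, i.e. 16 configurations.
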
def print_results(experiments: List[Experiment]):
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    # Drop columns that are constant or uninformative in the output.
    del table_data["device"]
    if table_data["backend"][0] is None:
        del table_data["backend"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))


def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()