xref: /aosp_15_r20/external/pytorch/benchmarks/transformer/sdpa.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport itertools
2*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
3*da0073e9SAndroid Build Coastguard Workerfrom contextlib import nullcontext
4*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import asdict, dataclass
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, List, Tuple
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerfrom tabulate import tabulate
8*da0073e9SAndroid Build Coastguard Workerfrom tqdm import tqdm
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport torch
11*da0073e9SAndroid Build Coastguard Workerimport torch.utils.benchmark as benchmark
12*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.attention import sdpa_kernel, SDPBackend
13*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.functional import scaled_dot_product_attention
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Return the median runtime of ``func(*args, **kwargs)`` in microseconds.

    The call is first executed five times untimed (warmup). It is then measured
    with ``torch.utils.benchmark.Timer.adaptive_autorange`` using a minimum
    run time of 0.1 s. The median of that measurement is converted from
    seconds to microseconds.
    """
    warmup_iters = 5
    for _ in range(warmup_iters):
        func(*args, **kwargs)

    # The Timer executes `stmt` against the names supplied in `globals`.
    timer = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals=dict(func=func, args=args, kwargs=kwargs),
    )
    measurement = timer.adaptive_autorange(min_run_time=0.1)
    seconds_to_microseconds = 1e6
    return measurement.median * seconds_to_microseconds
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
28*da0073e9SAndroid Build Coastguard Workerclass ExperimentConfig:
29*da0073e9SAndroid Build Coastguard Worker    batch_size: int
30*da0073e9SAndroid Build Coastguard Worker    num_heads: int
31*da0073e9SAndroid Build Coastguard Worker    q_seq_len: int
32*da0073e9SAndroid Build Coastguard Worker    kv_seq_len: int
33*da0073e9SAndroid Build Coastguard Worker    embed_dim: int
34*da0073e9SAndroid Build Coastguard Worker    is_causal: bool
35*da0073e9SAndroid Build Coastguard Worker    dtype: torch.dtype
36*da0073e9SAndroid Build Coastguard Worker    backend: SDPBackend
37*da0073e9SAndroid Build Coastguard Worker    device: torch.device = torch.device("cuda")
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker    @property
40*da0073e9SAndroid Build Coastguard Worker    def head_dim(self) -> int:
41*da0073e9SAndroid Build Coastguard Worker        return self.embed_dim // self.num_heads
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker    def asdict(self):
44*da0073e9SAndroid Build Coastguard Worker        dict_obj = asdict(self)
45*da0073e9SAndroid Build Coastguard Worker        dict_obj["head_dim"] = self.head_dim
46*da0073e9SAndroid Build Coastguard Worker        return dict_obj
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class ExperimentResults:
    """Timings measured for one configuration, in microseconds."""

    forward_time: float  # median forward-pass time (microseconds)
    backward_time: float  # median backward-pass time (microseconds)

    def asdict(self):
        """Return the timings as a plain ``{field_name: value}`` dict."""
        as_mapping = asdict(self)
        return as_mapping
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class Experiment:
    """Pairs a benchmark configuration with the results measured for it."""

    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        """Flatten config and results into a single dict (one table row).

        This delegates to each component's own ``asdict`` method, so derived
        columns such as ``ExperimentConfig.head_dim`` are included. The plain
        ``dataclasses.asdict`` function only emits declared fields and would
        drop them. Result keys are merged after config keys.
        """
        return {**self.config.asdict(), **self.results.asdict()}
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker
def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random query, key and value tensors for ``config``.

    The tensor shapes are:

    * query: ``(batch_size, num_heads, q_seq_len, head_dim)``
    * key and value: ``(batch_size, num_heads, kv_seq_len, head_dim)``

    All three are drawn with ``torch.randn`` in ``config.dtype`` on
    ``config.device``. They have ``requires_grad=True`` so that a backward pass
    can be run through them.
    """

    def _make(seq_len: int) -> torch.Tensor:
        # Shared constructor; only the sequence-length axis differs.
        return torch.randn(
            (config.batch_size, config.num_heads, seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )

    # Draw order (q, then k, then v) matters for seeded reproducibility.
    query = _make(config.q_seq_len)
    key = _make(config.kv_seq_len)
    value = _make(config.kv_seq_len)
    return query, key, value
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    """Time forward and backward scaled-dot-product attention for one config.

    Inputs come from ``get_input``. When ``config.backend`` is set, both
    measurements run inside ``sdpa_kernel(config.backend``). Otherwise a
    ``nullcontext`` leaves backend selection unchanged.

    Returns:
        ExperimentResults holding the median forward and backward times in
        microseconds.
    """
    query, key, value = get_input(config)
    sdpa_kwargs = {"is_causal": config.is_causal, "attn_mask": None}

    if config.backend is None:
        backend_ctx = nullcontext()
    else:
        backend_ctx = sdpa_kernel(config.backend)

    with backend_ctx:
        fwd_us = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention, query, key, value, **sdpa_kwargs
        )
        # One extra forward produces an output whose backward can be timed.
        # retain_graph=True lets that same graph be replayed on every timed call.
        output = scaled_dot_product_attention(query, key, value, **sdpa_kwargs)
        grad_output = torch.randn_like(output)
        bwd_us = benchmark_torch_function_in_microseconds(
            output.backward, grad_output, retain_graph=True
        )

    return ExperimentResults(forward_time=fwd_us, backward_time=bwd_us)
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker
def generate_experiment_configs() -> List[ExperimentConfig]:
    """Build one ExperimentConfig per point in the benchmark sweep.

    The sweep is the ``itertools.product`` over the following axes, in order:
    batch size, head count, (q_len, kv_len) pair, embed dim, causality, dtype
    and backend. The last axis varies fastest.
    """
    batch_sizes = [1, 8]
    head_counts = [16]
    seq_len_pairs = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    backends = [None]  # None: no backend is pinned, all stay enabled
    dtypes = [torch.bfloat16]
    causal_flags = [True, False]

    sweep = itertools.product(
        batch_sizes,
        head_counts,
        seq_len_pairs,
        embed_dims,
        causal_flags,
        dtypes,
        backends,
    )
    return [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=heads,
            q_seq_len=q_len,
            kv_seq_len=kv_len,
            embed_dim=dim,
            is_causal=causal,
            dtype=dtype,
            backend=backend,
        )
        for bsz, heads, (q_len, kv_len), dim, causal, dtype, backend in sweep
    ]
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker
def print_results(experiments: List[Experiment]):
    """Print one table row per experiment.

    Columns come from ``Experiment.asdict()``. Adjustments:

    * The ``device`` column is dropped.
    * The ``backend`` column is dropped only when no experiment pinned a
      backend (every value is ``None``).
    * Float values (the timings) are shown with three decimals.

    An empty ``experiments`` list prints a short notice instead of crashing.
    """
    if not experiments:
        print("No experiments to report.")
        return

    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            # tablefmt="pretty" turns off tabulate's numeric parsing, so its
            # floatfmt argument is not applied; format floats explicitly.
            table_data[key].append(
                f"{value:.3f}" if isinstance(value, float) else value
            )

    table_data.pop("device", None)
    # Check every row, not just the first, so mixed runs keep the column.
    if all(backend is None for backend in table_data.get("backend", [])):
        table_data.pop("backend", None)
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker
def main():
    """Run every generated configuration and print a summary table."""
    # Fixed seed so the randomly generated inputs are reproducible across runs.
    torch.manual_seed(123)
    experiments = [
        Experiment(cfg, run_single_experiment(cfg))
        for cfg in tqdm(generate_experiment_configs())
    ]
    print_results(experiments)
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run the sweep only when executed as a script, not on import.
    main()
186