import itertools from collections import defaultdict from contextlib import nullcontext from dataclasses import asdict, dataclass from typing import Callable, List, Tuple from tabulate import tabulate from tqdm import tqdm import torch import torch.utils.benchmark as benchmark from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.functional import scaled_dot_product_attention def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float: # warmup for _ in range(5): func(*args, **kwargs) t0 = benchmark.Timer( stmt="func(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "func": func}, ) return t0.adaptive_autorange(min_run_time=0.1).median * 1e6 @dataclass(frozen=True) class ExperimentConfig: batch_size: int num_heads: int q_seq_len: int kv_seq_len: int embed_dim: int is_causal: bool dtype: torch.dtype backend: SDPBackend device: torch.device = torch.device("cuda") @property def head_dim(self) -> int: return self.embed_dim // self.num_heads def asdict(self): dict_obj = asdict(self) dict_obj["head_dim"] = self.head_dim return dict_obj @dataclass(frozen=True) class ExperimentResults: forward_time: float backward_time: float def asdict(self): return asdict(self) @dataclass(frozen=True) class Experiment: config: ExperimentConfig results: ExperimentResults def asdict(self): dict1 = asdict(self.config) dict2 = asdict(self.results) return {**dict1, **dict2} def get_input( config: ExperimentConfig, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: q = torch.randn( (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim), dtype=config.dtype, device=config.device, requires_grad=True, ) k = torch.randn( (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim), dtype=config.dtype, device=config.device, requires_grad=True, ) v = torch.randn( (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim), dtype=config.dtype, device=config.device, requires_grad=True, ) return q, k, v 
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    """Benchmark the forward and backward pass of SDPA for one configuration.

    Args:
        config: Shapes, dtype, causal flag and optional backend to benchmark.

    Returns:
        An ``ExperimentResults`` holding the forward and backward timings as
        returned by ``benchmark_torch_function_in_microseconds``.
    """
    q, k, v = get_input(config)
    is_causal = config.is_causal
    # Restrict to the requested backend only when one is given; with ``None``
    # the call runs under a no-op context.
    context = (
        sdpa_kernel(config.backend) if config.backend is not None else nullcontext()
    )
    with context:
        forward_time = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=is_causal,
            attn_mask=None,
        )
        out_torch = scaled_dot_product_attention(
            q, k, v, is_causal=is_causal, attn_mask=None
        )
        dOut = torch.randn_like(out_torch)
        # retain_graph=True lets the timer call backward repeatedly on the
        # same output.
        backward_time = benchmark_torch_function_in_microseconds(
            out_torch.backward, dOut, retain_graph=True
        )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
    """Build the cartesian product of all benchmark parameter choices.

    Returns:
        One ``ExperimentConfig`` per combination, in ``itertools.product``
        order over (batch size, heads, sequence lengths, embed dim, causal,
        dtype, backend).
    """
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    backends = [None]  # If set to None, all backends are enabled
    dtypes = [
        torch.bfloat16,
    ]
    is_causal = [True, False]
    return [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=heads,
            q_seq_len=q_seq_len,
            kv_seq_len=kv_seq_len,
            embed_dim=embed_dim,
            is_causal=causal,
            dtype=dtype,
            backend=backend,
        )
        for (
            bsz,
            heads,
            (q_seq_len, kv_seq_len),
            embed_dim,
            causal,
            dtype,
            backend,
        ) in itertools.product(
            batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes, backends
        )
    ]


def print_results(experiments: List[Experiment]):
    """Print the experiments as a table with one row per experiment.

    The ``device`` column is always omitted. The ``backend`` column is omitted
    only when no experiment selected a backend, so explicit backend choices
    are never hidden. An empty ``experiments`` list prints a short notice
    instead of raising.

    Args:
        experiments: Completed experiments to tabulate.
    """
    if not experiments:
        print("No experiment results to display.")
        return

    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)

    # pop() with a default tolerates a missing key, unlike ``del``.
    table_data.pop("device", None)
    # Check every row, not just the first, before hiding the column.
    if all(backend is None for backend in table_data.get("backend", [])):
        table_data.pop("backend", None)

    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))


def main():
    """Seed torch, run every generated configuration, and print the table."""
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()