# Benchmark forward and backward latency of scaled_dot_product_attention across a
# sweep of batch size, sequence length, causality, and dtype configurations.
import itertools
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple

from tabulate import tabulate
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    # Time with torch.utils.benchmark and report the median runtime in microseconds.
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    backend: SDPBackend
    device: torch.device = torch.device("cuda")

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    forward_time: float
    backward_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        dict1 = asdict(self.config)
        dict2 = asdict(self.results)
        return {**dict1, **dict2}


def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Random q/k/v tensors in (batch, heads, seq_len, head_dim) layout, with
    # gradients enabled so the backward pass can be benchmarked as well.
    q = torch.randn(
        (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    k = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    v = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    return q, k, v


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    q, k, v = get_input(config)
    is_causal = config.is_causal
    # Pin SDPA to the requested backend, or let PyTorch dispatch freely when it is None.
    context = (
        sdpa_kernel(config.backend) if config.backend is not None else nullcontext()
    )
    with context:
        forward_time = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=is_causal,
            attn_mask=None,
        )
        # Time the backward pass through a fresh forward output, reusing the graph
        # across repeated timed runs.
        out_torch = scaled_dot_product_attention(
            q, k, v, is_causal=is_causal, attn_mask=None
        )
        dOut = torch.randn_like(out_torch)
        backward_time = benchmark_torch_function_in_microseconds(
            out_torch.backward, dOut, retain_graph=True
        )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    backends = [None]  # If set to None, all backends are enabled
    dtypes = [
        torch.bfloat16,
    ]
    is_causal = [True, False]
    all_configs = []
    # One ExperimentConfig per point in the Cartesian product of the sweep parameters.
    for (
        bsz,
        heads,
        (q_seq_len, kv_seq_len),
        embed_dim,
        causal,
        dtype,
        backend,
    ) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes, backends
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                is_causal=causal,
                dtype=dtype,
                backend=backend,
            )
        )

    return all_configs


def print_results(experiments: List[Experiment]):
    # Collect results column-wise; drop the device column, and the backend column
    # when no backend was pinned.
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    del table_data["device"]
    if table_data["backend"][0] is None:
        del table_data["backend"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))


def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()
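
# Note on backend selection: `backends = [None]` in generate_experiment_configs lets
# scaled_dot_product_attention dispatch to whichever implementation PyTorch selects.
# As an illustrative variation (a sketch, not the script's default), individual
# implementations can be compared by listing SDPBackend members instead, e.g.:
#
#     backends = [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]
#
# run_single_experiment then wraps each timed call in `sdpa_kernel(config.backend)`,
# so every row reflects a single backend. Backend availability depends on the GPU,
# dtype, and head_dim, so some combinations may raise a "no available kernel" error.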