import itertools
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional, Tuple

from tabulate import tabulate
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Return the median runtime of ``func(*args, **kwargs)`` in microseconds.

    The callable is first run five times as warmup, then measured with
    ``torch.utils.benchmark.Timer.adaptive_autorange(min_run_time=0.1)``.
    """
    # Warmup runs so one-time costs are not part of the measurement.
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    # The Measurement median is in seconds; convert to microseconds.
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of a single SDPA benchmark configuration."""

    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    # Total feature width; the per-head size is derived by ``head_dim``.
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    # ``None`` means no ``sdpa_kernel`` restriction is applied when benchmarking.
    backend: Optional[SDPBackend]
    device: torch.device = torch.device("cuda")

    @property
    def head_dim(self) -> int:
        """Per-head feature size: ``embed_dim // num_heads`` (floor division)."""
        return self.embed_dim // self.num_heads

    def asdict(self):
        """Return all fields as a dict, plus the derived ``head_dim`` entry."""
        # Inside this method ``asdict`` resolves to the module-level
        # ``dataclasses.asdict`` import, not to this method.
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    """Timings measured for one configuration."""

    # Median forward-pass time, in microseconds.
    forward_time: float
    # Median backward-pass time, in microseconds.
    backward_time: float

    def asdict(self):
        """Return the timings as a plain dict."""
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    """A configuration paired with the results measured for it."""

    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        """Flatten config and results into a single dict (one table row).

        Uses each member's own ``asdict`` method so that derived config fields
        such as ``head_dim`` are included in the row.
        """
        dict1 = self.config.asdict()
        dict2 = self.results.asdict()
        return {**dict1, **dict2}


def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random query, key and value tensors for ``config``.

    Each tensor has shape ``(batch_size, num_heads, seq_len, head_dim)``, where
    ``seq_len`` is ``q_seq_len`` for the query and ``kv_seq_len`` for key and
    value. All three require grad so the backward pass can be benchmarked.
    """

    def _randn(seq_len: int) -> torch.Tensor:
        # One standard-normal tensor with the config's dtype/device.
        return torch.randn(
            (config.batch_size, config.num_heads, seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )

    # Order q, k, v is kept so the RNG stream matches the seeded run.
    q = _randn(config.q_seq_len)
    k = _randn(config.kv_seq_len)
    v = _randn(config.kv_seq_len)
    return q, k, v


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    """Benchmark the SDPA forward and backward passes for one configuration.

    If ``config.backend`` is set, both measurements run inside
    ``sdpa_kernel(config.backend)``; otherwise no backend restriction applies.
    Times are reported in microseconds.
    """
    query, key, value = get_input(config)
    sdpa_kwargs = {"is_causal": config.is_causal, "attn_mask": None}

    if config.backend is None:
        kernel_ctx = nullcontext()
    else:
        kernel_ctx = sdpa_kernel(config.backend)

    with kernel_ctx:
        fwd_us = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention, query, key, value, **sdpa_kwargs
        )
        # Build one graph and replay its backward repeatedly; retain_graph
        # keeps the saved tensors alive across the timed iterations.
        output = scaled_dot_product_attention(query, key, value, **sdpa_kwargs)
        grad_output = torch.randn_like(output)
        bwd_us = benchmark_torch_function_in_microseconds(
            output.backward, grad_output, retain_graph=True
        )

    return ExperimentResults(forward_time=fwd_us, backward_time=bwd_us)


def generate_experiment_configs() -> List[ExperimentConfig]:
    """Return one ``ExperimentConfig`` per combination of benchmark parameters.

    Configurations are produced in ``itertools.product`` order over batch
    size, heads, (q, kv) sequence lengths, embed dim, causality, dtype and
    backend.
    """
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    # None: no sdpa_kernel restriction, so every backend stays eligible.
    backends = [None]
    dtypes = [
        torch.bfloat16,
    ]
    is_causal = [True, False]
    return [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=heads,
            q_seq_len=q_seq_len,
            kv_seq_len=kv_seq_len,
            embed_dim=embed_dim,
            is_causal=causal,
            dtype=dtype,
            backend=backend,
        )
        for (
            bsz,
            heads,
            (q_seq_len, kv_seq_len),
            embed_dim,
            causal,
            dtype,
            backend,
        ) in itertools.product(
            batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes, backends
        )
    ]


def print_results(experiments: List[Experiment]):
    """Print a table with one row per experiment.

    The ``device`` column is always omitted. The ``backend`` column is omitted
    only when every experiment used the default (``None``) backend.
    """
    if not experiments:
        print("No experiments to report.")
        return

    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            # tabulate's "pretty" format disables numeric parsing, so
            # ``floatfmt`` alone is not applied; format floats explicitly.
            if isinstance(value, float):
                value = f"{value:.3f}"
            table_data[key].append(value)

    table_data.pop("device", None)
    # Only hide the backend column if it carries no information at all.
    if all(backend is None for backend in table_data.get("backend", [])):
        table_data.pop("backend", None)
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))


def main():
    """Seed the RNG, run every generated configuration and print the results."""
    seed = 123
    torch.manual_seed(seed)
    results = [
        Experiment(config, run_single_experiment(config))
        for config in tqdm(generate_experiment_configs())
    ]
    print_results(results)


if __name__ == "__main__":
    main()