import argparse
import random

import torch


def format_csv_row(values):
    """Join ``values`` into a single comma-separated line, applying ``str()`` to each item."""
    return ",".join(map(str, values))


def bench(nt_a, nt_b, niter):
    """Return the mean time in milliseconds of ``nt_a.bmm(nt_b)`` over ``niter`` calls.

    Timing uses CUDA events, so the inputs are expected to be on a CUDA device.
    One untimed call is issued first as a warmup.

    Raises:
        ValueError: if ``niter`` is less than 1 (the mean would be undefined or
            meaningless: zero divides by zero, negative gives a negative time).
    """
    if niter < 1:
        raise ValueError(f"niter must be a positive integer, got {niter}")

    # Warmup: one untimed call so one-time costs (if any) stay out of the timed loop.
    nt_a.bmm(nt_b)

    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(niter):
        nt_a.bmm(nt_b)
    end_event.record()
    # Event recording is asynchronous; wait for the device before reading the timing.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / niter


def sweep_n(niter, dtype):
    """Benchmark nested-tensor bmm over several batch sizes, printing one CSV row per size.

    For each ``ntensor`` in 4, 8, ..., 256, ``nt_a`` is built from ``ntensor`` random
    256 x L matrices (L drawn with ``random.randint(100, 200)``, inclusive). ``nt_b`` is
    built from their transposes, so every per-component product is 256 x 256. Each row
    holds ntensor, dtype, the min/mean/max of L, and the mean runtime from ``bench``.
    """
    for ntensor in [4, 8, 16, 32, 64, 128, 256]:
        tensors = [torch.randn(256, random.randint(100, 200)) for _ in range(ntensor)]
        nt_a = torch.nested.nested_tensor(
            tensors,
            dtype=dtype,
            device="cuda",
        )
        nt_b = torch.nested.nested_tensor(
            [t.t() for t in tensors],
            dtype=dtype,
            device="cuda",
        )
        runtime = bench(nt_a, nt_b, niter)
        # One row of sizes per component; column 1 holds each component's length L.
        nt_a_size = torch.ops.aten._nested_tensor_size(nt_a)
        lengths = nt_a_size[:, 1]
        print(
            format_csv_row(
                [
                    ntensor,
                    dtype,
                    lengths.min().item(),
                    lengths.float().mean().item(),
                    lengths.max().item(),
                    runtime,
                ]
            )
        )


if __name__ == "__main__":
    # Seeds Python's RNG only, so the per-component lengths L are reproducible.
    random.seed(123)
    parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
    parser.add_argument(
        "--niter",
        default=10,
        type=int,
        help="number of timed bmm calls per configuration (must be >= 1)",
    )

    args = parser.parse_args()
    if args.niter < 1:
        parser.error("--niter must be >= 1")
    niter = args.niter

    print("ntensor,dtype,min_length,mean_length,max_length,runtime")
    sweep_n(niter, torch.float32)
    sweep_n(niter, torch.float16)