# flake8: noqa
"""Microbenchmark comparing ATen matmul against the Triton matmul generated by
TorchInductor, with optional bias-add and/or ReLU epilogue fusion, on linear-layer
shapes from AlexNet, BERT, and hf_GPT2. Results are printed as a PrettyTable,
one row per layer."""

import triton
from prettytable import PrettyTable

import torch
import torch._dynamo
import torch._inductor.config
from torch._inductor.runtime.benchmarking import benchmarker


# torch._inductor.config.debug = True
torch._inductor.config.triton.dense_indexing = True
torch.manual_seed(0)


# The flag below controls whether to allow TF32 on matmul.
torch.backends.cuda.matmul.allow_tf32 = True


class Func(object):
    # mm
    @torch._dynamo.optimize("inductor")
    def mm(a, b, bias):
        y = torch.mm(a, b)
        return y

    # mm+bias
    @torch._dynamo.optimize("inductor")
    def mm_add(a, b, bias):
        y = torch.mm(a, b)
        return y + bias

    # relu(mm)
    @torch._dynamo.optimize("inductor")
    def mm_relu(a, b, bias):
        y = torch.mm(a, b)
        return torch.relu(y)

    # relu(mm+bias)
    @torch._dynamo.optimize("inductor")
    def mm_add_relu(a, b, bias):
        y = torch.mm(a, b)
        y += bias
        return torch.relu(y)


def bench(shape, layer_id, p, fusion_types=[""]):
    dtype = torch.float16
    M, K = shape[0]
    _, N = shape[1]
    torch.manual_seed(0)
    # allocate inputs
    a = torch.randn(shape[0], device="cuda", dtype=dtype)
    b = torch.randn(shape[1], device="cuda", dtype=dtype)

    def tflops(ms):
        # Counts one multiply-accumulate per output element (M*K*N MACs);
        # multiply by 2 for the conventional FLOP count.
        return M * K * N / ms * 1e-9

    row = [layer_id]
    for fusion_type in fusion_types:
        if fusion_type == "":
            fn_mm = getattr(Func, "mm")
        else:
            fn_mm = getattr(Func, f"mm_{fusion_type}")

        if "add" in fusion_type:
            bias = torch.randn((M, N), dtype=dtype, device="cuda")
        else:
            bias = None

        args = (a, b, bias)

        def fn():
            return fn_mm(*args)

        torch._inductor.config.triton.mm = "aten"
        torch_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
        torch._inductor.config.triton.mm = "triton"
        # reset to force codegen of new Python code
        torch._dynamo.reset()
        torch._inductor.metrics.reset()
        triton_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
        assert (
            torch._inductor.metrics.generated_kernel_count == 1
        ), "codegen #kernel != 1"
        row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)])

    p.add_row(row)


fusion_types = ["", "add", "relu", "add_relu"]
shapes = [
    # alexnet
    ([128, 9216], [9216, 4096]),
    ([128, 4096], [4096, 4096]),
    ([128, 4096], [4096, 1000]),
    # BERT
    ([2048, 768], [768, 768]),
    ([2048, 768], [768, 3072]),
    ([2048, 3072], [3072, 768]),
    # hf_GPT2
    ([1024, 768], [768, 768]),
    ([1024, 768], [768, 3072]),
    ([1024, 3072], [3072, 768]),
    ([1024, 768], [768, 2304]),
]
p = PrettyTable()
field_names = ["layer"]
for fusion_type in fusion_types:
    if fusion_type == "":
        field_names.append("torch mm")
        field_names.append("triton mm")
    else:
        field_names.append(f"torch mm+{fusion_type}")
        field_names.append(f"triton mm+{fusion_type}")

p.field_names = field_names
p.float_format = ".3"
for id, shape in enumerate(shapes):
    bench(shape, id, p, fusion_types)

print(p)
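

# --- Optional correctness check (a minimal sketch, not part of the original benchmark) ---
# Verifies that the inductor-compiled fused kernel matches an eager reference on one of
# the shapes above before trusting the timings. The function name, shape, dtype, and
# tolerances below are illustrative assumptions; call check_mm_add_relu() manually if desired.
def check_mm_add_relu():
    dtype = torch.float16
    a = torch.randn(128, 9216, device="cuda", dtype=dtype)
    b = torch.randn(9216, 4096, device="cuda", dtype=dtype)
    bias = torch.randn(128, 4096, device="cuda", dtype=dtype)
    expected = torch.relu(torch.mm(a, b) + bias)  # eager reference
    actual = Func.mm_add_relu(a, b, bias)  # compiled fused version
    torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)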