"""Microbenchmark: eager ``torch.mm`` vs. an inductor-compiled ``torch.mm``.

Run as a script, this times three callables on float16 CUDA tensors for a
fixed list of matrix shapes: eager ``torch.mm``, eager ``torch.mm`` followed
by ReLU, and ``torch.mm`` compiled through dynamo with the inductor backend.
Timing and reporting are delegated to ``benchmark_helper.time_with_torch_timer``.
"""
from benchmark_helper import time_with_torch_timer

import torch
import torch._dynamo
import torch._inductor.config as inductor_config


# Module-level side effect: sets inductor's ``triton.mm`` option to "triton"
# before anything is compiled.
# NOTE(review): presumably this makes inductor emit Triton matmul kernels
# instead of falling back to ATen -- confirm against the inductor config of
# the torch version in use.
inductor_config.triton.mm = "triton"


@torch._dynamo.optimize("inductor", nopython=True)
def inductor_mm(a, b):
    """Return ``torch.mm(a, b)``, compiled by dynamo with the inductor backend."""
    return torch.mm(a, b)


def torch_mm_relu(a, b):
    """Return ``relu(torch.mm(a, b))`` computed eagerly."""
    return torch.nn.functional.relu(torch.mm(a, b))


def torch_mm(a, b):
    """Return ``torch.mm(a, b)`` computed eagerly."""
    return torch.mm(a, b)


# Each entry is one benchmark case: (shape of ``a``, shape of ``b``).
# Keeping both operands in one entry means the two shapes of a case cannot
# drift out of alignment when cases are added or removed.
# The inner shapes stay lists so the printed "Shape:" line keeps its format.
_SHAPE_PAIRS = [
    # Real shapes from torchbench
    ([2048, 768], [768, 3072]),
    ([64, 1280], [1280, 1000]),
    ([2048, 768], [768, 768]),
    ([32, 2048], [2048, 1000]),
    ([1, 39200], [39200, 50]),
    ([128, 3072], [3072, 1000]),
    ([16, 1280], [1280, 1000]),
    # Artificial larger shapes
    ([10240, 512], [512, 10240]),
    ([10240, 1024], [1024, 10240]),
]


def _run_benchmarks():
    """Time every callable on every shape pair in ``_SHAPE_PAIRS``.

    For each pair, prints the shapes, allocates random float16 operands on
    the CUDA device, and hands each callable to ``time_with_torch_timer``.
    """
    for a_shape, b_shape in _SHAPE_PAIRS:
        print("Shape:", a_shape, "x", b_shape)
        a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
        # ``b`` takes its dtype from ``a`` so both operands always match.
        b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

        time_with_torch_timer(torch_mm, (a, b), string_id="torch mm")
        time_with_torch_timer(torch_mm_relu, (a, b), string_id="torch mm + relu")
        time_with_torch_timer(inductor_mm, (a, b), string_id="inductor mm")


if __name__ == "__main__":
    _run_benchmarks()


# Results obtained on the AWS AI cluster
# CPU: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
# GPU: NVIDIA A100-SXM 40GB memory
"""
Shape: [2048, 768] x [768, 3072]
torch mm mean: 0.0592 ms
torch mm + relu mean: 0.0759 ms
inductor mm mean: 0.0653 ms
Shape: [64, 1280] x [1280, 1000]
torch mm mean: 0.0231 ms
torch mm + relu mean: 0.0316 ms
inductor mm mean: 0.0252 ms
Shape: [2048, 768] x [768, 768]
torch mm mean: 0.0190 ms
torch mm + relu mean: 0.0277 ms
inductor mm mean: 0.0274 ms
Shape: [32, 2048] x [2048, 1000]
torch mm mean: 0.0188 ms
torch mm + relu mean: 0.0290 ms
inductor mm mean: 0.0244 ms
Shape: [1, 39200] x [39200, 50]
torch mm mean: 0.0134 ms
torch mm + relu mean: 0.0234 ms
inductor mm mean: 0.0290 ms
Shape: [128, 3072] x [3072, 1000]
torch mm mean: 0.0181 ms
torch mm + relu mean: 0.0322 ms
inductor mm mean: 0.0319 ms
Shape: [16, 1280] x [1280, 1000]
torch mm mean: 0.0188 ms
torch mm + relu mean: 0.0289 ms
inductor mm mean: 0.0255 ms
Shape: [10240, 512] x [512, 10240]
torch mm mean: 0.4589 ms
torch mm + relu mean: 0.7896 ms
inductor mm mean: 0.5090 ms
Shape: [10240, 1024] x [1024, 10240]
torch mm mean: 0.9152 ms
torch mm + relu mean: 1.2124 ms
inductor mm mean: 0.9462 ms
"""