xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/microbenchmarks/matmul_relu.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from benchmark_helper import time_with_torch_timer

import torch
import torch._dynamo
import torch._inductor.config as inductor_config


# Select inductor's Triton-based lowering for torch.mm so the "inductor mm"
# timings below measure the Triton codegen path.
# NOTE(review): internal `torch._inductor.config` knobs change between PyTorch
# releases — verify `triton.mm` still exists on the version being benchmarked.
inductor_config.triton.mm = "triton"
9
10
# Compile the matmul with TorchDynamo + the "inductor" backend.
# nopython=True makes dynamo raise on any graph break instead of silently
# falling back to eager, so the timed call is guaranteed to be the compiled
# artifact.
@torch._dynamo.optimize("inductor", nopython=True)
def inductor_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Matrix-multiply `a @ b` via torch.mm, compiled by inductor."""
    # Same computation as torch_mm below; only the execution path differs.
    return torch.mm(a, b)
14
15
def torch_mm_relu(a, b):
    """Eager-mode baseline: matrix product of `a` and `b` followed by ReLU."""
    product = torch.mm(a, b)
    return torch.nn.functional.relu(product)
18
19
def torch_mm(a, b):
    """Eager-mode baseline: plain matrix product of `a` and `b`."""
    # Tensor.mm is the method form of torch.mm — identical 2-D-only semantics.
    return a.mm(b)
22
23
if __name__ == "__main__":
    # Real (a, b) operand shapes collected from torchbench models.
    a_shapes = [
        [2048, 768],
        [64, 1280],
        [2048, 768],
        [32, 2048],
        [1, 39200],
        [128, 3072],
        [16, 1280],
    ]
    b_shapes = [
        [768, 3072],
        [1280, 1000],
        [768, 768],
        [2048, 1000],
        [39200, 50],
        [3072, 1000],
        [1280, 1000],
    ]

    # Artificial larger shapes to exercise bigger GEMMs.
    a_shapes += [[512, 10240], [1024, 10240]] and [[10240, 512], [10240, 1024]]
    b_shapes += [[512, 10240], [1024, 10240]]

    # Iterate operand pairs in lockstep instead of indexing via
    # range(len(...)); the two shape lists are parallel by construction.
    for a_shape, b_shape in zip(a_shapes, b_shapes):
        print("Shape:", a_shape, "x", b_shape)
        # fp16 CUDA inputs — matches the configuration of the recorded
        # results at the bottom of this file.
        a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
        b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

        time_with_torch_timer(torch_mm, (a, b), string_id="torch mm")
        time_with_torch_timer(torch_mm_relu, (a, b), string_id="torch mm + relu")
        time_with_torch_timer(inductor_mm, (a, b), string_id="inductor mm")
59
60
# Results obtained on the AWS AI cluster
# CPU: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
# GPU: NVIDIA A100-SXM 40GB memory
"""
Shape: [2048, 768] x [768, 3072]
torch mm         mean: 0.0592 ms
torch mm + relu  mean: 0.0759 ms
inductor mm      mean: 0.0653 ms
Shape: [64, 1280] x [1280, 1000]
torch mm         mean: 0.0231 ms
torch mm + relu  mean: 0.0316 ms
inductor mm      mean: 0.0252 ms
Shape: [2048, 768] x [768, 768]
torch mm         mean: 0.0190 ms
torch mm + relu  mean: 0.0277 ms
inductor mm      mean: 0.0274 ms
Shape: [32, 2048] x [2048, 1000]
torch mm         mean: 0.0188 ms
torch mm + relu  mean: 0.0290 ms
inductor mm      mean: 0.0244 ms
Shape: [1, 39200] x [39200, 50]
torch mm         mean: 0.0134 ms
torch mm + relu  mean: 0.0234 ms
inductor mm      mean: 0.0290 ms
Shape: [128, 3072] x [3072, 1000]
torch mm         mean: 0.0181 ms
torch mm + relu  mean: 0.0322 ms
inductor mm      mean: 0.0319 ms
Shape: [16, 1280] x [1280, 1000]
torch mm         mean: 0.0188 ms
torch mm + relu  mean: 0.0289 ms
inductor mm      mean: 0.0255 ms
Shape: [10240, 512] x [512, 10240]
torch mm         mean: 0.4589 ms
torch mm + relu  mean: 0.7896 ms
inductor mm      mean: 0.5090 ms
Shape: [10240, 1024] x [1024, 10240]
torch mm         mean: 0.9152 ms
torch mm + relu  mean: 1.2124 ms
inductor mm      mean: 0.9462 ms
"""
102