# external/pytorch/benchmarks/dynamo/microbenchmarks/inductor_mm.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
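# Microbenchmark comparing matmul on CUDA: eager torch.mm, triton.ops.matmul,
# and torch.mm compiled with TorchInductor using either the ATen or the Triton
# mm lowering (selected via config.triton.mm). Total wall-clock time is measured
# with time_with_torch_timer and GPU-only time with benchmarker.benchmark_gpu.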
import triton
from benchmark_helper import time_with_torch_timer

import torch
import torch._dynamo
import torch._dynamo.config
import torch._inductor.config as config
from torch._inductor.runtime.benchmarking import benchmarker


# The flag below controls whether to allow TF32 on matmul.
# This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True


# The two inductor functions below have identical bodies; they are kept separate
# so that each one is compiled with the mm backend selected via config.triton.mm
# ("aten" or "triton") at the time it is first called.
@torch._dynamo.optimize("inductor", nopython=True)
def inductor_aten_mm(a, b):
    return torch.mm(a, b)


@torch._dynamo.optimize("inductor", nopython=True)
def inductor_triton_mm(a, b):
    return torch.mm(a, b)


def torch_mm(a, b):
    return torch.mm(a, b)


def triton_mm(a, b):
    return triton.ops.matmul(a, b)


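# Wall-clock comparison: mean time from time_with_torch_timer, reported in ms.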
def test_total_time(shapes):
    print("shape; torch mm; triton mm; inductor aten mm; inductor triton mm")
    for a_shape, b_shape in shapes:
        print(a_shape, "x", b_shape, end="; ")
        a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
        b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

        # Warm up: compile each inductor variant with its mm backend before timing.
        config.triton.mm = "aten"
        inductor_aten_mm(a, b)

        config.triton.mm = "triton"
        inductor_triton_mm(a, b)

        torch_ms = time_with_torch_timer(torch_mm, (a, b)).mean * 1000

        triton_ms = time_with_torch_timer(triton_mm, (a, b)).mean * 1000

        config.triton.mm = "aten"
        ind_aten_ms = time_with_torch_timer(inductor_aten_mm, (a, b)).mean * 1000

        config.triton.mm = "triton"
        ind_triton_ms = time_with_torch_timer(inductor_triton_mm, (a, b)).mean * 1000

        print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")

        # Drop compiled artifacts so the next shape triggers fresh compilation.
        torch._dynamo.reset()


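# GPU-only comparison: kernel time in ms from benchmarker.benchmark_gpu, which
# excludes the CPU-side overhead included in the wall-clock numbers above.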
def test_GPU_time(shapes):
    print("shape; torch mm; triton mm; inductor aten mm; inductor triton mm")
    for a_shape, b_shape in shapes:
        print(a_shape, "x", b_shape, end="; ")
        a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
        b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

        # Warm up: compile each inductor variant with its mm backend before timing.
        config.triton.mm = "aten"
        inductor_aten_mm(a, b)

        config.triton.mm = "triton"
        inductor_triton_mm(a, b)

        torch_ms = benchmarker.benchmark_gpu(lambda: torch_mm(a, b))
        triton_ms = benchmarker.benchmark_gpu(lambda: triton_mm(a, b))
        ind_aten_ms = benchmarker.benchmark_gpu(lambda: inductor_aten_mm(a, b))
        ind_triton_ms = benchmarker.benchmark_gpu(lambda: inductor_triton_mm(a, b))
        print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")

        # Drop compiled artifacts so the next shape triggers fresh compilation.
        torch._dynamo.reset()


if __name__ == "__main__":
    shapes = [
        # alexnet
        ([128, 9216], [9216, 4096]),
        ([128, 4096], [4096, 4096]),
        ([128, 4096], [4096, 1000]),
        # BERT
        ([2048, 768], [768, 768]),
        ([2048, 768], [768, 3072]),
        ([2048, 3072], [3072, 768]),
        # hf_GPT2
        ([1024, 768], [768, 768]),
        ([1024, 768], [768, 3072]),
        ([1024, 3072], [3072, 768]),
        ([1024, 768], [768, 2304]),
    ]
    print("test total time")
    test_total_time(shapes)

    print("test GPU time")
    test_GPU_time(shapes)


# Results Preview on AWS AI cluster
"""
test total time
shape; torch mm; triton mm; inductor aten mm; inductor triton mm
[128, 9216] x [9216, 4096]; 0.07240759208798409; 0.10885953903198242; 0.20063146017491817; 0.20054904278367758
[128, 4096] x [4096, 4096]; 0.03640300128608942; 0.10960095096379519; 0.09948539081960917; 0.0996188772842288
[128, 4096] x [4096, 1000]; 0.02215010579675436; 0.12592008337378502; 0.031120930798351765; 0.0370654184371233
[2048, 768] x [768, 768]; 0.023501068353652954; 0.10804693214595318; 0.03004650119692087; 0.0276932492852211
[2048, 768] x [768, 3072]; 0.045639658346772194; 0.10883208829909563; 0.062736920081079; 0.06480381824076176
[2048, 3072] x [3072, 768]; 0.054093082435429096; 0.10804777964949608; 0.08744294755160809; 0.07766005117446184
[1024, 768] x [768, 768]; 0.021525858901441097; 0.10909941978752613; 0.02656651195138693; 0.02683836966753006
[1024, 768] x [768, 3072]; 0.027319076471030712; 0.10825308971107006; 0.040118801407516; 0.039282338693737984
[1024, 3072] x [3072, 768]; 0.034132059663534164; 0.10594133753329515; 0.05069758277386427; 0.04572632722556591
[1024, 768] x [768, 2304]; 0.02529360819607973; 0.10486091021448374; 0.03724239766597748; 0.036449190229177475
test GPU time
shape; torch mm; triton mm; inductor aten mm; inductor triton mm
[128, 9216] x [9216, 4096]; 0.09113600105047226; 0.09011200070381165; 0.21606400609016418; 0.21606400609016418
[128, 4096] x [4096, 4096]; 0.053247999399900436; 0.05222399905323982; 0.1157120019197464; 0.1157120019197464
[128, 4096] x [4096, 1000]; 0.026623999699950218; 0.02969600073993206; 0.04710400104522705; 0.05222399905323982
[2048, 768] x [768, 768]; 0.02457600086927414; 0.020479999482631683; 0.04095999896526337; 0.03993599861860275
[2048, 768] x [768, 3072]; 0.05119999870657921; 0.05222399905323982; 0.07475200295448303; 0.07577600330114365
[2048, 3072] x [3072, 768]; 0.05939200147986412; 0.05222399905323982; 0.09830400347709656; 0.0870399996638298
[1024, 768] x [768, 768]; 0.01945599913597107; 0.016383999958634377; 0.03276799991726875; 0.03276799991726875
[1024, 768] x [768, 3072]; 0.03174399957060814; 0.03276799991726875; 0.053247999399900436; 0.053247999399900436
[1024, 3072] x [3072, 768]; 0.04403200000524521; 0.03379200026392937; 0.06860800087451935; 0.062463998794555664
[1024, 768] x [768, 2304]; 0.02969600073993206; 0.02969600073993206; 0.04915200173854828; 0.048128001391887665
"""