# flake8: noqa
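# Microbenchmark: ATen mm vs. Inductor's Triton matmul template, with optional
# bias-add and/or ReLU epilogue fusion, on fp16 GEMM shapes taken from AlexNet,
# BERT, and hf_GPT2. Results are printed as a PrettyTable of throughput
# (see tflops() below for the exact metric).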

import triton
from prettytable import PrettyTable

import torch
import torch._dynamo
import torch._inductor.config
from torch._inductor.runtime.benchmarking import benchmarker


# torch._inductor.config.debug = True
torch._inductor.config.triton.dense_indexing = True
torch.manual_seed(0)


# The flag below controls whether to allow TF32 on matmul.
torch.backends.cuda.matmul.allow_tf32 = True


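# Each variant is compiled through TorchDynamo with the Inductor backend so
# that the bias-add / ReLU epilogue can be fused into the generated matmul
# kernel.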
class Func(object):
    # mm
    @torch._dynamo.optimize("inductor")
    def mm(a, b, bias):
        y = torch.mm(a, b)
        return y

    # mm+bias
    @torch._dynamo.optimize("inductor")
    def mm_add(a, b, bias):
        y = torch.mm(a, b)
        return y + bias

    # relu(mm)
    @torch._dynamo.optimize("inductor")
    def mm_relu(a, b, bias):
        y = torch.mm(a, b)
        return torch.relu(y)

    # relu(mm+bias)
    @torch._dynamo.optimize("inductor")
    def mm_add_relu(a, b, bias):
        y = torch.mm(a, b)
        y += bias
        return torch.relu(y)


def bench(shape, layer_id, p, fusion_types=[""]):
    dtype = torch.float16
    M, K = shape[0]
    _, N = shape[1]
    torch.manual_seed(0)
    # allocate inputs
    a = torch.randn(shape[0], device="cuda", dtype=dtype)
    b = torch.randn(shape[1], device="cuda", dtype=dtype)

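    # Throughput metric: M*K*N multiply-accumulates per matmul divided by the
    # time in milliseconds, which works out to tera-MACs per second (the
    # conventional FLOP count would be twice this).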
    def tflops(ms):
        return M * K * N / ms * 1e-9

    row = [layer_id]
    for fusion_type in fusion_types:
        if fusion_type == "":
            fn_mm = getattr(Func, "mm")
        else:
            fn_mm = getattr(Func, f"mm_{fusion_type}")

        if "add" in fusion_type:
            bias = torch.randn((M, N), dtype=dtype, device="cuda")
        else:
            bias = None

        args = (a, b, bias)

        def fn():
            return fn_mm(*args)

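        # Benchmark the same compiled function twice: first with Inductor
        # lowering mm to the ATen kernel, then with the Triton mm template.
        # Dynamo is reset in between so the new config takes effect on recompile.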
        torch._inductor.config.triton.mm = "aten"
        torch_mm_ms = benchmarker.benchmark_gpu(fn)
        torch._inductor.config.triton.mm = "triton"
        # reset to force codegen of new Python code
        torch._dynamo.reset()
        torch._inductor.metrics.reset()
        triton_mm_ms = benchmarker.benchmark_gpu(fn)
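        # The matmul and any bias-add / ReLU epilogue should compile to a
        # single generated kernel.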
        assert (
            torch._inductor.metrics.generated_kernel_count == 1
        ), "codegen #kernel != 1"
        row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)])

    p.add_row(row)


fusion_types = ["", "add", "relu", "add_relu"]
shapes = [
    # alexnet
    ([128, 9216], [9216, 4096]),
    ([128, 4096], [4096, 4096]),
    ([128, 4096], [4096, 1000]),
    # BERT
    ([2048, 768], [768, 768]),
    ([2048, 768], [768, 3072]),
    ([2048, 3072], [3072, 768]),
    # hf_GPT2
    ([1024, 768], [768, 768]),
    ([1024, 768], [768, 3072]),
    ([1024, 3072], [3072, 768]),
    ([1024, 768], [768, 2304]),
]
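# One pair of columns per fusion variant: ATen ("torch") vs. Triton throughput.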
p = PrettyTable()
field_names = ["layer"]
for fusion_type in fusion_types:
    if fusion_type == "":
        field_names.append("torch mm")
        field_names.append("triton mm")
    else:
        field_names.append(f"torch mm+{fusion_type}")
        field_names.append(f"triton mm+{fusion_type}")

p.field_names = field_names
p.float_format = ".3"
for layer_id, shape in enumerate(shapes):
    bench(shape, layer_id, p, fusion_types)

print(p)