#!/usr/bin/env python3
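"""Microbenchmark a handful of small ops, comparing cudagraphs-wrapped eager
and TorchScript execution against TorchInductor across devices and sizes."""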

import argparse
import inspect
import sys

import numpy as np
import tabulate

import torch
import torch._inductor
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._dynamo.testing import same
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import timed


aten = torch.ops.aten

try:
    import test.test_torchinductor as tti
except ImportError:
    tti = None

def compute_speedups(args, models, example_inputs):
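    """Check that every model agrees with models[0] on the example inputs,
    then time all models with interleaved repetitions and return the median
    speedup of each remaining model relative to models[0] (the baseline)."""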
    expected = models[0](*example_inputs)
    for model in models[1:]:
        actual = model(*example_inputs)
        assert same(actual, expected), expected[0] - actual[0]

    timings = np.zeros((args.repeat, len(models)), np.float64)
    for rep in range(args.repeat):
        # interleave the runs to handle frequency scaling and load changes
        for m, model in enumerate(models):
            timings[rep, m] = timed(model, example_inputs)
    median = np.median(timings, axis=0)
    return (median[0] / median[1:]).tolist()


def microbenchmark(args, model, example_inputs):
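    """Benchmark three variants of `model`: eager wrapped in CUDA graphs,
    a TorchScript trace wrapped in CUDA graphs, and an inductor-compiled
    FX trace; returns their speedups over the cudagraphs eager baseline."""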
    compiled_fn = compile_fx(torch.fx.symbolic_trace(model), example_inputs)
    cudagraphs_eager = cudagraphs_inner(model, example_inputs, copy_outputs=False)
    cudagraphs_jit = cudagraphs_inner(
        torch.jit.trace(model, example_inputs), example_inputs, copy_outputs=False
    )
    return compute_speedups(
        args,
        [cudagraphs_eager, cudagraphs_jit, compiled_fn],
        example_inputs,
    )


class MyModel1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(1024, 1024),
            torch.nn.ReLU(),
        )

    def forward(self, input):
        # return (self.model(input) + 1,)
        return (self.model(input),)


class MyModel2(torch.nn.Module):
    def forward(self, x, y):
        # return x / (torch.abs(x) + 1.0),
        return (x + y,)


class MicroBenchmarks:
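    """Small standalone ops (pointwise, reduction, view) used as benchmark cases."""
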
    @staticmethod
    def add(a, b):
        return (a + b,)

    @staticmethod
    def scale(x, m, d):
        return ((x - m) / torch.clip(d, 1e-4),)

    @staticmethod
    def abs_norm(x):
        return (x / (torch.abs(x) + 1),)

    @staticmethod
    def add_relu_softmax(x, a):
        return (torch.softmax(torch.relu(x + a), -1),)

    @staticmethod
    def sum(a, b):
        return ((a + b).sum(),)

    @staticmethod
    def view(x):
        return (aten.alias(x),)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--filter", "-k", action="append", help="filter benchmarks with regexp"
    )
    parser.add_argument(
        "--exclude", "-x", action="append", help="exclude benchmarks with regexp"
    )
    parser.add_argument("--devices", "-d", action="append", help="cpu or cuda")
    parser.add_argument(
        "--size", "-s", action="append", help="benchmark input sizes n (tensors are n x n)"
    )
    parser.add_argument(
        "--repeat", "-n", type=int, default=30, help="number of timing runs"
    )
    parser.add_argument(
        "--threads", "-t", type=int, help="number of threads to use for eager"
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
    )
    parser.add_argument(
        "--nvfuser", action="store_true", help="enable nvfuser globally"
    )
    parser.add_argument("--transpose", action="store_true", help="transpose one input")
    parser.add_argument("--broadcast", action="store_true", help="broadcast one input")
    args = parser.parse_args()

    # defaults
    args.devices = args.devices or ["cpu", "cuda"]
    args.filter = args.filter or [r"."]
    args.exclude = args.exclude or [r"^$"]
    args.size = args.size or [64, 256, 1024, 4096, 8192]

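    # pick the TorchScript fuser: nvfuser when requested, otherwise the
    # default texpr (NNC) fuser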
    if args.nvfuser:
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        torch._C._jit_override_can_fuse_on_cpu(torch._C._llvm_enabled())
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        if torch.cuda.is_available():
            torch._C._jit_set_nvfuser_enabled(False)

    if args.threads:
        torch.set_num_threads(args.threads)
        torch._inductor.config.cpp.threads = args.threads

    if args.verbose:
        torch._inductor.config.debug = True

    torch._inductor.config.triton.autotune_pointwise = True

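    # run each benchmark for every device/size combination, one row per config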
    rows = []
    for model in (MicroBenchmarks.sum, MicroBenchmarks.view):
        nargs = len(inspect.signature(model).parameters)
        for device in args.devices:
            for n in args.size:
                n = int(n)
                sys.stdout.write(f"{model.__name__:10} {device:4} {n:5} ")
                sys.stdout.flush()
                inputs = [torch.rand((n, n), device=device) for _ in range(nargs)]
                if args.broadcast:
                    inputs[-1] = torch.rand((1, n), device=device)
                if args.transpose:
                    inputs[-1] = inputs[-1].transpose(0, 1)
                result = microbenchmark(args, model, inputs)
                rows.append([model.__name__, device, str(n)] + result)
                print(" ".join(f"{v:.2f}x" for v in result))

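    # "ts" and "inductor" columns are speedups over the cudagraphs eager baseline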
    print(
        tabulate.tabulate(
            rows,
            headers=[
                "model",
                "dev",
                "n",
                "ts",
                "inductor",
            ],
        )
    )


if __name__ == "__main__":
    main()