#!/usr/bin/env python3
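"""Microbenchmarks for TorchInductor.

Each op is timed three ways -- an eager baseline and a torch.jit.trace'd
version (both wrapped with cudagraphs_inner), plus an Inductor-compiled
version via compile_fx -- and speedups over the eager baseline are printed
as a table.

Run directly, e.g.:

    python microbench.py --devices cuda --repeat 50
"""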

import argparse
import inspect
import sys

import numpy as np
import tabulate

import torch
import torch._inductor
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._dynamo.testing import same
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import timed


aten = torch.ops.aten

try:
    import test.test_torchinductor as tti
except ImportError:
    tti = None


def compute_speedups(args, models, example_inputs):
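    """Return speedups of models[1:] relative to models[0] on example_inputs.

    Outputs are checked against the baseline with `same` before timing.
    Runs are interleaved across all models for `args.repeat` repetitions,
    and the per-model median time is used to compute each speedup.
    """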
    expected = models[0](*example_inputs)
    for model in models[1:]:
        actual = model(*example_inputs)
        assert same(actual, expected), expected[0] - actual[0]

    timings = np.zeros((args.repeat, len(models)), np.float64)
    for rep in range(args.repeat):
        # interleave the runs to handle frequency scaling and load changes
        for m, model in enumerate(models):
            timings[rep, m] = timed(model, example_inputs)
    median = np.median(timings, axis=0)
    return (median[0] / median[1:]).tolist()


def microbenchmark(args, model, example_inputs):
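    """Build eager, TorchScript-traced, and Inductor-compiled variants of
    `model` and return their speedups over the eager variant."""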
    compiled_fn = compile_fx(torch.fx.symbolic_trace(model), example_inputs)
    cudagraphs_eager = cudagraphs_inner(model, example_inputs, copy_outputs=False)
    cudagraphs_jit = cudagraphs_inner(
        torch.jit.trace(model, example_inputs), example_inputs, copy_outputs=False
    )
    return compute_speedups(
        args,
        [cudagraphs_eager, cudagraphs_jit, compiled_fn],
        example_inputs,
    )


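# Sample modules for manual experiments; main() below only benchmarks
# entries from MicroBenchmarks.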
class MyModel1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(1024, 1024),
            torch.nn.ReLU(),
        )

    def forward(self, input):
        # return (self.model(input) + 1,)
        return (self.model(input),)


class MyModel2(torch.nn.Module):
    def forward(self, x, y):
        # return x / (torch.abs(x) + 1.0),
        return (x + y,)


class MicroBenchmarks:
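    """Small ops to benchmark; each returns a one-element tuple."""
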
    @staticmethod
    def add(a, b):
        return (a + b,)

    @staticmethod
    def scale(x, m, d):
        return ((x - m) / torch.clip(d, 1e-4),)

    @staticmethod
    def abs_norm(x):
        return (x / (torch.abs(x) + 1),)

    @staticmethod
    def add_relu_softmax(x, a):
        return (torch.softmax(torch.relu(x + a), -1),)

    @staticmethod
    def sum(a, b):
        return ((a + b).sum(),)

    @staticmethod
    def view(x):
        return (aten.alias(x),)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--filter", "-k", action="append", help="filter benchmarks with regexp"
    )
    parser.add_argument(
        "--exclude", "-x", action="append", help="exclude benchmarks matching regexp"
    )
    parser.add_argument("--devices", "-d", action="append", help="cpu or cuda")
    parser.add_argument("--size", "-s", action="append", help="benchmark size(s)")
    parser.add_argument(
        "--repeat", "-n", type=int, default=30, help="number of timing runs"
    )
    parser.add_argument(
        "--threads", "-t", type=int, help="number of threads to use for eager"
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
    )
    parser.add_argument(
        "--nvfuser", action="store_true", help="enable nvfuser globally"
    )
    parser.add_argument("--transpose", action="store_true", help="transpose one input")
    parser.add_argument("--broadcast", action="store_true", help="broadcast one input")
    args = parser.parse_args()

    # defaults
    args.devices = args.devices or ["cpu", "cuda"]
    args.filter = args.filter or [r"."]
    args.exclude = args.exclude or [r"^$"]
    args.size = args.size or [64, 256, 1024, 4096, 8192]

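    # Configure the TorchScript fuser: route all fusion to nvFuser when
    # requested, otherwise use the default TensorExpr fuser (CPU fusion
    # only when LLVM support is built in).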
    if args.nvfuser:
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        torch._C._jit_override_can_fuse_on_cpu(torch._C._llvm_enabled())
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        if torch.cuda.is_available():
            torch._C._jit_set_nvfuser_enabled(False)

    if args.threads:
        torch.set_num_threads(args.threads)
        torch._inductor.config.cpp.threads = args.threads

    if args.verbose:
        torch._inductor.config.debug = True

    torch._inductor.config.triton.autotune_pointwise = True

    rows = []
    for model in (MicroBenchmarks.sum, MicroBenchmarks.view):
        nargs = len(inspect.signature(model).parameters)
        for device in args.devices:
            for n in args.size:
                n = int(n)
                sys.stdout.write(f"{model.__name__:10} {device:4} {n:5} ")
                sys.stdout.flush()
                inputs = [torch.rand((n, n), device=device) for _ in range(nargs)]
                if args.broadcast:
                    inputs[-1] = torch.rand((1, n), device=device)
                if args.transpose:
                    inputs[-1] = inputs[-1].transpose(0, 1)
                result = microbenchmark(args, model, inputs)
                rows.append([model.__name__, device, str(n)] + result)
                print(" ".join(f"{v:.2f}x" for v in result))

    print(
        tabulate.tabulate(
            rows,
            headers=[
                "model",
                "dev",
                "n",
                "ts",
                "inductor",
            ],
        )
    )


if __name__ == "__main__":
    main()