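"""Micro-benchmarks for functorch's @pointwise_operator (NNC) kernels.

Compares NNC-compiled pointwise ops against eager ATen and TorchScript
baselines across tensor sizes, broadcasting patterns, strided/sliced inputs,
mixed dtypes, out= variants, and backward passes, then prints the measured
speedups as a pandas table.
"""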
import timeit
from functools import partial

import numpy as np
import pandas as pd

import torch
from functorch.compile import pointwise_operator


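# Benchmark configuration: SIZES are the n in the n x n test tensors,
# NUMBER[s] is the timeit iteration count used for SIZES[s], and REPEAT is the
# number of timing runs whose median is reported.  Set CUDA = True to run on
# GPU and WRITE_CSV = True to also dump the results table to a CSV file.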
WRITE_CSV = False
CUDA = False
SIZES = [1, 512, 8192]
NUMBER = [100, 10, 1, 1]
REPEAT = 20


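# Kernels under test: nnc_* are compiled through functorch's
# @pointwise_operator (NNC); the eager and in-place addnorm variants below
# serve as ATen baselines.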
@pointwise_operator
def nnc_add(a, b):
    return a + b


@pointwise_operator
def nnc_addnorm(a, b, mean, std):
    return (a + b - mean) / std


def eager_addnorm(a, b, mean, std):
    return (a + b - mean) / std


def inplace_addnorm(a, b, mean, std, out):
    out = torch.add(a, b, out=out)
    torch.sub(out, mean, out=out)
    torch.div(out, std, out=out)
    return out


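# TorchScript-compiled baselines, used for the "(vs TS)" rows in the results.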
ts_addnorm = torch.jit.script(eager_addnorm)
ts_ip_addnorm = torch.jit.script(inplace_addnorm)


def maybe_synced(fn):
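    """Wrap fn so every timed call ends with torch.cuda.synchronize() when
    CUDA is enabled; on CPU, return fn unchanged."""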
    if CUDA:
        synchronize = torch.cuda.synchronize
        synchronize()  # warmup

        def _fn():
            result = fn()
            synchronize()
            return result

        return _fn
    return fn


def benchmark_loop(setup):
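    """Time the (nnc, aten) callables returned by setup(n) for every size in
    SIZES, take the median over REPEAT runs, and return the per-size speedup
    of NNC over ATen (aten time / nnc time)."""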
    result = np.zeros((REPEAT, len(SIZES), 2), dtype=np.float64)
    for s, n in enumerate(SIZES):
        nnc, aten = setup(n)
        nnc = maybe_synced(nnc)
        aten = maybe_synced(aten)

        for r in range(result.shape[0]):
            result[r, s, 0] = timeit.timeit(nnc, number=NUMBER[s])
            result[r, s, 1] = timeit.timeit(aten, number=NUMBER[s])

    result = np.median(result, axis=0)
    assert result.shape == (len(SIZES), 2)
    result = result[:, 1] / result[:, 0]
    print(result)
    return result


def test(make_args, nnc=nnc_add, aten=torch.add):
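    """Benchmark out-of-place calls on the tensors produced by make_args(n),
    first checking that nnc and aten agree on dtype, shape, strides, and
    values."""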
    def setup(n):
        args = make_args(n)
        result_aten = aten(*args)
        result_nnc = nnc(*args)
        assert result_nnc.dtype == result_aten.dtype
        assert result_nnc.size() == result_aten.size()
        assert result_nnc.stride() == result_aten.stride()
        torch.testing.assert_close(result_aten, result_nnc)
        return (lambda: nnc(*args), lambda: aten(*args))

    return benchmark_loop(setup)


def test_inplace(make_args, nnc=nnc_add, aten=torch.add):
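    """Benchmark in-place use, writing the result back into the first input
    via out=a after checking that both implementations agree."""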
    def inplace_setup(n):
        a, b = make_args(n)
        result_aten = torch.clone(a)
        result_nnc = torch.clone(a)
        nnc(result_nnc, b, out=result_nnc)
        aten(result_aten, b, out=result_aten)
        torch.testing.assert_close(result_aten, result_nnc)
        return (lambda: nnc(a, b, out=a), lambda: aten(a, b, out=a))

    return benchmark_loop(inplace_setup)


def test_out(make_args, out, nnc=nnc_add, aten=torch.add):
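    """Benchmark writing into a preallocated output tensor produced by out(n)
    (which may be strided or of a different dtype than the inputs)."""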
    def out_setup(n):
        args = make_args(n)
        result_aten = out(n)
        result_nnc = out(n)
        aten(*args, out=result_aten)
        nnc(*args, out=result_nnc)
        torch.testing.assert_close(result_aten, result_nnc)
        result = out(n)
        return (lambda: nnc(*args, out=result), lambda: aten(*args, out=result))

    return benchmark_loop(out_setup)


def test_backwards(make_args, nnc=nnc_add, aten=torch.add):
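    """Benchmark forward plus backward through .sum().backward(), checking
    that both implementations produce the same gradient for the single
    requires_grad input."""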
    def backwards_setup(n):
        args = make_args(n)
        (grad_var,) = (a for a in args if a.requires_grad)
        aten(*args).sum().backward()
        correct = grad_var.grad.clone()
        grad_var.grad.zero_()
        nnc(*args).sum().backward()
        torch.testing.assert_close(correct, grad_var.grad)
        return (
            lambda: nnc(*args).sum().backward(),
            lambda: aten(*args).sum().backward(),
        )

    return benchmark_loop(backwards_setup)


def main():
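    """Run every benchmark configuration and print a table of NNC-over-ATen
    speedups, one column per size in SIZES."""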
    torch.set_num_threads(1)  # TODO(jansel): add parallel support
    torch._C._jit_override_can_fuse_on_cpu(True)

    device = "cuda" if CUDA else "cpu"
    I = partial(torch.randint, 0, 100, device=device)
    R = partial(torch.randn, device=device)
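    # I(...) builds random integer tensors in [0, 100); R(...) builds random
    # float tensors, both on the device selected above.  Each results entry
    # pairs a row label with the per-size speedup array from the helpers above.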

    results = [
        ("add", test(lambda n: (R(n, n), R(n, n)))),
        ("broadcast1", test(lambda n: (R(n, n), R(1)))),
        ("broadcast2", test(lambda n: (R(n, n), R(n, 1)))),
        ("broadcast3", test(lambda n: (R(n, 1), R(1, n)))),
        ("inplace", test_inplace(lambda n: (R(n, n), R(n, 1)))),
        ("out=", test_out(lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n))),
        ("transposed1", test(lambda n: (R(n, n), R(n, n).transpose(0, 1)))),
        (
            "transposed2",
            test(lambda n: (R(n, n).transpose(0, 1), R(n, n).transpose(0, 1))),
        ),
        ("slice1", test(lambda n: (R(n + 1, n + 1, 2)[:n, :n, 0], R(n, n)))),
        ("slice2", test(lambda n: (R(n, n, 2)[:, :, 0], R(n, n, 2)[:, :, 0]))),
        (
            "strided out",
            test_out(
                lambda n: (R(n, n), R(n, n)),
                out=lambda n: R(n + 1, n + 1, 2)[:n, :n, 0],
            ),
        ),
        (
            "out convert",
            test_out(
                lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n, dtype=torch.float64)
            ),
        ),
        ("issue #57611 (n,32,32,2)", test(lambda n: (R(1, 32, 32, 2), R(n, 1, 1, 2)))),
        ("float+double", test(lambda n: (R(n, n), R(n, n, dtype=torch.float64)))),
        (
            "int+long",
            test(
                lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int64))
            ),
        ),
        (
            "int+short",
            test(
                lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int16))
            ),
        ),
        (
            "float+int",
            test(
                lambda n: (R([n, n], dtype=torch.float32), I([n, n], dtype=torch.int32))
            ),
        ),
        (
            "double+long",
            test(
                lambda n: (R([n, n], dtype=torch.float64), I([n, n], dtype=torch.int64))
            ),
        ),
        (
            "fused addnorm",
            test(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=eager_addnorm,
            ),
        ),
        (
            "fused addnorm (vs TS)",
            test(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_addnorm,
            ),
        ),
        (
            "fused addnorm out=",
            test_out(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=inplace_addnorm,
                out=lambda n: R(n, n),
            ),
        ),
        (
            "fused addnorm out= (vs TS)",
            test_out(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_ip_addnorm,
                out=lambda n: R(n, n),
            ),
        ),
        (
            "fused addnorm backward",
            test_backwards(
                lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=eager_addnorm,
            ),
        ),
        (
            "fused addnorm backward (vs TS)",
            test_backwards(
                lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_addnorm,
            ),
        ),
    ]

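    # One row per benchmark, one column per tensor size; values are the median
    # speedup of the NNC kernel over the ATen baseline.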
    df = pd.DataFrame(
        np.stack([r for n, r in results]),
        columns=[f"{n}x{n}".rjust(9) for n in SIZES],
        index=[n for n, r in results],
    )

    if WRITE_CSV:
        df.to_csv("../operator_authoring_results.csv")
        print("wrote ../operator_authoring_results.csv")

    print()
    print("Speedups over aten")
    pd.options.display.float_format = "{:.2f}x".format
    print(df)


if __name__ == "__main__":
    main()