xref: /aosp_15_r20/external/pytorch/functorch/examples/compilation/fuse_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import timeit
2
3import torch
4import torch.nn as nn
5from functorch.compile import compiled_module, tvm_compile
6
7
def nop(f, _):
    """No-op "compiler": hand the traced graph module back untouched.

    Used to stand in for a real fw/bw compiler so the compiled_module
    pipeline can be exercised without invoking TVM.
    """
    return f
10
11
# Build real TVM-backed forward/backward compilers (tuning logs would go to
# "fw_keops" / "bw_keops")...
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
# ...then immediately override both with the no-op compiler, so the example
# runs without TVM doing any work. Remove these two lines to exercise the
# TVM path instead.
fw_compiler = nop
bw_compiler = nop
16
17
def run(mod, input):
    """Run one forward+backward pass and collect the results.

    Returns a tuple of the module output followed by the .grad of every
    parameter (gradients accumulate across calls, as in eager PyTorch).
    """
    output = mod(input)
    loss = output.sum()
    loss.backward()
    grads = tuple(p.grad for p in mod.parameters())
    return (output,) + grads
23
24
class Foo(nn.Module):
    """Tiny test module: elementwise affine transform, then a sum.

    Carries one learnable scalar parameter and one non-trainable buffer so
    the compiled module has both kinds of state to thread through.
    """

    def __init__(self) -> None:
        super().__init__()
        self.param = nn.Parameter(torch.randn(1))
        self.register_buffer("buf", torch.randn(1))

    def forward(self, x):
        # Scale by the parameter, shift by the buffer, reduce over dim 0.
        scaled = self.param * x
        shifted = scaled + self.buf
        return shifted.sum(dim=0)
33
34
# Single random scalar input shared by every check below.
# NOTE(review): "input" shadows the builtin; kept for parity with upstream.
input = torch.randn(1)
mod = Foo()
# Wrap the eager module with the fw/bw compilers chosen above.
compiled_mod = compiled_module(mod, fw_compiler, bw_compiler)

# First correctness check: output and each parameter gradient must match
# between the eager module and its compiled counterpart.
for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_close(a, b)
41
# Take one SGD-style step (lr=1) on both copies so the second comparison
# runs with updated weights.
out = mod(input)
out.sum().backward()
mod.param.data -= mod.param.grad
compiled_mod.orig_module.param.data -= compiled_mod.orig_module.param.grad
# NOTE(review): the reset is asymmetric — only the compiled module's grad is
# cleared while mod.param.grad is left to accumulate. Presumably
# compiled_module manages gradient accumulation differently; verify before
# relying on this in real code.
compiled_mod.orig_module.param.grad = None

# Second correctness check, after the in-place parameter update.
for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_close(a, b)
50
# Micro-benchmark: eager vs. compiled forward pass, reported as average
# microseconds per call; repeated 5 times so warm-up effects are visible.
for _ in range(5):
    i = 10000  # iterations per timing run
    # Fix: pass `i` to timeit() instead of repeating the literal 10000, so
    # the iteration count and the per-call average can never drift apart.
    t = timeit.Timer("mod(input)", globals=globals()).timeit(i)
    print(f"eager {t / i * 1e6}")
    t = timeit.Timer("compiled_mod(input)", globals=globals()).timeit(i)
    print(f"compiled {t / i * 1e6}")
57