"""Check and benchmark a tiny nn.Module against its functorch-compiled twin.

Verifies that forward outputs and parameter gradients agree between the
eager module and ``compiled_module`` (before and after one SGD-style
parameter update), then times both variants.
"""
import timeit

import torch
import torch.nn as nn
from functorch.compile import compiled_module, tvm_compile


def nop(f, _):
    """Identity "compiler": return the graph unchanged (disables TVM)."""
    return f


fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
# NOTE(review): the TVM compilers built above are immediately overridden —
# this is a debug toggle. Comment out the next two lines to actually run
# through TVM instead of the no-op compiler.
fw_compiler = nop
bw_compiler = nop


def run(mod, input):
    """Run one forward+backward pass; return (output, *parameter grads)."""
    out = mod(input)
    out.sum().backward()
    grads = [p.grad for p in mod.parameters()]
    return (out, *grads)


class Foo(nn.Module):
    """Minimal module: one learnable parameter and one non-learnable buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.param = nn.Parameter(torch.randn(1))
        self.register_buffer("buf", torch.randn(1))

    def forward(self, x):
        return (self.param * x + self.buf).sum(dim=0)


input = torch.randn(1)
mod = Foo()
compiled_mod = compiled_module(mod, fw_compiler, bw_compiler)

# Outputs and gradients must agree before any parameter update.
for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_close(a, b)

# Apply one SGD-style update to both copies of the parameter.
out = mod(input)
out.sum().backward()
mod.param.data -= mod.param.grad
compiled_mod.orig_module.param.data -= compiled_mod.orig_module.param.grad
# NOTE(review): only the compiled module's grad is zeroed here. If
# ``orig_module`` does not alias ``mod``, the eager side accumulates stale
# gradient into the next ``run`` comparison — confirm whether
# ``mod.param.grad = None`` is also needed.
compiled_mod.orig_module.param.grad = None

# Outputs and gradients must still agree after the update.
for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_close(a, b)

for _ in range(5):
    i = 10000
    # Pass ``i`` to timeit so the per-call average ``t / i`` stays correct
    # even if the iteration count changes (the literal 10000 previously had
    # to be kept in sync with ``i`` by hand).
    t = timeit.Timer("mod(input)", globals=globals()).timeit(i)
    print(f"eager {t/i*1e6}")
    t = timeit.Timer("compiled_mod(input)", globals=globals()).timeit(i)
    print(f"compiled {t/i*1e6}")