# AOT Autograd - Introduction to an experimental compilation feature in Functorch

The primary compilation API we provide is AOT Autograd. AOT Autograd is an
experimental feature that captures forward and backward graphs ahead of time
and makes it easy to integrate with compilers. This creates an easy-to-hack,
Python-based development environment for speeding up the training of PyTorch
models. AOT Autograd currently lives in the `functorch.compile` namespace.

AOT Autograd is experimental and the APIs are likely to change. We are looking
for feedback. If you are interested in using AOT Autograd and need help or have
suggestions, please feel free to open an issue. We will be happy to help.

Here are some examples of how to use it.
```python
from functorch.compile import aot_function, aot_module, draw_graph
import torch.fx as fx
import torch

# This simply prints out the FX graph of the forwards and the backwards
def print_graph(name):
    def f(fx_g: fx.GraphModule, inps):
        print(name)
        print(fx_g.code)
        return fx_g
    return f

def f(x):
    return x.cos().cos()

nf = aot_function(f, fw_compiler=print_graph("forward"), bw_compiler=print_graph("backward"))
nf(torch.randn(3, requires_grad=True))

# You can do whatever you want before and after, and you can still backprop through the function.
inp = torch.randn(3, requires_grad=True)
inp = inp.cos()
out = nf(inp)
out.sin().sum().backward()  # backward() returns None, so its result is not assigned

def f(x):
    return x.cos().cos()

# This draws out the forwards and the backwards graphs as svg files
def graph_drawer(name):
    def f(fx_g: fx.GraphModule, inps):
        draw_graph(fx_g, name)
        return fx_g
    return f

aot_function(f, fw_compiler=graph_drawer("forward"), bw_compiler=graph_drawer("backward"))(torch.randn(3, requires_grad=True))

# We also have a convenience API for applying AOTAutograd to modules
from torchvision.models import resnet18
aot_module(resnet18(), print_graph("forward"), print_graph("backward"))(torch.randn(1, 3, 200, 200))
# output elided since it's very long

# In practice, you might want to speed it up by sending it to TorchScript.
# You might also lower it to TorchScript before passing it to another compiler.

def f(x):
    return x.cos().cos()

def ts_compiler(fx_g: fx.GraphModule, inps):
    scripted = torch.jit.script(fx_g)
    print(scripted.graph)
    # Note: This eval() works fine *even* though we're using this for training
    scripted = torch.jit.freeze(scripted.eval())
    return scripted

aot_function(f, ts_compiler, ts_compiler)(torch.randn(3, requires_grad=True))
```
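
As a quick sanity check on the examples above, you can compare a function wrapped by `aot_function` against plain eager execution. The following is a minimal sketch, not an official test: `recording_compiler`, `calls`, and the variable names are hypothetical helpers introduced only for illustration. The compiler callbacks return the captured graph unchanged, so the wrapped function should produce the same outputs and gradients as the original.

```python
import torch
import torch.fx as fx
from functorch.compile import aot_function

# Hypothetical helper: records which compiler callbacks were invoked.
calls = []

def recording_compiler(name):
    def compiler(fx_g: fx.GraphModule, inps):
        calls.append(name)
        # Returning the GraphModule unchanged means "run the captured graph as-is".
        return fx_g
    return compiler

def f(x):
    return x.cos().cos()

compiled_f = aot_function(
    f,
    fw_compiler=recording_compiler("forward"),
    bw_compiler=recording_compiler("backward"),
)

# Two independent leaf tensors holding the same values, so gradients can be compared.
x_ref = torch.randn(3, requires_grad=True)
x_aot = x_ref.detach().clone().requires_grad_(True)

out_ref = f(x_ref)
out_aot = compiled_f(x_aot)

out_ref.sum().backward()
out_aot.sum().backward()

outputs_match = torch.allclose(out_ref, out_aot)
grads_match = torch.allclose(x_ref.grad, x_aot.grad)
print("outputs match:", outputs_match)
print("gradients match:", grads_match)
print("compiler callbacks invoked:", calls)
```

The same pattern can be reused when swapping in a real compiler callback, such as the TorchScript one above, to confirm that the compiled function still agrees with eager mode.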

## Documentation
* AOT Autograd [documentation](https://pytorch.org/functorch/nightly/)
* Min-cut [recomputation](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) with AOT Autograd.

## Tutorials
You can use this [tutorial](https://pytorch.org/functorch/nightly/notebooks/aot_autograd_optimizations.html) to play with AOT Autograd.