# AOT Autograd - Introduction to an experimental compilation feature in Functorch

The primary compilation API we provide is called AOT Autograd. AOT Autograd is
an experimental feature that captures forward and backward graphs ahead of time
and integrates easily with compilers. This creates an easy-to-hack, Python-based
development environment for speeding up training of PyTorch models. AOT Autograd
currently lives in the `functorch.compile` namespace.

AOT Autograd is experimental and the APIs are likely to change. We are looking
for feedback. If you are interested in using AOT Autograd and need help or have
suggestions, please feel free to open an issue. We will be happy to help.

Here are some examples of how to use it.
```python
from functorch.compile import aot_function, aot_module, draw_graph
import torch.fx as fx
import torch

# This simply prints out the FX graph of the forwards and the backwards
def print_graph(name):
    def f(fx_g: fx.GraphModule, inps):
        print(name)
        print(fx_g.code)
        return fx_g
    return f

def f(x):
    return x.cos().cos()

nf = aot_function(f, fw_compiler=print_graph("forward"), bw_compiler=print_graph("backward"))
nf(torch.randn(3, requires_grad=True))

# You can do whatever you want before and after, and you can still backprop through the function.
inp = torch.randn(3, requires_grad=True)
inp = inp.cos()
out = nf(inp)
out.sin().sum().backward()  # backward() returns None, so there is nothing to assign

# This draws out the forwards and the backwards graphs as svg files
def graph_drawer(name):
    def f(fx_g: fx.GraphModule, inps):
        draw_graph(fx_g, name)
        return fx_g
    return f

aot_function(f, fw_compiler=graph_drawer("forward"), bw_compiler=graph_drawer("backward"))(torch.randn(3, requires_grad=True))

# We also have a convenience API for applying AOT Autograd to modules
from torchvision.models import resnet18
aot_module(resnet18(), print_graph("forward"), print_graph("backward"))(torch.randn(1, 3, 200, 200))
# output elided since it's very long

# In practice, you might want to speed it up by sending it to TorchScript.
# You might also lower it to TorchScript before passing it to another compiler.
def ts_compiler(fx_g: fx.GraphModule, inps):
    scripted = torch.jit.script(fx_g)
    print(scripted.graph)
    # Note: This eval() works fine *even* though we're using this for training
    scripted = torch.jit.freeze(scripted.eval())
    return scripted

aot_function(f, ts_compiler, ts_compiler)(torch.randn(3, requires_grad=True))
```

## Documentation
* AOT Autograd [documentation](https://pytorch.org/functorch/nightly/)
* Min-cut [recomputation](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) with AOT Autograd.

## Tutorials
You can use this [tutorial](https://pytorch.org/functorch/nightly/notebooks/aot_autograd_optimizations.html) to play with AOT Autograd.
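
Before working through the tutorial, it can help to confirm that a function wrapped with `aot_function` behaves like its eager counterpart. The following is a minimal sketch rather than an official example: `nop_compiler`, `compiled_f`, and `x_ref` are illustrative names introduced here, and the identity compiler simply hands the captured graph back unchanged. It assumes a PyTorch build in which `functorch.compile` is importable.

```python
import torch
from functorch.compile import aot_function

def f(x):
    return x.cos().cos()

# Hypothetical identity "compiler": returns the captured FX graph unchanged,
# so any difference from eager mode would come from the tracing itself.
def nop_compiler(fx_g, inps):
    return fx_g

compiled_f = aot_function(f, fw_compiler=nop_compiler, bw_compiler=nop_compiler)

# Two independent leaf tensors holding the same values.
x = torch.randn(3, requires_grad=True)
x_ref = x.detach().clone().requires_grad_(True)

out = compiled_f(x)
out_ref = f(x_ref)

out.sum().backward()
out_ref.sum().backward()

# Compare the forward results and the gradients against eager mode.
outputs_match = torch.allclose(out, out_ref)
grads_match = torch.allclose(x.grad, x_ref.grad)
print(outputs_match, grads_match)
```

The same comparison can be repeated after swapping `nop_compiler` for a real compiler function, which is a cheap way to catch numerical regressions while experimenting.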