1""" 2============================= 3Jacobians, hessians, and more 4============================= 5 6Computing jacobians or hessians are useful in a number of non-traditional 7deep learning models. It is difficult (or annoying) to compute these quantities 8efficiently using a standard autodiff system like PyTorch Autograd; functorch 9provides ways of computing various higher-order autodiff quantities efficiently. 10""" 11from functools import partial 12 13import torch 14import torch.nn.functional as F 15 16 17torch.manual_seed(0) 18 19 20###################################################################### 21# Setup: Comparing functorch vs the naive approach 22# -------------------------------------------------------------------- 23# Let's start with a function that we'd like to compute the jacobian of. 24# This is a simple linear function with non-linear activation. 25def predict(weight, bias, x): 26 return F.linear(x, weight, bias).tanh() 27 28 29# Here's some dummy data: a weight, a bias, and a feature vector. 30D = 16 31weight = torch.randn(D, D) 32bias = torch.randn(D) 33x = torch.randn(D) 34 35# Let's think of ``predict`` as a function that maps the input ``x`` from R^D -> R^D. 36# PyTorch Autograd computes vector-Jacobian products. In order to compute the full 37# Jacobian of this R^D -> R^D function, we would have to compute it row-by-row 38# by using a different unit vector each time. 39xp = x.clone().requires_grad_() 40unit_vectors = torch.eye(D) 41 42 43def compute_jac(xp): 44 jacobian_rows = [ 45 torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0] 46 for vec in unit_vectors 47 ] 48 return torch.stack(jacobian_rows) 49 50 51jacobian = compute_jac(xp) 52 53# Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid 54# of the for-loop and vectorize the computation. We can't directly apply vmap 55# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform: 56from functorch import vjp, vmap 57 58 59_, vjp_fn = vjp(partial(predict, weight, bias), x) 60(ft_jacobian,) = vmap(vjp_fn)(unit_vectors) 61assert torch.allclose(ft_jacobian, jacobian) 62 63# In another tutorial a composition of reverse-mode AD and vmap gave us 64# per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap 65# gives us Jacobian computation! Various compositions of vmap and autodiff 66# transforms can give us different interesting quantities. 67# 68# functorch provides ``jacrev`` as a convenience function that performs 69# the vmap-vjp composition to compute jacobians. ``jacrev`` accepts an argnums 70# argument that says which argument we would like to compute Jacobians with 71# respect to. 72from functorch import jacrev 73 74 75ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x) 76assert torch.allclose(ft_jacobian, jacobian) 77 78# Let's compare the performance of the two ways to compute jacobian. 79# The functorch version is much faster (and becomes even faster the more outputs 80# there are). In general, we expect that vectorization via ``vmap`` can help 81# eliminate overhead and give better utilization of your hardware. 82from torch.utils.benchmark import Timer 83 84 85without_vmap = Timer(stmt="compute_jac(xp)", globals=globals()) 86with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) 87print(without_vmap.timeit(500)) 88print(with_vmap.timeit(500)) 89 90# It's pretty easy to flip the problem around and say we want to compute 91# Jacobians of the parameters to our model (weight, bias) instead of the input. 

######################################################################
# reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)
# --------------------------------------------------------------------
# We offer two APIs to compute Jacobians, jacrev and jacfwd:
# - jacrev uses reverse-mode AD. As you saw above, it is a composition of our
#   vjp and vmap transforms.
# - jacfwd uses forward-mode AD. It is implemented as a composition of our
#   jvp and vmap transforms.
# jacfwd and jacrev can be substituted for each other, but they have different
# performance characteristics.
#
# As a general rule of thumb, when computing the Jacobian of an R^N -> R^M
# function: if there are many more outputs than inputs (i.e. M > N), jacfwd is
# preferred; otherwise, use jacrev. There are exceptions to this rule, but a
# non-rigorous argument for it follows:

# In reverse-mode AD, we are computing the Jacobian row-by-row, while in
# forward-mode AD (which computes Jacobian-vector products), we are computing
# it column-by-column. The Jacobian matrix has M rows and N columns, so if it
# is tall (M > N), building it column-by-column with forward-mode takes fewer
# passes, and if it is wide (N > M), building it row-by-row with reverse-mode
# takes fewer passes.
from functorch import jacfwd, jacrev


# Benchmark with more outputs than inputs; the rule of thumb predicts that
# jacfwd is faster here.
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(f"jacfwd time: {using_fwd.timeit(500)}")
print(f"jacrev time: {using_bwd.timeit(500)}")

# Benchmark with more inputs than outputs; the rule of thumb predicts that
# jacrev is faster here.
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(f"jacfwd time: {using_fwd.timeit(500)}")
print(f"jacrev time: {using_bwd.timeit(500)}")

######################################################################
# Hessian computation with functorch.hessian
# --------------------------------------------------------------------
# We offer a convenience API to compute Hessians: functorch.hessian.
# The Hessian is the Jacobian of the Jacobian, which suggests that one can just
# compose functorch's Jacobian transforms to compute it.
# Indeed, under the hood, ``hessian(f)`` is simply ``jacfwd(jacrev(f))``.
#
# Depending on your model, you may want to use ``jacfwd(jacfwd(f))`` or
# ``jacrev(jacrev(f))`` instead to compute Hessians.
from functorch import hessian


# # TODO: make sure PyTorch has tanh_backward implemented for jvp!!
# hess0 = hessian(predict, argnums=2)(weight, bias, x)
# hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess2 = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
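
# The most common Hessian use case is a scalar-valued R^N -> R function, where
# ``hessian`` returns the familiar square (N, N) matrix. Here is a small,
# self-contained sketch of that case; ``cube_sum``, ``z_demo``, and
# ``hess_demo`` are illustrative names used only in this example. The Hessian
# of sum(z_i ** 3) should be diag(6 * z_i).
def cube_sum(z):
    return (z ** 3).sum()


z_demo = torch.randn(5)
hess_demo = hessian(cube_sum)(z_demo)  # jacfwd(jacrev(cube_sum)) under the hood
assert hess_demo.shape == (5, 5)
assert torch.allclose(hess_demo, torch.diag(6 * z_demo))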

######################################################################
# Batch Jacobian (and Batch Hessian)
# --------------------------------------------------------------------
# In the above examples we've been operating with a single feature vector.
# In some cases you might want to take the Jacobian of a batch of outputs
# with respect to a batch of inputs, where each input produces an independent
# output. That is, given a batch of inputs of shape (B, N) and a function
# that goes from (B, N) -> (B, M), we would like a Jacobian of shape (B, M, N).
# The simplest trick is to sum the outputs over the batch dimension and
# compute the Jacobian of that summed function. Because each output depends
# only on its own input, the per-example Jacobians all appear in the result,
# although the dimensions come out ordered as (M, B, N):


def predict_with_output_summed(weight, bias, x):
    return predict(weight, bias, x).sum(0)


batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)

batch_jacobian0 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x)

# If you instead think of ``predict`` as a function from R^N -> R^M applied to
# batched inputs, you can compose vmap with jacrev to compute batched Jacobians
# directly, in the desired (B, M, N) layout:

compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian1 = compute_batch_jacobian(weight, bias, x)
# The two results agree once we move the batch dimension of the summed version
# to the front.
assert torch.allclose(batch_jacobian0.movedim(1, 0), batch_jacobian1)

# Finally, batch Hessians can be computed similarly. The easiest way is to use
# vmap to batch over the Hessian computation, though in some cases the sum
# trick also works.
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
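
# As an optional sanity check, the vmapped Jacobian should match an explicit
# Python loop over the batch (up to floating-point tolerance), and the batch
# Hessian should contain one (Dout, Din, Din) Hessian per batch entry.
# ``expected_jac`` is a scratch name used only for this check.
expected_jac = torch.stack(
    [jacrev(predict, argnums=2)(weight, bias, xi) for xi in x]
)
assert expected_jac.shape == (batch_size, Dout, Din)
assert torch.allclose(batch_jacobian1, expected_jac)
assert batch_hess.shape == (batch_size, Dout, Din, Din)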