1"""
2=============================
3Jacobians, hessians, and more
4=============================
5
6Computing jacobians or hessians are useful in a number of non-traditional
7deep learning models. It is difficult (or annoying) to compute these quantities
8efficiently using a standard autodiff system like PyTorch Autograd; functorch
9provides ways of computing various higher-order autodiff quantities efficiently.
10"""
from functools import partial

import torch
import torch.nn.functional as F


torch.manual_seed(0)


######################################################################
# Setup: Comparing functorch vs the naive approach
# --------------------------------------------------------------------
# Let's start with a function that we'd like to compute the jacobian of.
# This is a simple linear function with non-linear activation.
def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()


# Here's some dummy data: a weight, a bias, and a feature vector.
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)

# Let's think of ``predict`` as a function that maps the input ``x`` from R^D -> R^D.
# PyTorch Autograd computes vector-Jacobian products. In order to compute the full
# Jacobian of this R^D -> R^D function, we would have to compute it row-by-row
# by using a different unit vector each time.
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)


def compute_jac(xp):
    jacobian_rows = [
        torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
        for vec in unit_vectors
    ]
    return torch.stack(jacobian_rows)


jacobian = compute_jac(xp)
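
# As a quick sanity check (added here for illustration): the Jacobian of this
# R^D -> R^D function is a (D, D) matrix whose (i, j) entry is
# d predict(weight, bias, x)[i] / d x[j].
assert jacobian.shape == (D, D)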

# Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid
# of the for-loop and vectorize the computation. We can't directly apply vmap
# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
from functorch import vjp, vmap


_, vjp_fn = vjp(partial(predict, weight, bias), x)
(ft_jacobian,) = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian)

# In another tutorial a composition of reverse-mode AD and vmap gave us
# per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap
# gives us Jacobian computation! Various compositions of vmap and autodiff
# transforms can give us different interesting quantities.
#
# functorch provides ``jacrev`` as a convenience function that performs
# the vmap-vjp composition to compute jacobians. ``jacrev`` accepts an argnums
# argument that says which argument we would like to compute Jacobians with
# respect to.
from functorch import jacrev


ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)

78# Let's compare the performance of the two ways to compute jacobian.
79# The functorch version is much faster (and becomes even faster the more outputs
80# there are). In general, we expect that vectorization via ``vmap`` can help
81# eliminate overhead and give better utilization of your hardware.
82from torch.utils.benchmark import Timer
83
84
85without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
86with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
87print(without_vmap.timeit(500))
88print(with_vmap.timeit(500))
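
# If you'd rather see a single relative number than two printouts, the
# Measurement objects returned by ``timeit`` expose a ``mean`` time in seconds.
# A small sketch (using fewer iterations to keep it quick; the variable names
# here are just for illustration):
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)
print(f"approximate speedup from vmap: {no_vmap_timing.mean / with_vmap_timing.mean:.2f}x")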

# It's pretty easy to flip the problem around and say we want to compute
# Jacobians with respect to the model parameters (weight, bias) instead of the input.
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
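
# The shapes follow the usual jacrev convention of output dimensions first, then
# the dimensions of the argument we differentiated with respect to (a quick
# check added for illustration):
assert ft_jac_weight.shape == (D, D, D)  # output is R^D, weight is (D, D)
assert ft_jac_bias.shape == (D, D)  # output is R^D, bias is (D,)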

######################################################################
# reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)
# --------------------------------------------------------------------
# We offer two APIs to compute jacobians: jacrev and jacfwd:
# - jacrev uses reverse-mode AD. As you saw above, it is a composition of our
#   vjp and vmap transforms.
# - jacfwd uses forward-mode AD. It is implemented as a composition of our
#   jvp and vmap transforms.
# jacfwd and jacrev can be substituted for each other and have different
# performance characteristics.
#
# As a general rule of thumb, if you're computing the jacobian of an R^N -> R^M
# function and there are many more outputs than inputs (i.e. M > N), then jacfwd
# is preferred; otherwise, use jacrev. There are exceptions to this rule, but a
# non-rigorous argument for this follows:

# In reverse-mode AD, we are computing the jacobian row-by-row (one vjp per
# output, so M passes), while in forward-mode AD (which computes Jacobian-vector
# products), we are computing it column-by-column (one jvp per input, so N
# passes). The Jacobian matrix has M rows and N columns, so the cheaper
# direction is the one with fewer passes.
from functorch import jacfwd, jacrev


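# To make the row/column picture concrete, here is a minimal sketch using the
# D=16 setup from earlier (``func``, ``e0``, ``row0``, and ``col0`` are
# illustrative names, not part of the original script): a single vjp call with a
# unit vector recovers one *row* of the Jacobian, while a single jvp call
# recovers one *column*.
from functorch import jvp

func = partial(predict, weight, bias)
e0 = torch.zeros(D)
e0[0] = 1.0
_, vjp_fn = vjp(func, x)
(row0,) = vjp_fn(e0)  # e0 @ J: the first row of the Jacobian
_, col0 = jvp(func, (x,), (e0,))  # J @ e0: the first column of the Jacobian
assert torch.allclose(row0, jacobian[0])
assert torch.allclose(col0, jacobian[:, 0])
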
# Benchmark with more outputs than inputs (Dout >> Din); by the rule of thumb
# above, jacfwd should have the advantage here.
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(f"jacfwd time: {using_fwd.timeit(500)}")
print(f"jacrev time: {using_bwd.timeit(500)}")

# Benchmark with more inputs than outputs (Din >> Dout); by the rule of thumb
# above, jacrev should have the advantage here.
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(f"jacfwd time: {using_fwd.timeit(500)}")
print(f"jacrev time: {using_bwd.timeit(500)}")

######################################################################
# Hessian computation with functorch.hessian
# --------------------------------------------------------------------
# We offer a convenience API to compute hessians: functorch.hessian.
# Hessians are the jacobian of the jacobian, which suggests that one can just
# compose functorch's jacobian transforms to compute one.
# Indeed, under the hood, ``hessian(f)`` is simply ``jacfwd(jacrev(f))``.
#
# Depending on your model, you may want to use ``jacfwd(jacfwd(f))`` or
# ``jacrev(jacrev(f))`` instead to compute hessians.
from functorch import hessian


# # TODO: make sure PyTorch has tanh_backward implemented for jvp!!
# hess0 = hessian(predict, argnums=2)(weight, bias, x)
# hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess2 = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
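
# For a scalar-valued function the Hessian is a single (N, N) matrix. A minimal
# sketch of that case, reusing the jacrev-of-jacrev composition from above with
# the current Din setup (``predict_scalar`` is introduced here purely for
# illustration):
def predict_scalar(weight, bias, x):
    return predict(weight, bias, x).sum()


scalar_hess = jacrev(jacrev(predict_scalar, argnums=2), argnums=2)(weight, bias, x)
assert scalar_hess.shape == (Din, Din)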

######################################################################
# Batch Jacobian (and Batch Hessian)
# --------------------------------------------------------------------
# In the above examples we've been operating with a single feature vector.
# In some cases you might want to take the Jacobian of a batch of outputs
# with respect to a batch of inputs, where each input produces an independent
# output. That is, given a batch of inputs of shape (B, N) and a function
# that goes from (B, N) -> (B, M), we would like a Jacobian of shape (B, M, N).
# The easiest way to do this is to sum over the batch dimension and then
# compute the Jacobian of that function. Because each sample's output depends
# only on its own input, the cross-sample derivatives are zero, so
# differentiating the summed output recovers each sample's Jacobian exactly:


def predict_with_output_summed(weight, bias, x):
    return predict(weight, bias, x).sum(0)


batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)

batch_jacobian0 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x)
# ``jacrev`` puts the output dimensions first, giving shape (M, B, N) here, so
# we move the batch dimension to the front to get the desired (B, M, N):
batch_jacobian0 = batch_jacobian0.movedim(0, 1)

# If you instead have a function that goes from R^N -> R^M but whose inputs are
# batched, you can compose vmap with jacrev to compute batched jacobians:

compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian1 = compute_batch_jacobian(weight, bias, x)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
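
# A quick shape check (added for illustration): one (Dout, Din) Jacobian per
# sample in the batch, i.e. the (B, M, N) shape described above.
assert batch_jacobian1.shape == (batch_size, Dout, Din)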

# Finally, batch hessians can be computed similarly. It's easiest to think about
# them by using vmap to batch over hessian computation, but in some cases the sum
# trick also works.
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
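
# And a final shape check (added for illustration): one (Din, Din) Hessian per
# output per sample.
assert batch_hess.shape == (batch_size, Dout, Din, Din)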