# functorch

[**Why functorch?**](#why-composable-function-transforms)
| [**Install guide**](#install)
| [**Transformations**](#what-are-the-transforms)
| [**Documentation**](#documentation)
| [**Future Plans**](#future-plans)

**This library is currently under heavy development - if you have suggestions
on the API or use-cases you'd like to be covered, please open a GitHub issue
or reach out. We'd love to hear about how you're using the library.**

`functorch` provides [JAX-like](https://github.com/google/jax) composable function
transforms for PyTorch.

It aims to provide composable `vmap` and `grad` transforms that work with
PyTorch modules and PyTorch autograd with good eager-mode performance.

In addition, there is experimental functionality to trace through these
transformations using FX in order to capture the results of these transforms
ahead of time. This would allow us to compile the results of `vmap` or `grad`
to improve performance.

## Why composable function transforms?

There are a number of use cases that are tricky to do in
PyTorch today:
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above
without designing a separate subsystem for each. This idea of composable function
transforms comes from the [JAX framework](https://github.com/google/jax).

## Install

There are two ways to install functorch:
1. functorch from source
2. functorch beta (compatible with recent PyTorch releases)

We recommend trying out the functorch beta first.

### Installing functorch from source

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [in this Colab notebook](https://colab.research.google.com/drive/1CrLkqIrydBYP_svnF89UUO-aQEqNPE8x?usp=sharing).

#### Locally

As of 9/21/2022, `functorch` comes installed alongside a nightly PyTorch binary.
Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/
for instructions.

Once you've done that, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

#### functorch development setup

As of 9/21/2022, `functorch` comes installed alongside PyTorch and is in the
PyTorch source tree. Please install
[PyTorch from source](https://github.com/pytorch/pytorch#from-source); then
you will be able to `import functorch`.

Run some tests to make sure everything is OK:
```bash
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

AOTAutograd has some additional optional requirements. You can install them via:
```bash
pip install networkx
```

To run functorch tests, please install our test dependencies (`expecttest`, `pyyaml`).

</p>
</details>

### Installing functorch beta (compatible with recent PyTorch releases)

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [here](https://colab.research.google.com/drive/1GNfb01W_xf8JRu78ZKoNnLqiwcrJrbYG#scrollTo=HJ1srOGeNCGA).

#### pip

Prerequisite: [Install PyTorch](https://pytorch.org/get-started/locally/)

```bash
pip install functorch
```

Finally, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

</p>
</details>

## What are the transforms?

Right now, we support the following transforms:
- `grad`, `vjp`, `jvp`
- `jacrev`, `jacfwd`, `hessian`
- `vmap`

Furthermore, we have some utilities for working with PyTorch modules:
- `make_functional(model)`
- `make_functional_with_buffers(model)`

### vmap

Note: `vmap` imposes restrictions on the code that it can be used on.
For more details, please read its docstring.

`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor
operations in `func`. `vmap(func)` returns a new function that maps `func` over
some dimension (default: 0) of each Tensor in `inputs`.

`vmap` is useful for hiding batch dimensions: one can write a function `func`
that runs on examples and then lift it to a function that can take batches of
examples with `vmap(func)`, leading to a simpler modeling experience:

```py
import torch
from functorch import vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)
```

### grad

`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes
the gradients of the output of `func` with respect to `inputs[0]`.

```py
import torch
from functorch import grad

x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())
```

When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```py
import torch
from functorch import vmap, grad

batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
```

### vjp

The `vjp` transform applies `func` to `inputs` and returns a new function that
computes vjps given some `cotangents` Tensors.
```py
from functorch import vjp
outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents)
```
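
For example, a minimal sketch that recovers the gradient of `torch.sin` through
`vjp`, using a tensor of ones as the cotangent, might look like:

```py
import torch
from functorch import vjp

x = torch.randn(5)
output, vjp_fn = vjp(torch.sin, x)
# vjp_fn returns one vjp per primal input, packed in a tuple.
(grad_x,) = vjp_fn(torch.ones(5))
assert torch.allclose(grad_x, x.cos())
```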

### jvp

The `jvp` transform computes Jacobian-vector products and is also known as
"forward-mode AD". Unlike most other transforms, it is not a higher-order
function: it returns the outputs of `func(inputs)` as well as the jvps.
```py
import torch
from functorch import jvp

x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
_, jvp_out = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(jvp_out, x + y)
```

### jacrev, jacfwd, and hessian

`jacrev(func)` returns a new function that computes the Jacobian of `func` using
reverse-mode AD. Here we compute the Jacobian of `torch.sin` with respect to `x`:
```py
import torch
from functorch import jacrev

x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```
`jacrev` can be composed with `vmap` to produce batched Jacobians:

```py
from functorch import vmap, jacrev

x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)
```

`jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using
forward-mode AD:
```py
import torch
from functorch import jacfwd

x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```

Composing `jacrev` with itself or with `jacfwd` can produce Hessians:
```py
from functorch import jacrev, jacfwd

def f(x):
  return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)
```

`hessian` is a convenience function that combines `jacfwd` and `jacrev`:
```py
import torch
from functorch import hessian

def f(x):
  return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)
```
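
As with Jacobians, Hessians compose with `vmap`; a minimal sketch of computing
batched Hessians:

```py
import torch
from functorch import vmap, hessian

def f(x):
  return x.sin().sum()

batched_x = torch.randn(64, 5)
# One 5x5 Hessian per example in the batch.
batched_hess = vmap(hessian(f))(batched_x)
assert batched_hess.shape == (64, 5, 5)
```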

### Tracing through the transformations

We can also trace through these transformations in order to capture the results
as new code using `make_fx`. There is also experimental integration with the NNC
compiler (which only works on CPU for now!).

```py
import torch
from functorch import make_fx, grad

def f(x):
    return torch.sin(x).sum()

x = torch.randn(100)
grad_f = make_fx(grad(f))(x)
print(grad_f.code)

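# The printed code looks something like: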
def forward(self, x_1):
    sin = torch.ops.aten.sin(x_1)
    sum_1 = torch.ops.aten.sum(sin, None);  sin = None
    cos = torch.ops.aten.cos(x_1);  x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul = torch.ops.aten.mul(_tensor_constant0, cos);  _tensor_constant0 = cos = None
    return mul
```

### Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters
and/or buffers of an nn.Module. This can happen, for example, in:
- model ensembling, where all of your weights and buffers have an additional
dimension
- per-sample-gradient computation, where you want to compute per-sample-grads
of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a
stateless version of it that can be called like a function.

- `make_functional(model)` returns a functional version of `model` along with
`model.parameters()`
- `make_functional_with_buffers(model)` returns a functional version of
`model` along with `model.parameters()` and `model.buffers()`.

Here's an example where we compute per-sample-gradients using an nn.Linear
layer:

```py
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)
```
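
For models that carry buffers (e.g. BatchNorm running statistics),
`make_functional_with_buffers` works analogously; a minimal sketch:

```py
import torch
from functorch import make_functional_with_buffers

# eval() so this example does not try to update running statistics in-place.
model = torch.nn.BatchNorm1d(3).eval()
func_model, params, buffers = make_functional_with_buffers(model)

data = torch.randn(64, 3)
out = func_model(params, buffers, data)
```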

If you're making an ensemble of models, you may find
`combine_state_for_ensemble` useful.
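
For instance, a minimal sketch of running five identically-structured models on
the same batch of data (the model choice and shapes here are illustrative):

```py
import torch
from functorch import combine_state_for_ensemble, vmap

models = [torch.nn.Linear(3, 3) for _ in range(5)]
fmodel, params, buffers = combine_state_for_ensemble(models)

x = torch.randn(64, 3)
# Map over the stacked parameters and buffers (dim 0), share the input (None).
predictions = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
assert predictions.shape == (5, 64, 3)
```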

## Documentation

For more documentation, see [our docs website](https://pytorch.org/functorch).

## Debugging

- `torch._C._functorch.dump_tensor`: dumps the dispatch keys on the stack.
- `torch._C._functorch._set_vmap_fallback_warning_enabled(False)`: use this if the vmap fallback warning spam bothers you.
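
For example, to silence the fallback warnings:

```py
import torch

# Disable the warnings emitted when vmap falls back to a per-example loop.
torch._C._functorch._set_vmap_fallback_warning_enabled(False)
```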

## Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the
design details. To figure out the details, we need your help -- please send us
your use cases by starting a conversation in the issue tracker or by trying our
project out.

## License

Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.

## Citing functorch

If you use functorch in your publication, please cite it using the following BibTeX entry.

```bibtex
@Misc{functorch2021,
  author =       {Horace He and Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year =         {2021}
}
```