xref: /aosp_15_r20/external/pytorch/test/functorch/test_eager_transforms.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: functorch"]
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import copy
10import math
11import os
12import subprocess
13import sys
14import unittest
15import warnings
16from functools import partial, wraps
17
18# NB: numpy is a testing dependency!
19import numpy as np
20from common_utils import expectedFailureIf
21
22import functorch
23import torch
24import torch.autograd.forward_ad as fwAD
25import torch.nn as nn
26import torch.nn.functional as F
27from functorch import (
28    combine_state_for_ensemble,
29    grad,
30    grad_and_value,
31    hessian,
32    jacfwd,
33    jacrev,
34    jvp,
35    make_functional,
36    make_functional_with_buffers,
37    make_fx,
38    vjp,
39    vmap,
40)
41from functorch.experimental import functionalize, replace_all_batch_norm_modules_
42from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
43from torch._dynamo import allow_in_graph
44from torch._functorch.eager_transforms import _slice_argnums
45from torch._functorch.make_functional import (
46    functional_init,
47    functional_init_with_buffers,
48)
49from torch._functorch.utils import enable_single_level_autograd_function
50from torch._ops import HigherOrderOperator
51from torch._subclasses.fake_tensor import FakeTensorMode
52from torch.func import functional_call, linearize, stack_module_state
53from torch.testing import make_tensor
54from torch.testing._internal.common_cuda import (
55    SM70OrLater,
56    TEST_CUDA,
57    tf32_on_and_off,
58    with_tf32_off,
59)
60from torch.testing._internal.common_device_type import (
61    dtypes,
62    instantiate_device_type_tests,
63    onlyCPU,
64    onlyCUDA,
65)
66from torch.testing._internal.common_dtype import get_all_fp_dtypes
67from torch.testing._internal.common_utils import (
68    freeze_rng_state,
69    instantiate_parametrized_tests,
70    IS_FBCODE,
71    IS_WINDOWS,
72    markDynamoStrictTest,
73    parametrize,
74    run_tests,
75    skipIfRocm,
76    skipIfTorchDynamo,
77    subtest,
78    TEST_WITH_TORCHDYNAMO,
79    TestCase,
80    xfailIfTorchDynamo,
81)
82from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
83
84
85USE_TORCHVISION = False
86try:
87    import torchvision  # noqa: F401
88
89    USE_TORCHVISION = True
90except ImportError:
91    warnings.warn(
92        "Couldn't import torchvision. Some of our tests use it, try "
93        "to install it with commands from pytorch.org, post-fixed with "
94        "`--no-deps` to avoid overwriting the pytorch installation",
95        UserWarning,
96    )
97
98# TestSliceArgnums (below) covers _slice_argnums, an important helper function
99
100
101class VmapTearDownMixin:
102    def tearDown(self):
103        # Ensure that a failing test doesn't break the next one: a failure can
104        # leave a _vmap_increment_nesting call without its matching decrement
105        # (e.g. test_vmap_free_tensor when PYTORCH_TEST_WITH_DYNAMO=1), so pop
106        # any leftover Vmap interpreters here.
107        if not TEST_WITH_TORCHDYNAMO:
108            return
109
110        warn = False
111        while ci := torch._C._functorch.peek_interpreter_stack():
112            if ci.key() == torch._C._functorch.TransformType.Vmap:
113                warn = True
114                torch._C._functorch._vmap_decrement_nesting()
115            else:
116                break
117
118        if warn:
119            msg = (
120                "Interpreter stack is not empty. Test should have called "
121                "'torch._C._functorch._vmap_decrement_nesting()'"
122            )
123            warnings.warn(msg)
124
125
126@markDynamoStrictTest
127class TestSliceArgnums(TestCase):
128    def test_invalid_argnum_type(self):
129        x = torch.randn(3)
130        args = (x,)
131        with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
132            _slice_argnums(args, 0.0)
133        with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
134            _slice_argnums(args, [0])
135        with self.assertRaisesRegex(RuntimeError, "must be int"):
136            _slice_argnums(args, (0.0,))
137
138        args = (0.1, 1.1, 2.1, 3.1, 4.1)
139
140        with self.assertRaisesRegex(RuntimeError, "must be int"):
141            _slice_argnums(args, ((0, 1), 2))
142
143    def test_out_of_bounds_argnum_values(self):
144        x = torch.randn(3)
145        args = (x,)
146        with self.assertRaisesRegex(RuntimeError, "positional inputs"):
147            _slice_argnums(args, 1)
148        with self.assertRaisesRegex(RuntimeError, "positional inputs"):
149            _slice_argnums(args, -2)
150        with self.assertRaisesRegex(RuntimeError, "positional inputs"):
151            _slice_argnums(args, (-2,))
152
153    def test_not_enough_argnums(self):
154        x = torch.randn(3)
155        args = (x,)
156        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
157            _slice_argnums(args, ())
158
159    def test_duplicate_argnums(self):
160        x = torch.randn(3)
161        args = (x, x)
162        with self.assertRaisesRegex(RuntimeError, "must be unique"):
163            _slice_argnums(args, (0, 0))
164        with self.assertRaisesRegex(RuntimeError, "must be unique"):
165            _slice_argnums(args, (0, -2))
166
167    def test_flat_args_with_positive_int_argnum(self):
168        args = (0.1, 1.1, 2.1, 3.1, 4.1)
169
170        res = _slice_argnums(args, 0)
171        self.assertEqual(res, (0.1,))
172
173        res = _slice_argnums(args, 4)
174        self.assertEqual(res, (4.1,))
175
176    def test_flat_args_with_negative_int_argnum(self):
177        args = (0.1, 1.1, 2.1, 3.1, 4.1)
178
179        res = _slice_argnums(args, -1)
180        self.assertEqual(res, (4.1,))
181
182        res = _slice_argnums(args, -5)
183        self.assertEqual(res, (0.1,))
184
185    def test_flat_args_with_tuple_argnum(self):
186        args = (0.1, 1.1, 2.1, 3.1, 4.1)
187
188        res = _slice_argnums(args, (0, 1, 2, 3, 4))
189        self.assertEqual(res, args)
190
191        res = _slice_argnums(args, (0, -3))
192        self.assertEqual(res, (0.1, 2.1))
193
194    def test_pytree_args(self):
195        args = ((0.1, 1.1), 2.0, [3.1])
196
197        res = _slice_argnums(args, 0)
198        self.assertEqual(res, args[0:1])
199
200        res = _slice_argnums(args, (0,))
201        self.assertEqual(res, args[0:1])
202
203        res = _slice_argnums(args, -1)
204        self.assertEqual(res, args[-1:])
205
206        res = _slice_argnums(args, (0, -2))
207        self.assertEqual(res, args[0:2])
208
209    def test_argnums_reorders(self):
210        args = ((0.1, 1.1, 2.1), 3.1, 4.1)
211
212        res = _slice_argnums(args, (1, 0))
213        self.assertEqual(res, (args[1], args[0]))
214
215
216def _get_weights_and_functional_call(net, mechanism):
217    if mechanism == "make_functional":
218        return make_functional(net)
219    else:
220        assert mechanism == "functional_call"
221        # give this wrapper the same signature as the function returned by make_functional
222
223        def net_func(weights, data):
224            return functional_call(net, weights, (data,))
225
226        return net_func, dict(net.named_parameters())
227
228
229def _get_weights_and_functional_call_with_buffers(net, mechanism):
230    if mechanism == "make_functional":
231        return make_functional_with_buffers(net)
232    else:
233        assert mechanism == "functional_call"
234
235        # give this wrapper the same signature as the function returned by make_functional_with_buffers
236        def net_func(weights, buffers, data):
237            return functional_call(net, (weights, buffers), (data,))
238
239        return net_func, dict(net.named_parameters()), dict(net.named_buffers())
240
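
# A minimal usage sketch (not part of the original test suite) of the two helpers
# above: both mechanisms expose the same (weights, data) calling convention, so
# tests can be parametrized over either one.  The helper name
# `_example_weights_mechanisms` is hypothetical and used only for illustration.
def _example_weights_mechanisms():
    net = nn.Linear(3, 3)
    data = torch.randn(2, 3)
    for mechanism in ("make_functional", "functional_call"):
        fn, weights = _get_weights_and_functional_call(net, mechanism)
        # Same signature either way: a callable taking (weights, data).
        loss = fn(weights, data).sum()
        # grad differentiates w.r.t. the weights pytree (a tuple of Tensors for
        # make_functional, a dict of Tensors for functional_call).
        grads = grad(lambda w: fn(w, data).sum())(weights)
    return loss, grads
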
241
242@markDynamoStrictTest
243class TestGradTransform(TestCase):
244    def test_primitive(self, device):
245        x = torch.randn([], device=device)
246        result = grad(torch.sin)(x)
247        self.assertEqual(result, torch.cos(x))
248
249    def test_composite_simple(self, device):
250        x = torch.randn(2, 3, 4, device=device)
251        result = grad(lambda x: torch.flatten(x).sum())(x)
252        self.assertEqual(result, torch.ones_like(x))
253
254    def test_fn_with_kwargs(self, device):
255        def foo(x, y):
256            return (x * y).sum()
257
258        x = torch.randn(3, device=device)
259        y = torch.randn(3, device=device)
260        expected = grad(foo)(x, y)
261        result = grad(foo)(x, y=y)
262        self.assertEqual(result, expected)
263
264    def test_composite_complicated(self, device):
265        x = torch.randn(3, device=device)
266        y = torch.randn(3, 5, device=device)
267
268        def foo(x, y):
269            result = x @ y
270            return result.sum()
271
272        result = grad(foo)(x, y)
273
274        x.requires_grad_()
275        out = foo(x, y)
276        (expected,) = torch.autograd.grad(out, x)
277
278        self.assertEqual(result, expected)
279
280    def test_composite_two_ops(self, device):
281        N, C = 2, 5
282        y = torch.randn(N, C, device=device)
283        targets = torch.randint(0, C, (N,), device=device)
284
285        def foo(y, targets):
286            return F.cross_entropy(y, targets)
287
288        result = grad(foo)(y, targets)
289
290        y.requires_grad_()
291        (expected,) = torch.autograd.grad(foo(y, targets), y)
292
293        self.assertEqual(result, expected)
294
295    def _test_attributes(self, get_attr_lambda, device):
296        x = torch.randn(2, 3, 5, dtype=torch.double, device=device)
297        expected = get_attr_lambda(x)
298
299        def foo(x):
300            self.assertEqual(get_attr_lambda(x), expected)
301            return x.sum()
302
303        grad(foo)(x)
304
305    def test_shape(self, device):
306        self._test_attributes(lambda x: x.shape, device)
307
308    def test_dtype(self, device):
309        self._test_attributes(lambda x: x.dtype, device)
310
311    def test_is_cuda(self, device):
312        self._test_attributes(lambda x: x.is_cuda, device)
313
314    def test_numel(self, device):
315        self._test_attributes(lambda x: x.numel(), device)
316
317    def test_inplace(self, device):
318        x = torch.randn([], device=device)
319
320        def foo(x):
321            return x.clone().sin_()
322
323        result = grad(foo)(x)
324        self.assertEqual(result, x.cos())
325
326    def test_inplace_on_view(self, device):
327        x = torch.randn(3, device=device)
328
329        def foo(x):
330            y = x.clone()
331            y0 = y[0]
332            y0.sin_()
333            return y.sum()
334
335        result = grad(foo)(x)
336
337        x.requires_grad_()
338        out = foo(x)
339        (expected,) = torch.autograd.grad(out, x)
340
341        self.assertEqual(result, expected)
342
343    def test_inplace_on_view_base(self, device):
344        x = torch.randn(3, device=device)
345
346        def foo(x):
347            y = x.clone()
348            y0 = y[0]
349            y.sin_()
350            return y0
351
352        result = grad(foo)(x)
353
354        x.requires_grad_()
355        out = foo(x)
356        (expected,) = torch.autograd.grad(out, x)
357
358        self.assertEqual(result, expected)
359
360    def test_inplace_on_captures(self, device):
361        x = torch.tensor([1.0, 2.0, 3.0], device=device)
362        captured = torch.randn(3, device=device)
363
364        def foo(x):
365            captured.copy_(x)
366            return (x * captured).sum()
367
368        with self.assertRaisesRegex(RuntimeError, "mutate a captured Tensor"):
369            grad(foo)(x)
370
371    def test_nesting_simple(self, device):
372        x = torch.randn([], device=device)
373        result = grad(grad(torch.sin))(x)
374        self.assertEqual(result, -torch.sin(x))
375
376    @skipIfTorchDynamo("Ref: https://github.com/pytorch/pytorch/issues/103613")
377    def test_escaped_wrappers_are_marked_as_dead(self, device):
378        x = torch.randn([], device=device)
379        escaped = []
380
381        def foo(x):
382            y = x.sin()
383            escaped.append(y)
384            return y
385
386        grad(foo)(x)
387        self.assertEqual(torch._C._functorch.dlevel(escaped[0]), -1)
388
389    @skipIfTorchDynamo("Ref: https://github.com/pytorch/pytorch/issues/103613")
390    def test_escaped_wrappers_are_ignored(self, device):
391        x = torch.randn([], device=device)
392        escaped = []
393
394        def foo(x):
395            y = x.sin()
396            escaped.append(y)
397            return y
398
399        grad(foo)(x)
400
401        something = escaped[0].sum()
402        self.assertEqual(torch._C._functorch.dlevel(something), 0)
403        self.assertEqual(something, x.sin().sum())
404
405    def test_manual_seed_inside_grad(self, device):
406        x = torch.randn([], device=device)
407
408        def f(x):
409            torch.manual_seed(0)
410            return x * torch.randn_like(x)
411
412        with freeze_rng_state():
413            result = grad(f)(x)
414            x.requires_grad_()
415            (expected,) = torch.autograd.grad(f(x), x)
416            self.assertEqual(result, expected)
417
418    def test_vjp(self, device):
419        x = torch.randn([], device=device)
420        out, vjp_fn = vjp(torch.sin, x)
421        self.assertEqual(out, x.sin())
422
423        v = torch.randn([], device=device)
424        (result,) = vjp_fn(v)
425        self.assertEqual(result, v * x.cos())
426
427    def test_vjp_two_outputs(self, device):
428        def f(x):
429            return x, x
430
431        result, vjp_fn = vjp(f, torch.tensor(1.0))
432        vjp_fn(result)
433
434    def test_conj_bit(self):
435        x = torch.tensor(1 + 1j)
436
437        def foo(x):
438            assert not x.is_conj()
439            y = x.conj()
440            assert y.is_conj()
441            return y.abs()
442
443        res = grad(foo)(x)
444        with torch.no_grad():
445            self.assertEqual(res, torch.ones_like(res) * torch.sgn(x))
446
447    def test_composed_with_autograd(self, device):
448        x = torch.randn([], requires_grad=True, device=device)
449
450        y = grad(torch.sin)(x)
451        (result,) = torch.autograd.grad(y, x)
452        self.assertEqual(result, -x.sin())
453
454    def test_grad_of_vjp_composition(self, device):
455        x = torch.randn([], device=device)
456        y = torch.randn([], device=device)
457
458        def foo(x, y):
459            out, vjp_fn = vjp(torch.sin, x)
460            return grad(lambda y: vjp_fn(y)[0])(y)
461
462        result = foo(x, y)
463        expected = x.cos()
464        self.assertEqual(result, expected)
465
466    def test_vjp_of_grad_composition(self, device):
467        x = torch.randn([], device=device)
468        y = torch.randn([], device=device)
469
470        def foo(x, y):
471            out, vjp_fn = vjp(grad(torch.sin), x)
472            return vjp_fn(y)[0]
473
474        result = foo(x, y)
475        expected = -y * x.sin()
476        self.assertEqual(result, expected)
477
478    def test_grad_of_vjp_of_grad_composition(self, device):
479        x = torch.randn([], device=device)
480        y = torch.randn([], device=device)
481
482        def foo(x, y):
483            df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
484            return grad(lambda y: vjp_fn(y)[0])(y)
485
486        result = foo(x, y)
487        expected = x.cos()
488        self.assertEqual(result, expected)
489
490    def test_views(self, device):
491        x = torch.randn([], requires_grad=True, device=device)
492        y = torch.randn([], requires_grad=True, device=device)
493
494        def silly_sin(x):
495            x = x.view([])
496            x = x.sin()
497            return x
498
499        def foo(x, y):
500            z1 = grad(silly_sin)(x)
501            z2 = torch.cos(y)
502            return z1 + z2
503
504        result = foo(x, y)
505        grads = torch.autograd.grad(result, [x, y])
506        self.assertEqual(grads[0], -x.sin())
507        self.assertEqual(grads[1], -y.sin())
508
509    def test_view_inplace_simple(self, device):
510        def foo(x):
511            x = x.clone()
512            x.view([]).sin_()
513            return x
514
515        x = torch.randn([], requires_grad=True, device=device)
516        result = grad(foo)(x)
517        self.assertEqual(result, x.cos())
518
519    def test_invalid_argnums(self, device):
520        x = torch.randn([])
521        y = torch.randn([])
522        with self.assertRaisesRegex(RuntimeError, "but only"):
523            grad(torch.mul, argnums=-3)(x, y)
524        with self.assertRaisesRegex(RuntimeError, "but only"):
525            grad(torch.mul, argnums=2)(x, y)
526        with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
527            grad(torch.mul, argnums=[0])(x, y)
528        with self.assertRaisesRegex(RuntimeError, "must be int"):
529            grad(torch.mul, argnums=("0",))(x, y)
530        with self.assertRaisesRegex(RuntimeError, "must be unique"):
531            grad(torch.mul, argnums=(0, 0))(x, y)
532        with self.assertRaisesRegex(RuntimeError, "must be unique"):
533            grad(torch.mul, argnums=(0, -2))(x, y)
534
535    def test_argnums(self, device):
536        x = torch.randn([])
537        y = torch.randn([])
538        gx = grad(torch.mul, argnums=0)(x, y)
539        self.assertEqual(gx, y)
540
541        gy = grad(torch.mul, argnums=1)(x, y)
542        self.assertEqual(gy, x)
543
544        (gx,) = grad(torch.mul, argnums=(0,))(x, y)
545        self.assertEqual(gx, y)
546
547        gx, gy = grad(torch.mul, argnums=(0, 1))(x, y)
548        self.assertEqual(gx, y)
549        self.assertEqual(gy, x)
550
551    def test_out_of_order_argnums(self, device):
552        x = torch.randn([])
553        y = torch.randn([])
554        gy, gx = grad(torch.mul, argnums=(1, 0))(x, y)
555        self.assertEqual(gx, y)
556        self.assertEqual(gy, x)
557
558    def test_negative_argnums(self, device):
559        x = torch.randn([])
560        y = torch.randn([])
561        gx = grad(torch.mul, argnums=-2)(x, y)
562        self.assertEqual(gx, y)
563
564        gy = grad(torch.mul, argnums=-1)(x, y)
565        self.assertEqual(gy, x)
566
567        (gx,) = grad(torch.mul, argnums=(-2,))(x, y)
568        self.assertEqual(gx, y)
569
570        gx, gy = grad(torch.mul, argnums=(-2, -1))(x, y)
571        self.assertEqual(gx, y)
572        self.assertEqual(gy, x)
573
574    def test_grad_pytree_inputs(self, device):
575        x = torch.randn([], device=device)
576
577        def f(a, b):
578            x, y = a
579            return 1 * x + 2 * y + 3 * b["foo"]
580
581        args = ((x, x), {"foo": x})
582
583        gx, gy = grad(f)(*args)
584        self.assertEqual(gx, torch.tensor(1.0, device=device))
585        self.assertEqual(gy, torch.tensor(2.0, device=device))
586
587        ((gx, gy),) = grad(f, argnums=(0,))(*args)
588        self.assertEqual(gx, torch.tensor(1.0, device=device))
589        self.assertEqual(gy, torch.tensor(2.0, device=device))
590
591        (gx, gy), gz = grad(f, argnums=(0, 1))(*args)
592        self.assertEqual(gx, torch.tensor(1.0, device=device))
593        self.assertEqual(gy, torch.tensor(2.0, device=device))
594        self.assertEqual(gz["foo"], torch.tensor(3.0, device=device))
595
596    def test_grad_aux_tensor(self, device):
597        x = torch.randn(3, device=device)
598
599        with self.assertRaisesRegex(
600            RuntimeError,
601            r"grad_and_value\(f\)\(\*args\): output of function f should be a tuple",
602        ):
603            grad(lambda t: [t, t], has_aux=True)(x)
604
605        with self.assertRaisesRegex(
606            RuntimeError,
607            r"grad_and_value\(f\)\(\*args\): output of function f should be a tuple",
608        ):
609            grad(lambda t: (t, t + 2, t + 3), has_aux=True)(x)
610
611        def f(t):
612            y = t.sin()
613            return y.sum(), t.cos()
614
615        out, aux = grad(f, has_aux=True)(x)
616        self.assertEqual(aux, x.cos())
617        self.assertEqual(out, x.cos())
618
619    def test_grad_aux_pytree(self, device):
620        def f(x):
621            y = x.sin()
622            return y.sum(), {"a": x.cos(), "b": [x.tan()]}
623
624        x = torch.randn(3, device=device)
625
626        out, aux = grad(f, has_aux=True)(x)
627        _, expected_aux = f(x)
628        self.assertEqual(aux, expected_aux)
629        self.assertEqual(out, x.cos())
630
631        for aux in [1, 1.0, "abc"]:
632            with self.assertRaisesRegex(
633                RuntimeError, r"Expected tensors, got unsupported type"
634            ):
635                _ = grad(lambda x: (x.sum(), aux), has_aux=True)(x)
636            with self.assertRaisesRegex(
637                RuntimeError, r"Expected tensors, got unsupported type"
638            ):
639                _ = grad(lambda x: (x.sum(), [x, aux]), has_aux=True)(x)
640
641    def test_zero_grad(self, device):
642        def f(x):
643            return (x["a"] ** 2.0).sum()
644
645        inps = {
646            "a": torch.randn(10, device=device) + 3,
647            "b": torch.randn(10, device=device),
648        }
649        grads = grad(f)(inps)
650        self.assertNotEqual(grads["a"].sum(), 0.0)
651        self.assertEqual(grads["b"].sum(), 0.0)
652
653    def test_unrelated_grad(self, device):
654        x = torch.tensor(1.0, device=device)
655        y = torch.tensor(2.0, device=device)
656
657        def unrelated(x):
658            return y
659
660        result = grad(unrelated)(x)
661        self.assertEqual(result, torch.zeros_like(x))
662
663    def test_unrelated_vjp(self, device):
664        x = torch.tensor(1.0, device=device)
665        y = torch.tensor(2.0, device=device)
666        v = torch.tensor(1.0, device=device)
667
668        def unrelated(x):
669            return y
670
671        out, vjp_fn = vjp(unrelated, x)
672        result = vjp_fn(v)
673        expected = (torch.zeros_like(x),)
674        self.assertEqual(result, expected)
675
676    def test_unrelated_vjp_multiple_inputs_outputs(self, device):
677        w = torch.tensor(3.0, device=device)
678        x = torch.tensor(4.0, device=device)
679        y = torch.tensor(2.0, device=device)
680        v = torch.tensor(1.0, device=device)
681
682        def unrelated(w, x):
683            return y, y, x
684
685        out, vjp_fn = vjp(unrelated, w, x)
686        result = vjp_fn((v, v, v))
687        expected = (torch.zeros_like(x), torch.ones_like(x))
688        self.assertEqual(result, expected)
689
690    # TODO: https://github.com/zou3519/functorch/issues/12
691    @onlyCPU
692    def test_unrelated_hessian(self, device):
693        N = 5
694        M = 3
695        W = torch.randn(N, M, device=device)
696
697        def f(x):
698            return W @ x
699
700        x = torch.randn(M)
701        result = jacrev(jacrev(f))(x)
702        expected = torch.zeros(N, M, M, device=device)
703        self.assertEqual(result, expected)
704
705    def test_vjp_pytree_input(self, device):
706        def f(x):
707            return x[0] * x[1][0]
708
709        x = torch.randn([], device=device)
710        v = torch.randn([], device=device)
711        out, vjp_fn = vjp(f, (x, (x, x)))
712        self.assertEqual(out, x * x)
713        result = vjp_fn(v)
714        self.assertEqual(result, ((x * v, (x * v, 0.0)),))
715
716    def test_vjp_pytree_output(self, device):
717        def f(x):
718            return x, (x, x)
719
720        x = torch.randn([], device=device)
721        v1 = torch.randn([], device=device)
722        v2 = torch.randn([], device=device)
723        v3 = torch.randn([], device=device)
724        _, vjp_fn = vjp(f, x)
725        (result,) = vjp_fn((v1, (v2, v3)))
726        self.assertEqual(result, v1 + v2 + v3)
727
728    def test_vjp_outputs_can_any_pytree(self, device):
729        x = torch.randn(2, 3, device=device)
730        t = torch.randn(2, 3, device=device)
731
732        for output in [None, ()]:
733            with self.assertRaisesRegex(
734                RuntimeError,
735                r"vjp\(f, \*primals\): Expected f to be a function that has non-empty output",
736            ):
737                _, vjp_fn = vjp(lambda _: output, x)
738                vjp_fn(t)
739
740        for output in [1, True, 12.2, "abc"]:
741            with self.assertRaisesRegex(
742                RuntimeError,
743                r"vjp\(f, \*primals\): expected f\(\*primals\) to return only tensors",
744            ):
745                _, vjp_fn = vjp(lambda _: output, x)
746                vjp_fn(t)
747
748        # Check list output
749        output, vjp_fn = vjp(lambda x: [x, x.sum()], x)
750        (vjp_out,) = vjp_fn([t, t.sum()])
751        assert isinstance(output, list) and len(output) == 2
752        assert isinstance(vjp_out, torch.Tensor)
753
754        # Check dict output
755        output, vjp_fn = vjp(lambda x: {"x": x, "xsum": x.sum()}, x)
756        (vjp_out,) = vjp_fn({"x": t, "xsum": t.sum()})
757        assert isinstance(output, dict) and len(output) == 2 and "xsum" in output
758        assert isinstance(vjp_out, torch.Tensor)
759
760        def composite_output(x):
761            out = x.sum()
762            return [
763                (out, {"a": x, "out": [x, out]}),
764            ]
765
766        output, vjp_fn = vjp(composite_output, x)
767        (vjp_out,) = vjp_fn(
768            [
769                (t.sum(), {"a": t, "out": [t, t.sum()]}),
770            ]
771        )
772        assert isinstance(output, list)
773        assert isinstance(output[0], tuple) and isinstance(output[0][1], dict)
774        assert isinstance(vjp_out, torch.Tensor)
775
776    def test_vjp_pytree_error(self, device):
777        def f(x):
778            return x, (x, x)
779
780        x = torch.randn([], device=device)
781        v1 = torch.randn([], device=device)
782        v2 = torch.randn([], device=device)
783        v3 = torch.randn([], device=device)
784        _, vjp_fn = vjp(f, x)
785        with self.assertRaisesRegex(RuntimeError, "Expected pytree structure"):
786            (result,) = vjp_fn(((v1, (v2, v3)),))
787
788    def test_vjp_aux_tensor(self, device):
789        x = torch.randn(3, device=device)
790
791        with self.assertRaisesRegex(
792            RuntimeError, r"vjp\(f, \*primals\): output of function f should be a tuple"
793        ):
794            vjp(lambda t: [t, t], x, has_aux=True)
795
796        with self.assertRaisesRegex(
797            RuntimeError, r"vjp\(f, \*primals\): output of function f should be a tuple"
798        ):
799            vjp(lambda t: (t, t + 2, t + 3), x, has_aux=True)
800
801        def f(t):
802            y = t.sin()
803            return y, t.cos()
804
805        out, vjp_fn, aux = vjp(f, x, has_aux=True)
806        self.assertEqual(aux, x.cos())
807        self.assertEqual(out, x.sin())
808
809        v = torch.randn(3, device=device)
810        (grad_x,) = vjp_fn(v)
811        self.assertEqual(grad_x, v * x.cos())
812
813    def test_vjp_aux_pytree(self, device):
814        def f(x):
815            y = x.sin()
816            return y, {"a": x.cos(), "b": [x.tan()]}
817
818        x = torch.randn(3, device=device)
819
820        out, vjp_fn, aux = vjp(f, x, has_aux=True)
821        expected_out, expected_aux = f(x)
822        self.assertEqual(out, expected_out)
823        self.assertEqual(aux, expected_aux)
824
825        v = torch.randn(3, device=device)
826        (grad_x,) = vjp_fn(v)
827        self.assertEqual(grad_x, v * x.cos())
828
829        for aux in [1, 1.0, "abc"]:
830            with self.assertRaisesRegex(
831                RuntimeError, r"Expected tensors, got unsupported type"
832            ):
833                _ = vjp(lambda x: (x, aux), x, has_aux=True)
834            with self.assertRaisesRegex(
835                RuntimeError, r"Expected tensors, got unsupported type"
836            ):
837                _ = vjp(lambda x: (x, [x, aux]), x, has_aux=True)
838
839    def test_functional_init(self, device):
840        class MLPClassifier(nn.Module):
841            def __init__(self, hidden_dim=32, n_classes=2):
842                super().__init__()
843                self.hidden_dim = hidden_dim
844                self.n_classes = n_classes
845
846                self.fc1 = nn.Linear(2, self.hidden_dim)
847                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
848
849            def forward(self, x):
850                x = self.fc1(x)
851                x = F.relu(x)
852                x = self.fc2(x)
853                x = F.log_softmax(x, -1)
854                return x
855
856        B = 10
857        weights, fn, _ = functional_init(MLPClassifier, (B,), device=device)(32, 2)
858        inputs = torch.randn(B, 7, 2, device=device)
859        vmap(fn)(weights, (inputs,))
860
861    def test_functional_init_with_buffers(self, device):
862        class MLPClassifier(nn.Module):
863            def __init__(self, hidden_dim=32, n_classes=2):
864                super().__init__()
865                self.hidden_dim = hidden_dim
866                self.n_classes = n_classes
867
868                self.fc1 = nn.Linear(2, self.hidden_dim)
869                self.bn = nn.BatchNorm1d(self.hidden_dim, affine=True)
870                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
871
872            def forward(self, x):
873                x = self.fc1(x)
874                x = F.relu(x)
875                x = self.bn(x)
876                x = self.fc2(x)
877                x = F.log_softmax(x, -1)
878                return x
879
880        B = 10
881        weights, buffers, fn, _, _ = functional_init_with_buffers(
882            MLPClassifier, [B], device=device
883        )(32, 2)
884        inputs = torch.randn(B, 7, 2, device=device)
885        vmap(fn)(weights, buffers, (inputs,))
886
887    def test_advanced_indexing(self, device):
888        def f(value):
889            log_prob = torch.ones((), device=device)
890            val = torch.zeros(()) > 0
891            log_prob[val] = 0
892            return value
893
894        result = grad(f)(torch.randn((), device=device))
895        self.assertEqual(result, torch.ones_like(result))
896
897        def f2(value):
898            value = value.clone()
899            value[value > 0] = 0
900            return value.sum()
901
902        x = torch.randn(100, device=device)
903        result = grad(f2)(x)
904        self.assertEqual(result, (x <= 0).type_as(x))
905
906    def test_tensor_ctor_inside_grad(self, device):
907        def foo(x):
908            return x * torch.tensor(2.0, device=device)
909
910        x = torch.tensor(3.14, device=device)
911        functorch.grad(foo)(x)
912
913    @parametrize(
914        "op_list_data",
915        [
916            subtest(
917                (
918                    [
919                        vmap,
920                    ],
921                    [(4, 2), (64, 3, 32, 32)],
922                ),
923                name="vmap",
924            ),
925            subtest(([vmap, vmap], [(4, 3, 2), (64, 3, 32, 32)]), name="vmap_vmap"),
926            subtest(
927                (
928                    [
929                        grad,
930                    ],
931                    [(0,), [], (4, 2), (64, 3, 32, 32)],
932                ),
933                name="grad",
934            ),
935            subtest(
936                (
937                    [grad, grad],
938                    [
939                        [],
940                    ],
941                ),
942                name="grad_grad",
943            ),
944            subtest(([vmap, grad], [(4, 2)]), name="vmap_grad"),
945        ],
946    )
947    def test_tensor_print(self, device, op_list_data):
948        op_list, shapes = op_list_data
949
950        for dt in get_all_fp_dtypes():
951            data = [torch.randn(s, dtype=dt, device=device) for s in shapes]
952
953            for x in data:
954                buf = None
955
956                def foo(t):
957                    nonlocal buf
958                    buf = repr(t)
959                    return t.mean()
960
961                fn = foo
962                bdim = 0
963                for op in reversed(op_list):
964                    if op == vmap:
965                        fn = op(fn, in_dims=bdim)
966                        bdim += 1
967                    else:
968                        fn = op(fn)
969
970                expected = f"{repr(x)}"
971                level = 0
972                for op in op_list:
973                    level += 1
974                    if op == grad:
975                        expected = f"GradTrackingTensor(lvl={level}, value={expected})"
976                    elif op == vmap:
977                        bdim -= 1
978                        expected = (
979                            f"BatchedTensor(lvl={level}, bdim={bdim}, value={expected})"
980                        )
981
982                fn(x)
983                buf = buf.replace("\n", "").replace("  ", "")
984                expected = expected.replace("\n", "").replace("  ", "")
985                self.assertEqual(expected, buf)
986
987    def test_print_captured_tensor_inside_transform(self, device):
988        x = torch.tensor([1.0, 2.0, 3.0], device=device)
989        out = None
990
991        def f(y):
992            nonlocal out
993            out = repr(x)
994            return y
995
996        vjp(f, torch.randn(4, device=device))
997        self.assertEqual(out, repr(x))
998
999    def test_no_grad_outside(self, device):
1000        x = torch.randn([], device=device, requires_grad=True)
1001        with torch.no_grad():
1002            y = grad(torch.sin)(x)
1003        self.assertEqual(y, x.cos())
1004        self.assertFalse(y.requires_grad)
1005
1006    def test_no_grad_inside(self, device):
1007        def f(x):
1008            with torch.no_grad():
1009                shift = x**2
1010            return x**2 - shift
1011
1012        x = torch.randn([], device=device)
1013        y = grad(f)(x)
1014        self.assertEqual(y, 2 * x)
1015        y = grad(grad(f))(x)
1016        self.assertEqual(y, 2)
1017
1018        x = torch.randn([], device=device, requires_grad=True)
1019        y = grad(f)(x)
1020        (z,) = torch.autograd.grad(y, x)
1021        self.assertEqual(z, 2)
1022
1023    def test_no_grad_mixed(self, device):
1024        def f(x):
1025            with torch.no_grad():
1026                shift = x**2
1027            return x**2 - shift
1028
1029        x = torch.randn([], device=device, requires_grad=True)
1030        with torch.no_grad():
1031            y = grad(f)(x)
1032
1033        self.assertEqual(y, 2 * x)
1034        self.assertFalse(y.requires_grad)
1035
1036    def test_no_grad_nested_simple(self, device):
1037        def h(x):
1038            with torch.no_grad():
1039                shift = grad(lambda x: 0.25 * x**4)(x)
1040            return x**3 - shift
1041
1042        x = torch.tensor(1.5, device=device, requires_grad=True)
1043        y = grad(h)(x)
1044        self.assertEqual(y, 3 * x**2)
1045
1046        (z,) = torch.autograd.grad(y, x)
1047        self.assertEqual(z, 6 * x)
1048
1049    def test_no_grad_nested_complicated(self, device):
1050        def f(x):
1051            with torch.no_grad():
1052                shift = x**3
1053            return x**3 - shift
1054
1055        def g(x):
1056            r1 = grad(f)(x)
1057            with torch.no_grad():
1058                shift = grad(f)(x)
1059            return r1 - shift
1060
1061        x = torch.randn([], requires_grad=True, device=device)
1062        y = grad(g)(x)
1063        # The only differentiable part of g is x ** 3
1064        self.assertEqual(y, 6 * x)
1065
1066        (z,) = torch.autograd.grad(y, x)
1067        self.assertEqual(z, 6)
1068
1069    def test_no_grad_value(self, device):
1070        def h(x):
1071            with torch.no_grad():
1072                gvalue, value = grad_and_value(lambda x: x**3)(x)
1073            return x**3 - value
1074
1075        x = torch.tensor(1.6, device=device, requires_grad=True)
1076        y = grad(h)(x)
1077        self.assertEqual(y, 3 * x**2)
1078
1079        (z,) = torch.autograd.grad(y, x)
1080        self.assertEqual(z, 6 * x)
1081
1082    def test_no_grad_outside_vjp(self, device):
1083        def h(x):
1084            return x**2
1085
1086        x = torch.tensor(2.0, requires_grad=True, device=device)
1087        with torch.no_grad():
1088            out, vjp_fn = vjp(h, x)
1089            (y,) = vjp_fn(torch.tensor(1.0, device=device))
1090
1091        self.assertEqual(y, 2 * x)
1092        self.assertFalse(y.requires_grad)
1093        self.assertFalse(out.requires_grad)
1094
1095    def test_no_grad_outside_vjp_fn(self, device):
1096        def h(x):
1097            return x**2
1098
1099        x = torch.tensor(3.14, requires_grad=True, device=device)
1100        out, vjp_fn = vjp(h, x)
1101        with torch.no_grad():
1102            (y,) = vjp_fn(torch.tensor(1.0, device=device))
1103
1104        self.assertEqual(y, 2 * x)
1105        self.assertFalse(y.requires_grad)
1106        self.assertTrue(out.requires_grad)
1107
1108        (z,) = torch.autograd.grad(out, x)
1109        self.assertEqual(z, 2 * x)
1110
1111    def test_no_grad_outside_vjp_only(self, device):
1112        def h(x):
1113            return x**2
1114
1115        x = torch.tensor(3.14, requires_grad=True, device=device)
1116        with torch.no_grad():
1117            out, vjp_fn = vjp(h, x)
1118        (y,) = vjp_fn(torch.tensor(1.0, device=device))
1119
1120        self.assertEqual(y, 2 * x)
1121        self.assertFalse(out.requires_grad)
1122
1123        # This one is a little weird: vjp_fn is called outside the no_grad block, so y is still differentiable w.r.t. x
1124        self.assertTrue(y.requires_grad)
1125
1126        (z,) = torch.autograd.grad(y, x)
1127        self.assertEqual(z, 2)
1128
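
# A minimal sketch (not part of the original file) of the "new-style"
# autograd.Function layout used by the test classes below: forward takes no ctx,
# and saved state goes through a separate setup_context staticmethod, which is
# what lets the Function compose with functorch transforms such as grad.
# The class name `_ExampleSquare` is hypothetical.
class _ExampleSquare(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


# e.g. grad(_ExampleSquare.apply)(torch.tensor(3.0)) evaluates to tensor(6.0)
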
1129
1130@markDynamoStrictTest
1131class TestAutogradFunction(TestCase):
1132    def test_set_materialize_grads(self, device):
1133        class A(torch.autograd.Function):
1134            @staticmethod
1135            def forward(x, y):
1136                return x, y
1137
1138            @staticmethod
1139            def setup_context(ctx, inputs, output):
1140                ctx.set_materialize_grads(False)
1141
1142            @staticmethod
1143            def backward(ctx, gx, gy):
1144                self.assertIsNotNone(gx)
1145                self.assertIsNone(gy)
1146                return gx, gy
1147
1148        def f(y, x):
1149            x, y = A.apply(x, y)
1150            return x**2
1151
1152        x = torch.tensor(2.0, device=device)
1153        y = torch.tensor(3.0, device=device)
1154        # grad differentiates w.r.t. arg 0 by default
1155        grad(f)(y, x)
1156        grad(grad(f))(y, x)
1157
1158    @parametrize("inner_requires_grad", [True, False])
1159    @parametrize("save_for", ["jvp", "vjp"])
1160    @parametrize("save_tensors", ["input", "output", "neither"])
1161    @parametrize("mark_dirty", [True, False])
1162    def test_function_returns_input(
1163        self, device, inner_requires_grad, save_for, save_tensors, mark_dirty
1164    ):
1165        class A(torch.autograd.Function):
1166            @staticmethod
1167            def forward(x):
1168                return x
1169
1170            @staticmethod
1171            def setup_context(ctx, inputs, output):
1172                if save_for == "jvp":
1173                    save_fn = ctx.save_for_forward
1174                else:
1175                    save_fn = ctx.save_for_backward
1176
1177                if mark_dirty:
1178                    ctx.mark_dirty(inputs[0])
1179
1180                if save_tensors == "input":
1181                    save_fn(inputs[0])
1182                elif save_tensors == "output":
1183                    save_fn(output)
1184                elif save_tensors == "neither":
1185                    pass
1186
1187            @staticmethod
1188            def backward(ctx, grad_output):
1189                return grad_output
1190
1191            @staticmethod
1192            def jvp(ctx, x_t):
1193                # NB: the logic to check ctx.save_for_forward happens
1194                #     before we reach this!
1195                if mark_dirty:
1196                    ret = x_t.add_(0)
1197                else:
1198                    ret = x_t.view_as(x_t)
1199                return ret
1200
1201        def fn(x):
1202            return A.apply(x.clone())
1203
1204        err_msg = "A input that has been returned as-is"
1205
1206        a = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad)
1207        a_t = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad)
1208        if save_tensors in ("input", "output") and not mark_dirty:
1209            with self.assertRaisesRegex(RuntimeError, err_msg):
1210                grad(fn)(a)
1211            with self.assertRaisesRegex(RuntimeError, err_msg):
1212                jvp(fn, (a,), (a_t,))
1213        else:
1214            grad(fn)(a)
1215            jvp(fn, (a,), (a_t,))
1216
1217        a = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad).clone()
1218        a_t = torch.tensor(
1219            2.0, device=device, requires_grad=inner_requires_grad
1220        ).clone()
1221
1222        if save_tensors in ("input", "output") and not mark_dirty:
1223            with self.assertRaisesRegex(RuntimeError, err_msg):
1224                A.apply(a)
1225            with self.assertRaisesRegex(RuntimeError, err_msg):
1226                with fwAD.dual_level():
1227                    A.apply(fwAD.make_dual(a, a_t))
1228        else:
1229            b = A.apply(a)
1230            if mark_dirty:
1231                self.assertTrue(a is b)
1232            if not (
1233                mark_dirty and save_for == "vjp" and save_tensors in ("input", "output")
1234            ):
1235                # TODO(soulitzer): https://github.com/pytorch/pytorch/issues/97827
1236                with fwAD.dual_level():
1237                    a_dual = fwAD.make_dual(a, a_t)
1238                    b_dual = A.apply(a_dual)
1239                if mark_dirty:
1240                    self.assertTrue(a_dual is b_dual)
1241
1242    def test_needs_input_grads(self, device):
1243        class A(torch.autograd.Function):
1244            @staticmethod
1245            def forward(x, y):
1246                return x * y
1247
1248            @staticmethod
1249            def setup_context(ctx, inputs, output):
1250                return
1251
1252            @staticmethod
1253            def backward(ctx, grad_output):
1254                self.assertTrue(ctx.needs_input_grad[0])
1255                self.assertFalse(ctx.needs_input_grad[1])
1256                return None, None
1257
1258        x = torch.tensor(2.0, device=device)
1259        y = torch.tensor(3.0, device=device)
1260        # grad differentiates w.r.t. arg 0 by default
1261        grad(A.apply)(x, y)
1262        grad(grad(A.apply))(x, y)
1263
1264    def _get_NumpyCubeNotComposable(self):
1265        class NumpyCubeNotComposable(torch.autograd.Function):
1266            @staticmethod
1267            def forward(input):
1268                input_np = input.cpu().numpy()
1269                return torch.tensor(input_np**3, device=input.device), input_np
1270
1271            @staticmethod
1272            def setup_context(ctx, inputs, output):
1273                ctx.input_np = output[1]
1274                ctx.device = inputs[0].device
1275
1276            @staticmethod
1277            @torch.autograd.function.once_differentiable
1278            def backward(ctx, grad_output, grad_saved):
1279                result_np = 3 * (ctx.input_np**2)
1280                return torch.tensor(result_np, device=ctx.device)
1281
1282        return NumpyCubeNotComposable
1283
1284    def test_once_differentiable_autograd_vjp(self, device):
1285        NumpyCubeNotComposable = self._get_NumpyCubeNotComposable()
1286
1287        def f(x):
1288            y, _ = NumpyCubeNotComposable.apply(x)
1289            return y
1290
1291        # regular autograd x vjp
1292        x = torch.randn([], requires_grad=True, device=device)
1293        grad_y = torch.randn_like(x, requires_grad=True)
1294        _, vjp_fn = vjp(f, x)
1295        (gx,) = vjp_fn(grad_y)
1296
1297        with self.assertRaisesRegex(RuntimeError, "marked with @once_differentiable"):
1298            gx.backward()
1299
1300    # TODO: support torch.autograd.function.once_differentiable
1301    # (or, if impossible, figure out how to raise a nice error)
1302    # https://github.com/pytorch/pytorch/issues/90224
1303    @unittest.expectedFailure
1304    def test_once_differentiable_grad_vjp(self, device):
1305        NumpyCubeNotComposable = self._get_NumpyCubeNotComposable()
1306
1307        # grad x vjp
1308        x = torch.randn([], device=device)
1309        grad_y = torch.randn_like(x)
1310
1311        def h(x, grad_y):
1312            _, vjp_fn = vjp(f, x)  # noqa: F821
1313            (gx,) = vjp_fn(grad_y)
1314            return gx
1315
1316        grad(h, argnums=(0, 1))(x, grad_y)
1317
1318    def test_grad_fn_name(self, device):
1319        names = []
1320
1321        class FooBar(torch.autograd.Function):
1322            @staticmethod
1323            def forward(x):
1324                return x.clone()
1325
1326            @staticmethod
1327            def setup_context(ctx, inputs, output):
1328                return
1329
1330            @staticmethod
1331            def backward(ctx, grad_output):
1332                return grad_output
1333
1334        def f(x):
1335            y = FooBar.apply(x)
1336            names.append(type(y.grad_fn).__name__)
1337            return y
1338
1339        x = torch.tensor(1.0)
1340        grad(f)(x)
1341        self.assertEqual(names, ["FooBarGeneratedBackward"])
1342
1343
1344@markDynamoStrictTest
1345class TestAutogradFunctionVmapAPI(TestCase):
1346    def test_no_vmap_staticmethod_and_no_generate_vmap_rule(self, device):
1347        class NumpyCube(torch.autograd.Function):
1348            @staticmethod
1349            def forward(input):
1350                input_np = to_numpy(input)  # noqa: F821
1351                dinput = torch.tensor(3 * input_np**2, device=input.device)
1352                return torch.tensor(input_np**3, device=input.device), dinput
1353
1354            @staticmethod
1355            def setup_context(ctx, inputs, output):
1356                ctx.save_for_backward(inputs, output[1])
1357
1358            @staticmethod
1359            def backward(ctx, grad_output, grad_saved):
1360                raise RuntimeError("foobar")
1361
1362        x = torch.randn(3, device=device)
1363        with self.assertRaisesRegex(RuntimeError, "does not have vmap support"):
1364            vmap(NumpyCube.apply)(x)
1365
1366    def test_has_vmap_staticmethod_and_has_generate_vmap_rule(self, device):
1367        class NumpyCube(torch.autograd.Function):
1368            generate_vmap_rule = True
1369
1370            @staticmethod
1371            def forward(input):
1372                input_np = to_numpy(input)  # noqa: F821
1373                dinput = torch.tensor(3 * input_np**2, device=input.device)
1374                return torch.tensor(input_np**3, device=input.device), dinput
1375
1376            @staticmethod
1377            def setup_context(ctx, outputs, input):
1378                ctx.save_for_backward(input, outputs[1])
1379
1380            @staticmethod
1381            def backward(ctx, grad_output, grad_saved):
1382                raise RuntimeError("foobar")
1383
1384            @staticmethod
1385            def vmap(infos, in_dims, x):
1386                raise RuntimeError("foobar")
1387
1388        x = torch.randn(3, device=device)
1389        with self.assertRaisesRegex(RuntimeError, "generate_vmap_rule=True and"):
1390            vmap(NumpyCube.apply)(x)
1391
1392    def test_info_object(self, device):
1393        batch_size = 10
1394
1395        class Id(torch.autograd.Function):
1396            @staticmethod
1397            def forward(input):
1398                pass
1399
1400            @staticmethod
1401            def setup_context(ctx, inputs, output):
1402                pass
1403
1404            @staticmethod
1405            def backward(ctx, grad_output, grad_saved):
1406                pass
1407
1408            @staticmethod
1409            def vmap(info, in_dims, input):
1410                self.assertEqual(info.batch_size, batch_size)
1411                self.assertEqual(info.randomness, randomness)
1412                return input, in_dims[0]
1413
1414        x = torch.randn(batch_size, 3, device=device)
1415
1416        for randomness in ("error", "different", "same"):
1417            vmap(Id.apply, randomness=randomness)(x)
1418
1419    def test_in_dims_single_input(self, device):
1420        class Id(torch.autograd.Function):
1421            @staticmethod
1422            def forward(input):
1423                pass
1424
1425            @staticmethod
1426            def setup_context(ctx, inputs, output):
1427                pass
1428
1429            @staticmethod
1430            def backward(ctx, grad_output, grad_saved):
1431                pass
1432
1433            @staticmethod
1434            def vmap(info, in_dims, input):
1435                self.assertEqual(in_dims, (1,))
1436                return input, in_dims[0]
1437
1438        B = 10
1439        x = torch.randn(3, B, device=device)
1440        vmap(Id.apply, in_dims=1)(x)
1441        vmap(Id.apply, in_dims=(1,))(x)
1442
1443    def test_in_dims_multiple_inputs(self, device):
1444        class Id(torch.autograd.Function):
1445            @staticmethod
1446            def forward(x, y):
1447                pass
1448
1449            @staticmethod
1450            def setup_context(ctx, inputs, output):
1451                pass
1452
1453            @staticmethod
1454            def backward(ctx, grad_output, grad_saved):
1455                pass
1456
1457            @staticmethod
1458            def vmap(info, in_dims, x, y):
1459                self.assertEqual(in_dims, (0, [0, 0]))
1460                self.assertTrue(isinstance(in_dims, tuple))
1461                self.assertTrue(isinstance(in_dims[1], list))
1462                return (x, y), in_dims
1463
1464        x = torch.randn(2, device=device)
1465        vmap(Id.apply)(x, [x, x])
1466
1467    def test_skips_empty_layer(self, device):
1468        class Id(torch.autograd.Function):
1469            @staticmethod
1470            def forward(input):
1471                return input
1472
1473            @staticmethod
1474            def setup_context(ctx, inputs, output):
1475                pass
1476
1477            @staticmethod
1478            def backward(ctx, grad_output, grad_saved):
1479                pass
1480
1481            @staticmethod
1482            def vmap(info, in_dims, input):
1483                raise RuntimeError("expected to not be called")
1484
1485        def f(x):
1486            y = torch.tensor(1.0)
1487            y = Id.apply(y)
1488            return x * 1
1489
1490        x = torch.randn(2, 3)
1491        vmap(f)(x)
1492
1493    def test_none_returns(self, device):
1494        class Zeros(torch.autograd.Function):
1495            @staticmethod
1496            def forward(input):
1497                return torch.zeros(input.shape, device=input.device)
1498
1499            @staticmethod
1500            def setup_context(ctx, inputs, output):
1501                pass
1502
1503            @staticmethod
1504            def vmap(info, in_dims, input):
1505                assert in_dims == (0,)
1506                return torch.zeros(input.shape[1:], device=input.device), None
1507
1508        B = 2
1509        x = torch.randn(B, 3)
1510        y = vmap(Zeros.apply)(x)
1511        self.assertEqual(y, torch.zeros_like(x))
1512
1513        class TwoZeros(torch.autograd.Function):
1514            @staticmethod
1515            def forward(input):
1516                r = torch.zeros(input.shape, device=input.device)
1517                return r, r
1518
1519            @staticmethod
1520            def setup_context(ctx, inputs, output):
1521                pass
1522
1523            @staticmethod
1524            def vmap(info, in_dims, input):
1525                assert in_dims == (0,)
1526                r = torch.zeros(input.shape[1:], device=input.device)
1527                return (r, r), None
1528
1529        B = 2
1530        x = torch.randn(B, 3)
1531        result = vmap(TwoZeros.apply)(x)
1532
1533        self.assertTrue(isinstance(result, tuple))
1534        y, z = result
1535        self.assertEqual(y, torch.zeros_like(x))
1536        self.assertEqual(z, torch.zeros_like(x))
1537
1538    def test_should_have_two_returns(self, device):
1539        class Zeros(torch.autograd.Function):
1540            @staticmethod
1541            def forward(input):
1542                r = torch.zeros(input.shape, device=input.device)
1543                return r
1544
1545            @staticmethod
1546            def setup_context(ctx, inputs, output):
1547                pass
1548
1549            @staticmethod
1550            def vmap(info, in_dims, input):
1551                r = torch.zeros(input.shape[1:], device=input.device)
1552                return r
1553
1554        B = 2
1555        x = torch.randn(B, 3)
1556        with self.assertRaisesRegex(RuntimeError, "to have two returns"):
1557            result = vmap(Zeros.apply)(x)
1558
1559        class TwoZeros(torch.autograd.Function):
1560            @staticmethod
1561            def forward(input):
1562                r = torch.zeros(input.shape, device=input.device)
1563                return r, r
1564
1565            @staticmethod
1566            def setup_context(ctx, inputs, output):
1567                pass
1568
1569            @staticmethod
1570            def vmap(info, in_dims, input):
1571                r = torch.zeros(input.shape[1:], device=input.device)
1572                return r, r, 0, 0
1573
1574        B = 2
1575        x = torch.randn(B, 3)
1576        with self.assertRaisesRegex(RuntimeError, "to have two returns"):
1577            result = vmap(TwoZeros.apply)(x)
1578
1579    def test_incompatible_out_dims_error_msg(self, device):
1580        class Zeros(torch.autograd.Function):
1581            @staticmethod
1582            def forward(input):
1583                r = torch.zeros(input.shape, device=input.device)
1584                return r
1585
1586            @staticmethod
1587            def setup_context(ctx, inputs, output):
1588                pass
1589
1590            @staticmethod
1591            def vmap(info, in_dims, input):
1592                r = torch.zeros(input.shape[1:], device=input.device)
1593                return r, (None,)
1594
1595        B = 2
1596        x = torch.randn(B, 3)
1597        with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
1598            result = vmap(Zeros.apply)(x)
1599
1600        class Zeros(torch.autograd.Function):
1601            @staticmethod
1602            def forward(input):
1603                r = torch.zeros(input.shape, device=input.device)
1604                return [r]
1605
1606            @staticmethod
1607            def setup_context(ctx, inputs, output):
1608                pass
1609
1610            @staticmethod
1611            def vmap(info, in_dims, input):
1612                r = torch.zeros(input.shape[1:], device=input.device)
1613                return [r], (None,)
1614
1615        B = 2
1616        x = torch.randn(B, 3)
1617        with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
1618            result = vmap(Zeros.apply)(x)
1619
1620    def test_kwarg_only_tensors(self, device):
1621        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
1622
1623            class MyClass(torch.autograd.Function):
1624                @staticmethod
1625                def forward(x, *, y):
1626                    return x + y
1627
1628                @staticmethod
1629                def setup_context(ctx, inputs, output):
1630                    pass
1631
1632                @staticmethod
1633                def vmap(info, in_dims, x, *, y):
1634                    assert in_dims == (0,)
1635                    return x + y, 0
1636
1637            x = torch.randn(3)
1638            y = torch.randn(3)
1639
1640            vmap(MyClass.apply)(x, y=y)
1641
1642
1643@markDynamoStrictTest
1644class TestVmapOfGrad(TestCase):
1645    def test_per_sample_grads_inplace_view(self, device):
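        # Per-sample gradients: vmapping grad(compute_loss) over the batch should
        # match computing the gradient one (x, t) sample at a time, even though
        # compute_loss goes through an in-place view op (squeeze_).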
1646        def compute_loss(weight, x, t):
1647            x = x.mm(weight)
1648            y = x.squeeze_(0)
1649            return (y - t).sum()
1650
1651        weight = torch.randn(16, 2, device=device)
1652        x = torch.randn(64, 1, 16, device=device)
1653        t = torch.randn(64, 2, device=device)
1654        result = vmap(partial(grad(compute_loss), weight))(x, t)
1655        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
1656        expected = torch.stack(expected)
1657        # TODO: Check if the rtol is a problem
1658        self.assertEqual(result, expected, atol=0, rtol=5e-4)
1659
1660    def test_new_zeros_materializes_tensor(self, device):
1661        N = 3
1662        C = 5
1663
1664        def foo(y, x):
1665            result = x.new_zeros((C,))
1666            result.copy_(y)
1667            return result.sum()
1668
1669        x = torch.randn(N, device=device)
1670        y = torch.randn(N, C, device=device)
1671        result = vmap(grad(foo))(y, x)
1672        self.assertEqual(result, torch.ones_like(y))
1673
1674    def test_new_empty_materializes_tensor(self, device):
1675        N = 3
1676        C = 5
1677
1678        def foo(y, x):
1679            result = x.new_empty((C,))
1680            result.copy_(y)
1681            return result.sum()
1682
1683        x = torch.randn(N, device=device)
1684        y = torch.randn(N, C, device=device)
1685        result = vmap(grad(foo))(y, x)
1686        self.assertEqual(result, torch.ones_like(y))
1687
1688    def test_per_sample_grads_simple(self, device):
1689        def compute_loss(weight, x, t):
1690            y = x @ weight
1691            return ((y - t) ** 2).sum()
1692
1693        weight = torch.randn(16, 2, device=device)
1694        x = torch.randn(64, 16, device=device)
1695        t = torch.randn(64, 2, device=device)
1696        result = vmap(partial(grad(compute_loss), weight))(x, t)
1697        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
1698        expected = torch.stack(expected)
1699        # TODO: Check if the rtol is a problem
1700        self.assertEqual(result, expected, atol=0, rtol=5e-4)
1701
1702    def _compare_expected_and_result(self, expected, result, mechanism):
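        # `expected` holds per-sample grads (tuples for make_functional, dicts for
        # functional_call); stack them along a new leading batch dim and compare
        # against the vmapped `result`.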
1703        if mechanism == "make_functional":
1704            expected = zip(*expected)
1705            expected = tuple(torch.stack(shards) for shards in expected)
1706            for r, e in zip(result, expected):
1707                self.assertEqual(r, e, atol=0, rtol=1.5e-3)
1708        else:
1709            assert mechanism == "functional_call"
1710            expected = {
1711                k: tuple(d[k] for d in expected) for k, v in expected[0].items()
1712            }
1713            expected = {k: torch.stack(shards) for k, shards in expected.items()}
1714            for key in result:
1715                self.assertEqual(result[key], expected[key], atol=0, rtol=1.5e-3)
1716
1717    @tf32_on_and_off(0.005)
1718    @parametrize("mechanism", ["make_functional", "functional_call"])
1719    def test_per_sample_grads_embeddingnet(self, device, mechanism):
1720        class SampleNet(nn.Module):
1721            def __init__(self, vocab_size: int):
1722                super().__init__()
1723                self.emb = nn.Embedding(vocab_size, 16)
1724                self.fc1 = nn.Linear(16, 16)
1725                self.fc2 = nn.Linear(16, 2)
1726
1727            def forward(self, x):
1728                x = self.emb(x)
1729                x = torch.transpose(x, -1, -2)
1730                x = torch.mean(x, -1)
1731                x = self.fc1(x)
1732                x = F.relu(x)
1733                x = self.fc2(x)
1734                return x
1735
1736            def name(self):
1737                return "SampleNet"
1738
1739        # Create our inputs...
1740        vocab_size = 1000
1741        batch_shape = [64]
1742        words_per_sentence = 5
1743        data = torch.randint(
1744            0, vocab_size, (*batch_shape, words_per_sentence), device=device
1745        )
1746        targets = torch.randint(0, 1, (*batch_shape,), device=device)
1747
1748        # Construct our module
1749        net = SampleNet(vocab_size).to(device=device)
1750        criterion = nn.CrossEntropyLoss()
1751
1752        net_func, weights = _get_weights_and_functional_call(net, mechanism)
1753
1754        def compute_loss(weights, data, target):
1755            output = net_func(weights, data)
1756            result = criterion(output, target)
1757            return result
1758
1759        expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
1760        result = vmap(partial(grad(compute_loss), weights))(data, targets)
1761        self._compare_expected_and_result(expected, result, mechanism)
1762
1763    def test_log_softmax(self, device):
1764        x = torch.randn(3, 5, device=device)
1765        v = torch.randn(5, device=device)
1766
1767        def foo(x, v):
1768            _, vjp_fn = vjp(partial(torch.log_softmax, dim=-1), x)
1769            return vjp_fn(v)[0]
1770
1771        result = vmap(foo, (0, None))(x, v)
1772
1773        v = v.expand_as(x)
1774        x.requires_grad_()
1775        output = torch.log_softmax(x, dim=-1)
1776        output.backward(v)
1777        self.assertEqual(result, x.grad)
1778
1779
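# Run each Jacobian test under both reverse-mode (jacrev) and forward-mode
# (jacfwd); the two APIs should agree on these functions.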
1780jacrev_and_jacfwd = parametrize(
1781    "jacapi", [subtest(jacrev, name="jacrev"), subtest(jacfwd, name="jacfwd")]
1782)
1783
1784FIXME_jacrev_only = parametrize("jacapi", [subtest(jacrev, name="jacrev")])
1785
1786
1787@markDynamoStrictTest
1788class TestJac(VmapTearDownMixin, TestCase):
1789    @jacrev_and_jacfwd
1790    def test_simple(self, device, jacapi):
1791        x = torch.randn(3, device=device)
1792        y = jacapi(torch.sin)(x)
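        # sin applies elementwise, so its Jacobian is diagonal with entries cos(x).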
1793        expected = torch.diagflat(x.cos())
1794        assert torch.allclose(y, expected)
1795
1796    @jacrev_and_jacfwd
1797    def test_simple_not_flat(self, device, jacapi):
1798        x = torch.randn(2, 3, device=device)
1799        y = jacapi(torch.sin)(x)
1800        expected = torch.diagflat(x.view(-1).cos())
1801        expected = expected.view(2, 3, 2, 3)
1802        assert torch.allclose(y, expected)
1803
1804    @jacrev_and_jacfwd
1805    def test_take(self, device, jacapi):
1806        x = torch.rand(5)
1807
1808        def func(x):
1809            y = torch.ones(3, dtype=torch.long)
1810            z = torch.take(x, y)
1811            return z
1812
1813        self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))
1814
1815    @jacrev_and_jacfwd
1816    def test_diff_numel(self, device, jacapi):
1817        x = torch.randn(2, 4, device=device)
1818
1819        # Tensor[2, 4] -> Tensor[3, 1]
1820        def f(x):
1821            return x[0, 1:].unsqueeze(-1)
1822
1823        y = jacapi(f)(x)
1824        self.assertEqual(y.shape, (3, 1, 2, 4))
1825
1826        expected = x.new_zeros(3, 1, 2, 4)
1827        expected[0, 0, 0, 1] = 1
1828        expected[1, 0, 0, 2] = 1
1829        expected[2, 0, 0, 3] = 1
1830        self.assertEqual(y, expected)
1831
1832    @jacrev_and_jacfwd
1833    def test_vmap_on_jac_simple(self, device, jacapi):
1834        x = torch.randn(2, 3, device=device)
1835        y = vmap(jacapi(torch.sin))(x)
1836        expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
1837        assert torch.allclose(y, expected)
1838
1839    @jacrev_and_jacfwd
1840    def test_nested_jac_simple(self, device, jacapi):
1841        def foo(x):
1842            return x.sin().sum()
1843
1844        x = torch.randn(3, device=device)
1845        y = jacapi(jacapi(foo))(x)
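        # The Hessian of sum(sin(x)) is diagonal with entries -sin(x).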
1846        expected = torch.diagflat(-x.sin())
1847        assert torch.allclose(y, expected)
1848
1849    @jacrev_and_jacfwd
1850    def test_multiple_args(self, device, jacapi):
1851        x = torch.randn(3, device=device)
1852        y = torch.randn(3, device=device)
1853        z = jacapi(torch.multiply, argnums=1)(x, y)
1854        expected = torch.diagflat(x)
1855        assert torch.allclose(z, expected)
1856
1857    @jacrev_and_jacfwd
1858    def test_multiple_outputs_multiple_argnums(self, device, jacapi):
1859        def f(x, y):
1860            return 2 * x + 3 * y, 4 * x + 5 * y
1861
1862        x = torch.randn(3, device=device)
1863        y = torch.randn(3, device=device)
1864        z = jacapi(f, argnums=(0, 1))(x, y)
1865        expected_out0_x = torch.diagflat(torch.full_like(x, 2))
1866        expected_out0_y = torch.diagflat(torch.full_like(y, 3))
1867        expected_out1_x = torch.diagflat(torch.full_like(x, 4))
1868        expected_out1_y = torch.diagflat(torch.full_like(y, 5))
1869
1870        self.assertEqual(len(z), 2)
1871        self.assertTrue(isinstance(z, tuple))
1872        self.assertEqual(len(z[0]), 2)
1873        self.assertTrue(isinstance(z[0], tuple))
1874        self.assertEqual(z[0][0], expected_out0_x)
1875        self.assertEqual(z[0][1], expected_out0_y)
1876        self.assertEqual(z[1][0], expected_out1_x)
1877        self.assertEqual(z[1][1], expected_out1_y)
1878
1879    @jacrev_and_jacfwd
1880    def test_multiple_outputs_single_argnums(self, device, jacapi):
1881        def f(x, y):
1882            return 2 * x + 3 * y, 4 * x + 5 * y
1883
1884        x = torch.randn(3, device=device)
1885        y = torch.randn(3, device=device)
1886        expected_out0_x = torch.diagflat(torch.full_like(x, 2))
1887        expected_out1_x = torch.diagflat(torch.full_like(x, 4))
1888
1889        z = jacapi(f, argnums=0)(x, y)
1890        self.assertEqual(len(z), 2)
1891        self.assertTrue(isinstance(z, tuple))
1892        self.assertEqual(z, (expected_out0_x, expected_out1_x))
1893
1894        z = jacapi(f, argnums=(0,))(x, y)
1895        self.assertEqual(len(z), 2)
1896        self.assertTrue(isinstance(z, tuple))
1897        self.assertTrue(isinstance(z[0], tuple))
1898        self.assertEqual(z, ((expected_out0_x,), (expected_out1_x,)))
1899
1900    @jacrev_and_jacfwd
1901    def test_multiple_outputs_pytree(self, device, jacapi):
1902        def f(x, y):
1903            return {"left": 2 * x + 3 * y, "right": 4 * x + 5 * y}
1904
1905        x = torch.randn(3, device=device)
1906        y = torch.randn(3, device=device)
1907        z = jacapi(f, argnums=(0, 1))(x, y)
1908        expected_left_x = torch.diagflat(torch.full_like(x, 2))
1909        expected_left_y = torch.diagflat(torch.full_like(y, 3))
1910        expected_right_x = torch.diagflat(torch.full_like(x, 4))
1911        expected_right_y = torch.diagflat(torch.full_like(y, 5))
1912        expected = {
1913            "left": (expected_left_x, expected_left_y),
1914            "right": (expected_right_x, expected_right_y),
1915        }
1916        self.assertTrue(isinstance(z, dict))
1917        self.assertTrue(isinstance(z["left"], tuple))
1918        self.assertTrue(isinstance(z["right"], tuple))
1919        self.assertEqual(z, expected)
1920
1921    @jacrev_and_jacfwd
1922    def test_multiple_inputs_pytree(self, device, jacapi):
1923        def f(a, b, c):
1924            a0, a1 = a
1925            return a0 + a1 * 2 + b * 3 + c * 4
1926
1927        x = torch.randn([], device=device)
1928        args = ((x, x), x, x)
1929
1930        result = jacapi(f, argnums=(0, 1, 2))(*args)
1931        expected = (
1932            (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
1933            torch.tensor(3.0, device=device),
1934            torch.tensor(4.0, device=device),
1935        )
1936        self.assertEqual(result, expected)
1937
1938        result = jacapi(f, argnums=(0,))(*args)
1939        expected = (
1940            (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
1941        )
1942        self.assertEqual(result, expected)
1943
1944        result = jacapi(f)(*args)
1945        expected = (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device))
1946        self.assertEqual(result, expected)
1947
1948    @jacrev_and_jacfwd
1949    def test_dimensionality(self, device, jacapi):
1950        def f(x):
1951            return x
1952
1953        x = torch.randn([], device=device)
1954        result = jacapi(f)(x)
1955        self.assertEqual(result.dim(), 0)
1956        self.assertEqual(result, torch.ones_like(x))
1957
1958        x = torch.randn([1], device=device)
1959        result = jacapi(f)(x)
1960        self.assertEqual(result.dim(), 2)
1961        self.assertEqual(result, x.new_ones(1, 1))
1962
1963    @jacrev_and_jacfwd
1964    def test_aux_tensor(self, device, jacapi):
1965        def f(x):
1966            y = x.clone()
1967            return y, y.cos()
1968
1969        x = torch.randn(3, device=device)
1970        result, aux = jacapi(f, has_aux=True)(x)
1971
1972        self.assertEqual(result, torch.eye(3, 3, device=device))
1973        self.assertEqual(aux, x.cos())
1974
1975    @jacrev_and_jacfwd
1976    def test_aux_pytree(self, device, jacapi):
1977        def f(x):
1978            y = x.clone()
1979            return y, {"a": y.cos(), "b": [y.tan()]}
1980
1981        x = torch.randn(3, device=device)
1982
1983        result, aux = jacapi(f, has_aux=True)(x)
1984        self.assertEqual(result, torch.eye(3, 3, device=device))
1985        _, expected_aux = f(x)
1986        self.assertEqual(aux, expected_aux)
1987
1988        for aux in [1, 1.0, "abc"]:
1989            with self.assertRaisesRegex(
1990                RuntimeError, r"Expected tensors, got unsupported type"
1991            ):
1992                _ = jacapi(lambda x: (x, aux), has_aux=True)(x)
1993            with self.assertRaisesRegex(
1994                RuntimeError, r"Expected tensors, got unsupported type"
1995            ):
1996                _ = jacapi(lambda x: (x, [x, aux]), has_aux=True)(x)
1997
1998    @jacrev_and_jacfwd
1999    def test_outputs_can_any_pytree(self, device, jacapi):
2000        x = torch.randn(2, 3, device=device)
2001
2002        for output in [None, ()]:
2003            with self.assertRaisesRegex(
2004                RuntimeError,
2005                r"(vjp|jvp).+: Expected f to be a function that has non-empty output",
2006            ):
2007                jacapi(lambda _: output)(x)
2008
2009        for output in [1, True, 12.2, "abc"]:
2010            with self.assertRaisesRegex(
2011                RuntimeError,
2012                r"(vjp|jvp).+: expected f\(\*primals\) to return only tensors",
2013            ):
2014                jacapi(lambda _: output)(x)
2015
2016        # Check list output
2017        out = jacapi(lambda x: [x, x.sum()])(x)
2018        assert isinstance(out, list) and len(out) == 2
2019
2020        # Check dict output
2021        out = jacapi(lambda x: {"x": x, "xsum": x.sum()})(x)
2022        assert isinstance(out, dict) and len(out) == 2 and "xsum" in out
2023
2024        def composite_output(x):
2025            out = x.sum()
2026            return [
2027                (out, {"a": x, "out": [x, out]}),
2028            ]
2029
2030        out = jacapi(composite_output)(x)
2031        assert isinstance(out, list)
2032        assert isinstance(out[0], tuple) and isinstance(out[0][1], dict)
2033
2034    @jacrev_and_jacfwd
2035    def test_multiple_inputs_outputs_pytree(self, device, jacapi):
2036        def f(a, b, c):
2037            a0, a1 = a
2038            return a0 + a1 * 2, {"foo": b * 3 + c * 4}
2039
2040        x = torch.randn([], device=device)
2041        zero = torch.zeros([], device=device)
2042        args = ((x, x), x, x)
2043
2044        result = jacapi(f)(*args)
2045        expected = (
2046            (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
2047            {"foo": (zero, zero)},
2048        )
2049        self.assertEqual(result, expected)
2050
2051        result = jacapi(f, argnums=(0,))(*args)
2052        expected = (
2053            ((torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),),
2054            {"foo": ((zero, zero),)},
2055        )
2056        self.assertEqual(result, expected)
2057
2058        result = jacapi(f, argnums=(0, 1))(*args)
2059        expected = (
2060            (
2061                (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
2062                zero,
2063            ),
2064            {"foo": ((zero, zero), torch.tensor(3.0, device=device))},
2065        )
2066        self.assertEqual(result, expected)
2067
2068    @jacrev_and_jacfwd
2069    def test_multiple_inputs_outputs_pytree_multidim(self, device, jacapi):
2070        def f(dct):
2071            a = dct["a"]
2072            b = dct["b"]
2073            return {"c": a.sin(), "d": b.cos()}
2074
2075        x = torch.randn(3, device=device)
2076        args = ({"a": x, "b": x},)
2077
2078        result = jacapi(f)(*args)
2079        expected = {
2080            "c": {"a": x.cos().diagflat(), "b": x.new_zeros(3, 3)},
2081            "d": {"a": x.new_zeros(3, 3), "b": -x.sin().diagflat()},
2082        }
2083        self.assertEqual(result, expected)
2084
2085    @jacrev_and_jacfwd
2086    def test_unrelated_input(self, device, jacapi):
2087        def f(x, y):
2088            return x
2089
2090        x = torch.randn(2, 3, device=device)
2091        y = torch.randn(2, 3, device=device)
2092
2093        result = jacapi(f, argnums=(0, 1))(x, y)
2094        expected0 = torch.eye(6, 6, device=device).view(2, 3, 2, 3)
2095        expected1 = y.new_zeros(2, 3, 2, 3)
2096        expected = (expected0, expected1)
2097        self.assertTrue(isinstance(result, tuple))
2098        self.assertEqual(result, expected)
2099
2100    @jacrev_and_jacfwd
2101    def test_unrelated_output(self, device, jacapi):
2102        y = torch.randn(2, 3, device=device)
2103
2104        def f(x):
2105            return y
2106
2107        x = torch.randn(2, 3, device=device)
2108
2109        result = jacapi(f)(x)
2110        expected = x.new_zeros(2, 3, 2, 3)
2111        self.assertEqual(result, expected)
2112
2113    @jacrev_and_jacfwd
2114    def test_empty_output(self, device, jacapi):
2115        x = torch.randn(3, device=device)
2116        y = torch.randn(3, device=device)
2117
2118        def f(x, y):
2119            return ()
2120
2121        with self.assertRaisesRegex(RuntimeError, "xpected"):
2122            jacapi(f)(x, y)
2123
2124    @jacrev_and_jacfwd
2125    def test_argnums_tuple(self, device, jacapi):
2126        x = torch.randn(3, device=device)
2127        y = torch.randn(3, device=device)
2128        z = jacapi(torch.multiply, argnums=(0, 1))(x, y)
2129        expected0 = torch.diagflat(y)
2130        expected1 = torch.diagflat(x)
2131        assert len(z) == 2
2132        assert torch.allclose(z[0], expected0)
2133        assert torch.allclose(z[1], expected1)
2134
2135    @jacrev_and_jacfwd
2136    def test_argnums_effect_on_return(self, device, jacapi):
2137        x = torch.randn(3, device=device)
2138        y = torch.randn(3, device=device)
2139        z = jacapi(torch.multiply, argnums=(0,))(x, y)
2140        expected0 = torch.diagflat(y)
2141        assert isinstance(z, tuple)
2142        assert len(z) == 1
2143        assert torch.allclose(z[0], expected0)
2144
2145        x = torch.randn(3, device=device)
2146        y = torch.randn(3, device=device)
2147        z = jacapi(torch.multiply, argnums=0)(x, y)
2148        expected0 = torch.diagflat(y)
2149        assert isinstance(z, torch.Tensor)
2150        assert torch.allclose(z, expected0)
2151
2152    @jacrev_and_jacfwd
2153    def test_argnums_defaults_to_zero(self, device, jacapi):
2154        def f(x, y):
2155            return x * 2 + y * 3
2156
2157        x = torch.randn(3, device=device)
2158        y = torch.randn(3, device=device)
2159        z = jacapi(f)(x, y)
2160        expected = torch.diagflat(torch.full_like(x, 2))
2161        self.assertEqual(z, expected)
2162
2163    @jacrev_and_jacfwd
2164    def test_empty_argnums(self, device, jacapi):
2165        x = torch.randn(3, device=device)
2166        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
2167            jacapi(torch.sin, argnums=())(x)
2168
2169    @jacrev_and_jacfwd
2170    def test_out_of_bounds_argnums(self, device, jacapi):
2171        x = torch.randn(3, device=device)
2172        with self.assertRaisesRegex(RuntimeError, "only 1 positional inputs"):
2173            jacapi(torch.sin, argnums=2)(x)
2174
2175    @jacrev_and_jacfwd
2176    def test_negative_argnums(self, device, jacapi):
2177        x = torch.randn(3, device=device)
2178        with self.assertRaisesRegex(RuntimeError, "only 1 positional inputs"):
2179            jacapi(torch.sin, argnums=-2)(x)
2180
2181    @jacrev_and_jacfwd
2182    def test_repeated_argnums(self, device, jacapi):
2183        x = torch.randn(3, device=device)
2184        with self.assertRaisesRegex(RuntimeError, "must be unique"):
2185            jacapi(torch.sin, argnums=(0, 0))(x)
2186
2187    @jacrev_and_jacfwd
2188    def test_float_argnums(self, device, jacapi):
2189        x = torch.randn(3, device=device)
2190        with self.assertRaisesRegex(RuntimeError, "must be int or Tuple"):
2191            jacapi(torch.sin, argnums=0.0)(x)
2192        with self.assertRaisesRegex(RuntimeError, "must be int"):
2193            jacapi(torch.multiply, argnums=(1, 0.0))(x, x)
2194
2195    def test_hessian_simple(self, device):
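        # Smoke test: only checks that hessian() runs without raising.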
2196        def f(x):
2197            return x.sin()
2198
2199        x = torch.randn(3, device=device)
2200        hessian(f)(x)
2201
2202    def _test_against_reference(self, f, inputs, jacapi):
2203        def foo(inputs):
2204            return f(*inputs)
2205
2206        expected = torch.autograd.functional.jacobian(f, inputs)
2207        result = jacapi(foo)(inputs)
2208        self.assertEqual(result, expected)
2209
2210    @jacrev_and_jacfwd
2211    def test_against_reference_simple(self, device, jacapi):
2212        def f(x):
2213            return 3 * x**2
2214
2215        x = torch.randn(2, 3, 5, device=device)
2216        self._test_against_reference(f, (x,), jacapi)
2217
2218    @jacrev_and_jacfwd
2219    def test_against_reference_multi_input(self, device, jacapi):
2220        def f(x, y):
2221            return (x.cos() * x) @ y.sin()
2222
2223        x = torch.randn(2, 3, device=device)
2224        y = torch.randn(3, 5, device=device)
2225        self._test_against_reference(f, (x, y), jacapi)
2226
2227    @jacrev_and_jacfwd
2228    def test_against_reference_multi_input_multi_output(self, device, jacapi):
2229        def f(x, y):
2230            return (x * x) @ y, x @ (x.sum(1) * y), y.sum()
2231
2232        x = torch.randn(5, 3, device=device)
2233        y = torch.randn(3, 5, device=device)
2234        self._test_against_reference(f, (x, y), jacapi)
2235
2236    @jacrev_and_jacfwd
2237    def test_against_reference_unrelated_outputs(self, device, jacapi):
2238        def f(x, y):
2239            return x, y, x, y
2240
2241        x = torch.randn(2, device=device)
2242        y = torch.randn(3, device=device)
2243        self._test_against_reference(f, (x, y), jacapi)
2244
2245    @jacrev_and_jacfwd
2246    def test_against_reference_zero_dim(self, device, jacapi):
2247        # zero-dim output
2248        def f(x, y):
2249            return x.sum(), y.sum(), x * y
2250
2251        x = torch.randn(3, device=device)
2252        y = torch.randn(3, device=device)
2253        self._test_against_reference(f, (x, y), jacapi)
2254
2255        # zero-dim input
2256        def g(x):
2257            return torch.stack([x, x, x])
2258
2259        x = torch.randn([], device=device)
2260        self._test_against_reference(g, (x,), jacapi)
2261
2262        # Mixed zero-dim input / zero-dim output
2263        def h(x, y):
2264            return y.sum(), x * y
2265
2266        x = torch.randn([], device=device)
2267        y = torch.randn(1, device=device)
2268        self._test_against_reference(h, (x, y), jacapi)
2269
2270    @jacrev_and_jacfwd
2271    def test_against_reference_correctness_different_devices(self, device, jacapi):
2272        def f(x, y):
2273            return x * y, (x * y).to(device=device)
2274
2275        x = torch.randn(3)
2276        y = torch.randn(3)
2277        self._test_against_reference(f, (x, y), jacapi)
2278
2279    @jacrev_and_jacfwd
2280    def test_against_reference_default_arg(self, device, jacapi):
2281        def f(x, y, z=3.0):
2282            return x * y * z
2283
2284        x = torch.randn(3, device=device)
2285        y = torch.randn(3, device=device)
2286        self._test_against_reference(f, (x, y), jacapi)
2287
2288    @jacrev_and_jacfwd
2289    def test_inplace(self, device, jacapi):
2290        def f(x, y):
2291            y.copy_(x)
2292            return y
2293
2294        out = jacapi(f, argnums=0)  # x is differentiable
2295        x, y = torch.randn(2, device=device), torch.randn(2, device=device)
2296        self.assertEqual(out(x, y), torch.eye(y.shape[0]))
2297
2298        # testing tuple of argnums with the example that raised this issue originally
2299        def g(x, y, z):
2300            x[:2] = y
2301            return torch.vstack([(x**2).sum(), (z**3).sum()])
2302
2303        out = jacapi(g, argnums=(1, 2))
2304        x, y, z = (
2305            torch.randn(3, device=device),
2306            torch.randn(2, device=device),
2307            torch.randn(2, device=device),
2308        )
2309
2310        expected_out = (
2311            torch.zeros(2, 1, 2, device=device),
2312            torch.zeros(2, 1, 2, device=device),
2313        )
2314        expected_out[0][0][0] = 2 * y  # top left corner
2315        expected_out[1][1][0] = 3 * (z**2)  # bottom right corner
2316
2317        out_val = out(x, y, z)
2318        self.assertEqual(out_val, expected_out)
2319
2320    @parametrize("_preallocate_and_copy", (True, False))
2321    def test_chunk_jacrev(self, device, _preallocate_and_copy):
2322        x = torch.randn(10, 2, device=device)
2323        y = torch.randn(1, 2, device=device)
2324
2325        def f(x, y):
2326            return (x.sin(), x + y), (x + 2, x.sum())
2327
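        # The chunked Jacobian should match default (unchunked) jacrev for a range
        # of chunk sizes, including ones larger than the number of output elements.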
2328        for chunk_size in (1, 2, 3, 4, 7, 10, 1000):
2329            expected = jacrev(f, argnums=(0, 1))(x, y)
2330            actual = jacrev(
2331                f,
2332                argnums=(0, 1),
2333                chunk_size=chunk_size,
2334                _preallocate_and_copy=_preallocate_and_copy,
2335            )(x, y)
2336            self.assertEqual(actual, expected)
2337
2338        err_msg = "jacrev: `chunk_size` should be greater than 0."
2339        with self.assertRaisesRegex(ValueError, err_msg):
2340            jacrev(f, argnums=(0,), chunk_size=0)(x, y)
2341
2342        with self.assertRaisesRegex(ValueError, err_msg):
2343            jacrev(f, argnums=(0,), chunk_size=-2)(x, y)
2344
2345    @parametrize("_preallocate_and_copy", (True, False))
2346    def test_chunk_jacrev_composition(self, device, _preallocate_and_copy):
2347        x = torch.randn(10, 2, device=device)
2348        chunk_size = 3
2349
2350        def f(x):
2351            return (x.sin(), x), (x + 2, x.sum())
2352
2353        expected = vmap(jacrev(jacrev(f)))(x)
2354        actual = vmap(
2355            jacrev(
2356                jacrev(
2357                    f,
2358                    chunk_size=chunk_size,
2359                    _preallocate_and_copy=_preallocate_and_copy,
2360                ),
2361                chunk_size=chunk_size,
2362            )
2363        )(x)
2364        self.assertEqual(actual, expected)
2365
2366    # https://github.com/pytorch/pytorch/issues/127036
2367    @xfailIfTorchDynamo
2368    @parametrize("_preallocate_and_copy", (True, False))
2369    def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
2370        # With chunk_size=1, we shouldn't `vmap`, and hence shouldn't be limited
2371        # by its constraints.
2372        x = torch.randn(3, 3, device=device)
2373
2374        # A function with a dynamic op in its backward pass.
2375        # This should cause jacrev/vmap(vjp) to fail.
2376        class IdentityWithDynamicBackwardOp(torch.autograd.Function):
2377            @staticmethod
2378            def forward(input):
2379                return input
2380
2381            @staticmethod
2382            def setup_context(ctx, inputs, output):
2383                pass
2384
2385            @staticmethod
2386            def backward(ctx, grad_output):
2387                # dynamic op in backward pass.
2388                grad_output.nonzero()
2389                return grad_output
2390
2391        def f(x):
2392            return IdentityWithDynamicBackwardOp.apply(x)
2393
2394        # With `chunk_size=1`, we don't use vmap. So the following should work.
2395        jacfn = jacrev(f, chunk_size=1, _preallocate_and_copy=_preallocate_and_copy)
2396        actual = jacfn(x)
2397        expected = torch.autograd.functional.jacobian(f, x, vectorize=False)
2398        self.assertEqual(actual, expected)
2399
2400        # Should fail with `chunk_size=2`.
2401        msg = (
2402            r"vmap: We do not support batching operators that can output dynamic shape."
2403        )
2404        with self.assertRaisesRegex(RuntimeError, msg):
2405            jacrev(f, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x)
2406
2407    def test_complex_error(self, device):
2408        # Verify complex input raises error
2409        # C -> C
2410        def fn(x):
2411            return x.conj()
2412
2413        x = torch.randn(1, device=device, dtype=torch.cfloat)
2414
2415        with self.assertRaisesRegex(RuntimeError, "jacrev: Expected all inputs"):
2416            jacrev(fn)(x)
2417
2418        with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all inputs"):
2419            jacfwd(fn)(x)
2420
2421        # Verify complex output raises error
2422        # R -> C
2423        def fn(x):
2424            return torch.conj(x * 0.5j)
2425
2426        x = torch.randn(1, device=device, dtype=torch.float)
2427
2428        with self.assertRaisesRegex(RuntimeError, "jacrev: Expected all outputs"):
2429            jacrev(fn)(x)
2430
2431        with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all outputs"):
2432            jacfwd(fn)(x)
2433
2434    @jacrev_and_jacfwd
2435    def test_jac_with_non_tensor_args(self, device, jacapi):
2436        def f(t, int_x):
2437            return t + int_x
2438
2439        t = torch.randn(3, 3, device=device)
2440
2441        actual = jacapi(f)(t, 3)
2442        expected = torch.autograd.functional.jacobian(partial(f, int_x=3), t)
2443        self.assertEqual(actual, expected)
2444
2445
2446@markDynamoStrictTest
2447class TestHessian(TestCase):
2448    def _test_against_reference(self, f, inputs):
2449        def foo(inputs):
2450            return f(*inputs)
2451
2452        expected = torch.autograd.functional.hessian(f, inputs)
2453        result = hessian(foo)(inputs)
2454        self.assertEqual(result, expected)
2455
2456    def test_hessian_vectorize_correctness_simple(self, device):
2457        def f(x):
2458            return (3 * x**2).sum()
2459
2460        x = torch.randn(2, 3, 5, device=device)
2461        self._test_against_reference(f, (x,))
2462
2463    def test_hessian_vectorize_correctness_multi_input(self, device):
2464        def f(x, y, z):
2465            return ((x.relu() * x) @ y.sin() @ z).sum()
2466
2467        x = torch.randn(2, 3, device=device)
2468        y = torch.randn(3, 5, device=device)
2469        z = torch.randn(5, 5, device=device)
2470        self._test_against_reference(f, (x, y, z))
2471
2472    def test_hessian_vectorize_correctness_unrelated_outputs(self, device):
2473        # output unrelated to one input
2474        def f(x, y):
2475            return (x**2).sum()
2476
2477        x = torch.randn(2, device=device)
2478        y = torch.randn(3, device=device)
2479        self._test_against_reference(f, (x, y))
2480
2481        # output unrelated to all inputs
2482        def f(x, y):
2483            return torch.ones([])
2484
2485        x = torch.randn(2, device=device)
2486        y = torch.randn(3, device=device)
2487        self._test_against_reference(f, (x, y))
2488
2489    def test_jacfwd_different_levels(self, device):
2490        # Test case from:
2491        # https://github.com/pytorch/functorch/issues/597
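        # hessian(f) is implemented as jacfwd(jacrev(f)), so under vmap it should
        # agree with composing jacrev with itself.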
2492        b = 8
2493        n = 100
2494        d = 2
2495        x1 = torch.randn(b, n, d, device=device)
2496        x2 = x1
2497        A = 0.1 * torch.randn(b, d, d, device=device)
2498
2499        def loss(A, x1, x2):
2500            x2_hat = (A @ (x1.T)).T
2501            res = x2 - x2_hat
2502            res_sqr = res**2
2503            return res_sqr.sum()
2504
2505        hess1 = vmap(jacrev(jacrev(loss)))(A, x1, x2)
2506        hess2 = vmap(hessian(loss))(A, x1, x2)
2507        self.assertEqual(hess2, hess1)
2508
2509
2510@markDynamoStrictTest
2511class TestJvp(TestCase):
2512    def test_inplace_on_captures(self, device):
2513        x = torch.tensor([1.0, 2.0, 3.0], device=device)
2514        captured = torch.randn(3, device=device)
2515
2516        def foo(x):
2517            captured.copy_(x)
2518            return (x * captured).sum()
2519
2520        with self.assertRaisesRegex(RuntimeError, "mutate a captured Tensor"):
2521            grad(foo)(x)
2522
2523    def test_simple(self, device):
2524        x = torch.randn(2, 3, device=device)
2525        t = torch.randn(2, 3, device=device)
2526        result = jvp(torch.sin, (x,), (t,))
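        # jvp returns (f(primals), J_f(primals) @ tangents); for sin that is
        # (sin(x), cos(x) * t).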
2527        expected = (x.sin(), x.cos() * t)
2528        self.assertTrue(isinstance(result, tuple))
2529        self.assertEqual(result, expected)
2530
2531    def test_multiple_inputs(self, device):
2532        x = torch.randn(2, 3, device=device)
2533        y = torch.randn(2, 3, device=device)
2534        tx = torch.randn(2, 3, device=device)
2535        ty = torch.randn(2, 3, device=device)
2536
2537        def f(x, y):
2538            return x * y
2539
2540        result = jvp(f, (x, y), (tx, ty))
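        # Product rule: the tangent of x * y is y * tx + x * ty.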
2541        expected = (x * y, y * tx + x * ty)
2542        self.assertTrue(isinstance(result, tuple))
2543        self.assertEqual(result, expected)
2544
2545    def test_pytree_inputs(self, device):
2546        def f(x, y, z):
2547            a, b = x
2548            return a + 2 * b + 3 * y + 4 * z
2549
2550        one = torch.tensor(1.0, device=device)
2551        primal_outs, tangent_outs = jvp(
2552            f, ((one, one), one, one), ((one, one), one, one)
2553        )
2554        self.assertEqual(primal_outs, one * 10)
2555        self.assertEqual(tangent_outs, one * 10)
2556
2557    def test_pytree_inputs_error_cases(self, device):
2558        def f(x):
2559            return x
2560
2561        one = torch.tensor(1.0, device=device)
2562
2563        with self.assertRaisesRegex(RuntimeError, "Expected primals to be a tuple"):
2564            jvp(f, one, one)
2565        with self.assertRaisesRegex(RuntimeError, "same python structure"):
2566            jvp(f, ((one, one), one), (one, one))
2567        with self.assertRaisesRegex(RuntimeError, "only contain Tensors"):
2568            jvp(f, ((one, one), 1), ((one, one), one))
2569        with self.assertRaisesRegex(RuntimeError, "only contain Tensors"):
2570            jvp(f, ((one, one), 1), ((1, one), one))
2571        with self.assertRaisesRegex(RuntimeError, "at least one Tensor"):
2572            jvp(f, ((),), ((),))
2573
2574    def test_unrelated_input(self, device):
2575        def f(x, y):
2576            return x
2577
2578        x = torch.randn(2, 3, device=device)
2579        y = torch.randn(2, 3, device=device)
2580        tx = torch.randn(2, 3, device=device)
2581        ty = torch.randn(2, 3, device=device)
2582
2583        result = jvp(f, (x, y), (tx, ty))
2584        expected = (x, tx)
2585        self.assertTrue(isinstance(result, tuple))
2586        self.assertEqual(result, expected)
2587
2588    def test_unrelated_output(self, device):
2589        y = torch.randn(2, 3, device=device)
2590
2591        def f(x):
2592            return y
2593
2594        x = torch.randn(2, 3, device=device)
2595        tx = torch.randn(2, 3, device=device)
2596
2597        result = jvp(f, (x,), (tx,))
2598        expected = (y, torch.zeros_like(y))
2599        self.assertTrue(isinstance(result, tuple))
2600        self.assertEqual(result, expected)
2601
2602    def test_strict_mode(self, device):
2603        y = torch.randn(2, 3, device=device)
2604
2605        def f(x):
2606            return x, y
2607
2608        x = torch.randn(2, 3, device=device)
2609        tx = torch.randn(2, 3, device=device)
2610
2611        with self.assertRaisesRegex(RuntimeError, "strict"):
2612            jvp(f, (x,), (tx,), strict=True)
2613
2614    def test_multiple_outputs(self, device):
2615        x = torch.randn(2, 3, device=device)
2616        t = torch.randn(2, 3, device=device)
2617
2618        def f(x):
2619            return torch.sin(x), torch.cos(x)
2620
2621        result = jvp(f, (x,), (t,))
2622        expected = (f(x), (x.cos() * t, -x.sin() * t))
2623        self.assertTrue(isinstance(result, tuple))
2624        self.assertEqual(result, expected)
2625
2626    def test_multiple_inputs_outputs(self, device):
2627        x = torch.randn(2, 3, device=device)
2628        y = torch.randn(2, 3, device=device)
2629        tx = torch.randn(2, 3, device=device)
2630        ty = torch.randn(2, 3, device=device)
2631
2632        def f(x, y):
2633            return 2 * x + 3 * y, 4 * x + 5 * y
2634
2635        result = jvp(f, (x, y), (tx, ty))
2636        expected = (f(x, y), f(tx, ty))
2637        self.assertTrue(isinstance(result, tuple))
2638        self.assertEqual(result, expected)
2639
2640    def test_jvp_new_tensor(self):
2641        def f(x):
2642            y = x.new_tensor(0.5)
2643            return x + y
2644
2645        x = torch.rand(10, 10)
2646        tangents = torch.zeros_like(x)
2647        actual = jvp(f, (x,), (tangents,))
2648        expected = (f(x), torch.zeros_like(x))
2649        self.assertEqual(actual, expected)
2650
2651    def test_primals_tangents_length_mismatch(self, device):
2652        x = torch.randn(2, 3, device=device)
2653        t = torch.randn(2, 3, device=device)
2654
2655        msg = "same python structure"
2656        with self.assertRaisesRegex(RuntimeError, msg):
2657            jvp(torch.sin, (x,), (t, t))
2658        with self.assertRaisesRegex(RuntimeError, msg):
2659            jvp(torch.sin, (x, x), (t, t, t))
2660
2661    def test_nonempty_primals_and_tangents(self, device):
2662        with self.assertRaisesRegex(RuntimeError, "at least one Tensor"):
2663            jvp(torch.sin, (), ())
2664
2665    def test_inputs_are_tuples_of_tensors(self, device):
2666        x = torch.randn(2, 3, device=device)
2667        t = torch.randn(2, 3, device=device)
2668
2669        with self.assertRaisesRegex(RuntimeError, "be a tuple"):
2670            jvp(torch.sin, x, (t,))
2671        with self.assertRaisesRegex(RuntimeError, "same python structure"):
2672            jvp(torch.sin, (x,), t)
2673        with self.assertRaisesRegex(RuntimeError, "same python structure"):
2674            jvp(torch.sin, (x,), [t])
2675        with self.assertRaisesRegex(RuntimeError, "only contain Tensors"):
2676            jvp(torch.sin, (1.0,), (t,))
2677        with self.assertRaisesRegex(RuntimeError, "only contain Tensors"):
2678            jvp(torch.sin, (x,), (1.0,))
2679
2680    def test_outputs_can_any_pytree(self, device):
2681        x = torch.randn(2, 3, device=device)
2682        t = torch.randn(2, 3, device=device)
2683
2684        for output in [None, ()]:
2685            with self.assertRaisesRegex(
2686                RuntimeError,
2687                r"jvp\(f, primals, tangents\): Expected f to be a function that has non-empty output",
2688            ):
2689                jvp(lambda _: output, (x,), (t,))
2690
2691        for output in [1, True, 12.2, "abc"]:
2692            with self.assertRaisesRegex(
2693                RuntimeError,
2694                r"jvp\(f, primals, tangents\): expected f\(\*primals\) to return only tensors",
2695            ):
2696                jvp(lambda _: output, (x,), (t,))
2697
2698        # Check list output
2699        out = jvp(lambda x: [x, x.sum()], (x,), (t,))
2700        for i in range(2):
2701            assert isinstance(out[i], list) and len(out[i]) == 2
2702
2703        # Check dict output
2704        out = jvp(lambda x: {"x": x, "xsum": x.sum()}, (x,), (t,))
2705        for i in range(2):
2706            assert isinstance(out[i], dict) and len(out[i]) == 2 and "xsum" in out[i]
2707
2708        def composite_output(x):
2709            out = x.sum()
2710            return [
2711                (out, {"a": x, "out": [x, out]}),
2712            ]
2713
2714        out = jvp(composite_output, (x,), (t,))
2715        for i in range(2):
2716            assert isinstance(out[i], list)
2717            assert isinstance(out[i][0], tuple) and isinstance(out[i][0][1], dict)
2718
2719    def test_aux_tensor(self, device):
2720        x = torch.randn(3, device=device)
2721        t = torch.randn(3, device=device)
2722
2723        with self.assertRaisesRegex(
2724            RuntimeError,
2725            r"jvp\(f, primals, tangents\): output of function f should be a tuple",
2726        ):
2727            jvp(lambda t: [t, t], (x,), (t,), has_aux=True)
2728
2729        with self.assertRaisesRegex(
2730            RuntimeError,
2731            r"jvp\(f, primals, tangents\): output of function f should be a tuple",
2732        ):
2733            jvp(lambda t: (t, t + 2, t + 3), (x,), (t,), has_aux=True)
2734
2735        def f(z):
2736            y = z.sin()
2737            return y, z.cos()
2738
2739        out, jvp_out, aux = jvp(f, (x,), (t,), has_aux=True)
2740        self.assertEqual(aux, x.cos())
2741        self.assertEqual(out, x.sin())
2742        self.assertEqual(jvp_out, t * x.cos())
2743
2744    def test_aux_pytree(self, device):
2745        def f(x):
2746            y = x.sin()
2747            return y, {"a": x.cos(), "b": [x.tan()]}
2748
2749        x = torch.randn(3, device=device)
2750        t = torch.randn(3, device=device)
2751
2752        out, jvp_out, aux = jvp(f, (x,), (t,), has_aux=True)
2753        expected_out, expected_aux = f(x)
2754        self.assertEqual(out, expected_out)
2755        self.assertEqual(aux, expected_aux)
2756        self.assertEqual(jvp_out, t * x.cos())
2757
2758        for aux in [1, 1.0, "abc"]:
2759            with self.assertRaisesRegex(
2760                RuntimeError, r"Expected tensors, got unsupported type"
2761            ):
2762                _ = jvp(lambda x: (x, aux), (x,), (t,), has_aux=True)
2763            with self.assertRaisesRegex(
2764                RuntimeError, r"Expected tensors, got unsupported type"
2765            ):
2766                _ = jvp(lambda x: (x, [x, aux]), (x,), (t,), has_aux=True)
2767
2768    def test_autograd_function_disables_fwd_grad(self, device):
2769        # Sanity check. We don't really assume this anywhere so
2770        # it's fine if this breaks one day.
2771        class MySquare(torch.autograd.Function):
2772            @staticmethod
2773            def forward(ctx, x):
2774                enabled = fwAD._is_fwd_grad_enabled()
2775                self.assertFalse(enabled)
2776                return x * x
2777
2778            @staticmethod
2779            def backward(ctx, gx):
2780                return gx
2781
2782        x = torch.randn(3, requires_grad=True)
2783        MySquare.apply(x)
2784
2785    def test_disable_fwd_grad_outside(self, device):
2786        x = torch.randn([], device=device)
2787        t = torch.ones_like(x)
2788        with fwAD._set_fwd_grad_enabled(False):
2789            _, y = jvp(torch.sin, (x,), (t,))
2790        self.assertEqual(y, x.cos())
2791
2792    def test_disable_fwd_grad_inside(self, device):
2793        def f(x):
2794            with fwAD._set_fwd_grad_enabled(False):
2795                shift = x**2
2796            return x**2 - shift
2797
2798        x = torch.randn([], device=device)
2799        t = torch.ones_like(x)
2800        _, y = jvp(f, (x,), (t,))
2801        self.assertEqual(y, 2 * x)
2802        _, y = jvp(lambda x: jvp(f, (x,), (t,))[1], (x,), (t,))
2803        self.assertEqual(y, 2)
2804
2805    def test_disable_fwd_grad_mixed(self, device):
2806        def f(x):
2807            with fwAD._set_fwd_grad_enabled(False):
2808                shift = x**2
2809            return x**2 - shift
2810
2811        x = torch.randn([], device=device)
2812        t = torch.ones_like(x)
2813        with fwAD._set_fwd_grad_enabled(True):
2814            _, y = jvp(f, (x,), (t,))
2815
2816        self.assertEqual(y, 2 * x)
2817
2818    def test_jvp_inside_autograd_function(self, device):
2819        class MySin(torch.autograd.Function):
2820            @staticmethod
2821            def forward(ctx, x):
2822                t = torch.ones_like(x)
2823                _, neg_sin_x = jvp(torch.cos, (x,), (t,))
2824                ctx.save_for_backward(x)
2825                return -neg_sin_x
2826
2827            @staticmethod
2828            def backward(ctx, gx):
2829                (x,) = ctx.saved_tensors
2830                t = torch.ones_like(x)
2831                _, cos_x = jvp(torch.sin, (x,), (t,))
2832                return gx * cos_x
2833
2834        x = torch.randn([], device=device, requires_grad=True)
2835        y = MySin.apply(x)
2836        self.assertEqual(y, x.sin())
2837
2838        (gx,) = torch.autograd.grad(y, x)
2839        self.assertEqual(gx, x.cos())
2840
2841    def test_zerotensor_vmapjvp_interaction(self, device):
2842        dummy = torch.ones(4, 1)
2843        x = torch.randn(4, 2)
2844        x_tangent = torch.randn(2)
2845
2846        def push_jvp(dummy, x):
2847            result = jvp(torch.cov, (x,), (x_tangent,))
2848            return result
2849
2850        # Should not error
2851        vmap(vmap(push_jvp, (0, None)))(dummy, x)
2852
2853
2854@markDynamoStrictTest
2855class TestLinearize(TestCase):
2856    @dtypes(torch.float)
2857    def test_linearize_basic(self, device, dtype):
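        # linearize(fn, x_p) returns fn(x_p) plus a jvp_fn specialized to x_p;
        # applying jvp_fn to a tangent should match jvp(fn, (x_p,), (x_t,)).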
2858        x_p = make_tensor((3, 1), device=device, dtype=dtype)
2859        x_t = make_tensor((3, 1), device=device, dtype=dtype)
2860
2861        def fn(x):
2862            return x.cos()
2863
2864        actual_output, jvp_fn = linearize(fn, x_p)
2865        actual_jvp = jvp_fn(x_t)
2866        expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,))
2867        self.assertEqual(actual_output, expected_output)
2868        self.assertEqual(actual_jvp, expected_jvp)
2869
2870    @dtypes(torch.float)
2871    def test_linearize_return(self, device, dtype):
2872        x_p = make_tensor((3, 1), device=device, dtype=dtype)
2873        x_t = make_tensor((3, 1), device=device, dtype=dtype)
2874
2875        def fn(x):
2876            return (x.cos(), x.sum())
2877
2878        actual_output, jvp_fn = linearize(fn, x_p)
2879        actual_jvp = jvp_fn(x_t)
2880        expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,))
2881        self.assertEqual(actual_output, expected_output)
2882        self.assertEqual(actual_jvp, expected_jvp)
2883
2884    @dtypes(torch.float)
2885    def test_linearize_composition_vmap(self, device, dtype):
2886        x_p = make_tensor((3, 1), device=device, dtype=dtype)
2887        x_t = make_tensor((3, 3, 1), device=device, dtype=dtype)
2888
2889        def fn(x):
2890            return (x.cos(), x.sum())
2891
2892        _, jvp_fn = linearize(fn, x_p)
2893        actual_batched_jvp = vmap(jvp_fn)(x_t)
2894
2895        def jvp_fn(x_t):
2896            return jvp(fn, (x_p,), (x_t,))[1]
2897
2898        expected_batched_jvp = vmap(jvp_fn)(x_t)
2899
2900        self.assertEqual(actual_batched_jvp, expected_batched_jvp)
2901
2902    @dtypes(torch.float)
2903    def test_linearize_composition_grad(self, device, dtype):
2904        x_p = make_tensor((3,), device=device, dtype=dtype)
2905        x_t = make_tensor((3,), device=device, dtype=dtype)
2906
2907        def fn(x):
2908            z = torch.ones(3, device=device, dtype=dtype)
2909            return grad(lambda x: z @ x)(x)
2910
2911        _, jvp_fn = linearize(fn, x_p)
2912        actual_batched_jvp = jvp_fn(x_t)
2913
2914        def jvp_fn(x_t):
2915            return jvp(fn, (x_p,), (x_t,))[1]
2916
2917        expected_batched_jvp = jvp_fn(x_t)
2918
2919        self.assertEqual(actual_batched_jvp, expected_batched_jvp)
2920
2921    @dtypes(torch.float)
2922    def test_linearize_nested_input_nested_output(self, device, dtype):
2923        x_p = make_tensor((3, 1), device=device, dtype=dtype)
2924        x_t = make_tensor((3, 1), device=device, dtype=dtype)
2925        y_p = make_tensor((3, 1), device=device, dtype=dtype)
2926        y_t = make_tensor((3, 1), device=device, dtype=dtype)
2927        z_p = make_tensor((3, 1), device=device, dtype=dtype)
2928        z_t = make_tensor((3, 1), device=device, dtype=dtype)
2929
2930        def fn(arg):
2931            x = arg["x"]
2932            y = arg["yz"][0]
2933            z = arg["yz"][1]
2934
2935            return {"a": x.sum(), "b": {"c": y + z, "d": (x * z, y.exp())}}
2936
2937        inp_p = {"x": x_p, "yz": (y_p, z_p)}
2938        inp_t = {"x": x_t, "yz": (y_t, z_t)}
2939        actual_output, jvp_fn = linearize(fn, inp_p)
2940        actual_jvp = jvp_fn(inp_t)
2941
2942        expected_output, expected_jvp = jvp(fn, (inp_p,), (inp_t,))
2943
2944        self.assertEqual(actual_output, expected_output)
2945        self.assertEqual(actual_jvp, expected_jvp)
2946
2947    @onlyCUDA
2948    def test_linearize_errors(self):
2949        dtype = torch.float
2950        device = torch.device("cpu")
2951        x_p = make_tensor((3, 1), device=device, dtype=dtype)
2952        x_t = make_tensor((3, 1), device=device, dtype=dtype)
2953
2954        def fn(x):
2955            return x.sin()
2956
2957        _, jvp_fn = linearize(fn, x_p)
2958
2959        with self.assertRaisesRegex(
2960            RuntimeError, "to have the same argspec as the primals"
2961        ):
2962            jvp_fn((x_t, x_t))
2963
2964        with self.assertRaisesRegex(
2965            RuntimeError, "in flattened pytree doesn't match the shape"
2966        ):
2967            jvp_fn(x_t.unsqueeze(0))
2968
2969        with self.assertRaisesRegex(
2970            RuntimeError, "in flattened pytree doesn't match the dtype"
2971        ):
2972            jvp_fn(x_t.to(torch.double))
2973
2974        with self.assertRaisesRegex(
2975            RuntimeError, "in flattened pytree doesn't match the device"
2976        ):
2977            jvp_fn(x_t.to(torch.device("cuda")))
2978
2979
2980# The tests here follow the cases in [Forward Grad View/inplace]
2981# https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/autograd_meta.cpp#L18-L43
2982@markDynamoStrictTest
2983class TestVmapJvpInplaceView(TestCase):
2984    # Case 1 in [Forward Grad View/inplace]
2985    def test_all_dual_no_view(self, device):
2986        B = 2
2987
2988        def push_jvp(f):
2989            def inner(x, xt, y, yt):
2990                return jvp(f, (x, y), (xt, yt))
2991
2992            return inner
2993
2994        def f(x, y):
2995            x.copy_(y)
2996            return x
2997
2998        x = torch.randn(3, B, device=device)
2999        xt = torch.randn(3, B, device=device)
3000        y = torch.randn(3, B, device=device)
3001        yt = torch.randn(3, B, device=device)
3002        out, out_tangent = vmap(push_jvp(f), in_dims=1)(x, xt, y, yt)
3003        self.assertEqual(out, x.movedim(1, 0))
3004        self.assertEqual(out_tangent, yt.movedim(1, 0))
3005
3006        x = torch.randn(3, B, device=device)
3007        xt = torch.randn(3, B, device=device)
3008        y = torch.randn(3, 3, device=device)[:, 1]
3009        yt = torch.randn(6, device=device)[::2]
3010        out, out_tangent = vmap(push_jvp(f), in_dims=(1, 1, None, None))(x, xt, y, yt)
3011        self.assertEqual(out, x.movedim(1, 0))
3012        self.assertEqual(out_tangent, yt.expand(B, 3))
3013
3014    # Case 2 in [Forward Grad View/inplace]
3015    def test_all_dual_base_view_inplace(self, device):
3016        B = 2
3017
3018        def push_jvp(f):
3019            def inner(x, xt, y, yt):
3020                return jvp(f, (x, y), (xt, yt))
3021
3022            return inner
3023
3024        # with view, propagate from view to base
3025        def f(x, y):
3026            view = x[:, ::2]
3027            view.copy_(y)
3028            return view, x
3029
3030        orig_x = torch.randn(2, 6, B, device=device)
3031        orig_xt = torch.randn(2, 6, B, device=device)
3032        x = orig_x.clone()
3033        xt = orig_xt.clone()
3034        y = torch.randn(2, B, 3, device=device)
3035        yt = torch.randn(2, B, 3, device=device)
3036        out, out_tangent = vmap(push_jvp(f), in_dims=(2, 2, 1, 1))(x, xt, y, yt)
3037
3038        expected_out = vmap(f, in_dims=(2, 1))(orig_x.clone(), y)
3039        self.assertEqual(out[0], expected_out[0])
3040        self.assertEqual(out[1], expected_out[1])
3041
3042        self.assertEqual(out_tangent[0], yt.movedim(1, 0))
3043
3044        expected_x_tangent = orig_xt.movedim(-1, 0).clone()
3045        expected_x_tangent[:, :, ::2].copy_(yt.movedim(1, 0))
3046        self.assertEqual(out_tangent[1], expected_x_tangent)
3047
3048        expected = orig_x.movedim(2, 0).clone()
3049        expected[:, :, ::2] = y.movedim(1, 0)
3050        self.assertEqual(x.movedim(2, 0), expected)
3051
3052    # Case 3 in [Forward Grad View/inplace]
3053    def test_all_dual_base_inplace(self, device):
3054        B = 2
3055
3056        def push_jvp(f):
3057            def inner(x, xt, y, yt):
3058                return jvp(f, (x, y), (xt, yt))
3059
3060            return inner
3061
3062        # Case 3: with view, propagate from base to view
3063        def f(x, y):
3064            view = x[0, ::2]
3065            x.copy_(y)
3066            return x, view
3067
3068        x = torch.randn(2, B, 6, device=device)
3069        xt = torch.randn(2, 6, B, device=device)
3070        y = torch.randn(2, B, 6, device=device)
3071        yt = torch.randn(2, B, 6, device=device)
3072        out, out_tangent = vmap(push_jvp(f), in_dims=(1, 2, 1, 1))(x.clone(), xt, y, yt)
3073
3074        expected_out = vmap(f, in_dims=(1, 1))(x.clone(), y)
3075        self.assertEqual(out[0], expected_out[0])
3076        self.assertEqual(out[1], expected_out[1])
3077
3078        self.assertEqual(out_tangent[0], yt.movedim(1, 0))
3079        self.assertEqual(out_tangent[1], yt.movedim(1, 0)[:, 0, ::2])
3080
3081    # Case 4 in [Forward Grad View/inplace]
3082    def test_right_dual_view_prop(self, device):
3083        B = 2
3084
3085        # Changes on the view must propagate to its base. Also:
3086        # - x is a regular Tensor
3087        # - y is a dual tensor
3088        def f(x, y):
3089            x = x.clone()
3090            view = x[0]
3091            view.copy_(y)
3092            return view, x
3093
3094        def push_jvp(x, y, yt):
3095            return jvp(partial(f, x), (y,), (yt,))
3096
3097        x = torch.randn(2, B, 6, device=device)
3098        y = torch.randn(6, B, device=device)
3099        yt = torch.randn(6, B, device=device)
3100        outs, tangents = vmap(push_jvp, in_dims=(1, 1, 1))(x, y, yt)
3101
3102        expected_out = vmap(f, in_dims=(1, 1))(x.clone(), y)
3103        self.assertEqual(outs[0], expected_out[0])
3104        self.assertEqual(outs[1], expected_out[1])
3105
3106        self.assertEqual(tangents[0], yt.movedim(1, 0))
3107
3108        expected_tangent_1 = torch.zeros_like(x).movedim(1, 0)
3109        expected_tangent_1[:, 0].copy_(yt.movedim(1, 0))
3110        self.assertEqual(tangents[1], expected_tangent_1)
3111
3112    # Case 5 in [Forward Grad View/inplace]
3113    def test_right_dual_base_prop(self, device):
3114        B = 2
3115
3116        # Changes on the base must propagate to all its views. Also:
3117        # - x is a regular Tensor
3118        # - y is a dual tensor
3119        def f(x, y):
3120            x = x.clone()
3121            view = x[0]
3122            x.copy_(y)
3123            return view, x
3124
3125        def push_jvp(x, y, yt):
3126            return jvp(partial(f, x), (y,), (yt,))
3127
3128        x = torch.randn(2, B, 6)
3129        y = torch.randn(2, 6, B)
3130        yt = torch.randn(2, 6, B)
3131        outs, tangents = vmap(push_jvp, in_dims=(1, 2, 2))(x, y, yt)
3132
3133        expected_out = vmap(f, in_dims=(1, 2))(x, y)
3134        self.assertEqual(outs[0], expected_out[0])
3135        self.assertEqual(outs[1], expected_out[1])
3136
3137        self.assertEqual(tangents[0], yt.movedim(2, 0)[:, 0])
3138        self.assertEqual(tangents[1], yt.movedim(2, 0))
3139
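# A standalone, illustrative sketch of the pattern the "Forward Grad View/inplace"
# cases above exercise, without the extra vmap layer: an in-place copy into a
# dual-tensor base should show up in the tangent of a view taken from that base.
# The helper name below is made up for illustration and is not called by the tests.
def _jvp_view_inplace_sketch(device="cpu"):
    x = torch.randn(2, 6, device=device)
    xt = torch.randn(2, 6, device=device)
    y = torch.randn(2, 6, device=device)
    yt = torch.randn(2, 6, device=device)

    def f(x, y):
        view = x[0, ::2]
        x.copy_(y)  # writing through the base...
        return view  # ...should update the view's tangent as well

    _, out_tangent = jvp(f, (x, y), (xt, yt))
    # the view's tangent now comes from yt (the tangent of y), not from xt
    assert torch.allclose(out_tangent, yt[0, ::2])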
3140
3141# Used for testing miscellaneous helper functions
3142@markDynamoStrictTest
3143class TestHelpers(TestCase):
3144    def test_CtxWithSavedTensors_error_if_name_collision(self, device):
3145        x = torch.randn([], device=device, requires_grad=True)
3146        y = torch.randn([], device=device, requires_grad=True)
3147
3148        class A(torch.autograd.Function):
3149            @staticmethod
3150            def forward(ctx, x):
3151                ctx._pt_inner_ctx = 1
3152                ctx.save_for_backward(x)
3153                return x
3154
3155            @staticmethod
3156            def backward(ctx, gy):
3157                wrapped = torch._functorch.autograd_function.CtxWithSavedTensors(
3158                    ctx, (y,)
3159                )
3160                return gy
3161
3162        class B(torch.autograd.Function):
3163            @staticmethod
3164            def forward(ctx, x):
3165                ctx._pt_new_saved_tensors = 1
3166                ctx.save_for_backward(x)
3167                return x
3168
3169            @staticmethod
3170            def backward(ctx, gy):
3171                wrapped = torch._functorch.autograd_function.CtxWithSavedTensors(
3172                    ctx, (y,)
3173                )
3174                return gy
3175
3176        out = A.apply(x)
3177        with self.assertRaisesRegex(RuntimeError, "name collision"):
3178            out.backward()
3179        out = B.apply(x)
3180        with self.assertRaisesRegex(RuntimeError, "name collision"):
3181            out.backward()
3182
3183    def test_CtxWithSavedTensors_nesting(self, device):
3184        CtxWithSavedTensors = torch._functorch.autograd_function.CtxWithSavedTensors
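        # As the assertions below exercise: the wrapper reports the override via
        # `saved_tensors` while passing other ctx attributes through to the
        # wrapped ctx, and wrappers can themselves be wrapped without clobbering
        # each other's overrides.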
3185        x = torch.randn([], device=device, requires_grad=True)
3186        y = torch.randn([], device=device)
3187        z = torch.randn([], device=device)
3188
3189        class A(torch.autograd.Function):
3190            @staticmethod
3191            def forward(ctx, x):
3192                ctx.save_for_backward(x)
3193                return x
3194
3195            @staticmethod
3196            def backward(ctx, gy):
3197                ctx_y = CtxWithSavedTensors(ctx, (y,))
3198                # Can't use self.assertEqual because that relies on TLS
3199                # that is not available in multithreaded autograd
3200                assert len(ctx_y.saved_tensors) == 1
3201                assert torch.allclose(ctx_y.saved_tensors[0], y)
3202
3203                wrapped = CtxWithSavedTensors(ctx_y, (z,))
3204
3205                assert len(wrapped.saved_tensors) == 1
3206                assert torch.allclose(wrapped.saved_tensors[0], z)
3207
3208                assert len(ctx_y.saved_tensors) == 1
3209                assert torch.allclose(ctx_y.saved_tensors[0], y)
3210
3211                return gy * wrapped.saved_tensors[0]
3212
3213        out = A.apply(x)
3214        out.backward()
3215        self.assertEqual(x.grad, z)
3216
3217    def test_CtxWithSavedTensors_overrides_saved_tensors(self, device):
3218        x = torch.randn([], device=device, requires_grad=True)
3219
3220        class A(torch.autograd.Function):
3221            @staticmethod
3222            def forward(ctx, x):
3223                ctx.save_for_backward(x)
3224                return x
3225
3226            @staticmethod
3227            def backward(ctx, gy):
3228                # The override can be literally anything
3229                override = (1, 2, 3)
3230                wrapped = torch._functorch.autograd_function.CtxWithSavedTensors(
3231                    ctx, override
3232                )
3233                assert wrapped.saved_tensors == override
3234                return gy
3235
3236        out = A.apply(x)
3237        out.backward()
3238
3239    def test_CtxWithSavedTensors_passthrough(self, device):
3240        x = torch.randn([], device=device, requires_grad=True)
3241        y = torch.randn([], device=device)
3242
3243        class A(torch.autograd.Function):
3244            @staticmethod
3245            def forward(ctx, x, y):
3246                ctx.save_for_backward(x, y)
3247                return x * y
3248
3249            @staticmethod
3250            def backward(ctx, gz):
3251                # The override can be literally anything
3252                override = (1, 2, 3)
3253                wrapped = torch._functorch.autograd_function.CtxWithSavedTensors(
3254                    ctx, override
3255                )
3256
3257                assert wrapped.needs_input_grad[0] == ctx.needs_input_grad[0]
3258                assert wrapped.needs_input_grad[1] == ctx.needs_input_grad[1]
3259                wrapped.foo = "bar"
3260                assert wrapped.foo == "bar"
3261                assert ctx.foo == "bar"
3262                return gz, gz
3263
3264        out = A.apply(x, y)
3265        out.backward()
3266
3267    def test_reductify_leaf(self, device):
3268        reductify_leaf = torch._functorch.autograd_function.reductify_leaf
3269        B = 2
3270
3271        # grad_input None case
3272        output = reductify_leaf(None, None, 0, B)
3273        self.assertIsNone(output)
3274        output = reductify_leaf(None, None, None, B)
3275        self.assertIsNone(output)
3276
3277        # grad_input has bdim, input does not have bdim
3278        grad_input = torch.randn([B, 3, 4], device=device)
3279        output = reductify_leaf(grad_input, 0, None, B)
3280        self.assertEqual(output, grad_input.sum(0))
3281
3282        grad_input = torch.randn([3, B, 4], device=device)
3283        output = reductify_leaf(grad_input, 1, None, B, (3,))
3284        self.assertEqual(output, grad_input.sum(1))
3285
3286        # grad_input does not have bdim, input has bdim
3287        # This can happen if the user returns a fresh Tensor from the backward pass
3288        # that is unrelated to the input
3289        grad_input = torch.randn([3, 4], device=device)
3290        output = reductify_leaf(grad_input, None, 1, B)
3291        self.assertEqual(output, grad_input.view(3, 1, 4).expand(3, B, 4))
3292
3293        grad_input = torch.randn([3, 4], device=device)
3294        output = reductify_leaf(grad_input, None, 1, B, (4,))
3295        self.assertEqual(output, grad_input.view(3, 4, 1).expand(3, 4, B).sum(0))
3296
3297        # grad_input has bdim, input has bdim
3298        grad_input = torch.randn([B, 3, 4], device=device)
3299        output = reductify_leaf(grad_input, 0, 1, B)
3300        self.assertEqual(output, grad_input.movedim(0, 1))
3301
3302        grad_input = torch.randn([3, 4, 5, B], device=device)
3303        output = reductify_leaf(grad_input, 3, 0, B, (5,))
3304        self.assertEqual(output, grad_input.movedim(-1, 2).sum(0).sum(0))
3305
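# A minimal, illustrative sketch (not the real implementation) of the reduction
# rule that test_reductify_leaf exercises in its simpler cases: if the incoming
# gradient carries a batch dim but the corresponding input did not, the batch
# dim is summed out; if both carry one, the batch dims are simply lined up.
# The helper name below is made up for illustration only.
def _reductify_leaf_sketch(grad_input, grad_input_bdim, input_bdim):
    if grad_input is None:
        return None
    if grad_input_bdim is not None and input_bdim is None:
        # the input was not batched: sum the gradient over the batch dimension
        return grad_input.sum(grad_input_bdim)
    if grad_input_bdim is not None and input_bdim is not None:
        # both are batched: move the gradient's batch dim to the input's batch dim
        return grad_input.movedim(grad_input_bdim, input_bdim)
    if grad_input_bdim is None and input_bdim is not None:
        # the real helper expands the gradient along the batch dim here;
        # omitted from this sketch
        raise NotImplementedError
    return grad_input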
3306
3307@markDynamoStrictTest
3308class TestComposability(TestCase):
3309    def test_deprecation_vmap(self, device):
3310        x = torch.randn(3, device=device)
3311
3312        # functorch version of the API is deprecated
3313        with self.assertWarnsRegex(FutureWarning, "Please use `torch.vmap`"):
3314            vmap(torch.sin)
3315
3316        # the non-functorch version is not deprecated
3317        with warnings.catch_warnings():
3318            warnings.simplefilter("error")
3319            torch.vmap(torch.sin)
3320
3321    # Some of these pass, some of these don't
3322    @parametrize(
3323        "transform",
3324        ["grad", "jacrev", "jacfwd", "grad_and_value", "hessian", "functionalize"],
3325    )
3326    def test_deprecation_transforms(self, device, transform):
3327        api = getattr(functorch, transform)
3328        new_api = getattr(torch.func, transform)
3329
3330        # functorch version of the API is deprecated
3331        with self.assertWarnsRegex(
3332            FutureWarning, f"Please use `torch.func.{transform}`"
3333        ):
3334            api(torch.sin)
3335
3336        # the non-functorch version is not deprecated
3337        with warnings.catch_warnings():
3338            warnings.simplefilter("error")
3339            new_api(torch.sin)
3340
3341    def test_grad_grad(self, device):
3342        x = torch.randn([], device=device)
3343        y = grad(grad(torch.sin))(x)
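        # d/dx sin(x) = cos(x) and d/dx cos(x) = -sin(x), so the second
        # derivative of sin is -sin(x)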
3344        self.assertEqual(y, -x.sin())
3345
3346    def test_grad_vmap(self, device):
3347        def foo(x):
3348            y = vmap(torch.sin)(x)
3349            return y.sum()
3350
3351        x = torch.randn(3, device=device)
3352        y = grad(foo)(x)
3353        self.assertEqual(y, x.cos())
3354
3355    def test_grad_vjp(self, device):
3356        x = torch.randn(3, device=device)
3357
3358        def foo(x):
3359            _, vjp_fn = vjp(torch.sin, x)
3360            return vjp_fn(x)[0].sum()
3361
3362        y = grad(foo)(x)
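        # The vjp of sin at x maps a cotangent v to v * cos(x); with v = x the
        # pullback is x * cos(x), which is what the closed form below differentiates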
3363        expected = grad(lambda x: (x * x.cos()).sum())(x)
3364        self.assertEqual(y, expected)
3365
3366    def test_vmap_grad(self, device):
3367        x = torch.randn(3, device=device)
3368        y = vmap(grad(torch.sin))(x)
3369        self.assertEqual(y, x.cos())
3370
3371    def test_vmap_vmap(self, device):
3372        x = torch.randn(2, 3, device=device)
3373        y = vmap(vmap(torch.sin))(x)
3374        self.assertEqual(y, x.sin())
3375
3376    def test_vmap_vjp(self, device):
3377        x = torch.randn(3, device=device)
3378        _, vjp_fn = vjp(torch.sin, x)
3379
3380        def foo(x):
3381            _, vjp_fn = vjp(torch.sin, x)
3382            return vjp_fn(x)
3383
3384        y = vmap(foo)(x)
3385        self.assertEqual(y, vjp_fn(x))
3386
3387        # TODO: there's a very interesting error message when the following
3388        # is on CPU
3389        xs = torch.randn(5, 3, device=device)
3390        expected = torch.stack([vjp_fn(x)[0] for x in xs])
3391        result = vmap(lambda x: vjp_fn(x)[0])(xs)
3392        self.assertEqual(result, expected)
3393
3394    def test_vjp_grad(self, device):
3395        x = torch.randn([], device=device)
3396        y, vjp_fn = vjp(grad(torch.sin), x)
3397        self.assertEqual(y, x.cos())
3398
3399        v = torch.randn([])
3400        self.assertEqual(vjp_fn(v)[0], -x.sin() * v)
3401
3402    def test_vjp_vmap(self, device):
3403        x = torch.randn(3, device=device)
3404        y, vjp_fn = vjp(vmap(torch.sin), x)
3405        self.assertEqual(y, x.sin())
3406
3407        v = torch.randn(3, device=device)
3408        self.assertEqual(vjp_fn(v)[0], x.cos() * v)
3409
3410    def test_vjp_vjp(self, device):
3411        x = torch.randn(3, device=device)
3412        y, vjp_fn = vjp(torch.sin, x)
3413        self.assertEqual(y, x.sin())
3414
3415        y, vjp_fn = vjp(lambda x: vjp_fn(x)[0], x)
3416        self.assertEqual(y, x * x.cos())
3417
3418        y = vjp_fn(x)[0]
3419        # The exact value is not checked; this line is just a smoke test that the second vjp_fn runs
3420
3421    def test_make_fx_vmap(self, device):
3422        def f(x):
3423            return torch.sin(x)
3424
3425        inp = torch.randn(5, 3)
3426        f = vmap(f)
3427        fx_f = make_fx(f)(inp)
3428        new_inp = torch.randn(5, 3)
3429        self.assertEqual(fx_f(new_inp), f(new_inp))
3430
3431    def test_make_fx_jacrev(self, device):
3432        def f(x):
3433            return x.sin().sum()
3434
3435        inp = torch.randn(3)
3436        f = jacrev(jacrev(f))
3437        fx_f = make_fx(f)(inp)
3438        new_inp = torch.randn(3)
3439        self.assertEqual(fx_f(new_inp), f(new_inp))
3440
3441    def test_make_fx_vjp(self, device):
3442        def f(x):
3443            return torch.sin(x).sum()
3444
3445        primals = torch.randn(3)
3446        _, vjp_fn = vjp(f, primals)
3447        cotangent = torch.randn(())
3448        fx_f = make_fx(vjp_fn)(cotangent, True, True)
3449        new_cotangent = torch.randn(())
3450        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
3451
3452    # FIXME: test fails on Windows
3453    @unittest.skipIf(IS_WINDOWS, "fails on Windows; needs investigation")
3454    @unittest.skipIf(IS_FBCODE, "can't subprocess in fbcode")
3455    # it is redundant to run this test twice on a machine that has GPUs
3456    @onlyCPU
3457    def test_no_warning_on_import_functorch(self, device):
3458        out = subprocess.check_output(
3459            [sys.executable, "-W", "always", "-c", "import functorch"],
3460            stderr=subprocess.STDOUT,
3461            cwd=os.path.dirname(os.path.realpath(__file__)),
3462        ).decode("utf-8")
3463        self.assertEqual(out, "")
3464
3465    def test_requires_grad_inside_transform(self, device):
3466        def f(x):
3467            x.requires_grad_()
3468            return x.sin().sum()
3469
3470        x = torch.randn(3)
3471
3472        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
3473            vmap(f)(x)
3474        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
3475            grad(f)(x)
3476        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
3477            vmap(grad(f))(x)
3478
3479        x = torch.randn([])
3480        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
3481            grad(grad(f))(x)
3482
3483    def test_retain_grad_inside_transform(self, device):
3484        def f(x):
3485            y = x.sin()
3486            y.retain_grad()
3487            return y.sum()
3488
3489        x = torch.randn(3)
3490
3491        with self.assertRaisesRegex(RuntimeError, "Tensor.retain_grad()"):
3492            grad(f)(x)
3493
3494    def test_autograd_functional_jacrev_inside_transform(self, device):
3495        def f(x):
3496            y = torch.autograd.functional.jacobian(lambda x: x.sin().sum(), x)
3497            return y
3498
3499        B = 5
3500        x = torch.randn(B, 3)
3501        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
3502            vmap(f)(x)
3503
3504        x = torch.randn([])
3505        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
3506            grad(f)(x)
3507
3508    def test_autograd_functional_vjp_inside_transform(self, device):
3509        def f(x):
3510            y = torch.autograd.functional.vjp(lambda x: x.sin().sum(), x)
3511            return y
3512
3513        B = 5
3514        x = torch.randn(B, 3)
3515        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
3516            vmap(f)(x)
3517
3518        x = torch.randn([])
3519        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
3520            grad(f)(x)
3521
3522    def test_autograd_functional_jvp_inside_transform(self, device):
3523        def f(x):
3524            t = torch.ones_like(x)
3525            y = torch.autograd.functional.jvp(lambda x: x.sin().sum(), (x,), (t,))
3526            return y
3527
3528        B = 5
3529        x = torch.randn(B, 3)
3530        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
3531            vmap(f)(x)
3532
3533        x = torch.randn([])
3534        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
3535            grad(f)(x)
3536
3537    def test_autograd_functional_jacfwd_inside_transform(self, device):
3538        def f(x):
3539            y = torch.autograd.functional.jacobian(
3540                lambda x: x.sin().sum(), x, strategy="forward-mode", vectorize=True
3541            )
3542            return y
3543
3544        B = 5
3545        x = torch.randn(B, 3)
3546        with self.assertRaisesRegex(
3547            RuntimeError, "Batching rule not implemented for aten::_make_dual"
3548        ):
3549            vmap(f)(x)
3550
3551    @parametrize(
3552        "transform",
3553        [
3554            "vmap",
3555            "grad",
3556            "jacrev",
3557            "jacfwd",
3558            "grad_and_value",
3559            "hessian",
3560            "functionalize",
3561        ],
3562    )
3563    def test_autograd_function_no_setup_context(self, device, transform):
3564        class MySin(torch.autograd.Function):
3565            @staticmethod
3566            def forward(ctx, x):
3567                ctx.save_for_backward(x)
3568                return x.sin()
3569
3570            @staticmethod
3571            def backward(ctx, gy):
3572                (x,) = ctx.saved_tensors
3573                return gy * x.cos()
3574
3575        x = torch.randn(3, device=device)
3576        transform = getattr(functorch, transform)
3577        with self.assertRaisesRegex(RuntimeError, "must override the setup_context"):
3578            transform(MySin.apply)(x)
3579
3580    # Some of these pass, some of these don't
3581    @parametrize(
3582        "transform",
3583        [
3584            "grad",
3585            "jacrev",
3586            "grad_and_value",
3587            "hessian",
3588        ],
3589    )
3590    def test_transforms_dont_support_saved_tensor_hooks(self, device, transform):
3591        def f(x):
3592            return torch.sin(x).sum()
3593
3594        def g(x):
3595            with torch.autograd.graph.save_on_cpu():
3596                return f(x)
3597
3598        x = torch.randn(3, device=device)
3599
3600        if transform == "functionalize":
3601            transform = functorch.experimental.functionalize
3602        else:
3603            transform = getattr(functorch, transform)
3604        with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
3605            with torch.autograd.graph.save_on_cpu():
3606                transform(f)(x)
3607
3608        with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
3609            transform(g)(x)
3610
3611    def test_vjp_doesnt_support_saved_tensor_hooks(self, device):
3612        def f(x):
3613            return torch.sin(x).sum()
3614
3615        def g(x):
3616            with torch.autograd.graph.save_on_cpu():
3617                return f(x)
3618
3619        x = torch.randn(3, device=device)
3620        with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
3621            with torch.autograd.graph.save_on_cpu():
3622                vjp(f, x)
3623
3624        with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
3625            vjp(g, x)
3626
3627    def test_jvp_supports_saved_tensor_hooks(self, device):
3628        def f(x):
3629            return torch.sin(x).sum()
3630
3631        def g(x):
3632            with torch.autograd.graph.save_on_cpu():
3633                return f(x)
3634
3635        x = torch.randn(3, device=device)
3636        t = torch.randn(3, device=device)
3637
3638        # smoke tests
3639        with torch.autograd.graph.save_on_cpu():
3640            jvp(f, (x,), (t,))
3641
3642        # smoke tests
3643        jvp(g, (x,), (t,))
3644
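    # The next three tests check that each transform still works while the
    # dispatch key backing it is locally excluded, and that the thread-local
    # exclude set is left in place afterwards.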
3645    def test_can_use_functionalize_when_key_is_excluded(self, device):
3646        def f(x):
3647            y = x.clone()
3648            y.sin_()
3649            return y
3650
3651        x = torch.randn([], device=device)
3652        expected = f(x)
3653
3654        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
3655            gm = make_fx(functorch.functionalize(f))(x)
3656            self.assertTrue("sin_" not in gm.code)
3657            self.assertEqual(gm(x), expected)
3658
3659            local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
3660            self.assertTrue(local_exclude_set.has(DispatchKey.Functionalize))
3661
3662    def test_can_use_vmap_when_key_is_excluded(self, device):
3663        def f(x):
3664            return x.sum(0)
3665
3666        x = torch.randn(3, device=device)
3667        expected = vmap(f)(x)
3668
3669        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.FuncTorchBatched)):
3670            result = vmap(f)(x)
3671            self.assertEqual(result, expected)
3672            local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
3673            self.assertTrue(local_exclude_set.has(DispatchKey.FuncTorchBatched))
3674
3675    def test_can_use_grad_when_key_is_excluded(self, device):
3676        def f(x):
3677            return x.sin()
3678
3679        x = torch.randn([], device=device)
3680        expected = grad(f)(x)
3681
3682        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Autograd)):
3683            result = grad(f)(x)
3684            self.assertEqual(result, expected)
3685            local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
3686            self.assertTrue(local_exclude_set.has(DispatchKey.Autograd))
3687
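# A minimal sketch of the two "functional module" mechanisms the next test class
# parametrizes over: the legacy make_functional API and torch.func.functional_call,
# which takes the parameters explicitly as a dict. Illustrative only; nothing
# below calls this helper.
def _functional_module_sketch():
    mod = nn.Linear(3, 3)
    x = torch.randn(2, 3)

    # functional_call: run the module with an explicit parameter dict
    out1 = functional_call(mod, dict(mod.named_parameters()), x)

    # make_functional: split the module into a stateless callable plus params
    fmod, fparams = make_functional(mod)
    out2 = fmod(fparams, x)

    assert torch.allclose(out1, out2)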
3688
3689@markDynamoStrictTest
3690class TestMakeFunctional(TestCase):
3691    @parametrize("disable_autograd_tracking", [True, False])
3692    def test_disable_autograd_tracking(self, disable_autograd_tracking):
3693        class Foo(nn.Module):
3694            def __init__(self) -> None:
3695                super().__init__()
3696                self.linear = nn.Linear(3, 3)
3697
3698            def forward(self, x):
3699                x = self.linear(x)
3700                return x
3701
3702        mod = Foo()
3703        _, params = make_functional(
3704            mod, disable_autograd_tracking=disable_autograd_tracking
3705        )
3706        self.assertEqual(len(params), 2)
3707        for param in params:
3708            self.assertEqual(param.requires_grad, not disable_autograd_tracking)
3709
3710    def test_parameter_tying(self):
3711        class Foo(nn.Module):
3712            def __init__(self) -> None:
3713                super().__init__()
3714                self.bias = nn.Parameter(torch.randn(3))
3715                self.linear = nn.Linear(3, 3)
3716                self.linear.bias = self.bias
3717                self.linear_tied = self.linear
3718
3719            def forward(self, x):
3720                x = self.linear(x)
3721                x = self.linear_tied(x)
3722                x = x + self.bias
3723                return x
3724
3725        torch.manual_seed(1)
3726        mod = Foo()
3727        func, _ = make_functional(mod)
3728
3729        torch.manual_seed(0)
3730        mod = Foo()
3731        _, params = make_functional(mod)
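        # Tied parameters are deduplicated, so despite the tying above only the
        # linear weight and the shared bias remain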
3732        self.assertEqual(len(params), 2)
3733
3734        x = torch.randn(2, 3)
3735        result = func(params, x)
3736        expected = mod(x)
3737        self.assertEqual(result, expected)
3738
3739    def test_buffer_tying(self):
3740        class Foo(nn.Module):
3741            def __init__(self) -> None:
3742                super().__init__()
3743                self.bias = nn.Parameter(torch.randn(3))
3744                self.linear = nn.Linear(3, 3)
3745                self.buffer = nn.Buffer(torch.randn(3))
3746                self.buffer_tied = self.buffer
3747
3748            def forward(self, x):
3749                x = self.linear(x)
3750                x = x + self.bias
3751                x = x + self.buffer
3752                x = x + self.buffer_tied
3753                return x
3754
3755        torch.manual_seed(1)
3756        mod = Foo()
3757        func, _, _ = make_functional_with_buffers(mod)
3758
3759        torch.manual_seed(0)
3760        mod = Foo()
3761        _, params, buffers = make_functional_with_buffers(mod)
3762        self.assertEqual(len(params), 3)
3763        self.assertEqual(len(buffers), 1)
3764
3765        x = torch.randn(2, 3)
3766        result = func(params, buffers, x)
3767        expected = mod(x)
3768        self.assertEqual(result, expected)
3769
3770    @parametrize("disable_autograd_tracking", [True, False])
3771    def test_with_buffers_disable_autograd_tracking(self, disable_autograd_tracking):
3772        class Foo(nn.Module):
3773            def __init__(self) -> None:
3774                super().__init__()
3775                self.linear = nn.Linear(3, 3)
3776                self.buffer = nn.Buffer(torch.randn(3))
3777
3778            def forward(self, x):
3779                x = self.linear(x)
3780                x = x + self.buffer
3781                return x
3782
3783        mod = Foo()
3784        _, params, buffers = make_functional_with_buffers(
3785            mod, disable_autograd_tracking=disable_autograd_tracking
3786        )
3787        self.assertEqual(len(params), 2)
3788        self.assertEqual(len(buffers), 1)
3789        for param in params:
3790            self.assertEqual(param.requires_grad, not disable_autograd_tracking)
3791
3792    @parametrize("detach_params", [True, False])
3793    def test_using_detach_functional_call(self, detach_params):
3794        class Foo(nn.Module):
3795            def __init__(self) -> None:
3796                super().__init__()
3797                self.linear = nn.Linear(3, 3)
3798                self.buffer = nn.Buffer(torch.randn(3))
3799
3800            def forward(self, x):
3801                x = self.linear(x)
3802                x = x + self.buffer
3803                return x
3804
3805        def params_dict(mod):
3806            named_params = mod.named_parameters()
3807            return (
3808                {k: v.detach() for k, v in named_params}
3809                if detach_params
3810                else dict(named_params)
3811            )
3812
3813        mod = Foo()
3814        x = torch.randn(3, 3)
3815        d = (params_dict(mod), dict(mod.named_buffers()))
3816        out = functional_call(mod, d, x)
3817        self.assertEqual(out.grad_fn is None, detach_params)
3818
3819    def test_parameter_tying_grad(self):
3820        class Foo(nn.Module):
3821            def __init__(self) -> None:
3822                super().__init__()
3823                self.linear = nn.Linear(3, 3)
3824                self.weight = self.linear.weight
3825                self.bias = self.linear.bias
3826
3827            def forward(self, x):
3828                x = self.linear(x)
3829                x = F.linear(x, self.weight, self.bias)
3830                return x
3831
3832        x = torch.randn(2, 3)
3833        torch.manual_seed(0)
3834        mod = Foo()
3835        loss = mod(x).sum()
3836        expected = torch.autograd.grad(loss, mod.parameters())
3837
3838        mod = Foo()
3839        fmod, _, _ = make_functional_with_buffers(mod)
3840        torch.manual_seed(0)
3841        mod = Foo()
3842        _, params, buffers = make_functional_with_buffers(mod)
3843
3844        def compute_loss(params, buffers, x):
3845            return fmod(params, buffers, x).sum()
3846
3847        result = grad(compute_loss)(params, buffers, x)
3848
3849        self.assertEqual(result, expected)
3850
3851    def test_parameter_tying_ensemble(self):
3852        class Foo(nn.Module):
3853            def __init__(self) -> None:
3854                super().__init__()
3855                self.linear = nn.Linear(3, 3)
3856                self.weight = self.linear.weight
3857                self.bias = self.linear.bias
3858                self.buffer = nn.Buffer(torch.randn(3))
3859                self.buffer_tied = self.buffer
3860
3861            def forward(self, x):
3862                x = self.linear(x)
3863                x = F.linear(x, self.weight, self.bias)
3864                x = x + self.buffer
3865                x = x + self.buffer_tied
3866                return x
3867
3868        num_models = 2
3869        xs = torch.randn(num_models, 64, 3)
3870        models = [Foo() for _ in range(num_models)]
3871        fmodel, _, _ = combine_state_for_ensemble(models)
3872
3873        torch.manual_seed(0)
3874        models = [Foo() for _ in range(num_models)]
3875        _, params, buffers = combine_state_for_ensemble(models)
3876        result = vmap(fmodel)(params, buffers, xs)
3877
3878        torch.manual_seed(0)
3879        models = [Foo() for _ in range(num_models)]
3880        expected = torch.stack([model(x) for model, x in zip(models, xs)])
3881
3882        self.assertEqual(result, expected)
3883
3884    @parametrize("mechanism", ["make_functional", "functional_call"])
3885    def test_correctness_mnist(self, mechanism):
3886        class Net(nn.Module):
3887            def __init__(self) -> None:
3888                super().__init__()
3889                self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
3890                self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
3891                self.conv2_drop = nn.Dropout2d()
3892                self.fc1 = nn.Linear(320, 50)
3893                self.fc2 = nn.Linear(50, 10)
3894
3895            def forward(self, x):
3896                x = F.relu(F.max_pool2d(self.conv1(x), 2))
3897                x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
3898                x = x.view(-1, 320)
3899                x = F.relu(self.fc1(x))
3900                x = F.dropout(x, training=self.training)
3901                x = self.fc2(x)
3902                return F.log_softmax(x)
3903                return F.log_softmax(x, dim=1)
3904        x = torch.randn(64, 1, 32, 32)
3905        torch.manual_seed(301)
3906        fnet, _ = _get_weights_and_functional_call(Net(), mechanism)
3907
3908        torch.manual_seed(0)
3909        _, params = _get_weights_and_functional_call(Net(), mechanism)
3910        result = fnet(params, x)
3911
3912        torch.manual_seed(0)
3913        net = Net()
3914        expected = net(x)
3915
3916        self.assertEqual(result, expected)
3917
3918    def test_combine_state_for_ensemble_error(self):
3919        in_features = 2
3920        out_features = 2
3921
3922        models = []
3923        with self.assertRaisesRegex(RuntimeError, "Expected at least one model"):
3924            _ = combine_state_for_ensemble(models)
3925
3926        num_models = 3
3927        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
3928        models[1].eval()
3929        with self.assertRaisesRegex(RuntimeError, "same training/eval mode"):
3930            _ = combine_state_for_ensemble(models)
3931
3932        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
3933        models[1] = torch.nn.Conv2d(3, 3, (3, 3))
3934        with self.assertRaisesRegex(RuntimeError, "models to be of the same class"):
3935            _ = combine_state_for_ensemble(models)
3936
3937    def test_combine_state_for_ensemble_smoke(self):
3938        in_features = 2
3939        out_features = 2
3940        num_models = 3
3941        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
3942        _ = combine_state_for_ensemble(models)
3943
3944    def test_stack_module_state_smoke(self):
3945        in_features = 2
3946        out_features = 2
3947        num_models = 3
3948        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
3949        _ = stack_module_state(models)
3950
3951    def test_stack_module_state_leaf(self):
3952        in_features = 2
3953        out_features = 2
3954        num_models = 3
3955        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
3956        params, buffers = stack_module_state(models)
3957        for param in params.values():
3958            self.assertTrue(param.requires_grad)
3959            self.assertTrue(param.is_leaf)
3960
3961    def test_stack_module_state_mismatch_error(self):
3962        in_features = 2
3963        out_features = 2
3964        num_models = 3
3965        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
3966        models[0].weight.requires_grad_(False)
3967        with self.assertRaisesRegex(RuntimeError, "same .requires_grad"):
3968            params, buffers = stack_module_state(models)
3969
3970    def test_stack_module_state_error(self):
3971        in_features = 2
3972        out_features = 2
3973
3974        models = []
3975        with self.assertRaisesRegex(
3976            RuntimeError, "stack_module_state:.* Expected at least one model"
3977        ):
3978            _ = stack_module_state(models)
3979
3980        num_models = 3
3981        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
3982        models[1].eval()
3983        with self.assertRaisesRegex(
3984            RuntimeError, "stack_module_state:.* same training/eval mode."
3985        ):
3986            _ = stack_module_state(models)
3987
3988        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
3989        models[1] = torch.nn.Conv2d(3, 3, (3, 3))
3990        with self.assertRaisesRegex(
3991            RuntimeError, "stack_module_state:.* models to be of the same class"
3992        ):
3993            _ = stack_module_state(models)
3994
3995    @parametrize("mechanism", ["make_functional", "functional_call"])
3996    def test_make_functional_state_correctly_returned_after_forward(self, mechanism):
3997        class Net(nn.Module):
3998            def __init__(self) -> None:
3999                super().__init__()
4000                self.linear = nn.Linear(3, 3)
4001
4002            def forward(self, x):
4003                x = self.linear(x)
4004                return x
4005
4006        def get_module_info(mod):
4007            if mechanism == "make_functional":
4008                return make_functional(mod)
4009            else:
4010                assert mechanism == "functional_call"
4011                return mod, dict(mod.named_parameters())
4012
4013        mod = Net()
4014        func_mod, params = get_module_info(mod)
4015
4016        # state in func.names_map
4017        mod = func_mod.stateless_model if mechanism == "make_functional" else func_mod
4018        old_state_linear_weight = mod.linear.weight
4019        old_state_linear_bias = mod.linear.bias
4020
4021        self.assertIsNotNone(old_state_linear_weight)
4022        self.assertIsNotNone(old_state_linear_bias)
4023
4024        x = torch.randn(4, 3)
4025        if mechanism == "make_functional":
4026            func_mod(params, x)
4027        else:
4028            assert mechanism == "functional_call"
4029            functional_call(func_mod, params, x)
4030
4031        mod = func_mod.stateless_model if mechanism == "make_functional" else func_mod
4032        new_state_linear_weight = mod.linear.weight
4033        new_state_linear_bias = mod.linear.bias
4034
4035        self.assertIsNotNone(new_state_linear_weight)
4036        self.assertIsNotNone(new_state_linear_bias)
4037
4038        self.assertEqual(old_state_linear_weight, new_state_linear_weight)
4039        self.assertEqual(old_state_linear_bias, new_state_linear_bias)
4040
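# A small, illustrative sketch of the ensembling pattern used repeatedly in the
# next test class: stack the state of several identical models and vmap a
# functional_call over the stacked parameters. The names below are illustrative
# only and nothing in the tests calls this helper.
def _ensemble_sketch(num_models=2):
    models = [nn.Linear(3, 3) for _ in range(num_models)]
    params, _buffers = stack_module_state(models)  # nn.Linear has no buffers
    base = models[0]  # any structurally identical module works as the skeleton

    def call_one(p, x):
        return functional_call(base, p, (x,))

    x = torch.randn(4, 3)
    # map over the stacked parameters while sharing the same input batch
    return vmap(call_one, in_dims=(0, None))(params, x)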
4041
4042@markDynamoStrictTest
4043class TestExamplesCorrectness(TestCase):
4044    def _update_params(self, params, grads, alpha, mechanism):
4045        if mechanism == "make_functional":
4046            return [(params[i] - alpha * grads[i]) for i in range(len(params))]
4047        else:
4048            assert mechanism == "functional_call"
4049            return {k: params[k] - alpha * grads[k] for k in params}
4050
4051    @parametrize("mechanism", ["make_functional", "functional_call"])
4052    def test_maml_regression(self, device, mechanism):
4053        class ThreeLayerNet(nn.Module):
4054            def __init__(self) -> None:
4055                super().__init__()
4056                self.fc1 = nn.Linear(1, 40)
4057                self.relu1 = nn.ReLU()
4058                self.fc2 = nn.Linear(40, 40)
4059                self.relu2 = nn.ReLU()
4060                self.fc3 = nn.Linear(40, 1)
4061
4062            def forward(self, x):
4063                x = self.fc1(x)
4064                x = self.relu1(x)
4065                x = self.fc2(x)
4066                x = self.relu2(x)
4067                x = self.fc3(x)
4068                return x
4069
4070        # TODO: should replace with F.mse_loss
4071        def mse_loss(x, y):
4072            return torch.mean((x - y) ** 2)
4073
4074        net, params = _get_weights_and_functional_call(
4075            ThreeLayerNet().to(device), mechanism
4076        )
4077        K = 20
4078        num_tasks = 4
4079        alpha = 0.1
4080
4081        def sample_tasks(outer_batch_size, inner_batch_size):
4082            # Select amplitude and phase for the task
4083            As = []
4084            phases = []
4085            for _ in range(outer_batch_size):
4086                As.append(np.random.uniform(low=0.1, high=0.5))
4087                phases.append(np.random.uniform(low=0.0, high=np.pi))
4088
4089            def get_batch():
4090                xs, ys = [], []
4091                for A, phase in zip(As, phases):
4092                    x = np.random.uniform(
4093                        low=-5.0, high=5.0, size=(inner_batch_size, 1)
4094                    )
4095                    y = A * np.sin(x + phase)
4096                    xs.append(x)
4097                    ys.append(y)
4098                return torch.tensor(xs, dtype=torch.float, device=device), torch.tensor(
4099                    ys, dtype=torch.float, device=device
4100                )
4101
4102            x1, y1 = get_batch()
4103            x2, y2 = get_batch()
4104            return x1, y1, x2, y2
4105
4106        def get_loss_for_task(use_transform, x1, y1, x2, y2):
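        # MAML-style inner loop: adapt the shared params on the support set
        # (x1, y1) with one SGD step, then score the adapted params on the
        # query set (x2, y2). The transform path uses grad(); the reference
        # path uses torch.autograd.grad with create_graph=True.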
4107            def inner_loss(params, x1, y1):
4108                f = net(params, x1)
4109                loss = mse_loss(f, y1)
4110                return loss
4111
4112            if use_transform:
4113                grads = grad(inner_loss)(params, x1, y1)
4114            else:
4115                loss = inner_loss(params, x1, y1)
4116                grad_params, spec = tree_flatten(params)
4117                grads = torch.autograd.grad(loss, grad_params, create_graph=True)
4118                grads = tree_unflatten(grads, spec)
4119
4120            new_params = self._update_params(params, grads, alpha, mechanism)
4121
4122            v_f = net(new_params, x2)
4123            return mse_loss(v_f, y2)
4124
4125        task = sample_tasks(num_tasks, K)
4126        list_params = (
4127            params if mechanism == "make_functional" else list(params.values())
4128        )
4129
4130        # Compute with vmap+grad
4131        inner_losses = vmap(partial(get_loss_for_task, True))(
4132            task[0], task[1], task[2], task[3]
4133        )
4134        loss2 = sum(inner_losses) / len(inner_losses)
4135        result_grads = torch.autograd.grad(loss2, list_params)
4136
4137        # Compute without vmap+grad
4138        inner_losses = [
4139            get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i])
4140            for i in range(num_tasks)
4141        ]
4142        loss2 = sum(inner_losses) / len(inner_losses)
4143        expected_grads = torch.autograd.grad(loss2, list_params)
4144
4145        self.assertEqual(result_grads, expected_grads)
4146
4147    @parametrize("mechanism", ["make_functional", "functional_call"])
4148    def test_maml_omniglot(self, device, mechanism):
4149        # TODO: there appear to be precision issues with float32
4150        dtype = torch.double
4151
4152        # TODO: We don't support inplace relu?
4153        inplace_relu = False
4154        n_way = 5
4155        n_inner_iter = 2
4156        num_tasks = 2
4157
4158        # The real example uses batch norm, but it's numerically unstable in the first
4159        # iteration (when values are near 0) and won't produce the same gradients, so group norm is used instead
4160        net = (
4161            nn.Sequential(
4162                nn.Conv2d(1, 64, 3),
4163                nn.GroupNorm(64, 64, affine=True),
4164                nn.ReLU(inplace=inplace_relu),
4165                nn.MaxPool2d(2, 2),
4166                nn.Conv2d(64, 64, 3),
4167                nn.GroupNorm(64, 64, affine=True),
4168                nn.ReLU(inplace=inplace_relu),
4169                nn.MaxPool2d(2, 2),
4170                nn.Conv2d(64, 64, 3),
4171                nn.GroupNorm(64, 64, affine=True),
4172                nn.ReLU(inplace=inplace_relu),
4173                nn.MaxPool2d(2, 2),
4174                nn.Flatten(),
4175                nn.Linear(64, n_way),
4176            )
4177            .to(device)
4178            .to(dtype)
4179        )
4180
4181        fnet, params, buffers = _get_weights_and_functional_call_with_buffers(
4182            net, mechanism
4183        )
4184        net = (params, buffers, fnet)
4185
4186        def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
4187            params, buffers, fnet = net
4188            querysz = x_qry.size(0)
4189
4190            def compute_loss(new_params, buffers, x, y):
4191                logits = fnet(new_params, buffers, x)
4192                loss = F.cross_entropy(logits, y)
4193                return loss
4194
4195            new_params = params
4196            for _ in range(n_inner_iter):
4197                if use_transform:
4198                    grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
4199                else:
4200                    res = compute_loss(new_params, buffers, x_spt, y_spt)
4201                    grad_params, spec = tree_flatten(new_params)
4202                    grads = torch.autograd.grad(res, grad_params, create_graph=True)
4203                    grads = tree_unflatten(grads, spec)
4204
4205                new_params = self._update_params(new_params, grads, 1e-1, mechanism)
4206
4207            qry_logits = fnet(new_params, buffers, x_qry)
4208            qry_loss = F.cross_entropy(qry_logits, y_qry)
4209            qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz
4210
4211            return qry_loss, qry_acc
4212
4213        # Get some sample inputs...
4214        x_spt = torch.randn(num_tasks, 25, 1, 28, 28, dtype=dtype, device=device)
4215        y_spt = torch.randint(0, 5, (num_tasks, 25), device=device)
4216        x_qry = torch.randn(num_tasks, 75, 1, 28, 28, dtype=dtype, device=device)
4217        y_qry = torch.randint(0, 5, (num_tasks, 75), device=device)
4218
4219        # compute with vmap + grad
4220        compute_loss = partial(loss_for_task, net, n_inner_iter, True)
4221        qry_losses, _ = vmap(compute_loss)(x_spt, y_spt, x_qry, y_qry)
4222        list_params = (
4223            params if mechanism == "make_functional" else list(params.values())
4224        )
4225        result_grads = torch.autograd.grad(qry_losses.sum(), list_params)
4226
4227        # compute without vmap + grad
4228        compute_loss = partial(loss_for_task, net, n_inner_iter, False)
4229        losses = [
4230            compute_loss(x_spt[i], y_spt[i], x_qry[i], y_qry[i])[0]
4231            for i in range(num_tasks)
4232        ]
4233        expected_grads = torch.autograd.grad(sum(losses), list_params)
4234
4235        self.assertEqual(result_grads, expected_grads)
4236
4237    @parametrize("mechanism", ["make_functional", "functional_call"])
4238    @parametrize("originally_track_running_stats", [True, False])
4239    def test_update_batch_norm(self, device, originally_track_running_stats, mechanism):
4240        dtype = torch.double
4241        inplace_relu = False
4242        classes = 5
4243        num_batches = 2
4244        net = (
4245            nn.Sequential(
4246                nn.Conv2d(64, 64, 3),
4247                nn.BatchNorm2d(
4248                    64, affine=True, track_running_stats=originally_track_running_stats
4249                ),
4250                nn.ReLU(inplace=inplace_relu),
4251                nn.Flatten(),
4252                nn.Linear(43264, classes),
4253            )
4254            .to(device)
4255            .to(dtype)
4256        )
4257
4258        replace_all_batch_norm_modules_(net)
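        # Patch the BatchNorm modules in place so that they stop using running
        # stats, which makes them compatible with the vmap-based per-sample
        # gradients computed below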
4259        transformed_net = net
4260        fnet, params, buffers = _get_weights_and_functional_call_with_buffers(
4261            transformed_net, mechanism
4262        )
4263        criterion = nn.CrossEntropyLoss()
4264
4265        def compute_loss(x, y, params, buffers):
4266            return criterion(fnet(params, buffers, x), y)
4267
4268        # Get some sample inputs...
4269        x = torch.randn(num_batches, 1, 64, 28, 28, device=device, dtype=dtype)
4270        y = torch.randint(0, classes, (num_batches, 1), device=device)
4271
4272        # compute some per sample grads with vmap + grad
4273        result_grads = vmap(grad(compute_loss, argnums=2), in_dims=(0, 0, None, None))(
4274            x, y, params, buffers
4275        )
4276
4277        # compute some per sample grads without vmap + grad
4278        fnet, params, buffers = _get_weights_and_functional_call_with_buffers(
4279            transformed_net, mechanism
4280        )
4281        flat_params, spec = tree_flatten(params)
4282        expected_grads = [
4283            torch.autograd.grad(compute_loss(x[i], y[i], params, buffers), flat_params)
4284            for i in range(num_batches)
4285        ]
4286        expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]
4287        expected_grads = tree_unflatten(expected_grads, spec)
4288
4289        self.assertEqual(result_grads, expected_grads)
4290
4291    @parametrize("jac", ["jacfwd", "jacrev"])
4292    def test_lennard_jones_batched_jac(self, device, jac):
4293        sigma = 0.5
4294        epsilon = 4.0
4295
4296        jac = getattr(functorch, jac)
4297
4298        def lennard_jones(r):
4299            return epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)
4300
4301        def lennard_jones_force(r):
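        # Analytic derivative of the potential above: the force magnitude is
        # -dV/dr, which the network's (vmapped) Jacobian is compared against via
        # the force term of the loss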
4302            """Get magnitude of LJ force"""
4303            return -epsilon * (
4304                (-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)
4305            )
4306
4307        r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device)
4308        drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device))
4309        norms = torch.norm(drs, dim=1).reshape(-1, 1)
4310        training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1)
4311        training_forces = torch.stack(
4312            [force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)]
4313        )
4314
4315        model = nn.Sequential(
4316            nn.Linear(1, 16),
4317            nn.Tanh(),
4318            nn.Linear(16, 16),
4319            nn.Tanh(),
4320            nn.Linear(16, 16),
4321            nn.Tanh(),
4322            nn.Linear(16, 16),
4323            nn.Tanh(),
4324            nn.Linear(16, 1),
4325        ).to(device)
4326
4327        def make_prediction(model, drs, use_functorch):
4328            norms = torch.norm(drs, dim=1).reshape(-1, 1)
4329            energies = model(norms)
4330
4331            if use_functorch:
4332                network_derivs = vmap(jac(model))(norms).squeeze(-1)
4333                forces = -network_derivs * drs / norms
4334            else:
4335                forces = []
4336                for r, dr in zip(norms, drs):
4337                    network_deriv = torch.autograd.functional.jacobian(
4338                        model, r, create_graph=True
4339                    )
4340                    force = -network_deriv * dr / r
4341                    forces.append(force)
4342                forces = torch.cat(forces)
4343            return energies, forces
4344
4345        def loss_fn(energies, forces, predicted_energies, predicted_forces):
4346            return (
4347                F.mse_loss(energies, predicted_energies)
4348                + 0.01 * F.mse_loss(forces, predicted_forces) / 3
4349            )
4350
4351        energies, forces = make_prediction(model, drs, use_functorch=True)
4352        loss = loss_fn(training_energies, training_forces, energies, forces)
4353        result = torch.autograd.grad(loss, model.parameters())
4354
4355        energies, forces = make_prediction(model, drs, use_functorch=False)
4356        loss = loss_fn(training_energies, training_forces, energies, forces)
4357        expected = torch.autograd.grad(loss, model.parameters())
4358
4359        self.assertEqual(result, expected)
4360
4361    @parametrize("mechanism", ["make_functional", "functional_call"])
4362    def test_ensemble_regression(self, device, mechanism):
4363        def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
4364            ts = torch.linspace(0, 1, n_samples)
4365            rs = ts**0.5
4366            thetas = rs * rotations * 2 * math.pi
4367            signs = torch.randint(0, 2, (n_samples,)) * 2 - 1
4368            labels = (signs > 0).to(torch.long)
4369
4370            xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples) * noise_std
4371            ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples) * noise_std
4372            points = torch.stack([xs, ys], dim=1)
4373            return points.to(device), labels.to(device)
4374
4375        points, labels = make_spirals(100, noise_std=0.05)
4376
4377        class MLPClassifier(nn.Module):
4378            def __init__(self, hidden_dim=32, n_classes=2):
4379                super().__init__()
4380                self.hidden_dim = hidden_dim
4381                self.n_classes = n_classes
4382
4383                self.fc1 = nn.Linear(2, self.hidden_dim)
4384                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
4385
4386            def forward(self, x):
4387                x = self.fc1(x)
4388                x = F.relu(x)
4389                x = self.fc2(x)
4390                x = F.log_softmax(x, -1)
4391                return x
4392
4393        loss_fn = nn.NLLLoss()
4394
4395        func_model, weights = _get_weights_and_functional_call(
4396            MLPClassifier().to(device), mechanism
4397        )
4398
4399        def train_step_fn(use_transform, weights, batch, targets, lr=0.2):
4400            def compute_loss(weights, batch, targets):
4401                output = func_model(weights, batch)
4402                loss = loss_fn(output, targets)
4403                return loss
4404
4405            if use_transform:
4406                grad_weights, loss = grad_and_value(compute_loss)(
4407                    weights, batch, targets
4408                )
4409            else:
4410                loss = compute_loss(weights, batch, targets)
4411                flat_weights, spec = tree_flatten(weights)
4412                flat_grad_weights = torch.autograd.grad(loss, flat_weights)
4413                grad_weights = tree_unflatten(flat_grad_weights, spec)
4414
4415            new_weights = self._update_params(weights, grad_weights, lr, mechanism)
4416            return (loss, new_weights)
4417
4418        def unpack(train_result):
4419            return train_result[0], train_result[1]
4420
4421        def init_fn(num_models):
4422            models = tuple(MLPClassifier().to(device) for _ in range(num_models))
4423            if mechanism == "make_functional":
4424                return combine_state_for_ensemble(models)[1]
4425            else:
4426                return stack_module_state(models)[0]
4427
4428        def slice_weights(batched_weights, index):
4429            return tree_map(
4430                lambda weight: weight[index].detach().requires_grad_(), batched_weights
4431            )
4432
4433        batched_weights = init_fn(num_models=2)
4434        parallel_train_step_fn = vmap(
4435            partial(train_step_fn, True), in_dims=(0, None, None)
4436        )
4437
4438        result_loss, result_weights = unpack(
4439            parallel_train_step_fn(batched_weights, points, labels)
4440        )
4441
4442        loss0, weights0 = unpack(
4443            train_step_fn(False, slice_weights(batched_weights, 0), points, labels)
4444        )
4445        loss1, weights1 = unpack(
4446            train_step_fn(False, slice_weights(batched_weights, 1), points, labels)
4447        )
4448        expected_loss = torch.stack([loss0, loss1])
4449
4450        weights0, spec0 = tree_flatten(weights0)
4451        weights1, spec1 = tree_flatten(weights1)
4452        assert spec0 == spec1
4453        expected_weights = tuple(
4454            torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1)
4455        )
4456        expected_weights = tree_unflatten(expected_weights, spec0)
4457
4458        self.assertEqual(result_loss, expected_loss)
4459        self.assertEqual(result_weights, expected_weights)
4460
4461    @parametrize(
4462        "dropout_layer",
4463        [
4464            subtest(nn.Dropout, "Dropout"),
4465            subtest(nn.AlphaDropout, "AlphaDropout"),
4466            subtest(nn.FeatureAlphaDropout, "FeatureAlphaDropout"),
4467        ],
4468    )
4469    @parametrize("mechanism", ["make_functional", "functional_call"])
4470    def test_find_learning_rate_ensembling(self, device, dropout_layer, mechanism):
4471        # This example mimics what a user might do when trying to find the optimal learning rate: run a bunch
4472        # of models with identical behavior (including the same dropout pattern!) and have each of them train
4473        # with a different learning rate. Specifically, this is an example of using the same randomness with vmap
4474        points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint(
4475            0, 2, (100,), device=device
4476        )
4477
4478        class MLPClassifier(nn.Module):
4479            def __init__(self, hidden_dim=32, n_classes=2):
4480                super().__init__()
4481                self.hidden_dim = hidden_dim
4482                self.n_classes = n_classes
4483
4484                self.dropout = dropout_layer()
4485                self.fc1 = nn.Linear(16, self.hidden_dim)
4486                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
4487
4488            def forward(self, x):
4489                x = self.dropout(x)
4490                x = torch.flatten(x, start_dim=1)
4491                x = self.fc1(x)
4492                x = F.relu(x)
4493                x = self.fc2(x)
4494                x = F.log_softmax(x, -1)
4495                return x
4496
4497        loss_fn = nn.NLLLoss()
4498
4499        func_model, weights = _get_weights_and_functional_call(
4500            MLPClassifier().to(device), mechanism
4501        )
4502
4503        def train_step_fn(weights, batch, targets, lr):
4504            def compute_loss(weights, batch, targets):
4505                output = func_model(weights, batch)
4506                loss = loss_fn(output, targets)
4507                return loss
4508
4509            grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
4510            new_weights = self._update_params(weights, grad_weights, lr, mechanism)
4511            if mechanism != "make_functional":
4512                new_weights = list(new_weights.values())
4513            # NB: return looks weird because torch.vmap must return Tensors
4514            return (loss, *new_weights)
4515
4516        def unpack(train_result):
4517            return train_result[0], train_result[1:]
4518
4519        def init_fn(num_models):
4520            og_model = MLPClassifier().to(device)
4521            models = tuple(
4522                copy.deepcopy(og_model) for _ in range(num_models)
4523            )  # have same initialization
4524            if mechanism == "make_functional":
4525                return combine_state_for_ensemble(models)[1]
4526            else:
4527                return stack_module_state(models)[0]
4528
4529        batched_weights = init_fn(num_models=2)
4530        parallel_train_step_fn = vmap(
4531            train_step_fn, in_dims=(0, None, None, 0), randomness="same"
4532        )
4533
4534        lrs = torch.tensor([0.2, 0.4], device=device)
4535        result_loss, result_weights = unpack(
4536            parallel_train_step_fn(batched_weights, points, labels, lrs)
4537        )
4538
4539        self.assertEqual(result_loss[0], result_loss[1])
4540        self.assertNotEqual(
4541            tuple(weight[0] for weight in result_weights),
4542            tuple(weight[1] for weight in result_weights),
4543        )
4544
4545    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
4546    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
4547    @parametrize("mechanism", ["make_functional", "functional_call"])
4548    def test_resnet18_per_sample_grads(self, device, mechanism):
4549        import torchvision.models as models
4550
4551        model = models.__dict__["resnet18"](
4552            pretrained=False, norm_layer=(lambda c: nn.GroupNorm(min(32, c), c))
4553        ).to(device)
4554        criterion = nn.CrossEntropyLoss(
4555            reduction="sum"
4556        )  # avoid cross-batch reductions for the per-sample for-loop comparison
4557
4558        func_model, weights = _get_weights_and_functional_call(model, mechanism)
4559
4560        def compute_loss(weights, image, target):
4561            image = image.unsqueeze(0)
4562            target = target.unsqueeze(0)
4563            output = func_model(weights, image)
4564            loss = criterion(output, target)
4565            return loss
4566
4567        batch_size = 3
4568        images = torch.randn(batch_size, 3, 32, 32, device=device)
4569        targets = torch.randint(0, 10, (batch_size,), device=device)
4570
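        # vmap(grad(compute_loss)) computes per-sample gradients in one batched
        # call; compare against a reference for-loop of torch.autograd.grad below.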
4571        result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
4572            weights, images, targets
4573        )
4574
4575        flat_weights, spec = tree_flatten(weights)
4576        expected_grads = [
4577            torch.autograd.grad(
4578                compute_loss(weights, images[i], targets[i]), flat_weights
4579            )
4580            for i in range(batch_size)
4581        ]
4582        expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]
4583        expected_grads = tree_unflatten(expected_grads, spec)
4584
4585        self.assertEqual(result_grads, expected_grads, atol=1e-3, rtol=1.0)
4586
4587
4588def normalize_devices(fx_g):
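    """Rewrite any torch.device args/kwargs in the FX graph to "cpu" so the
    expected-code strings asserted below are device-agnostic, then recompile."""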
4589    for node in fx_g.graph.nodes:
4590        args = list(node.args)
4591        for idx, arg in enumerate(args):
4592            if isinstance(arg, torch.device):
4593                args[idx] = "cpu"
4594        node.args = tuple(args)
4595        new_kwargs = {}
4596        for k, v in node.kwargs.items():
4597            if isinstance(v, torch.device):
4598                v = "cpu"
4599            new_kwargs[k] = v
4600        node.kwargs = new_kwargs
4601    fx_g.recompile()
4602    return fx_g
4603
4604
4605@markDynamoStrictTest
4606class TestFunctionalize(TestCase):
4607    def _check_functionalize_correctness(self, f, inpt, *, skip_vmap=False):
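        """Run f eagerly, under functionalize (composed with vmap unless
        skip_vmap=True), and under functionalize(remove="mutations_and_views");
        check that the outputs match and that any input mutations agree."""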
4608        inpt1 = inpt.clone()
4609        inpt2 = inpt.clone()
4610        inpt3 = inpt.clone()
4611
4612        expected_outputs = f(inpt1)
4613        if skip_vmap:
4614            actual_outputs = functionalize(f)(inpt2)
4615        else:
4616            actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze()
4617        # Right now the flavor of functionalize that also removes view ops
4618        # isn't composed with vmap here:
4619        # {view}_copy ops don't have batching rules yet
4620        # (although we should probably fix that).
4621        actual_outputs_view_copy = functionalize(f, remove="mutations_and_views")(inpt3)
4622        # Check that outputs are the same
4623        self.assertEqual(actual_outputs, expected_outputs)
4624        self.assertEqual(actual_outputs_view_copy, expected_outputs)
4625
4626        # Inputs might have been mutated by f: check that they were mutated properly
4627        self.assertEqual(inpt1, inpt2)
4628        self.assertEqual(inpt1, inpt3)
4629
4630    def test_simple_view(self, device):
4631        def f(x: torch.Tensor) -> torch.Tensor:
4632            tmp = torch.ones(2, device=device)
4633            y = x.view(4, 2)
4634            y.add_(tmp)
4635            return x
4636
4637        self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
4638
4639    def test_multioutput_view(self, device):
4640        def f(x: torch.Tensor) -> torch.Tensor:
4641            tmp = torch.ones(2, device=device)
4642            y1, y2 = x.split(2)
4643            y1_view = y1.diagonal()
4644            y1_view.add_(tmp)
4645            return x
4646
4647        self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
4648
4649    def test_inplace_view(self, device):
4650        def f(x: torch.Tensor) -> torch.Tensor:
4651            tmp = torch.ones(4, device=device)
4652            y = x + x
4653            y2 = y.transpose(1, 0)
4654            z = y2[0]
4655            z.add_(tmp)
4656            return y
4657
4658        self._check_functionalize_correctness(
4659            f, torch.zeros(4, 2, device=device), skip_vmap=True
4660        )
4661
4662    # See https://github.com/pytorch/functorch/issues/780
4663    def test_linear(self, device):
4664        def f(x, y, z) -> torch.Tensor:
4665            return torch._C._nn.linear(x, y, z)
4666
4667        x = torch.randn(14, 1, 384, device=device)
4668        y = torch.randn(96, 384, device=device)
4669        z = torch.randn(96, device=device)
4670
4671        out_expected = f(x, y, z)
4672        out_actual = functionalize(f)(x, y, z)
4673        self.assertEqual(out_expected, out_actual)
4674
4675    def test_multioutput_inplace_slice_view(self, device):
4676        def f(x: torch.Tensor) -> torch.Tensor:
4677            tmp = torch.ones(2, 2, device=device)
4678            y = x.view(8)
4679            z0 = y.reshape(2, 4)
4680            z1 = z0.transpose(1, 0)
4681            z1.unsqueeze_(0)
4682            z1.squeeze_()
4683            z2, z3 = z1.split(2)
4684            z2.add_(tmp)
4685            return x
4686
4687        # See Note [Fix vmap slice_scatter]
4688        self._check_functionalize_correctness(
4689            f, torch.zeros(4, 2, device=device), skip_vmap=True
4690        )
4691
4692    # Ensure functionalize works with List[Optional[Tensor]] arguments.
4693    # See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085
4694    def test_functionalize_opt_tensor_list(self, device):
4695        def f(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
4696            return x[indices]
4697
4698        inpta = torch.ones(4, device=device)
4699        inptb = torch.arange(2, device=device)
4700        out1 = f(inpta, inptb)
4701        out2 = functionalize(f)(inpta, inptb)
4702        self.assertEqual(out1, out2)
4703        out = make_fx(functionalize(f))(inpta, inptb)
4704        self.assertExpectedInline(
4705            (out.code),
4706            """\
4707
4708
4709
4710def forward(self, x_1, indices_1) -> torch.Tensor:
4711    index = torch.ops.aten.index.Tensor(x_1, [indices_1]);  x_1 = indices_1 = None
4712    return index
4713    """,
4714        )
4715
4716    # Ensure grad(functionalize(f)) works
4717    def test_functionalize_grad(self, device):
4718        def f(x: torch.Tensor) -> torch.Tensor:
4719            tmp = torch.ones(2, device=device)
4720            y = x + x
4721            z = y.view(4, 2)
4722            y.add_(tmp)
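            # z aliases y, so the in-place add above is reflected in z.sum()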
4723            return z.sum()
4724
4725        inpt1 = torch.ones(4, 2, device=device)
4726        inpt2 = torch.ones(4, 2, device=device)
4727        out1 = grad(f)(inpt1)
4728        out2 = grad(functionalize(f))(inpt2)
4729        self.assertEqual(out1, out2)
4730        self.assertEqual(inpt1, inpt2)
4731
4732    @unittest.skipIf(IS_FBCODE, "fails in fbcode")
4733    def test_vmap_functionalize_jvp(self, device):
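        # f returns a view of a tensor that is mutated afterwards; check that
        # vmap(functionalize(jvp_wrapper)) matches vmap(jvp_wrapper).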
4734        def f(x: torch.Tensor) -> torch.Tensor:
4735            y = x + x
4736            z = y.view(-1)
4737            y.add_(1)
4738            return z
4739
4740        def jvp_wrapper(x, t):
4741            return jvp(
4742                f,
4743                (x,),
4744                (t,),
4745            )
4746
4747        x = torch.randn(2, 3, device=device)
4748        t = torch.randn(2, 3, device=device)
4749
4750        out1 = vmap(jvp_wrapper)(x, t)
4751        out2 = vmap(functionalize(jvp_wrapper))(x, t)
4752        self.assertEqual(out1, out2)
4753
4754    # TODO: move this test into test_fake_tensor.py
4755    # once functionalize() can be used in core tests.
4756    def test_functionalize_fake_tensors(self, device):
4757        def f(x: torch.Tensor) -> torch.Tensor:
4758            y = x.detach()
4759            return y + y
4760
4761        with FakeTensorMode() as mode:
4762            x = torch.ones(2, device=device, requires_grad=True)
4763            out = functionalize(f)(x)
4764        self.assertEqual(x.size(), (2,))
4765
4766    def test_functionalize_fx_simple(self, device):
4767        def f(x: torch.Tensor) -> torch.Tensor:
4768            tmp = torch.ones(2, device=device)
4769            y = x.view(4, 2)
4770            y.add_(tmp)
4771            return x
4772
4773        # There's a copy_ in the graph, because the input (x) was mutated.
4774        # To preserve semantics, functionalize() needs to propagate the mutation.
4775        fn = make_fx(functionalize(f, remove="mutations_and_views"))
4776        out = fn(torch.zeros(4, 2, device=device))
4777        out = normalize_devices(out)
4778        self.assertExpectedInline(
4779            (out.code),
4780            """\
4781
4782
4783
4784def forward(self, x_1) -> torch.Tensor:
4785    ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False)
4786    view_copy = torch.ops.aten.view_copy.default(x_1, [4, 2])
4787    add = torch.ops.aten.add.Tensor(view_copy, ones);  view_copy = ones = None
4788    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
4789    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]);  view_copy_2 = None
4790    copy_ = torch.ops.aten.copy_.default(x_1, view_copy_1);  x_1 = copy_ = None
4791    return view_copy_1
4792    """,
4793        )
4794
4795    def test_functionalize_fx_transpose_simple(self, device):
4796        def f(x: torch.Tensor) -> torch.Tensor:
4797            return x.transpose(1, 0)
4798
4799        fn = make_fx(functionalize(f, remove="mutations_and_views"))
4800        out = fn(torch.zeros(4, 2, device=device))
4801        out = normalize_devices(out)
4802        self.assertExpectedInline(
4803            out.code,
4804            """\
4805
4806
4807
4808def forward(self, x_1) -> torch.Tensor:
4809    transpose_copy = torch.ops.aten.transpose_copy.int(x_1, 1, 0);  x_1 = None
4810    return transpose_copy
4811    """,
4812        )
4813
4814    def test_functionalize_fx_out_op(self, device):
4815        def f(inpt: torch.Tensor) -> torch.Tensor:
4816            out = torch.empty((), dtype=torch.float32)
4817            torch.add(inpt, inpt, out=out)
4818            out_view = out.view(4)
4819            out_view.add_(1)
4820            return out
4821
4822        fn = make_fx(functionalize(f, remove="mutations_and_views"))
4823        out = fn(torch.arange(4, device=device, dtype=torch.float32))
4824        out = normalize_devices(out)
4825        self.assertExpectedInline(
4826            out.code,
4827            """\
4828
4829
4830
4831def forward(self, inpt_1) -> torch.Tensor:
4832    empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = 'cpu', pin_memory = False);  empty = None
4833    add = torch.ops.aten.add.Tensor(inpt_1, inpt_1);  inpt_1 = None
4834    view_copy = torch.ops.aten.view_copy.default(add, [4]);  view_copy = None
4835    view_copy_1 = torch.ops.aten.view_copy.default(add, [4]);  add = None
4836    add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1);  view_copy_1 = None
4837    view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]);  add_1 = None
4838    view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4]);  view_copy_3 = None
4839    return view_copy_2
4840    """,
4841        )
4842
4843    def test_functionalize_fx_multi_out_op(self, device):
4844        def f(inpt: torch.Tensor) -> torch.Tensor:
4845            mins = torch.empty(4, dtype=torch.float32)
4846            maxs = torch.empty(2, 2, dtype=torch.float32)
4847            maxs_view = maxs.view(4)
4848            inpt_view = inpt.view(2, 4)
4849            torch.aminmax(inpt_view, dim=0, out=(mins, maxs_view))
4850            return (maxs, mins)
4851
4852        fn = make_fx(functionalize(f, remove="mutations_and_views"))
4853        out = fn(torch.arange(8, device=device, dtype=torch.float32))
4854        out = normalize_devices(out)
4855        self.assertExpectedInline(
4856            out.code,
4857            """\
4858
4859
4860
4861def forward(self, inpt_1) -> torch.Tensor:
4862    empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = 'cpu', pin_memory = False);  empty = None
4863    empty_1 = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = 'cpu', pin_memory = False)
4864    view_copy = torch.ops.aten.view_copy.default(empty_1, [4]);  empty_1 = view_copy = None
4865    view_copy_1 = torch.ops.aten.view_copy.default(inpt_1, [2, 4]);  inpt_1 = None
4866    aminmax = torch.ops.aten.aminmax.default(view_copy_1, dim = 0);  view_copy_1 = None
4867    getitem = aminmax[0]
4868    getitem_1 = aminmax[1];  aminmax = None
4869    view_copy_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]);  getitem_1 = None
4870    view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4]);  view_copy_3 = None
4871    return (view_copy_2, getitem)
4872    """,
4873        )
4874
4875    def test_functionalize_fx_reapply_views_simple(self, device):
4876        def f(x: torch.Tensor) -> torch.Tensor:
4877            tmp = torch.ones(2, device=device)
4878            y = x.view(4, 2)
4879            y.add_(tmp)
4880            return x
4881
4882        out = make_fx(functionalize(f))(torch.zeros(4, 2, device=device))
4883        out = normalize_devices(out)
4884        self.assertExpectedInline(
4885            out.code,
4886            """\
4887
4888
4889
4890def forward(self, x_1) -> torch.Tensor:
4891    ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False)
4892    view = torch.ops.aten.view.default(x_1, [4, 2])
4893    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
4894    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
4895    view_2 = torch.ops.aten.view.default(view_1, [4, 2]);  view_2 = None
4896    copy_ = torch.ops.aten.copy_.default(x_1, view_1);  x_1 = copy_ = None
4897    return view_1
4898    """,
4899        )
4900
4901    def test_functionalize_nonfunctional_output(self, device):
4902        global_out = torch.ones(2, device=device)
4903
4904        def f() -> torch.Tensor:
4905            return global_out
4906
4907        out = make_fx(functionalize(f))()
4908        out = normalize_devices(out)
4909        self.assertExpectedInline(
4910            out.code,
4911            """\
4912
4913
4914
4915def forward(self) -> torch.Tensor:
4916    _tensor_constant0 = self._tensor_constant0
4917    return _tensor_constant0
4918    """,
4919        )
4920
4921    def test_functionalize_optional_tensorlist1(self, device):
4922        def f(a, b) -> torch.Tensor:
4923            # at::index has OptionalTensorList arguments,
4924            # test that here
4925            return a[b]
4926
4927        a = torch.arange(4).reshape(2, 2)
4928        b = torch.ones(2, dtype=torch.long)
4929        out = make_fx(functionalize(f))(a, b)
4930        out = normalize_devices(out)
4931        self.assertExpectedInline(
4932            out.code,
4933            """\
4934
4935
4936
4937def forward(self, a_1, b_1) -> torch.Tensor:
4938    index = torch.ops.aten.index.Tensor(a_1, [b_1]);  a_1 = b_1 = None
4939    return index
4940    """,
4941        )
4942
4943    @unittest.skipIf(IS_FBCODE, "fails in fbcode")
4944    def test_functionalize_optional_tensorlist2(self, device):
4945        def f(a, b) -> torch.Tensor:
4946            # See https://github.com/pytorch/pytorch/pull/77846
4947            return torch.ops.aten.index(a, b)
4948
4949        a = torch.arange(4).reshape(2, 2)
4950        b = torch.ones(2, dtype=torch.long)
4951        out = make_fx(functionalize(f))(a, b)
4952        self.assertExpectedInline(
4953            out.code,
4954            """\
4955
4956
4957
4958def forward(self, a_1, b_1) -> torch.Tensor:
4959    unbind = torch.ops.aten.unbind.int(b_1);  b_1 = None
4960    getitem = unbind[0]
4961    getitem_1 = unbind[1];  unbind = None
4962    index = torch.ops.aten.index.Tensor(a_1, [getitem, getitem_1]);  a_1 = getitem = getitem_1 = None
4963    return index
4964    """,
4965        )
4966
4967    def test_resize_program_inputs(self, device):
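        # f resizes and fills a program input in place; the functionalized graph
        # replays this as functional resize/fill followed by resize_ + copy_
        # back into the input.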
4968        def f(x):
4969            x.resize_(10)
4970            x.fill_(2)
4971
4972        fn = make_fx(functionalize(f))
4973        out = fn(torch.zeros(0, device=device))
4974        out = normalize_devices(out)
4975        self.assertExpectedInline(
4976            (out.code),
4977            """\
4978
4979
4980
4981def forward(self, x_1):
4982    resize = torch.ops.aten.resize.default(x_1, [10])
4983    fill = torch.ops.aten.fill.Scalar(resize, 2);  resize = None
4984    resize_ = torch.ops.aten.resize_.default(x_1, [10]);  x_1 = None
4985    copy_ = torch.ops.aten.copy_.default(resize_, fill);  resize_ = fill = copy_ = None
4986    return None
4987    """,
4988        )
4989
4990
4991def construct_sum_pyop():
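    """Build a toy HigherOrderOperator, ``mysum``, with py_impl rules for the
    Vmap and Grad transforms and for the CPU/CUDA autograd dispatch keys, so the
    tests below can exercise how functorch transforms interact with it."""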
4992    class MySum(HigherOrderOperator):
4993        def __init__(self):
4994            super().__init__("mysum")
4995
4996        def __call__(self, *args, **kwargs):
4997            return super().__call__(*args, **kwargs)
4998
4999    mysum = MySum()
5000
5001    @mysum.py_impl(torch._C._functorch.TransformType.Vmap)
5002    def mysum_batch_rule(interpreter, x, dim):
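        # Batching rule: if x is not batched at this level, lower and call mysum
        # again; otherwise unwrap x, move its batch dim to the front, call mysum
        # with `dim` shifted by one, and re-wrap the result as a batched tensor.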
5003        if not torch._C._functorch.is_batchedtensor(x):
5004            with interpreter.lower():
5005                x = x.view_as(x)  # unnecessary, just here to test the dispatch
5006                return mysum(x, dim)
5007
5008        bdim = torch._C._functorch.maybe_get_bdim(x)
5009        value = torch._C._functorch.get_unwrapped(x)
5010
5011        with interpreter.lower():
5012            value = value.movedim(bdim, 0)
5013            result = mysum(value, dim + 1)
5014
5015        return torch._C._functorch._add_batch_dim(result, 0, interpreter.level())
5016
5017    @mysum.py_impl(torch._C._functorch.TransformType.Grad)
5018    def mysum_grad_rule(interpreter, x, dim):
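        # Grad rule: unwrap x at this grad level, run mysum one level lower via a
        # single-level autograd.Function, and implement backward by broadcasting
        # the incoming gradient back to x's shape.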
5019        level = interpreter.level()
5020
5021        class MySum(torch.autograd.function._SingleLevelFunction):
5022            @staticmethod
5023            def forward(ctx, x, dim):
5024                ctx.x_shape = x.shape
5025                ctx.dim = dim
5026                x = torch._C._functorch._unwrap_for_grad(x, level)
5027                with torch.enable_grad(), interpreter.lower():
5028                    x = x.view_as(x)  # unnecessary, just here to test the dispatch
5029                    y = mysum(x, dim)
5030
5031                y = torch._C._functorch._wrap_for_grad(y, level)
5032                return y
5033
5034            @staticmethod
5035            def backward(ctx, gy):
5036                return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None
5037
5038        with enable_single_level_autograd_function():
5039            return MySum.apply(x, dim)
5040
5041    @mysum.py_impl(torch._C.DispatchKey.AutogradCPU)
5042    def mysum_autograd_cpu(x, dim):
5043        return torch.sum(x, dim)
5044
5045    @mysum.py_impl(torch._C.DispatchKey.AutogradCUDA)
5046    def mysum_autograd_cuda(x, dim):
5047        return torch.sum(x, dim)
5048
5049    return mysum
5050
5051
5052sum_pyop = construct_sum_pyop()
5053
5054
5055@markDynamoStrictTest
5056class TestHigherOrderOperatorInteraction(TestCase):
5057    def test_basic_sum(self, device):
5058        x = torch.randn(2, 3, 4, device=device)
5059        result = sum_pyop(x, 1)
5060        self.assertEqual(result, torch.sum(x, 1))
5061
5062    def test_vmap_sum(self, device):
5063        x = torch.randn(2, 3, 4, device=device)
5064        result = vmap(sum_pyop, (0, None))(x, 0)
5065        self.assertEqual(result, torch.sum(x, 1))
5066
5067        result = vmap(vmap(sum_pyop, (0, None)), (0, None))(x, 0)
5068        self.assertEqual(result, torch.sum(x, 2))
5069
5070    def test_grad_sum(self, device):
5071        x = torch.randn(3, device=device)
5072        gx = grad(sum_pyop)(x, 0)
5073        self.assertEqual(gx, torch.ones_like(x))
5074
5075    def test_grad_grad_sum(self, device):
5076        x = torch.randn(3, requires_grad=True, device=device)
5077
5078        def f(x):
5079            # Higher-order grad requires a non-linearity (the second derivative of a linear function is zero)
5080            return sum_pyop(x.sin(), 0)
5081
5082        def grad_f_sum(x):
5083            return grad(f)(x).sum()
5084
5085        ggx = grad(grad_f_sum)(x)
5086        self.assertEqual(ggx, -x.sin())
5087
5088    def test_vmap_grad_sum(self, device):
5089        x = torch.randn(2, 3, device=device)
5090        gx = vmap(grad(sum_pyop), (0, None))(x, 0)
5091        self.assertEqual(gx, torch.ones_like(x))
5092
5093    def test_no_grad_outside_grad(self, device):
5094        x = torch.randn(3, device=device, requires_grad=True)
5095        with torch.no_grad():
5096            y = grad(sum_pyop)(x, 0)
5097        self.assertEqual(y, torch.ones_like(x))
5098        self.assertFalse(y.requires_grad)
5099
5100    def test_no_grad_inside_grad(self, device):
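        # `shift` is computed under no_grad, so it acts as a constant w.r.t. x;
        # the gradient of sum(x**2) is 2*x and its derivative is 2.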
5101        def f(x):
5102            with torch.no_grad():
5103                shift = sum_pyop(x**2, 0)
5104            return sum_pyop(x**2, 0) - shift
5105
5106        x = torch.randn(3, device=device)
5107        y = grad(f)(x)
5108        self.assertEqual(y, 2 * x)
5109        y = grad(lambda x: grad(f)(x).sum())(x)
5110        self.assertEqual(y, torch.full_like(x, 2))
5111
5112        x = torch.randn(3, device=device, requires_grad=True)
5113        y = grad(f)(x)
5114        (z,) = torch.autograd.grad(y.sum(), x)
5115        self.assertEqual(z, torch.full_like(x, 2))
5116
5117    def test_grad_name_wrapping(self, device):
5118        def my_fn(x):
5119            return x.sum()
5120
5121        grad_fn = grad(my_fn)
5122        self.assertEqual(grad_fn.__name__, "my_fn")
5123
5124    def test_functional_call_multiple_dicts(self):
5125        mod = nn.Linear(1, 1)
5126        x = torch.randn((1, 1))
5127        params = ({"weight": torch.zeros(1, 1)}, {"bias": torch.ones(1)})
5128        functional_call(mod, params, x)
5129
5130
5131def traceable(f):
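    # allow_in_graph tells dynamo to insert calls to f into the graph as-is
    # instead of tracing into (or graph-breaking on) its body.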
5132    f = allow_in_graph(f)
5133
5134    @wraps(f)
5135    def wrapper(*args, **kwargs):
5136        return f(*args, **kwargs)
5137
5138    return wrapper
5139
5140
5141@markDynamoStrictTest
5142class TestCompileTransforms(TestCase):
5143    @skipIfRocm(msg="test leaks memory on ROCm")
5144    # torch.compile is not supported on Windows CUDA.
5145    # Triton only supports GPUs with SM70 or later.
5146    @expectedFailureIf((IS_WINDOWS and TEST_CUDA) or (TEST_CUDA and not SM70OrLater))
5147    def test_compile_vmap_hessian(self, device):
5148        # The model and inputs are a smaller version of the code in the
5149        # benchmark repo:
5150        # https://github.com/pytorch/benchmark/blob/main/userbenchmark/functorch/vmap_hessian_fc.py
5151        D = 2
5152        B = 4
5153
5154        x = torch.randn(B, D, device=device)
5155
5156        model = nn.Sequential(nn.Linear(D, D), nn.ReLU()).to(device)
5157
5158        params_and_buffers = (
5159            dict(model.named_parameters()),
5160            dict(model.named_buffers()),
5161        )
5162
5163        def predict(params_and_buffers, x):
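            # Return the output twice so the second copy flows through as aux (has_aux=True below).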
5164            out = torch.func.functional_call(model, params_and_buffers, x)
5165            return out, out
5166
5167        fn = vmap(
5168            jacfwd(jacrev(predict, argnums=1, has_aux=True), argnums=1, has_aux=True),
5169            in_dims=(None, 0),
5170        )
5171
5172        expected = fn(params_and_buffers, x)
5173
5174        opt_fn = torch.compile(traceable(fn))
5175        actual = opt_fn(params_and_buffers, x)
5176        self.assertEqual(actual, expected)
5177
5178    # torch.compile is not supported on Windows
5179    @torch._dynamo.config.patch(suppress_errors=False)
5180    def test_grad_deprecated_api(self, device):
5181        x = torch.randn((), device=device)
5182        y = torch.randn((), device=device)
5183
5184        def wrapper_fn(x, y):
5185            return functorch.grad(torch.mul)(x, y)
5186
5187        actual = wrapper_fn(x, y)
5188        expected = torch.compile(wrapper_fn, backend="eager", fullgraph=True)(x, y)
5190        self.assertEqual(actual, expected)
5191
5192        def wrapper_fn(x, y):
5193            return functorch.grad(torch.mul, argnums=(0, 1))(x, y)
5194
5195        actual = wrapper_fn(x, y)
5196        expected = torch.compile(wrapper_fn, backend="eager", fullgraph=True)(x, y)
5197        self.assertEqual(actual, expected)
5198
5199
5200only_for = ("cpu", "cuda")
5201instantiate_device_type_tests(
5202    TestGradTransform,
5203    globals(),
5204    only_for=only_for,
5205)
5206instantiate_device_type_tests(
5207    TestVmapOfGrad,
5208    globals(),
5209    only_for=only_for,
5210)
5211instantiate_device_type_tests(
5212    TestJac,
5213    globals(),
5214    only_for=only_for,
5215)
5216instantiate_device_type_tests(
5217    TestJvp,
5218    globals(),
5219    only_for=only_for,
5220)
5221instantiate_device_type_tests(
5222    TestLinearize,
5223    globals(),
5224    only_for=only_for,
5225)
5226instantiate_device_type_tests(
5227    TestVmapJvpInplaceView,
5228    globals(),
5229    only_for=only_for,
5230)
5231instantiate_device_type_tests(
5232    TestHessian,
5233    globals(),
5234    only_for=only_for,
5235)
5236instantiate_device_type_tests(
5237    TestComposability,
5238    globals(),
5239    only_for=only_for,
5240)
5241instantiate_device_type_tests(
5242    TestExamplesCorrectness,
5243    globals(),
5244    only_for=only_for,
5245)
5246instantiate_device_type_tests(
5247    TestHigherOrderOperatorInteraction,
5248    globals(),
5249    only_for=only_for,
5250)
5251instantiate_device_type_tests(
5252    TestFunctionalize,
5253    globals(),
5254    only_for=only_for,
5255)
5256instantiate_device_type_tests(
5257    TestAutogradFunction,
5258    globals(),
5259    only_for=only_for,
5260)
5261instantiate_device_type_tests(
5262    TestAutogradFunctionVmapAPI,
5263    globals(),
5264    only_for=only_for,
5265)
5266instantiate_device_type_tests(
5267    TestHelpers,
5268    globals(),
5269    only_for=only_for,
5270)
5271instantiate_parametrized_tests(
5272    TestMakeFunctional,
5273)
5274instantiate_device_type_tests(
5275    TestCompileTransforms,
5276    globals(),
5277    only_for=only_for,
5278)
5279
5280if __name__ == "__main__":
5281    run_tests()
5282