# Owner(s): ["module: functorch"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import itertools
import unittest

from common_utils import (
    check_vmap_fallback,
    decorate,
    expectedFailureIf,
    generate_vmap_inputs,
    get_fallback_and_vmap_exhaustive,
    is_batch_norm_training,
    is_valid_inplace_sample_input,
    loop,
    loop2,
    opsToleranceOverride,
    skip,
    skipOps,
    tol1,
    tol2,
    xfail,
)
from functorch_additional_op_db import additional_op_db

import torch
import torch.autograd.forward_ad as fwAD
from functorch import grad, jacfwd, jacrev, vjp, vmap
from torch import Tensor
from torch._functorch.eager_transforms import _as_tuple, jvp
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
    tol,
    toleranceOverride,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    IS_MACOS,
    IS_X86,
    noncontiguous_like,
    parametrize,
    run_tests,
    runOnRocm,
    skipIfRocm,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.testing._internal.opinfo.core import SampleInput
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


aten = torch.ops.aten


# Version of autograd.grad with some differences:
#   - pytree inputs are allowed (but all leaves of the pytree must be
#     tensors)
#   - if an input is not used as part of the derivatives, we return a
#     zero-filled tensor for its result
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)


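# A minimal illustrative sketch (ours; not exercised by the tests below and the
# _example_* name is hypothetical) of the two differences called out above:
# inputs may be an arbitrary pytree of tensors, and inputs that do not
# participate in the output get a zero-filled gradient instead of None.
def _example_autograd_grad_zero_fill():
    x = torch.randn(3, requires_grad=True)
    unused = torch.randn(3, requires_grad=True)
    out = (x * 2).sum()
    grads = _autograd_grad((out,), {"x": x, "unused": unused})
    # grads["x"] is all 2s; grads["unused"] is all zeros because `unused`
    # never fed into `out`.
    return grads

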
def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)


# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point Tensors
# - All outputs of f' are floating-point Tensors
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        i
        for i, arg in enumerate(flat_args)
        if diff_arg(arg, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            result = tuple(r for r in result if torch.is_floating_point(r))
            assert len(result) > 0
        return result

    return wrapped, primals


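# A small sketch (ours, not used by the tests; the helper name is hypothetical)
# of what normalize_op_input_output2 produces: a positional-only wrapper over
# the differentiable tensor args, with everything else (here a plain-float
# `alpha` kwarg) closed over.
def _example_normalize_op_input_output2():
    x = torch.randn(3, requires_grad=True)
    y = torch.randn(3, requires_grad=True)
    wrapped, primals = normalize_op_input_output2(torch.add, (x, y), {"alpha": 2.0})
    # primals == (x, y); wrapped(x, y) computes torch.add(x, y, alpha=2.0)
    return wrapped(*primals)

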
# TODO: consolidate with normalize_op_input_output2
def normalize_op_input_output3(
    f, args, kwargs, sample_args, output_process_fn_grad=None
):
    flat_args, args_spec = tree_flatten(args)
    flat_sample_args = pytree.tree_leaves(sample_args)
    diff_argnums = tuple(
        i
        for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args))
        if diff_arg(sample, requires_grad=True)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            result = tuple(r for r in result if torch.is_floating_point(r))
            assert len(result) > 0
        return result

    return wrapped, primals


def normalize_op_input_output(f, sample, requires_grad=True):
    args = tuple([sample.input] + list(sample.args))
    return normalize_op_input_output2(
        f,
        args,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )


def ref_vjp(f, *primals):
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(_as_tuple(result), primals, _as_tuple(cotangents))

    return result, wrapped


def simulate_jvp(f, primals, tangents):
    primals_out, tangents_out = torch.autograd.functional.jvp(f, primals, tangents)
    return primals_out, tangents_out


def ref_jvp(f, primals, tangents):
    with fwAD.dual_level():
        duals = tuple(fwAD.make_dual(p, t) for p, t in zip(primals, tangents))
        result_duals = f(*duals)
        result_duals, spec = tree_flatten(result_duals)
        primals_out, tangents_out = zip(*(fwAD.unpack_dual(d) for d in result_duals))
        return tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec)


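# A quick self-contained sketch (ours, not called anywhere) of the reference
# helpers above: ref_vjp pulls a cotangent back through autograd.grad, while
# ref_jvp pushes a tangent forward through forward-mode AD.
def _example_ref_vjp_and_jvp():
    x = torch.randn(3, requires_grad=True)
    cotangent = torch.randn(3)

    out, pullback = ref_vjp(torch.sin, x)
    (grad_x,) = pullback(cotangent)  # == cotangent * torch.cos(x)

    primal = torch.randn(3)
    tangent = torch.randn(3)
    primal_out, tangent_out = ref_jvp(torch.sin, (primal,), (tangent,))
    # tangent_out == tangent * torch.cos(primal)
    return grad_x, primal_out, tangent_out

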
def get_sample_cotangents(f, sample):
    fn, primals = normalize_op_input_output(f, sample)
    output = fn(*primals)
    return tree_map(torch.randn_like, output)


# Returns a new function g(*args, *cotangents) that computes vjps,
# along with the flat (*args, *cotangents) to call it with.
def get_vjp_fn_and_args_with_cotangents(f, sample, cotangents):
    args = tuple([sample.input] + list(sample.args))
    kwargs = sample.kwargs
    flat_args, args_spec = tree_flatten(args)
    flat_cotangents, cotangents_spec = tree_flatten(cotangents)

    @functools.wraps(f)
    def wrapped(*args):
        assert len(args) == len(flat_args) + len(flat_cotangents)
        actual_args = args[: len(flat_args)]
        cotangents = args[len(flat_args) :]
        actual_args = tree_unflatten(actual_args, args_spec)
        cotangents = tree_unflatten(cotangents, cotangents_spec)

        fn, primals = normalize_op_input_output3(
            f, actual_args, kwargs, flat_args, sample.output_process_fn_grad
        )
        _, vjp_fn = vjp(fn, *primals)
        return vjp_fn(cotangents)

    return wrapped, tuple(flat_args + flat_cotangents)


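# A hand-rolled sketch (ours, purely illustrative) of the calling convention
# the helper above sets up: primal args and cotangents travel through one flat
# positional signature, so a harness can batch any of them with vmap without
# special-casing.
def _example_positional_cotangent_vjp(x, cotangent):
    _, vjp_fn = vjp(torch.exp, x)
    return vjp_fn(cotangent)


# e.g. _example_positional_cotangent_vjp(torch.randn(3), torch.randn(3))

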
# Returns a new function g(*args, *cotangents) that computes vjps, along with
# the sample's (*args, *cotangents)
def get_vjpfull_variant(f, sample):
    fn, primals = normalize_op_input_output(f, sample)
    return _get_vjpfull_variant(fn, primals)


def get_vjpfull_variant2(f, args, kwargs):
    fn, primals = normalize_op_input_output2(f, args, kwargs)
    return _get_vjpfull_variant(fn, primals)


def _get_vjpfull_variant(fn, primals):
    result = fn(*primals)
    cotangents = _as_tuple(
        tree_map(lambda x: torch.randn_like(x, requires_grad=True), result)
    )
    num_primals = len(primals)
    args = (*primals, *cotangents)

    @functools.wraps(fn)
    def wrapped(*args):
        primals = args[:num_primals]
        cotangents = args[num_primals:]
        result, vjp_fn = vjp(fn, *primals)
        if isinstance(result, torch.Tensor):
            assert len(cotangents) == 1
            cotangents = cotangents[0]
        return vjp_fn(cotangents)

    return wrapped, args


def get_jvp_variant(f, sample):
    # We want this higher-order variant of jvp so that it can
    # be wrapped by vmap
    fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
    tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals))

    @functools.wraps(f)
    def wrapped(*args):
        tangents = args
        primals_out, tangents_out = jvp(fn, primals, tangents)

        if isinstance(primals_out, torch.Tensor):
            return (primals_out, tangents_out)
        else:
            flat_primals_out = pytree.tree_leaves(primals_out)
            flat_tangents_out = pytree.tree_leaves(tangents_out)
            return tuple(flat_primals_out + flat_tangents_out)

    return wrapped, tangents


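# A minimal sketch (ours, not part of the test suite) of why a positional jvp
# wrapper like the one above is handy: once tangents are ordinary positional
# tensor arguments, vmap can batch over them directly.
def _example_vmap_over_jvp_tangents():
    primal = torch.randn(3)
    batched_tangents = torch.randn(5, 3)

    def push_forward(tangent):
        _, tangent_out = jvp(torch.sin, (primal,), (tangent,))
        return tangent_out

    # Pushes all 5 tangents through torch.sin in one vmapped call.
    return vmap(push_forward)(batched_tangents)

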
def get_jvp_variant_primals_tangents2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=False
):
    fn, primals = normalize_op_input_output2(
        f, args, kwargs, output_process_fn_grad, requires_grad
    )
    tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals))
    return _get_jvp_variant(fn, primals, tangents)


def get_jvp_variant_primals_tangents(f, sample):
    # We want this higher-order variant of jvp so that it can
    # be wrapped by vmap
    fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
    tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals))
    return _get_jvp_variant(fn, primals, tangents)


def _get_jvp_variant(fn, primals, tangents):
    @functools.wraps(fn)
    def wrapped(*args):
        primals_in = args[: len(primals)]
        tangents_in = args[len(primals) :]
        primals_out, tangents_out = jvp(fn, primals_in, tangents_in)

        if isinstance(primals_out, torch.Tensor):
            return (primals_out, tangents_out)
        else:
            flat_primals_out = pytree.tree_leaves(primals_out)
            flat_tangents_out = pytree.tree_leaves(tangents_out)
            return tuple(flat_primals_out + flat_tangents_out)

    return wrapped, primals + tangents


def is_inplace(op, variant):
    if hasattr(variant, "__wrapped__"):
        return variant.__wrapped__ is op.get_inplace()
    return variant is op.get_inplace()


vjp_fail = {
    xfail("tensor_split"),  # data_ptr composite compliance
    # Very minor accuracy issue on ROCm
    decorate("nn.functional.scaled_dot_product_attention", decorator=skipIfRocm),
}

aliasing_ops = {
    "T",
    "broadcast_to",
    "conj",
    "contiguous",
    "diagonal",  # linalg.diagonal is an alias
    "expand",
    "flatten",
    "imag",
    "mH",  # adjoint is an alias
    "mT",
    "movedim",  # moveaxis is an alias
    "narrow",
    "permute",
    "positive",
    # 'ravel' is composite implicit autograd and may call clone
    "real",
    "reshape",
    "resolve_conj",
    "resolve_neg",
    "select",
    "squeeze",
    "transpose",  # swapdims and swapaxes are aliases
    "unflatten",
    "unfold",
    "unsqueeze",
    "view",
    "view_as",
    "view_as_complex",
    "view_as_real",
}

aliasing_ops_list_return = {
    "chunks",
    "dsplit",
    "hsplit",
    "split",
    "unbind",
    "vsplit",
    # 'tensor_split' is not composite compliant, see vjp_fail
}

skip_noncontig = {
    "_batch_norm_with_update",
    "as_strided_copy",
}


@unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant")
@unMarkDynamoStrictTest
class TestOperators(TestCase):
    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
    @skipOps(
        "TestOperators",
        "test_grad",
        vjp_fail.union(
            {
                xfail(
                    "chalf", "", device_type="cpu"
                ),  # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
                xfail(
                    "sparse.sampled_addmm", ""
                ),  # RuntimeError: Sparse CSR tensors do not have strides
                xfail(
                    "sparse.mm", "reduce"
                ),  # RuntimeError: Sparse CSR tensors do not have strides
                # Non-contiguous Bugs
                #
                # AssertionError: Tensor-likes are not close!
                xfail("_softmax_backward_data", device_type="cpu"),
                xfail("as_strided"),
                xfail("as_strided", "partial_views"),
                # RuntimeError: !self.requires_grad() || self.is_contiguous()
                xfail("as_strided_scatter"),
                # RuntimeError: Tensor must have a last dimension with stride 1
                xfail("view_as_complex"),
                # query: last dimension must be contiguous
                # Fused attention kernels require last dim to be contiguous
                decorate(
                    "nn.functional.scaled_dot_product_attention",
                    decorator=expectedFailureIf(not TEST_WITH_ROCM),
                ),  # Works on ROCm
                xfail("torch.ops.aten._flash_attention_forward"),
                xfail("torch.ops.aten._efficient_attention_forward"),
                # RuntimeError: Expected contiguous tensor, but got
                # non-contiguous tensor for argument #2 'grad_output'
                decorate(
                    "_batch_norm_with_update",
                    decorator=expectedFailureIf(TEST_WITH_ROCM),
                    device_type="cuda",
                ),
            }
        ),
    )
    @opsToleranceOverride(
        "TestOperators",
        "test_grad",
        (
            tol1(
                "nn.functional.binary_cross_entropy_with_logits",
                {torch.float32: tol(atol=1e-04, rtol=1e-04)},
            ),
            tol1("masked.cumprod", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
            tol1("svd_lowrank", {torch.float32: tol(atol=3e-04, rtol=3e-04)}),
            tol1(
                "linalg.multi_dot",
                {torch.float32: tol(atol=1e-05, rtol=8e-04)},
                device_type="cuda",
            ),
            tol1(
                "linalg.tensorsolve",
                {torch.float32: tol(atol=3e-04, rtol=3e-04)},
                device_type="cuda",
            ),
            tol1(
                "nn.functional.multi_head_attention_forward",
                {torch.float32: tol(atol=8e-04, rtol=1e-03)},
            ),
            tol1(
                "__rmatmul__",
                {torch.float32: tol(atol=3e-04, rtol=3e-04)},
                device_type="cuda",
            ),
            tol1(
                "matmul",
                {torch.float32: tol(atol=3e-04, rtol=3e-04)},
                device_type="cuda",
            ),
            tol1(
                "pca_lowrank",
                {torch.float32: tol(atol=3e-05, rtol=4e-06)},
                device_type="cpu",
            ),
        ),
    )
    def test_grad(self, device, dtype, op):
        if op.name in vjp_fail:
            self.skipTest("Skipped; Expected failures")
            return

        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped for redundancy. test_vjp handles in-place testing.")
            return

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs

            if op.name not in skip_noncontig:
                noncontig_sample = sample.noncontiguous()
                noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
                noncontig_kwargs = noncontig_sample.kwargs

            diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg))
            assert len(diff_argnums) > 0
            diff_args = tuple(args[i] for i in diff_argnums)

            def wrapped_fn(*args, **kwargs):
                result = op(*args, **kwargs)
                if sample.output_process_fn_grad is not None:
                    result = sample.output_process_fn_grad(result)

                def abs_if_complex(t):
                    if t.dtype.is_complex:
                        return t.abs()
                    return t

                # Reduce into single value for grad
                if isinstance(result, torch.Tensor):
                    return abs_if_complex(result.sum())
                result = sum(abs_if_complex(res.sum()) for res in result)
                return result

            result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
            expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args)
            self.assertEqual(result, expected)

            if op.name not in skip_noncontig:
                result_noncontig = grad(wrapped_fn, diff_argnums)(
                    *noncontig_args, **noncontig_kwargs
                )
                self.assertEqual(result_noncontig, expected)

    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
    @skipOps(
        "TestOperators",
        "test_jvp",
        set(
            {
                # Composite ops that do bad things. Need to be fixed in PyTorch core.
                # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
                xfail("tensor_split"),
                # BUG: silent incorrectness: runs and produces numerical differences
                skip("nn.functional.max_unpool1d"),  # fails everywhere except on mac
                skip(
                    "nn.functional.max_unpool2d"
                ),  # fails everywhere except on windows
                skip("nn.functional.max_unpool3d"),  # fails everywhere except on mac
                xfail(
                    "native_batch_norm"
                ),  # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
                xfail(
                    "_native_batch_norm_legit"
                ),  # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
                xfail(
                    "_batch_norm_with_update"
                ),  # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
                xfail("nn.functional.scaled_dot_product_attention"),
                xfail("torch.ops.aten._flash_attention_forward"),
                xfail("torch.ops.aten._efficient_attention_forward"),
                xfail(
                    "nn.functional.rrelu"
                ),  # in-place test errors out with no formula implemented
                xfail(
                    "NumpyExpMarkDirtyAutogradFunction"
                ),  # TODO: https://github.com/pytorch/pytorch/issues/91280
                # --- Non-Contiguous Failures! ---
                # This is expected to fail as the operator
                # expects last dim to have stride=1
                xfail("view_as_complex"),
                # BUG
                # AssertionError: Tensor-likes are not close!
                xfail("as_strided"),
                xfail("as_strided", "partial_views"),
                xfail("as_strided_scatter"),
                decorate(
                    "linalg.det",
                    "singular",
                    decorator=expectedFailureIf(IS_MACOS and IS_X86),
                ),
            }
        ),
    )
    @opsToleranceOverride(
        "TestOperators",
        "test_jvp",
        (
            tol1(
                "nn.functional.conv_transpose3d",
                {torch.float32: tol(atol=1e-04, rtol=1.3e-06)},
                device_type="cuda",
            ),
            tol1(
                "linalg.tensorsolve",
                {torch.float32: tol(atol=1e-04, rtol=1.3e-05)},
                device_type="cuda",
            ),
            tol1(
                "masked.prod",
                {torch.float32: tol(atol=1e-05, rtol=1.3e-05)},
                device_type="cuda",
            ),
            tol1(
                "nn.functional.binary_cross_entropy_with_logits",
                {torch.float32: tol(atol=4e-04, rtol=4e-04)},
            ),
            tol1(
                "nn.functional.batch_norm", {torch.float32: tol(atol=4e-05, rtol=5e-05)}
            ),
            tol1("nn.functional.conv2d", {torch.float32: tol(atol=4e-05, rtol=5e-05)}),
            tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
            tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
            tol1(
                "nn.functional.multi_head_attention_forward",
                {torch.float32: tol(atol=6e-05, rtol=2e-05)},
            ),
            tol2(
                "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-5, rtol=2e-5)}
            ),
        ),
    )
    def test_jvp(self, device, dtype, op):
        # TODO: get rid of vjp_decomp when we add decomposition support to
        # PyTorch's forward-mode ad. Currently the decomposition support only
        # works for functorch.jvp
        VJP_DECOMP = {
            "nn.functional.logsigmoid",
        }
        if op.name in VJP_DECOMP:
            fixme_ref_jvp_local = simulate_jvp
        else:
            fixme_ref_jvp_local = ref_jvp

        if not op.supports_forward_ad and op.name not in VJP_DECOMP:
            self.skipTest("Skipped! Forward AD not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        outplace_variant = op if not is_inplace(op, op.get_op()) else None
        inplace_variant = op.inplace_variant if op.supports_inplace_autograd else None

        for sample in samples:
            if outplace_variant:
                self.jvp_opinfo_test(
                    outplace_variant,
                    sample,
                    sample.output_process_fn_grad,
                    clone_inputs=False,
                    fixme_ref_jvp_local=fixme_ref_jvp_local,
                    test_noncontig=op.name not in skip_noncontig,
                )
            if is_valid_inplace_sample_input(sample, op, inplace_variant):
                self.jvp_opinfo_test(
                    inplace_variant,
                    sample,
                    sample.output_process_fn_grad,
                    clone_inputs=True,
                    fixme_ref_jvp_local=fixme_ref_jvp_local,
                    test_noncontig=op.name not in skip_noncontig,
                )

    def jvp_opinfo_test(
        self,
        fn,
        sample,
        output_process_fn,
        clone_inputs,
        fixme_ref_jvp_local,
        test_noncontig,
    ):
        # NB: we used requires_grad=True to determine where the primals are,
        # but don't need that information otherwise
        args = (sample.input,) + sample.args
        kwargs = sample.kwargs
        contig_fn, primals = normalize_op_input_output2(
            fn, args, kwargs, output_process_fn, requires_grad=True
        )
        orig_primals = tree_map(lambda x: x.detach(), primals)
        orig_tangents = tree_map(lambda x: torch.randn_like(x), primals)

        def maybe_clone_inputs():
            if clone_inputs:
                primals = tree_map(torch.clone, orig_primals)
                tangents = tree_map(torch.clone, orig_tangents)
                return primals, tangents
            return orig_primals, orig_tangents

        primals, tangents = maybe_clone_inputs()
        expected_primal_outs, expected_tangent_outs = fixme_ref_jvp_local(
            contig_fn, primals, tangents
        )

        primals, tangents = maybe_clone_inputs()
        primal_outs, tangent_outs = jvp(contig_fn, primals, tangents)

        self.assertEqual(primal_outs, expected_primal_outs)
        self.assertEqual(tangent_outs, expected_tangent_outs)

        if test_noncontig:
            noncontig_sample = sample.noncontiguous()
            noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
            noncontig_kwargs = sample.kwargs
            noncontig_fn, primals = normalize_op_input_output2(
                fn,
                noncontig_args,
                noncontig_kwargs,
                output_process_fn,
                requires_grad=True,
            )
            noncontig_primals = tree_map(lambda x: x.detach(), primals)
            noncontig_tangents = tree_map(
                lambda x: noncontiguous_like(x), orig_tangents
            )
            noncontig_primal_outs, noncontig_tangent_outs = jvp(
                noncontig_fn, noncontig_primals, noncontig_tangents
            )

            self.assertEqual(noncontig_primal_outs, expected_primal_outs)
            self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)

    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
    @skipOps(
        "TestOperators",
        "test_vjp",
        vjp_fail.union(
            {
                xfail("sparse.sampled_addmm", ""),
                xfail("sparse.mm", "reduce"),
                # ---- Non-Contiguous Failures ----
                # This is expected to fail as the operator
                # expects last dim to have stride=1
                xfail("view_as_complex"),
                # RuntimeError: query: last dimension must be contiguous
                # The fused attention kernels require the last dim to be contiguous
                decorate(
                    "nn.functional.scaled_dot_product_attention",
                    decorator=expectedFailureIf(not TEST_WITH_ROCM),
                ),  # Works on ROCm
                xfail("torch.ops.aten._flash_attention_forward"),
                xfail("torch.ops.aten._efficient_attention_forward"),
                # BUG
                # AssertionError: Tensor-likes are not close!
                xfail("as_strided"),
                xfail("as_strided_scatter"),
                xfail("_softmax_backward_data", device_type="cpu"),
                xfail("as_strided", "partial_views"),
            }
        ),
    )
    @opsToleranceOverride(
        "TestOperators",
        "test_vjp",
        (
            tol1(
                "nn.functional.conv_transpose3d",
                {torch.float32: tol(atol=5e-05, rtol=9e-05)},
                device_type="cuda",
            ),
            tol1(
                "nn.functional.binary_cross_entropy_with_logits",
                {torch.float32: tol(atol=1e-04, rtol=1e-04)},
            ),
            tol1(
                "nn.functional.multi_head_attention_forward",
                {torch.float32: tol(atol=2e-03, rtol=2e-04)},
            ),
            tol1("__rmatmul__", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
            tol1("matmul", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
            tol2(
                "linalg.pinv", "hermitian", {torch.float32: tol(atol=1e-05, rtol=1e-05)}
            ),
            tol1("linalg.tensorsolve", {torch.float32: tol(atol=9e-03, rtol=2e-04)}),
            tol1("linalg.multi_dot", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
            tol1("svd_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
            tol1("pca_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
        ),
    )
    def test_vjp(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        def _test(_op, inplace=False):
            for sample in samples:
                if inplace and not is_valid_inplace_sample_input(
                    sample, op, op.inplace_variant
                ):
                    continue
                fn, primals = normalize_op_input_output(_op, sample)
                result = fn(*primals)
                cotangents = tree_map(lambda x: torch.randn_like(x), result)

                out, vjp_fn = vjp(fn, *primals)
                self.assertEqual(out, result)
                result_vjps = vjp_fn(cotangents)

                _, vjp_fn = ref_vjp(fn, *primals)
                expected_vjps = vjp_fn(cotangents)

                self.assertEqual(result_vjps, expected_vjps)

                if op.name not in skip_noncontig:
                    noncontig_fn, noncontig_primals = normalize_op_input_output(
                        _op, sample.noncontiguous()
                    )
                    noncontig_cotangents = tree_map(
                        lambda x: noncontiguous_like(x), cotangents
                    )
                    out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
                    self.assertEqual(out_noncontig, result)
                    noncontig_result_vjps = vjp_fn(noncontig_cotangents)
                    self.assertEqual(noncontig_result_vjps, expected_vjps)

        _test(op)
        for a_op in op.aliases:
            _test(a_op)
        if op.inplace_variant:

            def f(inp, *args, **kwargs):
                return op.inplace_variant(inp.clone(), *args, **kwargs)

            _test(f, inplace=True)

    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
    @skipOps(
        "TestOperators",
        "test_vjpvjp",
        vjp_fail.union(
            {
                skip("nn.functional.max_unpool1d"),  # silent incorrectness; Flaky
                skip("nn.functional.max_unpool2d"),  # silent incorrectness; Flaky
                xfail("nn.functional.ctc_loss"),  # Not Implemented
                xfail(
                    "native_layer_norm", ""
                ),  # Expected a proper Tensor but got None for argument #1 'other'
                xfail("sparse.sampled_addmm", ""),  # sparse tensors have no strides
                xfail("sparse.mm", "reduce"),  # sparse tensors have no strides
                skip("nn.functional.scaled_dot_product_attention"),
                xfail("torch.ops.aten._flash_attention_forward"),
                xfail("torch.ops.aten._efficient_attention_forward"),
                # AssertionError: Tensor-likes are not close!
                # Mismatched elements: 1 / 15 (6.7%)
                # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
                # Greatest relative difference: 1.7933241714393998e-06 at index (2, 4) (up to 1.3e-06 allowed)
                # The failure occurred for item [0]
                xfail("masked.prod"),
            }
        ),
    )
    @opsToleranceOverride(
        "TestOperators",
        "test_vjpvjp",
        (
            tol1(
                "nn.functional.conv_transpose3d",
                {torch.float32: tol(atol=5e-05, rtol=9e-05)},
                device_type="cuda",
            ),
            tol1("prod", {torch.float32: tol(atol=2e-05, rtol=1e-04)}),
            tol1("masked.cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
            tol1("cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
            tol1("linalg.vander", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
            tol2(
                "linalg.det", "singular", {torch.float32: tol(atol=2e-05, rtol=2e-05)}
            ),
        ),
    )
    def test_vjpvjp(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return
        if not op.supports_gradgrad:
            self.skipTest("Skipped! Operation does not support gradgrad")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        def test(_op, inplace=False):
            for sample in samples:
                if inplace and not is_valid_inplace_sample_input(
                    sample, op, op.inplace_variant
                ):
                    continue
                fn, args = get_vjpfull_variant(_op, sample)
                result = fn(*args)
                cotangents = tree_map(lambda x: torch.randn_like(x), result)

                # Compute vjp of vjp
                _, vjp_fn = vjp(fn, *args)
                result_vjps = vjp_fn(cotangents)

                # Compute ref_vjp of vjp. We could have done ref_vjp of ref_vjp,
                # but since we're confident that vjp works by itself, this is
                # an equivalent way to test that.
                _, vjp_fn = ref_vjp(fn, *args)
                expected_vjps = vjp_fn(cotangents)

                self.assertEqual(result_vjps, expected_vjps)

        test(op)
        if op.inplace_variant:

            def fn(inp, *args, **kwargs):
                return op.inplace_variant(inp.clone(), *args, **kwargs)

            test(fn, inplace=True)

    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
    @skipOps(
        "TestOperators",
        "test_vmapvjpvjp",
        vjp_fail.union(
            {
                skip("atleast_1d"),  # Takes too long
                skip("atleast_2d"),  # Takes too long
                skip("atleast_3d"),  # Takes too long
                skip("ormqr"),  # Takes too long
                xfail("as_strided"),  # incorrect output
                xfail("as_strided", "partial_views"),  # incorrect output
                xfail("as_strided_scatter"),  # incorrect output
                skip("bernoulli"),  # calls random op
                xfail("bfloat16"),  # rank 4 tensor for channels_last
                xfail("cdouble"),  # rank 4 tensor for channels_last
                xfail("cfloat"),  # rank 4 tensor for channels_last
                xfail("chalf"),  # rank 4 tensor for channels_last
                xfail("double"),  # rank 4 tensor for channels_last
                xfail("float"),  # rank 4 tensor for channels_last
                xfail("half"),  # rank 4 tensor for channels_last
                xfail(
                    "NumpyCubeNotComposableAutogradFunction"
                ),  # Not composable autograd.Function
                # It looks like you're either (1) calling .item() on a Tensor or
                # (2) attempting to use a Tensor in some data-dependent control flow or
                # (3) encountering this error in PyTorch internals.
                xfail("index_reduce", "prod"),
                decorate(
                    "linalg.householder_product", decorator=runOnRocm
                ),  # works on ROCm
                xfail(
                    # nans
                    "masked.softmax",
                    device_type="cpu",
                ),
                xfail(
                    "nanquantile", device_type="cpu"
                ),  # vmap not implemented for at::equal.
                xfail("native_layer_norm"),  # vmap: inplace into a regular tensor
                # got a batched tensor as input while the running_mean or running_var,
                # which will be updated in place, were not batched.
                xfail("nn.functional.batch_norm"),
                xfail(
                    "nn.functional.binary_cross_entropy"
                ),  # vmap: inplace into a regular tensor
                xfail(
                    "nn.functional.ctc_loss"
                ),  # derivative not implemented for _ctc_loss_backward
                # flaky on ROCM needs investigation
                decorate("nn.functional.conv_transpose2d", decorator=skipIfRocm),
                skip("nn.functional.dropout"),  # calls random op
                skip("nn.functional.dropout2d"),  # calls random op
                skip("nn.functional.dropout3d"),  # calls random op
                skip("nn.functional.alpha_dropout"),  # calls random op
                skip(
                    "nn.functional.feature_alpha_dropout", "with_train"
                ),  # calls random op
                skip("nn.functional.fractional_max_pool2d"),  # calls random op
                skip("nn.functional.fractional_max_pool3d"),  # calls random op
                xfail("nn.functional.scaled_dot_product_attention"),  # randomness
                xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
                xfail("nn.functional.multi_head_attention_forward"),  # randomness
                # It looks like you're either (1) calling .item() on a Tensor or
                # (2) attempting to use a Tensor in some data-dependent control flow or
                # (3) encountering this error in PyTorch internals.
                xfail("nn.functional.gaussian_nll_loss"),
                # got a batched tensor as input while the running_mean or running_var,
                # which will be updated in place, were not batched.
                xfail("nn.functional.instance_norm"),
                xfail(
                    "nn.functional.layer_norm"
                ),  # vmap: inplace into a regular tensor
                # RuntimeError: NYI: querying is_contiguous inside of vmap
                # for memory_format other than torch.contiguous_format
                xfail("nn.functional.max_pool2d"),
                # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only
                # supported with memory_format torch.preserve_format or
                # torch.contiguous_format (got ChannelsLast)
                xfail("nn.functional.max_unpool2d"),
                # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only
                # supported with memory_format torch.preserve_format
                # or torch.contiguous_format (got ChannelsLast)
                xfail("nn.functional.max_unpool2d", "grad"),
                xfail(
                    "nn.functional.rrelu"
                ),  # RuntimeError: vmap: we do not yet support aten::rrelu_with_noise.
                xfail("normal"),  # calls random op
                xfail("normal", "number_mean"),  # calls random op
                xfail("pca_lowrank"),  # calls random op
                xfail(
                    "quantile", device_type="cpu"
                ),  # Batching rule not implemented for `at::equal`
                xfail(
                    "scatter_reduce", "prod"
                ),  # vmap (looks like you are calling item/data-dependent)
                xfail(
                    "sparse.sampled_addmm"
                ),  # RuntimeError: Sparse CSR tensors do not have strides
                xfail(
                    "sparse.mm", "reduce"
                ),  # RuntimeError: Sparse CSR tensors do not have strides
                xfail("svd_lowrank"),  # calls random op
                xfail("to"),  # rank 4 tensor for channels_last
                xfail(
                    "view_as_complex"
                ),  # RuntimeError: Tensor must have a last dimension with stride 1
                # got a batched tensor as input while the running_mean or running_var,
                # which will be updated in place, were not batched.
                xfail("nn.functional.batch_norm", "without_cudnn"),
                # view doesn't work on sparse
                xfail("to_sparse"),
                xfail("native_batch_norm"),
                xfail("_native_batch_norm_legit"),
                # TODO: implement batching rule
                xfail("_batch_norm_with_update"),
            }
        ),
    )
    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @opsToleranceOverride(
        "TestOperators",
        "test_vmapvjpvjp",
        (
            tol1("linalg.svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
            tol1("linalg.lu", {torch.float32: tol(atol=5e-04, rtol=7e-04)}),
            tol1("linalg.lu_factor", {torch.float32: tol(atol=2e-03, rtol=2e-02)}),
            tol1("linalg.multi_dot", {torch.float32: tol(atol=2e-03, rtol=2e-04)}),
            tol1("svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
            tol1("matrix_exp", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
            tol1("masked.prod", {torch.float32: tol(atol=2e-03, rtol=2e-04)}),
        ),
    )
    @skipOps(
        "TestOperators",
        "test_vmapvjpvjp",
        {
            xfail("as_strided", "partial_views"),
            xfail("as_strided_copy"),
        },
    )
    def test_vmapvjpvjp(self, device, dtype, op):
        # Since we test `vjpvjp` independently, for this test we just verify
        # that vmap of `vjpvjp` is correct.
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return
        if not op.supports_gradgrad:
            self.skipTest("Skipped! Operation does not support gradgrad")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        # TODO: test in-place
        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return

        for sample in samples:
            fn, args = get_vjpfull_variant(op, sample)
            result = fn(*args)
            cotangents = tree_map(lambda x: torch.randn_like(x), result)
            cotangents = pytree.tree_leaves(cotangents)
            num_args = len(args)

            args_and_cotangents = tuple(args) + tuple(cotangents)

            def vjp_of_vjp(*args_and_cotangents):
                args = args_and_cotangents[:num_args]
                cotangents = args_and_cotangents[num_args:]
                result, vjp_fn = vjp(fn, *args)
                result_vjps = vjp_fn(cotangents)
                result = pytree.tree_leaves(result)
                result_vjps = pytree.tree_leaves(result_vjps)
                return (*result, *result_vjps)

            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
            generator = get_fallback_and_vmap_exhaustive(
                vjp_of_vjp,
                args_and_cotangents,
                {},
                is_batch_norm_and_training=is_batch_norm_and_training,
            )
            for loop_out, batched_out in generator:
                self.assertEqual(loop_out, batched_out)

    vmapvjp_fail = vjp_fail.union(
        {
            # -------------------- ALLOWED FAILURES --------------------------------
            # The following are not bugs and are expected behavior
            xfail("masked_select"),  # Not possible due to dynamic shapes
            skip("bernoulli"),  # randomness
            skip("normal", ""),  # randomness
            skip("normal", "number_mean"),  # randomness
            skip("nn.functional.rrelu"),  # randomness
            skip("nn.functional.feature_alpha_dropout", "with_train"),  # randomness
            skip("nn.functional.feature_alpha_dropout", "without_train"),  # randomness
            skip("nn.functional.dropout"),  # randomness
            skip("nn.functional.dropout2d"),  # randomness
            skip("nn.functional.dropout3d", ""),  # randomness
            skip("nn.functional.alpha_dropout"),  # randomness
            skip("nn.functional.scaled_dot_product_attention"),  # randomness
            xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
            skip("nn.functional.multi_head_attention_forward"),  # randomness
            xfail(
                "index_put", ""
            ),  # not possible due to dynamic shapes; we support a subset
            xfail("nn.functional.fractional_max_pool2d"),  # random
            xfail("nn.functional.fractional_max_pool3d"),  # random
            xfail("pca_lowrank", ""),  # randomness
            xfail("svd_lowrank", ""),  # randomness
            xfail("to_sparse", ""),  # non-dense output
            skip(
                "to"
            ),  # RuntimeError: required rank 4 tensor to use channels_last format
            xfail("as_strided", "partial_views"),
            xfail(
                "NumpyCubeNotComposableAutogradFunction"
            ),  # Not composable autograd.Function
            # ----------------------------------------------------------------------
            # ---------------------------- BUGS ------------------------------------
            # All of the following are bugs and need to be fixed
            skip(
                "linalg.svdvals"
            ),  # really annoying thing where it passes correctness check but not has_batch_rule
            skip("native_batch_norm"),
            skip("_native_batch_norm_legit"),
            # TODO: implement batching rule
            skip("_batch_norm_with_update"),
            xfail("__getitem__", ""),  # dynamic error
            xfail("nanquantile", device_type="cpu"),  # checks q via a .item() call
            xfail("nn.functional.gaussian_nll_loss"),  # checks if any value in var is < 0
            xfail("narrow"),  # .item() call
            xfail("quantile", device_type="cpu"),  # checks q via a .item() call
            xfail("view_as_complex"),  # Tensor must have a last dimension with stride 1
            # required rank 4 tensor to use channels_last format
            xfail("bfloat16"),
            xfail("double"),
            xfail("float"),
            xfail("half"),
            xfail("cdouble", ""),
            xfail("cfloat", ""),
            xfail("chalf", ""),
            xfail("scatter_reduce", "prod"),  # item call
            # Batching rule not implemented for aten::_use_cudnn_ctc_loss.Tensor
            xfail("nn.functional.ctc_loss", device_type="cuda"),
            # NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format
            xfail("nn.functional.max_unpool2d"),
            xfail("nn.functional.max_unpool2d", "grad"),
            xfail("sparse.sampled_addmm", ""),
            xfail("sparse.mm", "reduce"),
            xfail("as_strided_scatter", ""),  # calls as_strided
            xfail("index_reduce", "prod"),  # .item() call
            # ---------------------------------------------------------------------
        }
    )

    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @opsToleranceOverride(
        "TestOperators",
        "test_vmapvjp",
        (
            tol1(
                "linalg.svd",
                {torch.float32: tol(atol=5e-04, rtol=1e-04)},
                device_type="cuda",
            ),
            tol1(
                "svd", {torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda"
            ),
            tol1(
                "linalg.householder_product",
                {torch.float32: tol(atol=3e-04, rtol=9e-04)},
            ),
            tol1(
                "matrix_exp",
                {torch.float32: tol(atol=5e-04, rtol=1e-04)},
                device_type="cuda",
            ),
            tol1(
                "nn.functional.layer_norm",
                {torch.float32: tol(atol=3e-4, rtol=1e-4)},
                device_type="cpu",
            ),
            tol1(
                "native_layer_norm",
                {torch.float32: tol(atol=3e-4, rtol=1e-4)},
                device_type="cpu",
            ),
        ),
    )
    @skipOps(
        "TestOperators",
        "test_vmapvjp",
        vmapvjp_fail.union(
            {
                xfail("as_strided"),
                xfail("as_strided_copy"),
                xfail("as_strided", "partial_views"),
            }
        ),
    )
    def test_vmapvjp(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        # TODO: test in-place
        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return
        for sample in samples:
            cotangents = get_sample_cotangents(op, sample)
            fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
            generator = get_fallback_and_vmap_exhaustive(
                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training
            )
            for loop_out, batched_out in generator:
                self.assertEqual(loop_out, batched_out)

1252    vmapjvpall_fail = {
1253        # -------------------- ALLOWED FAILURES --------------------------------
1254        # The following are expected (not a bug)
1255        skip("bernoulli", ""),  # randomness
1256        skip("nn.functional.dropout"),  # randomness
1257        skip("nn.functional.rrelu"),  # randomness
1258        skip("nn.functional.dropout2d", ""),
1259        skip("nn.functional.dropout3d", ""),
1260        skip("nn.functional.scaled_dot_product_attention"),  # randomness
1261        xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
1262        skip("nn.functional.multi_head_attention_forward"),  # randomness
1263        skip("nn.functional.alpha_dropout"),  # randomness
1264        skip("nn.functional.feature_alpha_dropout", "without_train"),
1265        skip("nn.functional.feature_alpha_dropout", "with_train"),
1266        xfail(
1267            "nn.functional.fractional_max_pool2d"
1268        ),  # Cannot access data pointer of Tensor that doesn't have storage
1269        xfail(
1270            "nn.functional.fractional_max_pool3d"
1271        ),  # Cannot access data pointer of Tensor that doesn't have storage
1272        # Not actually a problem: embedding with max_norm mutates the weight
1273        # and causes different runs to produce different results.
1274        # skip because this is flaky depending on what the max_norm is!
1275        skip("nn.functional.embedding", ""),
1276        skip("to"),  # RuntimeError: required rank 4 tensor to use channels_last format
1277        xfail(
1278            "NumpyExpMarkDirtyAutogradFunction"
1279        ),  # vmap: inplace into a regular tensor
1280        # ----------------------------------------------------------------------
1281        # ---------------------------- BUGS ------------------------------------
1282        # The following are bugs that we should fix
1283        xfail("masked.mean"),  # silent incorrectness (nan difference)
1284        xfail("as_strided", "partial_views"),  # Tensor-likes are not close!
1285        xfail(
1286            "nn.functional.soft_margin_loss", ""
1287        ),  # soft_margin_loss_backward does not support forward-ad
1288        xfail("tensor_split"),  # data_ptr composite compliance
1289        xfail("quantile"),  # at::equal batching rule (cpu), also, in-place vmap (cuda)
1290        skip("as_strided"),  # Test runner cannot handle this
1291        # requires special handling, and does not yet have a batching rule. Feel free to file a github issue!
1292        xfail("as_strided_scatter"),
1293        xfail(
1294            "nn.functional.gaussian_nll_loss"
1295        ),  # .item or data-dependent control flow
1296        xfail("scatter"),  # forward-mode AD does not support at::scatter
1297        xfail(
1298            "nanquantile"
1299        ),  # at::equal batching rule (cpu), also, in-place vmap (cuda)
1300        xfail("view_as_complex"),  # Tensor must have a last dimension with stride 1
1301        skip("pca_lowrank", ""),  # randomness
1302        skip("svd_lowrank", ""),  # randomness
1303        xfail("double"),  # required rank 4 tensor to use channels_last format
1304        xfail("cdouble"),  # required rank 4 tensor to use channels_last format
1305        # potential silent incorrectness
1306        skip(
1307            "nn.functional.max_unpool1d"
1308        ),  # Flaky, seems to sometimes hit max_unpool2d
1309        skip("nn.functional.max_unpool2d"),  # fails everywhere except on mac
1310        skip("nn.functional.max_unpool3d"),  # fails everywhere except on mac
1311        # erroring because running_mean and running_var aren't differentiable
1312        xfail("nn.functional.batch_norm"),
1313        xfail("nn.functional.batch_norm", "without_cudnn"),
1314        xfail("native_batch_norm"),
1315        xfail("_native_batch_norm_legit"),
1316        # TODO: implement batching rule
1317        xfail("_batch_norm_with_update"),
1318        # ----------------------------------------------------------------------
1319    }
1320
1321    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
1322    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
1323    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
1324    @opsToleranceOverride(
1325        "TestOperators",
1326        "test_vmapjvpall",
1327        (
1328            tol1(
1329                "nn.functional.conv_transpose3d",
1330                {torch.float32: tol(atol=2e-04, rtol=9e-3)},
1331                device_type="cuda",
1332            ),
1333            tol1(
1334                "linalg.householder_product",
1335                {torch.float32: tol(atol=2e-04, rtol=9e-3)},
1336            ),
1337        ),
1338    )
1339    @skipOps(
1340        "TestOperators",
1341        "test_vmapjvpall",
1342        vmapjvpall_fail.union(
1343            {
1344                xfail("as_strided_copy"),
1345                decorate(
1346                    "linalg.det",
1347                    "singular",
1348                    decorator=expectedFailureIf(IS_MACOS and IS_X86),
1349                ),
1350            }
1351        ),
1352    )
1353    # This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp
1354    # or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact
1355    # because that corresponds to "batched forward-mode AD" testing in PyTorch core (see the sketch after this test).
1356    def test_vmapjvpall(self, device, dtype, op):
1357        if is_inplace(op, op.get_op()):
1358            # TODO: test in-place
1359            self.skipTest("Skipped! NYI: inplace-testing not supported.")
1360            return
1361
1362        samples = op.sample_inputs(device, dtype, requires_grad=False)
1363
1364        if not op.supports_forward_ad:
1365            self.skipTest("Skipped! Forward AD not supported.")
1366            return
1367
1368        for sample in samples:
1369            arg_values = [sample.input] + list(sample.args)
1370            kwarg_values = sample.kwargs
1371            args = tuple(arg_values) + tuple(kwarg_values)
1372            fn, args = get_jvp_variant_primals_tangents(op, sample)
1373            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
1374            generator = get_fallback_and_vmap_exhaustive(
1375                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training
1376            )
1377            for loop_out, batched_out in generator:
1378                self.assertEqual(loop_out, batched_out)
1379
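    # Illustrative sketch (helper name is ours and nothing in this file calls
    # it): what the vmap-over-jvp composition exercised by test_vmapjvpall
    # looks like for a simple elementwise op. torch.sin is only a stand-in
    # example chosen because its derivative (cos) is easy to check against.
    def _sketch_vmap_of_jvp(self):
        def jvp_of_sin(primal, tangent):
            # jvp returns (sin(primal), directional derivative along tangent)
            return jvp(torch.sin, (primal,), (tangent,))

        primals = torch.randn(3, 5)
        tangents = torch.randn(3, 5)
        # Batched: push the leading dimension through vmap instead of looping.
        batched_out, batched_jvp = vmap(jvp_of_sin)(primals, tangents)
        # Loop reference, in the spirit of the loop()/loop2() helpers.
        expected_out = torch.stack([p.sin() for p in primals])
        expected_jvp = torch.stack(
            [p.cos() * t for p, t in zip(primals, tangents)]
        )
        return (batched_out, batched_jvp), (expected_out, expected_jvp)
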
1380    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
1381    @skipOps(
1382        "TestOperators",
1383        "test_vmapjvpall_has_batch_rule",
1384        vmapjvpall_fail.union(
1385            {
1386                skip(
1387                    "to"
1388                ),  # RuntimeError: required rank 4 tensor to use channels_last format
1389                xfail(
1390                    "cdouble"
1391                ),  # RuntimeError: required rank 4 tensor to use channels_last format
1392                xfail("cumprod"),
1393                xfail("masked_fill"),
1394                xfail("fill"),
1395                skip("masked.mean"),  # ???
1396                xfail("masked_scatter"),
1397                xfail("put"),
1398                xfail("take"),
1399                xfail("nn.functional.feature_alpha_dropout", "without_train"),
1400                xfail("nn.functional.dropout2d", ""),
1401                xfail("pca_lowrank", ""),
1402                xfail("svd_lowrank", ""),
1403                xfail("nn.functional.feature_alpha_dropout", "with_train"),
1404                xfail("special.log_ndtr", ""),
1405                xfail("fft.ihfft2"),  # conj_physical fallback
1406                xfail("fft.ihfftn"),  # conj_physical fallback
1407                xfail("nn.functional.max_unpool3d", "grad"),
1408                xfail("nn.functional.max_unpool2d", "grad"),
1409                xfail("nn.functional.soft_margin_loss", ""),
1410                xfail("nn.functional.max_unpool1d", "grad"),
1411                xfail("nn.functional.embedding", ""),
1412                xfail(
1413                    "scatter_reduce", "sum"
1414                ),  # aten::scatter_reduce.two hit the vmap fallback
1415                xfail(
1416                    "scatter_reduce", "mean"
1417                ),  # aten::scatter_reduce.two hit the vmap fallback
1418                xfail(
1419                    "scatter_reduce", "amin"
1420                ),  # aten::scatter_reduce.two hit the vmap fallback
1421                xfail(
1422                    "scatter_reduce", "amax"
1423                ),  # aten::scatter_reduce.two hit the vmap fallback
1424                xfail("nn.functional.glu"),
1425                xfail("nn.functional.bilinear"),  # trilinear doesn't have batching rule
1426                xfail("linalg.lu", ""),
1427                xfail("nn.functional.dropout3d", ""),
1428                xfail("as_strided_scatter", ""),
1429                xfail("masked.cumprod", ""),
1430                xfail("renorm"),  # hit vmap fallback, which is disabled
1431                xfail("t_copy"),
1432                xfail("unsqueeze_copy"),
1433            }
1434        ),
1435    )
1436    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
1437    def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
1438        if is_inplace(op, op.get_op()):
1439            # TODO: test in-place
1440            self.skipTest("Skipped! NYI: inplace-testing not supported.")
1441            return
1442
1443        samples = op.sample_inputs(device, dtype, requires_grad=False)
1444
1445        if not op.supports_forward_ad:
1446            self.skipTest("Skipped! Forward AD not supported.")
1447            return
1448
1449        def test():
1450            for sample in samples:
1451                arg_values = [sample.input] + list(sample.args)
1452                kwarg_values = sample.kwargs
1453                args = tuple(arg_values) + tuple(kwarg_values)
1454                fn, args = get_jvp_variant_primals_tangents(op, sample)
1455                is_batch_norm_and_training = is_batch_norm_training(
1456                    op.name, kwarg_values
1457                )
1458                for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
1459                    fn,
1460                    args,
1461                    {},
1462                    is_batch_norm_and_training=is_batch_norm_and_training,
1463                    compute_loop_out=False,
1464                ):
1465                    pass
1466
1467        check_vmap_fallback(self, test, op, dry_run=False)
1468
1469    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
1470    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
1471    @skipOps(
1472        "TestOperators",
1473        "test_vmapvjp_has_batch_rule",
1474        vmapvjp_fail.union(
1475            {
1476                skip(
1477                    "to"
1478                ),  # RuntimeError: required rank 4 tensor to use channels_last format
1479                xfail("view_as_complex"),
1480                xfail("cummax"),
1481                xfail("cummin"),
1482                xfail("fill"),
1483                xfail(
1484                    "narrow"
1485                ),  # Batching rule not implemented for `narrow.Tensor` (and view op)
1486                xfail("special.log_ndtr"),
1487                xfail("linalg.householder_product"),
1488                xfail("masked_fill"),
1489                xfail("masked_scatter"),
1490                xfail("masked_select"),
1491                xfail("nanquantile"),
1492                xfail("ormqr"),
1493                xfail("put"),
1494                xfail(
1495                    "scatter_reduce", "sum"
1496                ),  # aten::scatter_reduce.two hit the vmap fallback
1497                xfail(
1498                    "scatter_reduce", "mean"
1499                ),  # aten::scatter_reduce.two hit the vmap fallback
1500                xfail(
1501                    "scatter_reduce", "amin"
1502                ),  # aten::scatter_reduce.two hit the vmap fallback
1503                xfail(
1504                    "scatter_reduce", "amax"
1505                ),  # aten::scatter_reduce.two hit the vmap fallback
1506                xfail("quantile"),
1507                xfail("renorm"),
1508                xfail("take"),
1509                xfail("tensor_split"),
1510                xfail("to_sparse"),
1511                xfail("unfold"),
1512                xfail("unfold_copy"),
1513                xfail("nn.functional.dropout"),
1514                xfail("fft.ihfft2"),
1515                xfail("fft.ihfftn"),
1516                xfail("nn.functional.gaussian_nll_loss"),
1517                xfail("nn.functional.bilinear"),
1518                xfail("nn.functional.fractional_max_pool3d"),
1519                xfail("nn.functional.ctc_loss"),
1520                xfail("nn.functional.rrelu"),
1521                xfail("nn.functional.embedding_bag"),
1522                xfail("nn.functional.fractional_max_pool2d"),
1523                xfail("nn.functional.feature_alpha_dropout", "with_train"),
1524                xfail("pca_lowrank", ""),
1525                xfail("nn.functional.dropout2d", ""),
1526                xfail("nn.functional.feature_alpha_dropout", "without_train"),
1527                xfail("svd_lowrank", ""),
1528                xfail("nn.functional.max_unpool2d", ""),
1529                xfail("nn.functional.multi_margin_loss", ""),
1530                xfail("nn.functional.multilabel_margin_loss", ""),
1531                xfail("nn.functional.pdist", ""),
1532                xfail("scatter_reduce", "prod"),
1533                xfail("nn.functional.max_unpool1d", ""),
1534                xfail("nn.functional.max_unpool3d", ""),
1535                xfail("nn.functional.max_unpool3d", "grad"),
1536                xfail("nn.functional.soft_margin_loss", ""),
1537                xfail("nn.functional.max_unpool1d", "grad"),
1538                xfail("nn.functional.max_unpool2d", "grad"),
1539                xfail("linalg.lu", ""),
1540                xfail("cdouble", ""),
1541                xfail("cfloat", ""),
1542                xfail("chalf", ""),
1543                xfail(
1544                    "index_reduce", "prod"
1545                ),  # aten::index_reduce hit the vmap fallback which is currently disabled
1546                xfail(
1547                    "index_reduce", "mean"
1548                ),  # aten::index_reduce hit the vmap fallback which is currently disabled
1549                xfail(
1550                    "index_reduce", "amax"
1551                ),  # aten::index_reduce hit the vmap fallback which is currently disabled
1552                xfail(
1553                    "index_reduce", "amin"
1554                ),  # aten::index_reduce hit the vmap fallback which is currently disabled
1555                xfail("nn.functional.dropout3d", ""),
1556                xfail("as_strided_scatter", ""),
1557                xfail("_segment_reduce", "offsets"),
1558                xfail("_segment_reduce", "lengths"),
1559                xfail("sparse.sampled_addmm", ""),
1560                xfail("sparse.mm", "reduce"),
1561                xfail("native_batch_norm"),
1562                xfail("_native_batch_norm_legit"),
1563                # TODO: implement batching rule
1564                xfail("_batch_norm_with_update"),
1565                xfail("native_dropout_backward"),
1566                xfail(
1567                    "index_fill"
1568                ),  # aten::_unique hit the vmap fallback which is currently disabled
1569                xfail("t_copy"),
1570                xfail("unsqueeze_copy"),
1571            }
1572        ),
1573    )
1574    def test_vmapvjp_has_batch_rule(self, device, dtype, op):
1575        if not op.supports_autograd:
1576            self.skipTest("Skipped! Autograd not supported.")
1577            return
1578
1579        samples = op.sample_inputs(device, dtype, requires_grad=True)
1580
1581        # TODO: test in-place
1582        if is_inplace(op, op.get_op()):
1583            self.skipTest("Skipped! NYI: inplace-testing not supported.")
1584            return
1585
1586        def test():
1587            for sample in samples:
1588                cotangents = get_sample_cotangents(op, sample)
1589                fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
1590                is_batch_norm_and_training = is_batch_norm_training(
1591                    op.name, sample.kwargs
1592                )
1593                for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
1594                    fn,
1595                    args,
1596                    {},
1597                    is_batch_norm_and_training=is_batch_norm_and_training,
1598                    compute_loop_out=False,
1599                ):
1600                    pass
1601                for a_op in op.aliases:
1602                    fn, args = get_vjp_fn_and_args_with_cotangents(
1603                        a_op, sample, cotangents
1604                    )
1605                    for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
1606                        fn,
1607                        args,
1608                        {},
1609                        is_batch_norm_and_training=is_batch_norm_and_training,
1610                        compute_loop_out=False,
1611                    ):
1612                        pass
1613
1614        check_vmap_fallback(self, test, op, dry_run=False)
1615
1616    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
1617    @skipOps(
1618        "TestOperators",
1619        "test_vjpvmap",
1620        vjp_fail.union(
1621            {
1622                skip("bernoulli", ""),  # vjpvmap testing can't handle randomness
1623                skip("normal", ""),  # vjpvmap testing can't handle randomness
1624                skip(
1625                    "normal", "number_mean"
1626                ),  # vjpvmap testing can't handle randomness
1627                skip("nn.functional.rrelu"),  # randomness
1628                skip("nn.functional.feature_alpha_dropout", "with_train"),  # randomness
1629                skip(
1630                    "nn.functional.feature_alpha_dropout", "without_train"
1631                ),  # randomness
1632                skip("nn.functional.scaled_dot_product_attention"),
1633                xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
1634                skip("nn.functional.multi_head_attention_forward"),  # randomness
1635                skip("nn.functional.alpha_dropout"),  # randomness
1636                skip(
1637                    "to"
1638                ),  # RuntimeError: required rank 4 tensor to use channels_last format
1639                skip("to_sparse", ""),  # non-dense output
1640                skip("ormqr", ""),  # takes too long
1641                xfail(
1642                    "NumpyCubeNotComposableAutogradFunction"
1643                ),  # Not composable autograd.Function
1644                # fallback path doesn't work
1645                # All of the following are bugs and need to be fixed
1646                xfail("__getitem__", ""),
1647                xfail("index_put", ""),
1648                xfail("view_as_complex"),
1649                xfail("nn.functional.gaussian_nll_loss"),
1650                xfail("masked_select"),
1651                xfail(
1652                    "narrow"
1653                ),  # Batching rule not implemented for `narrow.Tensor` (and view op)
1654                skip(
1655                    "nn.functional.fractional_max_pool3d"
1656                ),  # generator works on cpu, fails on cuda
1657                skip(
1658                    "nn.functional.fractional_max_pool2d"
1659                ),  # generator works on cpu, fails on cuda
1660                xfail("column_stack", ""),
1661                xfail("nn.functional.dropout2d", ""),
1662                xfail("svd_lowrank", ""),
1663                xfail("pca_lowrank", ""),
1664                xfail("clamp"),
1665                # something weird happening with channels_last
1666                xfail("bfloat16"),
1667                xfail("double"),
1668                xfail("float"),
1669                xfail("half"),
1670                xfail("cdouble"),
1671                xfail("cfloat"),
1672                xfail("nn.functional.dropout3d", ""),
1673                xfail("as_strided_scatter", ""),
1674                xfail("sparse.sampled_addmm", ""),
1675                xfail("sparse.mm", "reduce"),
1676                xfail("native_batch_norm"),
1677                xfail("_native_batch_norm_legit"),
1678                # TODO: implement batching rule
1679                xfail("_batch_norm_with_update"),
1680                xfail("as_strided", "partial_views"),
1681            }
1682        ),
1683    )
1684    def test_vjpvmap(self, device, dtype, op):
1685        # NB: there is no vjpvmap_has_batch_rule test because that is almost
1686        # certainly redundant with the vmap_has_batch_rule test in test_vmap.py
1687
1688        # one-off skip
1689        if op.name == "nn.functional.dropout":
1690            self.skipTest("Skipped!")
1691
1692        if not op.supports_autograd:
1693            # If the op doesn't support autograd, vmap(op) won't either
1694            self.skipTest("Skipped! Autograd not supported.")
1695            return
1696
1697        # TODO: test in-place
1698        if is_inplace(op, op.get_op()):
1699            self.skipTest("Skipped! NYI: inplace-testing not supported.")
1700            return
1701
1702        samples = op.sample_inputs(device, dtype, requires_grad=True)
1703        batch_norm_fns = (
1704            "nn.functional.batch_norm",
1705            "nn.functional.instance_norm",
1706        )  # instance norm calls batch norm
1707        is_batch_norm = op.name in batch_norm_fns
1708
1709        for sample in samples:
1710            args = [sample.input] + list(sample.args)
1711            kwargs = sample.kwargs
1712
1713            is_batch_norm_and_training = is_batch_norm and is_batch_norm_training(
1714                op.name, kwargs
1715            )
1716            generator = generate_vmap_inputs(
1717                args, kwargs, is_batch_norm_and_training=is_batch_norm_and_training
1718            )
1719
1720            for batched_args, in_dims, kwargs in generator:
1721                vmapped_op = vmap(op, in_dims)
1722                fn, primals = normalize_op_input_output2(
1723                    vmapped_op, batched_args, kwargs, sample.output_process_fn_grad
1724                )
1725                result = fn(*primals)
1726                cotangents = tree_map(lambda x: torch.randn_like(x), result)
1727
1728                _, vjp_fn = vjp(fn, *primals)
1729                result_vjps = vjp_fn(cotangents)
1730
1731                _, vjp_fn = ref_vjp(fn, *primals)
1732                expected_vjps = vjp_fn(cotangents)
1733
1734                self.assertEqual(result_vjps, expected_vjps)
1735
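    # Illustrative sketch (helper not referenced by any test): the
    # vjp-through-vmap composition that test_vjpvmap checks op by op, written
    # out for an elementwise op where the expected cotangent is easy to state.
    def _sketch_vjp_of_vmap(self):
        batched_sin = vmap(torch.sin)  # maps over the leading dimension
        x = torch.randn(4, 3)
        out, vjp_fn = vjp(batched_sin, x)
        cotangents = torch.randn_like(out)
        (grad_x,) = vjp_fn(cotangents)
        # For an elementwise op the pullback is just cotangents * sin'(x).
        expected = cotangents * x.cos()
        return grad_x, expected
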
1736    def _compare_jacobians_of_vjp(
1737        self, fn, cotangents_and_primals, argnums=None, atol_rtol=None
1738    ):
1739        if argnums is None:
1740            argnums = tuple(range(len(cotangents_and_primals)))
1741
1742        def get_vjp(cotangents, *primals):
1743            _, vjp_fn = vjp(fn, *primals)
1744            return vjp_fn(cotangents)
1745
1746        jacobian_jvp = jacfwd(get_vjp, argnums)(*cotangents_and_primals)
1747        jacobian_vjp = jacrev(get_vjp, argnums)(*cotangents_and_primals)
1748
1749        # For dtype-changing operations, the jacobians have different dtypes.
1750        jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
1751        jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)
1752
1753        if atol_rtol is not None:
1754            (atol, rtol) = atol_rtol
1755            self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol)
1756        else:
1757            self.assertEqual(jacobian_jvp, jacobian_vjp)
1758
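    # Illustrative sketch (helper not referenced by any test):
    # _compare_jacobians_of_vjp above relies on forward-mode (jacfwd) and
    # reverse-mode (jacrev) Jacobians of the same function agreeing; for a
    # plain scalar-valued function the same cross-check looks like this.
    def _sketch_jacfwd_vs_jacrev(self):
        def f(x):
            return (x**2).sum()

        x = torch.randn(3)
        jac_forward = jacfwd(f)(x)
        jac_reverse = jacrev(f)(x)
        analytic = 2 * x  # gradient of a sum of squares
        return jac_forward, jac_reverse, analytic
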
1759    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
1760    @skipOps(
1761        "TestOperators",
1762        "test_jvpvjp",
1763        vjp_fail.union(
1764            {
1765                xfail("to_sparse", ""),  # NYI
1766                # RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
1767                # this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
1768                xfail("normal", ""),
1769                xfail("cdist", ""),  # NYI: forward-AD for _cdist_forward
1770                xfail("cholesky", ""),  # NYI: forward-AD for cholesky
1771                xfail(
1772                    "nn.functional.embedding_bag", ""
1773                ),  # NYI: forward-AD for _embedding_bag
1774                xfail(
1775                    "nn.functional.grid_sample", ""
1776                ),  # NYI: forward AD for grid_sampler_2d
1777                xfail("grid_sampler_2d", ""),  # NYI: forward AD for grid_sampler_2d
1778                xfail(
1779                    "nn.functional.hardsigmoid", ""
1780                ),  # NYI: forward AD for hardsigmoid_backward
1781                xfail(
1782                    "nn.functional.huber_loss", ""
1783                ),  # NYI: forward AD for huber_loss_backward
1784                xfail("NumpyCubeNotComposableAutogradFunction"),  # not composable
1785                xfail("ormqr", ""),  # NYI: forward AD for ormqr
1786                xfail(
1787                    "nn.functional.multilabel_margin_loss", ""
1788                ),  # NYI: multilabel_margin_loss_forward
1789                xfail(
1790                    "nn.functional.soft_margin_loss", ""
1791                ),  # NYI: forward-AD for soft_margin_loss_backward
1792                xfail("nn.functional.ctc_loss", ""),  # NYI: forward-AD for _ctc_loss
1793                xfail("nn.functional.pdist", ""),  # NYI: forward-AD with _pdist_forward
1794                skip("nn.functional.scaled_dot_product_attention"),
1795                xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
1796                xfail(
1797                    "nn.functional.multi_margin_loss", ""
1798                ),  # NYI: forward AD with multi_margin_loss
1799                skip(
1800                    "linalg.householder_product", "", device_type="cuda"
1801                ),  # flaky, I'm not sure why
1802                xfail("sparse.sampled_addmm", ""),  # Sparse tensors have no strides
1803                xfail(
1804                    "_segment_reduce", "offsets"
1805                ),  # NYI: forward-AD for _segment_reduce
1806                xfail("sparse.mm", "reduce"),  # Sparse tensors have no strides
1807                xfail("index_reduce", "prod"),  # NYI: forward-AD for index_reduce
1808                xfail("index_reduce", "mean"),  # NYI: forward-AD for index_reduce
1809                xfail("index_reduce", "amax"),  # NYI: forward-AD for index_reduce
1810                xfail("index_reduce", "amin"),  # NYI: forward-AD for index_reduce
1811                xfail(
1812                    "_segment_reduce", "lengths"
1813                ),  # NYI: forward-AD for _segment_reduce
1814                xfail("native_dropout_backward"),  # NYI
1815            }
1816        ),
1817    )
1818    @opsToleranceOverride(
1819        "TestOperators",
1820        "test_jvpvjp",
1821        (
1822            tol1("masked.prod", {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}),
1823            tol1("masked.cumprod", {torch.float32: tol(atol=1e-04, rtol=5e-04)}),
1824            tol1(
1825                "cumprod",
1826                {torch.float32: tol(atol=1e-03, rtol=5e-04)},
1827                device_type="cuda",
1828            ),
1829            tol1(
1830                "linalg.det",
1831                {torch.float32: tol(atol=3e-05, rtol=5e-06)},
1832                device_type="cuda",
1833            ),
1834            tol1(
1835                "linalg.vander",
1836                {torch.float32: tol(atol=1e-04, rtol=1.3e-05)},
1837                device_type="cuda",
1838            ),
1839            tol1(
1840                "nn.functional.group_norm", {torch.float32: tol(atol=1e-03, rtol=1e-03)}
1841            ),
1842            tol2(
1843                "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-03, rtol=5e-03)}
1844            ),
1845        ),
1846    )
1847    def test_jvpvjp(self, device, dtype, op):
1848        if not op.supports_autograd:
1849            self.skipTest("Skipped! Autograd not supported.")
1850            return
1851
1852        samples = op.sample_inputs(device, dtype, requires_grad=True)
1853
1854        # TODO: test in-place
1855        if is_inplace(op, op.get_op()):
1856            self.skipTest("Skipped! NYI: inplace-testing not supported.")
1857            return
1858
1859        for sample in samples:
1860            fn, primals = normalize_op_input_output(op, sample)
1861            result = fn(*primals)
1862            cotangents = tree_map(lambda x: torch.randn_like(x), result)
1863
1864            primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
1865            cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)
1866
1867            def push_vjp(primals, cotangents):
1868                _, vjp_fn = vjp(fn, *primals)
1869                return vjp_fn(cotangents)
1870
1871            result = jvp(
1872                push_vjp, (primals, cotangents), (primals_tangents, cotangents_tangents)
1873            )
1874            self.assertEqual(len(result), 2)
1875
1876            def tree_map2(fn, first, second):
1877                flat_first, spec_first = tree_flatten(first)
1878                flat_second, spec_second = tree_flatten(second)
1879                assert spec_first == spec_second
1880                flat_result = [fn(f, s) for f, s in zip(flat_first, flat_second)]
1881                return tree_unflatten(flat_result, spec_first)
1882
1883            def reference(primals, cotangents, primals_tangents, cotangents_tangents):
1884                with fwAD.dual_level():
1885                    primal_duals = tree_map2(fwAD.make_dual, primals, primals_tangents)
1886                    _, vjp_fn = ref_vjp(fn, *primal_duals)
1887
1888                    cotangent_duals = tree_map2(
1889                        fwAD.make_dual, cotangents, cotangents_tangents
1890                    )
1891                    result = vjp_fn(cotangent_duals)
1892
1893                    flat_result, spec = tree_flatten(result)
1894                    primals_out, tangents_out = zip(
1895                        *[fwAD.unpack_dual(r) for r in flat_result]
1896                    )
1897                    tangents_out = [
1898                        t if t is not None else torch.zeros_like(p)
1899                        for p, t in zip(primals_out, tangents_out)
1900                    ]
1901                    expected = (
1902                        tree_unflatten(primals_out, spec),
1903                        tree_unflatten(tangents_out, spec),
1904                    )
1905                return expected
1906
1907            expected = reference(
1908                primals, cotangents, primals_tangents, cotangents_tangents
1909            )
1910            self.assertEqual(result, expected)
1911
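    # Illustrative sketch (helper not referenced by any test): the jvp-over-vjp
    # composition that test_jvpvjp drives through OpInfo samples, written out
    # for a function with a closed-form answer. For f(x) = x.sin().sum(), the
    # vjp is c * x.cos(), so its jvp along (x_t, c_t) should equal
    # c_t * x.cos() - c * x.sin() * x_t.
    def _sketch_jvp_of_vjp(self):
        def f(x):
            return x.sin().sum()

        def push_vjp(primal, cotangent):
            _, vjp_fn = vjp(f, primal)
            return vjp_fn(cotangent)[0]

        x = torch.randn(3)
        c = torch.randn(())  # cotangent for the scalar output
        x_t, c_t = torch.randn(3), torch.randn(())
        _, tangent_out = jvp(push_vjp, (x, c), (x_t, c_t))
        expected_tangent = c_t * x.cos() - c * x.sin() * x_t
        return tangent_out, expected_tangent
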
1912    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
1913    @skipOps(
1914        "TestOperators",
1915        "test_vmapjvpvjp",
1916        vjp_fail.union(
1917            {
1918                # The following operators take too long, hence skipped
1919                skip("atleast_1d"),
1920                skip("atleast_2d"),
1921                skip("atleast_3d"),
1922                skip("meshgrid", "list_of_tensors"),
1923                skip("meshgrid", "variadic_tensors"),
1924                skip("broadcast_tensors"),
1925                skip("linalg.lstsq"),
1926                skip("nn.functional.bilinear"),
1927                skip("native_layer_norm"),
1928                skip("ormqr"),
1929                # Not actually a problem
1930                xfail("NumpyCubeNotComposableAutogradFunction"),  # not composable
1931                xfail(
1932                    "NumpyExpMarkDirtyAutogradFunction"
1933                ),  # vmap: inplace into a regular tensor
1934                # Potential bugs/errors
1935                xfail("as_strided"),  # AssertionError: Tensor-likes are not close!
1936                xfail(
1937                    "as_strided", "partial_views"
1938                ),  # AssertionError: Tensor-likes are not close!
1939                xfail("as_strided_copy"),  # AssertionError: Tensor-likes are not close!
1940                xfail(
1941                    "as_strided_scatter"
1942                ),  # AssertionError: Tensor-likes are not close!
1943                xfail("bernoulli"),  # calls random op
1944                xfail("bfloat16"),  # required rank 4 tensor to use channels_last format
1945                xfail("cdist"),  # Forward AD not implemented and no decomposition
1946                xfail("cdouble"),  # required rank 4 tensor to use channels_last format
1947                xfail("cfloat"),  # required rank 4 tensor to use channels_last format
1948                xfail("chalf"),  # required rank 4 tensor to use channels_last format
1949                xfail("cholesky"),  # Forward AD not implemented and no decomposition
1950                xfail("ormqr"),  # Forward AD not implemented and no decomposition
1951                xfail("double"),  # required rank 4 tensor to use channels_last format
1952                xfail("float"),  # required rank 4 tensor to use channels_last format
1953                xfail("half"),  # required rank 4 tensor to use channels_last format
1954                xfail("index_reduce", "prod"),  # NYI: forward AD for index_reduce
1955                xfail("index_reduce", "mean"),  # NYI: forward AD for index_reduce
1956                xfail("index_reduce", "amax"),  # NYI: forward AD for index_reduce
1957                xfail("index_reduce", "amin"),  # NYI: forward AD for index_reduce
1958                xfail(
1959                    "mvlgamma", "mvlgamma_p_1"
1960                ),  # vmap: inplace into a regular tensor
1961                xfail(
1962                    "mvlgamma", "mvlgamma_p_3"
1963                ),  # vmap: inplace into a regular tensor
1964                xfail(
1965                    "mvlgamma", "mvlgamma_p_5"
1966                ),  # vmap: inplace into a regular tensor
1967                xfail("nanquantile"),  # Batching rule not implemented for aten::equal
1968                # RuntimeError: Batch norm got a batched tensor as input while the
1969                # running_mean or running_var, which will be updated in place,
1970                # were not batched.
1971                xfail("nn.functional.batch_norm"),
1972                xfail("nn.functional.batch_norm", "without_cudnn"),
1973                xfail(
1974                    "nn.functional.ctc_loss"
1975                ),  # Forward AD not implemented and no decomposition
1976                xfail("nn.functional.dropout2d"),  # calls random op
1977                xfail("nn.functional.dropout3d"),  # calls random op
1978                xfail("nn.functional.dropout"),  # calls random op
1979                xfail("nn.functional.scaled_dot_product_attention"),  # randomness
1980                xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
1981                xfail("nn.functional.multi_head_attention_forward"),  # randomness
1982                xfail(
1983                    "nn.functional.embedding_bag"
1984                ),  # Forward AD not implemented and no decomposition
1985                xfail("nn.functional.alpha_dropout"),  # calls random op
1986                xfail(
1987                    "nn.functional.feature_alpha_dropout", "with_train"
1988                ),  # calls random op
1989                xfail("nn.functional.fractional_max_pool2d"),  # calls random op
1990                xfail("nn.functional.fractional_max_pool3d"),  # calls random op
1991                xfail("nn.functional.gaussian_nll_loss"),  # data-dependent control flow
1992                xfail(
1993                    "nn.functional.grid_sample"
1994                ),  # Forward AD not implemented and no decomposition
1995                xfail(
1996                    "grid_sampler_2d"
1997                ),  # Forward AD not implemented and no decomposition
1998                xfail(
1999                    "nn.functional.hardsigmoid"
2000                ),  # Forward AD not implemented and no decomposition
2001                xfail(
2002                    "nn.functional.hinge_embedding_loss"
2003                ),  # vmap: inplace into a regular tensor
2004                xfail(
2005                    "nn.functional.huber_loss"
2006                ),  # Forward AD not implemented and no decomposition
2007                # RuntimeError: Batch norm got a batched tensor as input while the
2008                # running_mean or running_var, which will be updated in place,
2009                # were not batched.
2010                xfail("nn.functional.instance_norm"),
2011                # NYI: Tensor.clone(memory_format) inside vmap is only supported with
2012                # memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast)
2013                xfail("nn.functional.max_unpool2d"),
2014                xfail("nn.functional.max_unpool2d", "grad"),
2015                xfail(
2016                    "nn.functional.multi_margin_loss"
2017                ),  # Forward AD not implemented and no decomposition
2018                xfail(
2019                    "nn.functional.multilabel_margin_loss"
2020                ),  # Forward AD not implemented and no decomposition
2021                xfail(
2022                    "nn.functional.pdist"
2023                ),  # Forward AD not implemented and no decomposition
2024                xfail(
2025                    "nn.functional.rrelu"
2026                ),  # vmap: we do not yet support aten::rrelu_with_noise.
2027                xfail(
2028                    "nn.functional.soft_margin_loss"
2029                ),  # Forward AD not implemented and no decomposition
2030                xfail("normal"),  # calls random op
2031                xfail("normal", "number_mean"),  # calls random op
2032                xfail("pca_lowrank"),  # calls random op
2033                xfail("quantile"),  # Batching rule not implemented for aten::equal
2034                xfail(
2035                    "scatter_reduce", "prod"
2036                ),  # Forward AD not implemented and no decomposition
2037                xfail(
2038                    "_segment_reduce", "lengths"
2039                ),  # Forward AD not implemented and no decomposition
2040                xfail(
2041                    "_segment_reduce", "offsets"
2042                ),  # Forward AD not implemented and no decomposition
2043                xfail(
2044                    "sparse.sampled_addmm"
2045                ),  # RuntimeError: Sparse CSR tensors do not have strides
2046                xfail(
2047                    "sparse.mm", "reduce"
2048                ),  # RuntimeError: Sparse CSR tensors do not have strides
2049                xfail("svd_lowrank"),  # calls random op
2050                xfail(
2051                    "to"
2052                ),  # RuntimeError: required rank 4 tensor to use channels_last format
2053                xfail("to_sparse"),  # Forward AD not implemented and no decomposition
2054                xfail(
2055                    "view_as_complex"
2056                ),  # RuntimeError: Tensor must have a last dimension with stride 1
2057                # RuntimeError: Batch norm got a batched tensor as
2058                # input while the running_mean or running_var, which will be updated in
2059                # place, were not batched.
2060                xfail("native_batch_norm"),
2061                xfail("_native_batch_norm_legit"),
2062                # TODO: implement batching rule
2063                xfail("_batch_norm_with_update"),
2064                xfail("native_dropout_backward"),
2065            }
2066        ),
2067    )
2068    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
2069    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
2070    @opsToleranceOverride(
2071        "TestOperators",
2072        "test_vmapjvpvjp",
2073        (
2074            tol1("linalg.svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
2075            tol1(
2076                "linalg.householder_product",
2077                {torch.float32: tol(atol=5e-03, rtol=5e-03)},
2078            ),
2079            tol1("linalg.multi_dot", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
2080            tol2(
2081                "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-04, rtol=5e-04)}
2082            ),
2083            tol1(
2084                "nn.functional.conv_transpose2d",
2085                {torch.float32: tol(atol=5e-04, rtol=5e-04)},
2086            ),
2087            tol1("svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
2088            tol1("matrix_exp", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
2089        ),
2090    )
2091    def test_vmapjvpvjp(self, device, dtype, op):
2092        # Since we test `jvpvjp` separately,
2093        # in this test we just check that vmap of `jvpvjp`
2094        # is correct.
2095        if not op.supports_autograd:
2096            self.skipTest("Skipped! Autograd not supported.")
2097            return
2098
2099        samples = op.sample_inputs(device, dtype, requires_grad=True)
2100
2101        # TODO: test in-place
2102        if is_inplace(op, op.get_op()):
2103            self.skipTest("Skipped! NYI: inplace-testing not supported.")
2104            return
2105
2106        for sample in samples:
2107            fn, primals = normalize_op_input_output(op, sample)
2108            result = fn(*primals)
2109            cotangents = tree_map(lambda x: torch.randn_like(x), result)
2110
2111            primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
2112            cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)
2113
2114            def push_vjp(primals, cotangents):
2115                _, vjp_fn = vjp(fn, *primals)
2116                return vjp_fn(cotangents)
2117
2118            args, spec = tree_flatten(
2119                ((primals, cotangents), (primals_tangents, cotangents_tangents))
2120            )
2121
2122            def jvp_of_vjp(*args):
2123                (primals, tangents) = tree_unflatten(args, spec)
2124                primals_out, tangents_out = jvp(push_vjp, primals, tangents)
2125
2126                flat_primals_out = pytree.tree_leaves(primals_out)
2127                flat_tangents_out = pytree.tree_leaves(tangents_out)
2128                return tuple(flat_primals_out + flat_tangents_out)
2129
2130            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
2131            generator = get_fallback_and_vmap_exhaustive(
2132                jvp_of_vjp,
2133                args,
2134                {},
2135                is_batch_norm_and_training=is_batch_norm_and_training,
2136            )
2137            for loop_out, batched_out in generator:
2138                self.assertEqual(loop_out, batched_out)
2139
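    # Illustrative sketch (helper not referenced by any test): the
    # tree_flatten / tree_unflatten round trip that test_vmapjvpvjp uses so
    # that jvp_of_vjp can accept a flat tuple of tensors (the calling
    # convention get_fallback_and_vmap_exhaustive expects) while still
    # reconstructing the nested (primals, tangents) pytree inside.
    def _sketch_pytree_roundtrip(self):
        nested = ((torch.randn(2), torch.randn(3)), {"c": torch.randn(4)})
        flat, spec = tree_flatten(nested)  # leaf tensors plus the structure
        rebuilt = tree_unflatten(flat, spec)  # same structure as `nested`
        return flat, rebuilt
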
2140    def _make_extremal_inputs(self, shape, device):
2141        if shape is None:
2142            return (None,)
2143        return (
2144            torch.full(shape, -1000.0, device=device),
2145            torch.zeros(shape, device=device),
2146            torch.full(shape, 1000.0, device=device),
2147        )
2148
2149    def _arg_and_kwarg_options(self, args_options, kwargs_options):
2150        return itertools.product(*args_options, kwargs_options)
2151
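    # Illustrative sketch (assumed shapes and kwargs; helper not referenced by
    # any test): how the two helpers above combine. _make_extremal_inputs
    # yields tensors filled with -1000, 0, and +1000 for a shape, and
    # _arg_and_kwarg_options takes the cartesian product of those input
    # choices with the kwargs variants an extremal test wants to cover.
    def _sketch_extremal_option_product(self, device="cpu"):
        input_options = self._make_extremal_inputs((2, 3), device)
        kwargs_options = ({"reduction": "sum"}, {})
        combos = list(
            self._arg_and_kwarg_options((input_options,), kwargs_options)
        )
        # 3 extremal inputs x 2 kwargs dicts -> 6 (input, kwargs) pairs.
        return combos
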
2152    def test_extremal_numerics_nll_loss(self, device):
2153        N, C = 3, 4
2154        d1, d2, d3 = 5, 6, 7
2155        shapes = (
2156            ((N, C), (N,), (C,)),
2157            ((N, C), (N,), None),
2158            ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
2159            ((N, C, d1, d2, d3), (N, d1, d2, d3), None),
2160        )
2161        kwargs_options = (
2162            {"ignore_index": 0, "reduction": "mean"},
2163            {"reduction": "sum"},
2164            {"reduction": "none"},
2165            {},
2166        )
2167        for input_shape, target_shape, weight_shape in shapes:
2168            input_options = self._make_extremal_inputs(input_shape, device)
2169            for input, kwargs in self._arg_and_kwarg_options(
2170                (input_options,), kwargs_options
2171            ):
2172                if weight_shape is None:
2173                    weight = None
2174                else:
2175                    weight = torch.randn(weight_shape, device=device)
2176                target = torch.randint(0, C, target_shape, device=device)
2177                target[
2178                    0
2179                ] = 1  # since we're ignoring index 0, at least one element must be non-zero
2180
2181                fn = functools.partial(
2182                    torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs
2183                )
2184                result = fn(input)
2185                cotangents = torch.randn_like(result, device=device)
2186                self._compare_jacobians_of_vjp(fn, (cotangents, input))
2187
2188    def test_extremal_numerics_l1_loss(self, device):
2189        N, C, H, W = 3, 4, 5, 6
2190        shapes = ((N, C), (N, C, H), (N, C, H, W))
2191        kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {})
2192        for shape in shapes:
2193            input_options = self._make_extremal_inputs(shape, device)
2194            target_options = self._make_extremal_inputs(shape, device)
2195            for input, target, kwargs in self._arg_and_kwarg_options(
2196                (input_options, target_options), kwargs_options
2197            ):
2198                result = torch.nn.functional.l1_loss(input, target)
2199                cotangents = torch.randn_like(result, device=device)
2200                self._compare_jacobians_of_vjp(
2201                    torch.nn.functional.l1_loss, (cotangents, input, target)
2202                )
2203
2204    def test_extremal_numerics_mse_loss(self, device):
2205        N, C, H, W = 3, 4, 5, 6
2206        shapes = ((N, C), (N, C, H), (N, C, H, W))
2207        kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {})
2208        for shape in shapes:
2209            input_options = self._make_extremal_inputs(shape, device)
2210            target_options = self._make_extremal_inputs(shape, device)
2211            for input, target, kwargs in self._arg_and_kwarg_options(
2212                (input_options, target_options), kwargs_options
2213            ):
2214                result = torch.nn.functional.mse_loss(input, target)
2215                cotangents = torch.randn_like(result, device=device)
2216                self._compare_jacobians_of_vjp(
2217                    torch.nn.functional.mse_loss, (cotangents, input, target)
2218                )
2219
2220    def test_extremal_numerics_softmax(self, device):
2221        N, C, H, W = 3, 4, 5, 6
2222        shapes = ((N, C), (N, C, H), (N, C, H, W))
2223        kwargs_options = ({"dim": 1}, {})
2224        for shape in shapes:
2225            input_options = self._make_extremal_inputs(shape, device)
2226            for input, kwargs in self._arg_and_kwarg_options(
2227                (input_options,), kwargs_options
2228            ):
2229                result = torch.nn.functional.softmax(input)
2230                cotangents = torch.randn_like(result, device=device)
2231                self._compare_jacobians_of_vjp(
2232                    torch.nn.functional.softmax, (cotangents, input)
2233                )
2234
2235    def test_extremal_numerics_log_softmax(self, device):
2236        N, C, H, W = 3, 4, 5, 6
2237        shapes = ((N, C), (N, C, H), (N, C, H, W))
2238        kwargs_options = ({"dim": 1}, {})
2239        for shape in shapes:
2240            input_options = self._make_extremal_inputs(shape, device)
2241            for input, kwargs in self._arg_and_kwarg_options(
2242                (input_options,), kwargs_options
2243            ):
2244                result = torch.nn.functional.log_softmax(input)
2245                cotangents = torch.randn_like(result, device=device)
2246                self._compare_jacobians_of_vjp(
2247                    torch.nn.functional.log_softmax, (cotangents, input)
2248                )
2249
2250    def test_extremal_numerics_cross_entropy(self, device):
2251        N, C = 3, 4
2252        d1, d2, d3 = 5, 6, 7
2253        shapes = (
2254            ((N, C), (N,), (C,)),
2255            ((N, C), (N,), None),
2256            ((N, C), (N, C), (C,)),
2257            ((N, C), (N, C), None),
2258            ((C,), (), (C,)),
2259            ((C,), (), None),
2260            ((C,), (C,), (C,)),
2261            ((C,), (C,), None),
2262            ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
2263            ((N, C, d1, d2, d3), (N, d1, d2, d3), None),
2264            ((N, C, d1, d2, d3), (N, C, d1, d2, d3), (C,)),
2265            ((N, C, d1, d2, d3), (N, C, d1, d2, d3), None),
2266        )
2267        for input_shape, target_shape, weight_shape in shapes:
2268            input_options = self._make_extremal_inputs(input_shape, device)
2269            kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}]
2270            if input_shape != target_shape:
2271                kwargs_options.append({"ignore_index": 0, "reduction": "mean"})
2272
2273            for input, kwargs in self._arg_and_kwarg_options(
2274                (input_options,), kwargs_options
2275            ):
2276                if weight_shape is None:
2277                    weight = None
2278                else:
2279                    weight = torch.randn(weight_shape, device=device)
2280
2281                if input_shape == target_shape:
2282                    target = torch.rand(target_shape, device=device)
2283                elif len(target_shape) == 0:
2284                    target = torch.tensor(
2285                        1, device=device
2286                    )  # must be non-zero since ignore_index may be 0
2287                else:
2288                    target = torch.randint(0, C, target_shape, device=device)
2289
2290                fn = functools.partial(
2291                    torch.nn.functional.cross_entropy,
2292                    target=target,
2293                    weight=weight,
2294                    **kwargs,
2295                )
2296                result = fn(input)
2297                cotangents = torch.randn_like(result, device=device)
2298                self._compare_jacobians_of_vjp(
2299                    fn, (cotangents, input), atol_rtol=(1e-4, 1e-5)
2300                )
2301
2302    def test_extremal_numerics_binary_cross_entropy(self, device):
2303        N, C, H, W = 3, 4, 5, 6
2304        shapes = ((N, C), (N, C, H), (N, C, H, W))
2305        for shape in shapes:
2306            weight_options = self._make_extremal_inputs(shape, device)
2307            kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}]
2308
2309            for weight, kwargs in self._arg_and_kwarg_options(
2310                (weight_options,), kwargs_options
2311            ):
2312                input = torch.rand(shape, device=device)
2313                target = torch.rand(shape, device=device)
2314                fn = functools.partial(
2315                    torch.nn.functional.binary_cross_entropy,
2316                    target=target,
2317                    weight=weight,
2318                    **kwargs,
2319                )
2320                result = fn(input)
2321                cotangents = torch.randn_like(result, device=device)
2322                self._compare_jacobians_of_vjp(
2323                    fn, (cotangents, input), atol_rtol=(1e-4, 2e-5)
2324                )
2325
2326    def test_extremal_numerics_layer_norm(self, device):
2327        N, C, H, W = 3, 4, 5, 6
2328        shapes = ((N, C), (N, C, H), (N, C, H, W))
2329        for shape in shapes:
2330            input_options = self._make_extremal_inputs(shape, device)
2331            normalized_shape = shape[1:]
2332            weight_options = self._make_extremal_inputs(normalized_shape, device)
2333            bias_options = self._make_extremal_inputs(normalized_shape, device)
2334
2335            for input, bias, weight in self._arg_and_kwarg_options(
2336                (input_options, bias_options, weight_options), ()
2337            ):
2338
2339                def fn(input, weight, bias):
2340                    return torch.nn.functional.layer_norm(
2341                        input, normalized_shape, weight=weight, bias=bias
2342                    )
2343
2344                result = fn(input, weight, bias)
2345                cotangents = torch.randn_like(result, device=device)
2346                self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))
2347
2348    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
2349    @ops(
2350        op_db + additional_op_db + autograd_function_db,
2351        allowed_dtypes=(torch.float32, torch.double),
2352    )
2353    @skipOps(
2354        "TestOperators",
2355        "test_vmap_autograd_grad",
2356        {
2357            # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0
2358            xfail("masked_select"),
2359            xfail("nn.functional.max_unpool2d", "grad"),  # contiguous call
2360            xfail("nn.functional.max_unpool2d"),  # contiguous call
2361            xfail("to_sparse"),  # dispatch key issue
2362            xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
2363            # https://github.com/pytorch/pytorch/issues/96560#issuecomment-2151063723
2364            # ** minor accuracy issue for float32 on ROCm
2365            decorate("xlogy", decorator=skipIfRocm),
2366            # numerical inconsistencies, look like bugs
2367            skip(
2368                "matrix_exp", dtypes=(torch.float32,), device_type="cuda"
2369            ),  # fails on linux, passes on windows
2370            skip(
2371                "ldexp", dtypes=(torch.float32,), device_type="cpu"
2372            ),  # fails on all but mac
2373            skip("__rmatmul__"),  # flaky, needs investigation
2374            skip("matmul"),  # flaky, needs investigation
2375            skip("nn.functional.conv_transpose3d"),  # flaky, needs investigation
2376            skip("nn.functional.conv_transpose2d"),  # flaky, needs investigation
2377            skip("nn.functional.conv_transpose1d"),  # flaky, needs investigation
2378            skip(
2379                "nn.functional.layer_norm", dtypes=(torch.float32,), device_type="cpu"
2380            ),  # fails on windows
2381            skip(
2382                "linalg.lu_factor", dtypes=(torch.float32,), device_type="cuda"
2383            ),  # fails on all but windows
2384            skip(
2385                "linalg.lu_factor_ex", dtypes=(torch.float32,), device_type="cuda"
2386            ),  # fails on all but windows
2387            skip("linalg.multi_dot", "", device_type="cpu"),
2388            skip("sparse.sampled_addmm", ""),
2389            skip("sparse.mm", "reduce"),
2390            skip("native_layer_norm", "", device_type="cpu"),
2391            # RuntimeError: Expected contiguous tensor, but got
2392            # non-contiguous tensor for argument #2 'grad_output'
2393            decorate(
2394                "_batch_norm_with_update",
2395                decorator=expectedFailureIf(TEST_WITH_ROCM),
2396                device_type="cuda",
2397            ),
2398        },
2399    )
2400    @opsToleranceOverride(
2401        "TestOperators",
2402        "test_vmap_autograd_grad",
2403        (
2404            tol1(
2405                "ldexp",
2406                {torch.float32: tol(atol=3e-04, rtol=1.6e-06)},
2407                device_type="cuda",
2408            ),
2409            tol1(
2410                "linalg.householder_product",
2411                {torch.float32: tol(atol=5e-04, rtol=9e-03)},
2412                device_type="cuda",
2413            ),
2414            tol1(
2415                "linalg.householder_product",
2416                {torch.float32: tol(atol=6e-03, rtol=1e-03)},
2417                device_type="cpu",
2418            ),
2419            tol1(
2420                "linalg.multi_dot",
2421                {torch.float32: tol(atol=2e-04, rtol=1e-04)},
2422                device_type="cuda",
2423            ),
2424            tol2(
2425                "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-06, rtol=5e-06)}
2426            ),
2427            tol1("nn.functional.conv3d", {torch.float32: tol(atol=5e-04, rtol=9e-03)}),
2428            tol1(
2429                "nn.functional.conv2d",
2430                {torch.float32: tol(atol=3e-05, rtol=5e-06)},
2431                device_type="cuda",
2432            ),
2433            tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
2434            tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
2435        ),
2436    )
2437    def test_vmap_autograd_grad(self, device, dtype, op):
2438        def is_differentiable(inp):
2439            return isinstance(inp, Tensor) and (
2440                inp.grad_fn is not None or inp.requires_grad
2441            )
2442
2443        def get_flat_differentiable(tree):
2444            flattened = pytree.tree_leaves(tree)
2445            return tuple(i for i in flattened if is_differentiable(i))
2446
2447        def get_differentiable_linked(list1, list2):
2448            paired_list = zip(list1, list2)
2449            paired_list = tuple(
2450                (first, second)
2451                for (first, second) in paired_list
2452                if is_differentiable(first)
2453            )
2454            return zip(*paired_list)
2455
2456        def filter_none(out):
2457            flattened = pytree.tree_leaves(out)
2458            return tuple(o for o in flattened if o is not None)
2459
2460        if not op.supports_autograd:
2461            self.skipTest("Skipped! Autograd not supported.")
2462            return
2463
2464        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
2465
2466        for sample_input in sample_inputs:
2467            fn, primals = normalize_op_input_output(op, sample_input)
2468            out = fn(*primals)
2469            cotangents = tree_map(torch.randn_like, out)
2470
2471            def compute_grad(cotangents):
2472                out_flattened = out
2473                cotangents_flattened = cotangents
2474                if not isinstance(out_flattened, torch.Tensor):
2475                    out_flattened = pytree.tree_leaves(out)
2476                    cotangents_flattened = pytree.tree_leaves(cotangents)
2477                    out_flattened, cotangents_flattened = get_differentiable_linked(
2478                        out_flattened, cotangents_flattened
2479                    )
2480
2481                return filter_none(
2482                    torch.autograd.grad(
2483                        out_flattened,
2484                        get_flat_differentiable(primals),
2485                        cotangents_flattened,
2486                        retain_graph=True,
2487                        allow_unused=True,
2488                    )
2489                )
2490
2491            is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs)
2492            generator = get_fallback_and_vmap_exhaustive(
2493                compute_grad,
2494                (cotangents,),
2495                {},
2496                is_batch_norm_and_training=is_batch_norm_and_training,
2497            )
2498            for loop_out, batched_out in generator:
2499                self.assertEqual(loop_out, batched_out)
2500
2501    def test_vmapvmapjvp_linalg_solve(self):
2502        ops = [op for op in op_db if op.name == "linalg.solve"]
2503        assert len(ops) > 0
2504
2505        # This specializes a lot of the code from the get_fallback_and_vmap_exhaustive test. If we need this
2506        # more generally, it could use a refactor.
2507
2508        B0 = 2
2509        B1 = 3
2510
2511        # We want to check the case where A is seen as contiguous by jvp but becomes non-contiguous during
2512        # the vmap calls because vmap expands it. This happens at both levels of vmap.
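        # (an expanded tensor has stride 0 along the expanded dimension, so it is no longer contiguous)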
2513        A = torch.randn(4, 4)
2514        k = torch.randn(4, 5, B1, B0)
2515        fn, args = get_jvp_variant_primals_tangents(
2516            torch.linalg.solve, SampleInput(A, args=(k,))
2517        )
2518
2519        in_dims_all = (None, -1, None, -1)
2520        batched_out = vmap(vmap(fn, in_dims=in_dims_all), in_dims=in_dims_all)(*args)
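        # The None entries leave A and its tangent unbatched; the -1 entries vmap k and its tangent over
        # their last dimension at both levels (assuming args is ordered as primals followed by tangents).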
2521        loop_out = loop2(fn, in_dims_all, in_dims_all, 0, 0, B0, B1, *args)
2522        self.assertEqual(loop_out, batched_out)
2523
2524    @ops(
2525        filter(lambda op: op.name in aliasing_ops, op_db + additional_op_db),
2526        allowed_dtypes=(torch.float,),
2527    )
2528    @parametrize("grad_op", ["jvp", "vjp"])
2529    def test_view_then_inplace(self, device, dtype, op, grad_op):
2530        for sample_input in op.sample_inputs(device, dtype):
2531
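            # op(...) returns a view that aliases sample_input.input; copying into that view in-place
            # under a grad transform is expected to raise the error asserted below.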
2532            def f(x):
2533                op(sample_input.input, *sample_input.args, **sample_input.kwargs).copy_(
2534                    x
2535                )
2536                return x
2537
2538            without_grad = op(
2539                sample_input.input, *sample_input.args, **sample_input.kwargs
2540            )
2541            if grad_op == "jvp":
2542                with self.assertRaisesRegex(
2543                    RuntimeError,
2544                    "During a grad .* attempted to call in-place operation",
2545                ):
2546                    jvp(
2547                        f,
2548                        (torch.randn_like(without_grad),),
2549                        (torch.randn_like(without_grad),),
2550                    )
2551            else:
2552                assert grad_op == "vjp"
2553                with self.assertRaisesRegex(
2554                    RuntimeError,
2555                    "During a grad .* attempted to call in-place operation",
2556                ):
2557                    vjp(f, torch.randn_like(without_grad))
2558
2559    @ops(
2560        filter(
2561            lambda op: op.name in aliasing_ops_list_return, op_db + additional_op_db
2562        ),
2563        allowed_dtypes=(torch.float,),
2564    )
2565    @parametrize("grad_op", ["jvp", "vjp"])
2566    def test_view_then_inplace_list_return(self, device, dtype, op, grad_op):
2567        for sample_input in op.sample_inputs(device, dtype):
2568
2569            def f(x):
2570                op(sample_input.input, *sample_input.args, **sample_input.kwargs)[
2571                    0
2572                ].copy_(x)
2573                return x
2574
2575            without_grad = op(
2576                sample_input.input, *sample_input.args, **sample_input.kwargs
2577            )[0]
2578            with self.assertRaisesRegex(
2579                RuntimeError, "During a grad .* attempted to call in-place operation"
2580            ):
2581                if grad_op == "jvp":
2582                    jvp(
2583                        f,
2584                        (torch.randn_like(without_grad),),
2585                        (torch.randn_like(without_grad),),
2586                    )
2587                else:
2588                    assert grad_op == "vjp"
2589                    vjp(f, torch.randn_like(without_grad))
2590
2591    @parametrize("grad_op", ["jvp", "vjp"])
2592    def test_view_then_inplace_special(self, grad_op):
2593        # Some paths through __getitem__ use at::index, which doesn't alias, so this tests a subset of the paths that do alias.
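        # e.g. basic indexing such as x[0] or x[:1] returns a view that aliases x, whereas advanced
        # indexing such as x[torch.tensor([0])] goes through at::index and returns a copy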
2594        ops = [
2595            lambda x: x[0],
2596            lambda x: x[0, 0, 0],
2597            lambda x: x[:1],
2598            lambda x: x[:, :1],
2599            lambda x: x[:, :1, :],
2600        ]
2601
2602        for op in ops:
2603
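            # `captured` is assigned below, after f is defined; the closure resolves it at call time,
            # by which point it exists.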
2604            def f(x):
2605                op(captured).copy_(x)
2606                return x
2607
2608            captured = torch.randn(4, 3, 3)
2609            without_grad = op(captured)
2610            if grad_op == "jvp":
2611                with self.assertRaisesRegex(
2612                    RuntimeError,
2613                    "During a grad .* attempted to call in-place operation",
2614                ):
2615                    jvp(
2616                        f,
2617                        (torch.randn_like(without_grad),),
2618                        (torch.randn_like(without_grad),),
2619                    )
2620            else:
2621                assert grad_op == "vjp"
2622                with self.assertRaisesRegex(
2623                    RuntimeError,
2624                    "During a grad .* attempted to call in-place operation",
2625                ):
2626                    vjp(f, torch.randn_like(without_grad))
2627
2628    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
2629    # NOTE: [three-transform testing]
2630    # We only run these tests over the ops in autograd_function_db.
2631    #
2632    # Usually testing the composition of two transforms is sufficient to convince
2633    # ourselves that an operator is correctly implemented. For the following cases,
2634    # we want to be extra sure, so we send those through some three-transform tests:
2635    # - autograd.Function. The mechanism is via PyDispatcher/HigherOrderOperator, not the
2636    #   regular PyTorch dispatcher, so it's good to exercise more caution.
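    # For example, test_vmapvjpvmap below checks vmap(vjp(vmap(op))) against a loop-based
    # reference, map(vjp(map(op))).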
2637    @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
2638    @skipOps(
2639        "TestOperators",
2640        "test_vmapvjpvmap",
2641        {
2642            xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable
2643        },
2644    )
2645    def test_vmapvjpvmap(self, device, dtype, op):
2646        samples = op.sample_inputs(device, dtype, requires_grad=True)
2647        B = 2
2648        for sample in samples:
2649            args = [sample.input] + list(sample.args)
2650            kwargs = sample.kwargs
2651            generator = generate_vmap_inputs(args, kwargs, batch_size=B)
2652            for batched_args, in_dims, kwargs in generator:
2653                inner_vmapped_op = vmap(op, in_dims)
2654                inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
2655
2656                inner_vmapped_fn, primals = normalize_op_input_output2(
2657                    inner_vmapped_op,
2658                    batched_args,
2659                    kwargs,
2660                    sample.output_process_fn_grad,
2661                )
2662                inner_mapped_fn, _ = normalize_op_input_output2(
2663                    inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
2664                )
2665                result = inner_mapped_fn(*primals)
2666                cotangents = tree_map(lambda x: torch.rand_like(x), result)
2667
2668                def apply_vjp(fn):
2669                    def inner(primals, cotangents):
2670                        _, vjp_fn = vjp(fn, *primals)
2671                        return vjp_fn(cotangents)
2672
2673                    return inner
2674
2675                vjpvmap_fn = apply_vjp(inner_vmapped_fn)
2676                vjpmap_fn = apply_vjp(inner_mapped_fn)
2677                batched_args = (primals, cotangents)
2678                generator = generate_vmap_inputs(batched_args, {})
2679
2680                for batched_args, in_dims, _ in generator:
2681                    # strategy: compare vmap(vjp(vmap(op))) vs map(vjp(map(op)))
2682                    vmapvjpvmap_fn = vmap(vjpvmap_fn, in_dims)
2683                    mapvjpmap_fn = functools.partial(loop, vjpmap_fn, in_dims, 0, B)
2684
2685                    result = vmapvjpvmap_fn(*batched_args)
2686                    expected = mapvjpmap_fn(*batched_args)
2687                    self.assertEqual(result, expected)
2688
2689    # See NOTE: [three-transform testing]
2690    @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
2691    @skipOps(
2692        "TestOperators",
2693        "test_vjpvmapvmap",
2694        {
2695            xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable
2696        },
2697    )
2698    def test_vjpvmapvmap(self, device, dtype, op):
2699        samples = op.sample_inputs(device, dtype, requires_grad=True)
2700        B = 2
2701        for sample in samples:
2702            args = [sample.input] + list(sample.args)
2703            kwargs = sample.kwargs
2704            generator = generate_vmap_inputs(args, kwargs, batch_size=B)
2705            for batched_args, inner_in_dims, kwargs in generator:
2706                inner_vmapped_op = vmap(op, inner_in_dims)
2707                inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B)
2708                generator = generate_vmap_inputs(batched_args, kwargs)
2709                for batched_args, in_dims, kwargs in generator:
2710                    # strategy: compare vjp(vmap(vmap(op))) vs vjp(map(map(op)))
2711                    vmapped_op = vmap(inner_vmapped_op, in_dims)
2712                    mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B)
2713
2714                    vmapped_fn, primals = normalize_op_input_output2(
2715                        vmapped_op, batched_args, kwargs, sample.output_process_fn_grad
2716                    )
2717                    mapped_fn, _ = normalize_op_input_output2(
2718                        mapped_op, batched_args, kwargs, sample.output_process_fn_grad
2719                    )
2720
2721                    result = mapped_fn(*primals)
2722                    cotangents = tree_map(lambda x: torch.rand_like(x), result)
2723
2724                    _, vjp_fn = vjp(mapped_fn, *primals)
2725                    expected_vjps = vjp_fn(cotangents)
2726
2727                    _, vjp_fn = vjp(vmapped_fn, *primals)
2728                    result_vjps = vjp_fn(cotangents)
2729
2730                    self.assertEqual(result_vjps, expected_vjps)
2731
2732    # See NOTE: [three-transform testing]
2733    @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
2734    @skipOps(
2735        "TestOperators",
2736        "test_vjpvjpvmap",
2737        {
2738            xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable
2739        },
2740    )
2741    def test_vjpvjpvmap(self, device, dtype, op):
2742        samples = op.sample_inputs(device, dtype, requires_grad=True)
2743        B = 2
2744        for sample in samples:
2745            args = [sample.input] + list(sample.args)
2746            kwargs = sample.kwargs
2747            generator = generate_vmap_inputs(args, kwargs, batch_size=B)
2748            for batched_args, in_dims, kwargs in generator:
2749                inner_vmapped_op = vmap(op, in_dims)
2750                inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
2751
2752                vjpmap_fn, args = get_vjpfull_variant2(
2753                    inner_mapped_op, batched_args, kwargs
2754                )
2755                vjpvmap_fn, _ = get_vjpfull_variant2(
2756                    inner_vmapped_op, batched_args, kwargs
2757                )
2758
2759                vjpvjpvmap_fn, new_args = get_vjpfull_variant2(vjpvmap_fn, args, {})
2760                vjpvjpmap_fn, _ = get_vjpfull_variant2(vjpmap_fn, args, {})
2761
2762                expected = vjpvjpmap_fn(*new_args)
2763                result = vjpvjpvmap_fn(*new_args)
2764                self.assertEqual(result, expected)
2765
2766    # We're generally convinced that jvp x vmap works (vmap turns an operator
2767    # into another operator and we test jvp support for operators). So
2768    # we only test it on the things we're not sure about:
2769    # - the autograd.Function <> functorch interaction
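    # (e.g. test_jvpvmap below compares jvp(vmap(op)) against jvp of a per-sample loop over op)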
2770    @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
2771    @skipOps(
2772        "TestOperators",
2773        "test_jvpvmap",
2774        {
2775            xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable
2776        },
2777    )
2778    def test_jvpvmap(self, device, dtype, op):
2779        samples = op.sample_inputs(device, dtype, requires_grad=True)
2780        B = 2
2781        for sample in samples:
2782            args = [sample.input] + list(sample.args)
2783            kwargs = sample.kwargs
2784            generator = generate_vmap_inputs(args, kwargs, batch_size=B)
2785            for batched_args, in_dims, kwargs in generator:
2786                inner_vmapped_op = vmap(op, in_dims)
2787                inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
2788
2789                jvpvmap_op, primals = get_jvp_variant_primals_tangents2(
2790                    inner_vmapped_op,
2791                    batched_args,
2792                    kwargs,
2793                    sample.output_process_fn_grad,
2794                )
2795                jvpmap_op, _ = get_jvp_variant_primals_tangents2(
2796                    inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
2797                )
2798
2799                expected = jvpmap_op(*primals)
2800                result = jvpvmap_op(*primals)
2801                self.assertEqual(result, expected)
2802
2803    # See NOTE: [three-transform testing]
2804    @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
2805    @skipOps(
2806        "TestOperators",
2807        "test_jvpvmapvmap",
2808        {
2809            xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable
2810        },
2811    )
2812    def test_jvpvmapvmap(self, device, dtype, op):
2813        samples = op.sample_inputs(device, dtype, requires_grad=True)
2814        B = 2
2815        for sample in samples:
2816            args = [sample.input] + list(sample.args)
2817            kwargs = sample.kwargs
2818            generator = generate_vmap_inputs(args, kwargs, batch_size=B)
2819            for batched_args, inner_in_dims, kwargs in generator:
2820                inner_vmapped_op = vmap(op, inner_in_dims)
2821                inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B)
2822                generator = generate_vmap_inputs(batched_args, kwargs)
2823                for batched_args, in_dims, kwargs in generator:
2824                    # strategy: compare jvp(vmap(vmap(op))) vs jvp(map(map(op)))
2825                    vmapped_op = vmap(inner_vmapped_op, in_dims)
2826                    mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B)
2827
2828                    jvpvmapvmap_fn, primals = get_jvp_variant_primals_tangents2(
2829                        vmapped_op, batched_args, kwargs, sample.output_process_fn_grad
2830                    )
2831                    jvpmapmap_fn, _ = get_jvp_variant_primals_tangents2(
2832                        mapped_op, batched_args, kwargs, sample.output_process_fn_grad
2833                    )
2834
2835                    expected = jvpmapmap_fn(*primals)
2836                    result = jvpvmapvmap_fn(*primals)
2837                    self.assertEqual(result, expected)
2838
2839    # See NOTE: [three-transform testing]
2840    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
2841    @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
2842    @skipOps(
2843        "TestOperators",
2844        "test_vmapjvpvmap",
2845        {
2846            xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable
2847        },
2848    )
2849    def test_vmapjvpvmap(self, device, dtype, op):
2850        samples = op.sample_inputs(device, dtype, requires_grad=True)
2851        B = 2
2852        for sample in samples:
2853            args = [sample.input] + list(sample.args)
2854            kwargs = sample.kwargs
2855            generator = generate_vmap_inputs(args, kwargs, batch_size=B)
2856            for batched_args, in_dims, kwargs in generator:
2857                inner_vmapped_op = vmap(op, in_dims)
2858                inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
2859
2860                jvpvmap_fn, primals = get_jvp_variant_primals_tangents2(
2861                    inner_vmapped_op,
2862                    batched_args,
2863                    kwargs,
2864                    sample.output_process_fn_grad,
2865                )
2866                jvpmap_fn, _ = get_jvp_variant_primals_tangents2(
2867                    inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
2868                )
2869
2870                generator = generate_vmap_inputs(primals, {})
2871
2872                for batched_args, in_dims, _ in generator:
2873                    # strategy: compare vmap(jvp(vmap(op))) vs map(jvp(map(op)))
2874                    vmapjvpvmap_fn = vmap(jvpvmap_fn, in_dims)
2875                    mapjvpmap_fn = functools.partial(loop, jvpmap_fn, in_dims, 0, B)
2876
2877                    result = vmapjvpvmap_fn(*batched_args)
2878                    expected = mapjvpmap_fn(*batched_args)
2879                    self.assertEqual(result, expected)
2880
2881    # See NOTE: [three-transform testing]
2882    @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
2883    @skipOps(
2884        "TestOperators",
2885        "test_jvpjvpvmap",
2886        {
2887            xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable
2888        },
2889    )
2890    def test_jvpjvpvmap(self, device, dtype, op):
2891        samples = op.sample_inputs(device, dtype, requires_grad=True)
2892        B = 2
2893        for sample in samples:
2894            args = [sample.input] + list(sample.args)
2895            kwargs = sample.kwargs
2896            generator = generate_vmap_inputs(args, kwargs, batch_size=B)
2897            for batched_args, in_dims, kwargs in generator:
2898                inner_vmapped_op = vmap(op, in_dims)
2899                inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
2900
2901                jvpmap_fn, args = get_jvp_variant_primals_tangents2(
2902                    inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
2903                )
2904                jvpvmap_fn, _ = get_jvp_variant_primals_tangents2(
2905                    inner_vmapped_op,
2906                    batched_args,
2907                    kwargs,
2908                    sample.output_process_fn_grad,
2909                )
2910
2911                jvpjvpvmap_fn, new_args = get_jvp_variant_primals_tangents2(
2912                    jvpvmap_fn, args, {}
2913                )
2914                jvpjvpmap_fn, _ = get_jvp_variant_primals_tangents2(jvpmap_fn, args, {})
2915
2916                expected = jvpjvpmap_fn(*new_args)
2917                result = jvpjvpvmap_fn(*new_args)
2918                self.assertEqual(result, expected)
2919
2920    # See NOTE: [three-transform testing]
2921    @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
2922    @skipOps(
2923        "TestOperators",
2924        "test_jvpvjpvmap",
2925        {
2926            xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable
2927        },
2928    )
2929    def test_jvpvjpvmap(self, device, dtype, op):
2930        samples = op.sample_inputs(device, dtype, requires_grad=True)
2931        B = 2
2932        for sample in samples:
2933            args = [sample.input] + list(sample.args)
2934            kwargs = sample.kwargs
2935            generator = generate_vmap_inputs(args, kwargs, batch_size=B)
2936            for batched_args, in_dims, kwargs in generator:
2937                inner_vmapped_op = vmap(op, in_dims)
2938                inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
2939
2940                vjpmap_fn, args = get_vjpfull_variant2(
2941                    inner_mapped_op, batched_args, kwargs
2942                )
2943                vjpvmap_fn, _ = get_vjpfull_variant2(
2944                    inner_vmapped_op, batched_args, kwargs
2945                )
2946
2947                jvpvjpvmap_fn, new_args = get_jvp_variant_primals_tangents2(
2948                    vjpvmap_fn, args, {}
2949                )
2950                jvpvjpmap_fn, _ = get_jvp_variant_primals_tangents2(vjpmap_fn, args, {})
2951
2952                expected = jvpvjpmap_fn(*new_args)
2953                result = jvpvjpvmap_fn(*new_args)
2954                self.assertEqual(result, expected)
2955
2956    def test_data_write_errors_under_transform(self, device):
2957        t = torch.randn(3, 3, device=device)
2958
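        # Reassigning `.data` inside a functorch transform is disallowed; grad, vjp, and jvp should all
        # reject it with the RuntimeError below.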
2959        def fn(t):
2960            t.data = torch.randn(3, 3)
2961            return t.sum()
2962
2963        msg = "mutating directly with `.data` inside functorch transform"
2964        with self.assertRaisesRegex(RuntimeError, msg):
2965            grad(fn)(t)
2966
2967        with self.assertRaisesRegex(RuntimeError, msg):
2968            vjp(fn, t)
2969
2970        with self.assertRaisesRegex(RuntimeError, msg):
2971            jvp(fn, (t,), (torch.randn_like(t),))
2972
2973    def test_tensor_with_scalar_list(self, device):
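        # Building a tensor from a list containing a 0-dim tensor should match building it from the
        # 0-dim tensor directly and viewing it as shape (1,), both in the output and in the vjp results.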
2974        x = torch.randn((), device=device)
2975
2976        def func_list_of_scalar(x):
2977            return torch.tensor([x], device=device)
2978
2979        def func(x):
2980            return torch.tensor(x, device=device).view(1)
2981
2982        actual_o, actual_fn = vjp(func_list_of_scalar, x)
2983        expected_o, expected_fn = vjp(func, x)
2984
2985        self.assertEqual(actual_o, expected_o)
2986        self.assertEqual(
2987            expected_fn(torch.ones_like(expected_o)),
2988            actual_fn(torch.ones_like(actual_o)),
2989        )
2990
2991
2992only_for = ("cpu", "cuda")
2993instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
2994
2995if __name__ == "__main__":
2996    run_tests()
2997