# Owner(s): ["module: functorch"] # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import functools import itertools import unittest from common_utils import ( check_vmap_fallback, decorate, expectedFailureIf, generate_vmap_inputs, get_fallback_and_vmap_exhaustive, is_batch_norm_training, is_valid_inplace_sample_input, loop, loop2, opsToleranceOverride, skip, skipOps, tol1, tol2, xfail, ) from functorch_additional_op_db import additional_op_db import torch import torch.autograd.forward_ad as fwAD from functorch import grad, jacfwd, jacrev, vjp, vmap from torch import Tensor from torch._functorch.eager_transforms import _as_tuple, jvp from torch.testing._internal.autograd_function_db import autograd_function_db from torch.testing._internal.common_cuda import with_tf32_off from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, ops, tol, toleranceOverride, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( is_iterable_of_tensors, IS_MACOS, IS_X86, noncontiguous_like, parametrize, run_tests, runOnRocm, skipIfRocm, TEST_WITH_ASAN, TEST_WITH_ROCM, TestCase, unMarkDynamoStrictTest, ) from torch.testing._internal.opinfo.core import SampleInput from torch.utils import _pytree as pytree from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten aten = torch.ops.aten # Version of autograd.grad with some differences: # - pytree inputs is allowed (but leaves of the pytree have to all # be tensors) # - if an input is not used as part of derivatives, we will return a # zero-filled tensor for the result def _autograd_grad( outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True ): inputs, inputs_spec = tree_flatten(inputs) diff_inputs = tuple(inp for inp in inputs if inp.requires_grad) if grad_outputs is None: diff_outputs = tuple(out for out in outputs if out.requires_grad) else: diff_grad_outputs = [ (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad ] if len(diff_grad_outputs) == 0: diff_outputs, grad_outputs = (), () else: diff_outputs, grad_outputs = zip(*diff_grad_outputs) grad_inputs = torch.autograd.grad( diff_outputs, diff_inputs, grad_outputs, retain_graph=retain_graph, create_graph=create_graph, allow_unused=True, ) result = [] grad_inputs_iter = iter(grad_inputs) for inp in inputs: if inp.requires_grad: grad_input = next(grad_inputs_iter) if grad_input is None: result.append(torch.zeros_like(inp)) else: result.append(grad_input) else: result.append(torch.zeros_like(inp)) return tree_unflatten(result, inputs_spec) def diff_arg(arg, requires_grad=True): def is_differentiable_arg(arg): if requires_grad: return arg.requires_grad else: return arg.is_floating_point() or arg.is_complex() if is_iterable_of_tensors(arg): if all(is_differentiable_arg(a) for a in arg): return True if all(not is_differentiable_arg(a) for a in arg): return False raise RuntimeError("NYI: The test runner can't handle this") return isinstance(arg, Tensor) and is_differentiable_arg(arg) # Given f, returns an f' such that: # - f' takes only positional arguments # - All arguments to f' are floating-point Tensors # - All outputs of f' are floating-point Tensors def normalize_op_input_output2( f, args, kwargs, output_process_fn_grad=None, requires_grad=True ): flat_args, args_spec = tree_flatten(args) diff_argnums = tuple( 
i for i, arg in enumerate(flat_args) if diff_arg(arg, requires_grad=requires_grad) ) assert len(diff_argnums) > 0 primals = tuple(flat_args[i] for i in diff_argnums) @functools.wraps(f) def wrapped(*primals): _args = list(flat_args) for num, arg in zip(diff_argnums, primals): _args[num] = arg _args = tree_unflatten(_args, args_spec) result = f(*_args, **kwargs) if output_process_fn_grad is not None: result = output_process_fn_grad(result) if isinstance(result, tuple): result = tuple(r for r in result if torch.is_floating_point(r)) assert len(result) > 0 return result return wrapped, primals # TODO: consolidate with normalize_op_input_output2 def normalize_op_input_output3( f, args, kwargs, sample_args, output_process_fn_grad=None ): flat_args, args_spec = tree_flatten(args) flat_sample_args = pytree.tree_leaves(sample_args) diff_argnums = tuple( i for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args)) if diff_arg(sample, requires_grad=True) ) assert len(diff_argnums) > 0 primals = tuple(flat_args[i] for i in diff_argnums) @functools.wraps(f) def wrapped(*primals): _args = list(flat_args) for num, arg in zip(diff_argnums, primals): _args[num] = arg _args = tree_unflatten(_args, args_spec) result = f(*_args, **kwargs) if output_process_fn_grad is not None: result = output_process_fn_grad(result) if isinstance(result, tuple): result = tuple(r for r in result if torch.is_floating_point(r)) assert len(result) > 0 return result return wrapped, primals def normalize_op_input_output(f, sample, requires_grad=True): args = tuple([sample.input] + list(sample.args)) return normalize_op_input_output2( f, args, sample.kwargs, sample.output_process_fn_grad, requires_grad=requires_grad, ) def ref_vjp(f, *primals): result = f(*primals) def wrapped(cotangents): return _autograd_grad(_as_tuple(result), primals, _as_tuple(cotangents)) return result, wrapped def simulate_jvp(f, primals, tangents): primals_out, tangents_out = torch.autograd.functional.jvp(f, primals, tangents) return primals_out, tangents_out def ref_jvp(f, primals, tangents): with fwAD.dual_level(): duals = tuple(fwAD.make_dual(p, t) for p, t in zip(primals, tangents)) result_duals = f(*duals) result_duals, spec = tree_flatten(result_duals) primals_out, tangents_out = zip(*(fwAD.unpack_dual(d) for d in result_duals)) return tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec) def get_sample_cotangents(f, sample): fn, primals = normalize_op_input_output(f, sample) output = fn(*primals) return tree_map(torch.randn_like, output) # returns a new function g(*args, *cotangents) # that computes vjps and (*args, cotangents) def get_vjp_fn_and_args_with_cotangents(f, sample, cotangents): args = tuple([sample.input] + list(sample.args)) kwargs = sample.kwargs flat_args, args_spec = tree_flatten(args) flat_cotangents, cotangents_spec = tree_flatten(cotangents) @functools.wraps(f) def wrapped(*args): assert len(args) == len(flat_args) + len(flat_cotangents) actual_args = args[: len(flat_args)] cotangents = args[len(flat_args) :] actual_args = tree_unflatten(actual_args, args_spec) cotangents = tree_unflatten(cotangents, cotangents_spec) fn, primals = normalize_op_input_output3( f, actual_args, kwargs, flat_args, sample.output_process_fn_grad ) _, vjp_fn = vjp(fn, *primals) return vjp_fn(cotangents) return wrapped, tuple(flat_args + flat_cotangents) # Returns a new function g(*args, *cotangents) that computes vjps and # sample (*args, *cotangents) def get_vjpfull_variant(f, sample): fn, primals = normalize_op_input_output(f, 
sample) return _get_vjpfull_variant(fn, primals) def get_vjpfull_variant2(f, args, kwargs): fn, primals = normalize_op_input_output2(f, args, kwargs) return _get_vjpfull_variant(fn, primals) def _get_vjpfull_variant(fn, primals): result = fn(*primals) cotangents = _as_tuple( tree_map(lambda x: torch.randn_like(x, requires_grad=True), result) ) num_primals = len(primals) args = (*primals, *cotangents) @functools.wraps(fn) def wrapped(*args): primals = args[:num_primals] cotangents = args[num_primals:] result, vjp_fn = vjp(fn, *primals) if isinstance(result, torch.Tensor): assert len(cotangents) == 1 cotangents = cotangents[0] return vjp_fn(cotangents) return wrapped, args def get_jvp_variant(f, sample): # We want this higher-order variant of jvp, so that it can # be used to wrap vmap fn, primals = normalize_op_input_output(f, sample, requires_grad=False) tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals)) @functools.wraps(f) def wrapped(*args): tangents = args primals_out, tangents_out = jvp(fn, primals, tangents) if isinstance(primals_out, torch.Tensor): return (primals_out, tangents_out) else: flat_primals_out = pytree.tree_leaves(primals_out) flat_tangents_out = pytree.tree_leaves(tangents_out) return tuple(flat_primals_out + flat_tangents_out) return wrapped, tangents def get_jvp_variant_primals_tangents2( f, args, kwargs, output_process_fn_grad=None, requires_grad=False ): fn, primals = normalize_op_input_output2( f, args, kwargs, output_process_fn_grad, requires_grad ) tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals)) return _get_jvp_variant(fn, primals, tangents) def get_jvp_variant_primals_tangents(f, sample): # We want this higher-order variant of jvp, so that it can # be used to wrap vmap fn, primals = normalize_op_input_output(f, sample, requires_grad=False) tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals)) return _get_jvp_variant(fn, primals, tangents) def _get_jvp_variant(fn, primals, tangents): @functools.wraps(fn) def wrapped(*args): primals_in = args[: len(primals)] tangents_in = args[len(primals) :] primals_out, tangents_out = jvp(fn, primals_in, tangents_in) if isinstance(primals_out, torch.Tensor): return (primals_out, tangents_out) else: flat_primals_out = pytree.tree_leaves(primals_out) flat_tangents_out = pytree.tree_leaves(tangents_out) return tuple(flat_primals_out + flat_tangents_out) return wrapped, primals + tangents def is_inplace(op, variant): if hasattr(variant, "__wrapped__"): return variant.__wrapped__ is op.get_inplace() return variant is op.get_inplace() vjp_fail = { xfail("tensor_split"), # data_ptr composite compliance # Very minor accuracy issue on ROCm decorate("nn.functional.scaled_dot_product_attention", decorator=skipIfRocm), } aliasing_ops = { "T", "broadcast_to", "conj", "contiguous", "diagonal", # linalg.diagonal is an alias "expand", "flatten", "imag", "mH", # adjoint is an alias "mT", "movedim", # moveaxis is an alias "narrow", "permute", "positive", # 'ravel', is composite implicit autograd and may call clone "real", "reshape", "resolve_conj", "resolve_neg", "select", "squeeze", "transpose", # swapdims and swapaxes are aliases "unflatten", "unfold", "unsqueeze", "view", "view_as", "view_as_complex", "view_as_real", } aliasing_ops_list_return = { "chunks", "dsplit", "hsplit", "split", "unbind", "vsplit", # 'tensor_split' not composite compliant, see vjp_fail } skip_noncontig = { "_batch_norm_with_update", "as_strided_copy", } @unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, 
are probably redundant") @unMarkDynamoStrictTest class TestOperators(TestCase): @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps( "TestOperators", "test_grad", vjp_fail.union( { xfail( "chalf", "", device_type="cpu" ), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' xfail( "sparse.sampled_addmm", "" ), # RuntimeError: Sparse CSR tensors do not have strides xfail( "sparse.mm", "reduce" ), # RuntimeError: Sparse CSR tensors do not have strides # Non-contiguous Bugs # # AssertionError: Tensor-likes are not close! xfail("_softmax_backward_data", device_type="cpu"), xfail("as_strided"), xfail("as_strided", "partial_views"), # RuntimeError: !self.requires_grad() || self.is_contiguous() xfail("as_strided_scatter"), # RuntimeError: Tensor must have a last dimension with stride 1 xfail("view_as_complex"), # query: last dimension must be contiguous # Fused attention kernels require last dim to be contiguous decorate( "nn.functional.scaled_dot_product_attention", decorator=expectedFailureIf(not TEST_WITH_ROCM), ), # Works on ROCm xfail("torch.ops.aten._flash_attention_forward"), xfail("torch.ops.aten._efficient_attention_forward"), # RuntimeError: Expected contiguous tensor, but got # non-contiguous tensor for argument #2 'grad_output' decorate( "_batch_norm_with_update", decorator=expectedFailureIf(TEST_WITH_ROCM), device_type="cuda", ), } ), ) @opsToleranceOverride( "TestOperators", "test_grad", ( tol1( "nn.functional.binary_cross_entropy_with_logits", {torch.float32: tol(atol=1e-04, rtol=1e-04)}, ), tol1("masked.cumprod", {torch.float32: tol(atol=1e-05, rtol=1e-05)}), tol1("svd_lowrank", {torch.float32: tol(atol=3e-04, rtol=3e-04)}), tol1( "linalg.multi_dot", {torch.float32: tol(atol=1e-05, rtol=8e-04)}, device_type="cuda", ), tol1( "linalg.tensorsolve", {torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type="cuda", ), tol1( "nn.functional.multi_head_attention_forward", {torch.float32: tol(atol=8e-04, rtol=1e-03)}, ), tol1( "__rmatmul__", {torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type="cuda", ), tol1( "matmul", {torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type="cuda", ), tol1( "pca_lowrank", {torch.float32: tol(atol=3e-05, rtol=4e-06)}, device_type="cpu", ), ), ) def test_grad(self, device, dtype, op): if op.name in vjp_fail: self.skipTest("Skipped; Expected failures") return if not op.supports_autograd: self.skipTest("Skipped! Autograd not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=True) if is_inplace(op, op.get_op()): self.skipTest("Skipped for redundancy. 
test_vjp handles in-place testing.") return for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs if op.name not in skip_noncontig: noncontig_sample = sample.noncontiguous() noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args) noncontig_kwargs = noncontig_sample.kwargs diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg)) assert len(diff_argnums) > 0 diff_args = tuple(args[i] for i in diff_argnums) def wrapped_fn(*args, **kwargs): result = op(*args, **kwargs) if sample.output_process_fn_grad is not None: result = sample.output_process_fn_grad(result) def abs_if_complex(t): if t.dtype.is_complex: return t.abs() return t # Reduce into single value for grad if isinstance(result, torch.Tensor): return abs_if_complex(result.sum()) result = sum(abs_if_complex(res.sum()) for res in result) return result result = grad(wrapped_fn, diff_argnums)(*args, **kwargs) expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args) self.assertEqual(result, expected) if op.name not in skip_noncontig: result_noncontig = grad(wrapped_fn, diff_argnums)( *noncontig_args, **noncontig_kwargs ) self.assertEqual(result_noncontig, expected) @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps( "TestOperators", "test_jvp", set( { # Composite ops that do bad things. Need to be fixed in PyTorch core. # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage xfail("tensor_split"), # BUG: silent incorrectness: runs and produces numerical differences skip("nn.functional.max_unpool1d"), # fails everywhere except on mac skip( "nn.functional.max_unpool2d" ), # fails everywhere except on windows skip("nn.functional.max_unpool3d"), # fails everywhere except on mac xfail( "native_batch_norm" ), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents xfail( "_native_batch_norm_legit" ), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents xfail( "_batch_norm_with_update" ), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents xfail("nn.functional.scaled_dot_product_attention"), xfail("torch.ops.aten._flash_attention_forward"), xfail("torch.ops.aten._efficient_attention_forward"), xfail( "nn.functional.rrelu" ), # in-place test errors out with no formula implemented xfail( "NumpyExpMarkDirtyAutogradFunction" ), # TODO: https://github.com/pytorch/pytorch/issues/91280 # --- Non-Contiguous Failures! --- # This is expected to fail as the operator # expects last dim to have stride=1 xfail("view_as_complex"), # BUG # AssertionError: Tensor-likes are not close! 
xfail("as_strided"), xfail("as_strided", "partial_views"), xfail("as_strided_scatter"), decorate( "linalg.det", "singular", decorator=expectedFailureIf(IS_MACOS and IS_X86), ), } ), ) @opsToleranceOverride( "TestOperators", "test_jvp", ( tol1( "nn.functional.conv_transpose3d", {torch.float32: tol(atol=1e-04, rtol=1.3e-06)}, device_type="cuda", ), tol1( "linalg.tensorsolve", {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, device_type="cuda", ), tol1( "masked.prod", {torch.float32: tol(atol=1e-05, rtol=1.3e-05)}, device_type="cuda", ), tol1( "nn.functional.binary_cross_entropy_with_logits", {torch.float32: tol(atol=4e-04, rtol=4e-04)}, ), tol1( "nn.functional.batch_norm", {torch.float32: tol(atol=4e-05, rtol=5e-05)} ), tol1("nn.functional.conv2d", {torch.float32: tol(atol=4e-05, rtol=5e-05)}), tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), tol1( "nn.functional.multi_head_attention_forward", {torch.float32: tol(atol=6e-05, rtol=2e-05)}, ), tol2( "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-5, rtol=2e-5)} ), ), ) def test_jvp(self, device, dtype, op): # TODO: get rid of vjp_decomp when we add decomposition support to # PyTorch's forward-mode ad. Currently the decomposition support only # works for functorch.jvp VJP_DECOMP = { "nn.functional.logsigmoid", } if op.name in VJP_DECOMP: fixme_ref_jvp_local = simulate_jvp else: fixme_ref_jvp_local = ref_jvp if not op.supports_forward_ad and op.name not in VJP_DECOMP: self.skipTest("Skipped! Forward AD not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=True) outplace_variant = op if not is_inplace(op, op.get_op()) else None inplace_variant = op.inplace_variant if op.supports_inplace_autograd else None for sample in samples: if outplace_variant: self.jvp_opinfo_test( outplace_variant, sample, sample.output_process_fn_grad, clone_inputs=False, fixme_ref_jvp_local=fixme_ref_jvp_local, test_noncontig=op.name not in skip_noncontig, ) if is_valid_inplace_sample_input(sample, op, inplace_variant): self.jvp_opinfo_test( inplace_variant, sample, sample.output_process_fn_grad, clone_inputs=True, fixme_ref_jvp_local=fixme_ref_jvp_local, test_noncontig=op.name not in skip_noncontig, ) def jvp_opinfo_test( self, fn, sample, output_process_fn, clone_inputs, fixme_ref_jvp_local, test_noncontig, ): # NB: we used requires_grad=True to determine where the primals are, # but don't need that information otherwise args = (sample.input,) + sample.args kwargs = sample.kwargs contig_fn, primals = normalize_op_input_output2( fn, args, kwargs, output_process_fn, requires_grad=True ) orig_primals = tree_map(lambda x: x.detach(), primals) orig_tangents = tree_map(lambda x: torch.randn_like(x), primals) def maybe_clone_inputs(): if clone_inputs: primals = tree_map(torch.clone, orig_primals) tangents = tree_map(torch.clone, orig_tangents) return primals, tangents return orig_primals, orig_tangents primals, tangents = maybe_clone_inputs() expected_primal_outs, expected_tangent_outs = fixme_ref_jvp_local( contig_fn, primals, tangents ) primals, tangents = maybe_clone_inputs() primal_outs, tangent_outs = jvp(contig_fn, primals, tangents) self.assertEqual(primal_outs, expected_primal_outs) self.assertEqual(tangent_outs, expected_tangent_outs) if test_noncontig: noncontig_sample = sample.noncontiguous() noncontig_args = (noncontig_sample.input,) + noncontig_sample.args noncontig_kwargs = sample.kwargs noncontig_fn, primals = normalize_op_input_output2( fn, 
noncontig_args, noncontig_kwargs, output_process_fn, requires_grad=True, ) noncontig_primals = tree_map(lambda x: x.detach(), primals) noncontig_tangents = tree_map( lambda x: noncontiguous_like(x), orig_tangents ) noncontig_primal_outs, noncontig_tangent_outs = jvp( noncontig_fn, noncontig_primals, noncontig_tangents ) self.assertEqual(noncontig_primal_outs, expected_primal_outs) self.assertEqual(noncontig_tangent_outs, expected_tangent_outs) @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps( "TestOperators", "test_vjp", vjp_fail.union( { xfail("sparse.sampled_addmm", ""), xfail("sparse.mm", "reduce"), # ---- Non-Contiguous Failures ---- # This is expected to fail as the operator # expects last dim to have stride=1 xfail("view_as_complex"), # RuntimeError: query: last dimension must be contiguous # The fused attention kernels require the last dim to be contiguous decorate( "nn.functional.scaled_dot_product_attention", decorator=expectedFailureIf(not TEST_WITH_ROCM), ), # Works on ROCm xfail("torch.ops.aten._flash_attention_forward"), xfail("torch.ops.aten._efficient_attention_forward"), # BUG # AssertionError: Tensor-likes are not close! xfail("as_strided"), xfail("as_strided_scatter"), xfail("_softmax_backward_data", device_type="cpu"), xfail("as_strided", "partial_views"), } ), ) @opsToleranceOverride( "TestOperators", "test_vjp", ( tol1( "nn.functional.conv_transpose3d", {torch.float32: tol(atol=5e-05, rtol=9e-05)}, device_type="cuda", ), tol1( "nn.functional.binary_cross_entropy_with_logits", {torch.float32: tol(atol=1e-04, rtol=1e-04)}, ), tol1( "nn.functional.multi_head_attention_forward", {torch.float32: tol(atol=2e-03, rtol=2e-04)}, ), tol1("__rmatmul__", {torch.float32: tol(atol=1e-05, rtol=1e-05)}), tol1("matmul", {torch.float32: tol(atol=1e-05, rtol=1e-05)}), tol2( "linalg.pinv", "hermitian", {torch.float32: tol(atol=1e-05, rtol=1e-05)} ), tol1("linalg.tensorsolve", {torch.float32: tol(atol=9e-03, rtol=2e-04)}), tol1("linalg.multi_dot", {torch.float32: tol(atol=1e-04, rtol=1e-04)}), tol1("svd_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}), tol1("pca_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}), ), ) def test_vjp(self, device, dtype, op): if not op.supports_autograd: self.skipTest("Skipped! 
Autograd not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=True) def _test(_op, inplace=False): for sample in samples: if inplace and not is_valid_inplace_sample_input( sample, op, op.inplace_variant ): continue fn, primals = normalize_op_input_output(_op, sample) result = fn(*primals) cotangents = tree_map(lambda x: torch.randn_like(x), result) out, vjp_fn = vjp(fn, *primals) self.assertEqual(out, result) result_vjps = vjp_fn(cotangents) _, vjp_fn = ref_vjp(fn, *primals) expected_vjps = vjp_fn(cotangents) self.assertEqual(result_vjps, expected_vjps) if op.name not in skip_noncontig: noncontig_fn, noncontig_primals = normalize_op_input_output( _op, sample.noncontiguous() ) noncontig_cotangents = tree_map( lambda x: noncontiguous_like(x), cotangents ) out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals) self.assertEqual(out_noncontig, result) noncontig_result_vjps = vjp_fn(noncontig_cotangents) self.assertEqual(noncontig_result_vjps, expected_vjps) _test(op) for a_op in op.aliases: _test(a_op) if op.inplace_variant: def f(inp, *args, **kwargs): return op.inplace_variant(inp.clone(), *args, **kwargs) _test(f, inplace=True) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps( "TestOperators", "test_vjpvjp", vjp_fail.union( { skip("nn.functional.max_unpool1d"), # silent incorrectness; Flaky skip("nn.functional.max_unpool2d"), # silent incorrectness; Flaky xfail("nn.functional.ctc_loss"), # Not Implemented xfail( "native_layer_norm", "" ), # Expected a proper Tensor but got None for argument #1 'other' xfail("sparse.sampled_addmm", ""), # sparse tensors have no strides xfail("sparse.mm", "reduce"), # sparse tensors have no strides skip("nn.functional.scaled_dot_product_attention"), xfail("torch.ops.aten._flash_attention_forward"), xfail("torch.ops.aten._efficient_attention_forward"), # AssertionError: Tensor-likes are not close! # Mismatched elements: 1 / 15 (6.7%) # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed) # Greatest relative difference: 1.7933241714393998e-06 at index (2, 4) (up to 1.3e-06 allowed) # The failure occurred for item [0] xfail("masked.prod"), } ), ) @opsToleranceOverride( "TestOperators", "test_vjpvjp", ( tol1( "nn.functional.conv_transpose3d", {torch.float32: tol(atol=5e-05, rtol=9e-05)}, device_type="cuda", ), tol1("prod", {torch.float32: tol(atol=2e-05, rtol=1e-04)}), tol1("masked.cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), tol1("cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), tol1("linalg.vander", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), tol2( "linalg.det", "singular", {torch.float32: tol(atol=2e-05, rtol=2e-05)} ), ), ) def test_vjpvjp(self, device, dtype, op): if not op.supports_autograd: self.skipTest("Skipped! Autograd not supported.") return if not op.supports_gradgrad: self.skipTest("Skipped! Operation does not support gradgrad") return samples = op.sample_inputs(device, dtype, requires_grad=True) def test(_op, inplace=False): for sample in samples: if inplace and not is_valid_inplace_sample_input( sample, op, op.inplace_variant ): continue fn, args = get_vjpfull_variant(_op, sample) result = fn(*args) cotangents = tree_map(lambda x: torch.randn_like(x), result) # Compute vjp of vjp _, vjp_fn = vjp(fn, *args) result_vjps = vjp_fn(cotangents) # Compute ref_vjp of vjp. We could have done ref_vjp of ref_vjp, # but since we're confident that vjp works by itself, this is # an equivalent way to test that. 
_, vjp_fn = ref_vjp(fn, *args) expected_vjps = vjp_fn(cotangents) self.assertEqual(result_vjps, expected_vjps) test(op) if op.inplace_variant: def fn(inp, *args, **kwargs): return op.inplace_variant(inp.clone(), *args, **kwargs) test(fn, inplace=True) @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @skipOps( "TestOperators", "test_vmapvjpvjp", vjp_fail.union( { skip("atleast_1d"), # Takes too long skip("atleast_2d"), # Takes too long skip("atleast_3d"), # Takes too long skip("ormqr"), # Takes too long xfail("as_strided"), # incorrect output xfail("as_strided", "partial_views"), # incorrect output xfail("as_strided_scatter"), # incorrect output skip("bernoulli"), # calls random op xfail("bfloat16"), # rank 4 tensor for channels_last xfail("cdouble"), # rank 4 tensor for channels_last xfail("cfloat"), # rank 4 tensor for channels_last xfail("chalf"), # rank 4 tensor for channels_last xfail("double"), # rank 4 tensor for channels_last xfail("float"), # rank 4 tensor for channels_last xfail("half"), # rank 4 tensor for channels_last xfail( "NumpyCubeNotComposableAutogradFunction" ), # Not composable autograd.Function # It looks like you're either (1) calling .item() on a Tensor or # (2) attempting to use a Tensor in some data-dependent control flow or # (3) encountering this error in PyTorch internals. xfail("index_reduce", "prod"), decorate( "linalg.householder_product", decorator=runOnRocm ), # works on ROCm xfail( # nans "masked.softmax", device_type="cpu", ), xfail( "nanquantile", device_type="cpu" ), # vmap not implemented for at::equal. xfail("native_layer_norm"), # vmap: inplace into a regular tensor # got a batched tensor as input while the running_mean or running_var, # which will be updated in place, were not batched. xfail("nn.functional.batch_norm"), xfail( "nn.functional.binary_cross_entropy" ), # vmap: inplace into a regular tensor xfail( "nn.functional.ctc_loss" ), # derivate not implemented for _ctc_loss_backward # flaky on ROCM needs investigation decorate("nn.functional.conv_transpose2d", decorator=skipIfRocm), skip("nn.functional.dropout"), # calls random op skip("nn.functional.dropout2d"), # calls random op skip("nn.functional.dropout3d"), # calls random op skip("nn.functional.alpha_dropout"), # calls random op skip( "nn.functional.feature_alpha_dropout", "with_train" ), # calls random op skip("nn.functional.fractional_max_pool2d"), # calls random op skip("nn.functional.fractional_max_pool3d"), # calls random op xfail("nn.functional.scaled_dot_product_attention"), # randomness xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints xfail("nn.functional.multi_head_attention_forward"), # randomness # It looks like you're either (1) calling .item() on a Tensor or # (2) attempting to use a Tensor in some data-dependent control flow or # (3) encountering this error in PyTorch internals. xfail("nn.functional.gaussian_nll_loss"), # got a batched tensor as input while the running_mean or running_var, # which will be updated in place, were not batched. 
xfail("nn.functional.instance_norm"), xfail( "nn.functional.layer_norm" ), # vmap: inplace into a regular tensor # RuntimeError: NYI: querying is_contiguous inside of vmap # for memory_format other than torch.contiguous_formats xfail("nn.functional.max_pool2d"), # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only # supported with memory_format torch.preserve_format or # torch.contiguous_format (got ChannelsLast) xfail("nn.functional.max_unpool2d"), # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only # supported with memory_format torch.preserve_format # or torch.contiguous_format (got ChannelsLast)s xfail("nn.functional.max_unpool2d", "grad"), xfail( "nn.functional.rrelu" ), # RuntimeError: vmap: we do not yet support aten::rrelu_with_noise. xfail("normal"), # calls random op xfail("normal", "number_mean"), # calls random op xfail("pca_lowrank"), # calls random op xfail( "quantile", device_type="cpu" ), # Batching rule not implemented for `at::equal` xfail( "scatter_reduce", "prod" ), # vmap (looks like you are calling item/data-dependent) xfail( "sparse.sampled_addmm" ), # RuntimeError: Sparse CSR tensors do not have strides xfail( "sparse.mm", "reduce" ), # RuntimeError: Sparse CSR tensors do not have strides xfail("svd_lowrank"), # calls random op xfail("to"), # rank 4 tensor for channels_last xfail( "view_as_complex" ), # RuntimeError: Tensor must have a last dimension with stride 1 # got a batched tensor as input while the running_mean or running_var, # which will be updated in place, were not batched. xfail("nn.functional.batch_norm", "without_cudnn"), # view doesn't work on sparse xfail("to_sparse"), xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), # TODO: implement batching rule xfail("_batch_norm_with_update"), } ), ) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @opsToleranceOverride( "TestOperators", "test_vmapvjpvjp", ( tol1("linalg.svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}), tol1("linalg.lu", {torch.float32: tol(atol=5e-04, rtol=7e-04)}), tol1("linalg.lu_factor", {torch.float32: tol(atol=2e-03, rtol=2e-02)}), tol1("linalg.multi_dot", {torch.float32: tol(atol=2e-03, rtol=2e-04)}), tol1("svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}), tol1("matrix_exp", {torch.float32: tol(atol=1e-03, rtol=5e-04)}), tol1("masked.prod", {torch.float32: tol(atol=2e-03, rtol=2e-04)}), ), ) @skipOps( "TestOperators", "test_vmapvjpvjp", { xfail("as_strided", "partial_views"), xfail("as_strided_copy"), }, ) def test_vmapvjpvjp(self, device, dtype, op): # Since, we test `vjpvjp` independently, # for this test, we just verify that vmap # of `vjpvjp` is correct. if not op.supports_autograd: self.skipTest("Skipped! Autograd not supported.") return if not op.supports_gradgrad: self.skipTest("Skipped! Operation does not support gradgrad") return samples = op.sample_inputs(device, dtype, requires_grad=True) # TODO: test in-place if is_inplace(op, op.get_op()): self.skipTest("Skipped! 
NYI: inplace-testing not supported.") return for sample in samples: fn, args = get_vjpfull_variant(op, sample) result = fn(*args) cotangents = tree_map(lambda x: torch.randn_like(x), result) cotangents = pytree.tree_leaves(cotangents) num_args = len(args) args_and_cotangents = tuple(args) + tuple(cotangents) def vjp_of_vjp(*args_and_cotangents): args = args_and_cotangents[:num_args] cotangents = args_and_cotangents[num_args:] result, vjp_fn = vjp(fn, *args) result_vjps = vjp_fn(cotangents) result = pytree.tree_leaves(result) result_vjps = pytree.tree_leaves(result_vjps) return (*result, *result_vjps) is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs) generator = get_fallback_and_vmap_exhaustive( vjp_of_vjp, args_and_cotangents, {}, is_batch_norm_and_training=is_batch_norm_and_training, ) for loop_out, batched_out in generator: self.assertEqual(loop_out, batched_out) vmapvjp_fail = vjp_fail.union( { # -------------------- ALLOWED FAILURES -------------------------------- # The following are not bugs and are expected behavior xfail("masked_select"), # Not possible due to dynamic shapes skip("bernoulli"), # randomness skip("normal", ""), # randomness skip("normal", "number_mean"), # randomness skip("nn.functional.rrelu"), # randomness skip("nn.functional.feature_alpha_dropout", "with_train"), # randomness skip("nn.functional.feature_alpha_dropout", "without_train"), # randomness skip("nn.functional.dropout"), # randomness skip("nn.functional.dropout2d"), # randomness skip("nn.functional.dropout3d", ""), # randomness skip("nn.functional.alpha_dropout"), # randomness skip("nn.functional.scaled_dot_product_attention"), # randomness xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints skip("nn.functional.multi_head_attention_forward"), # randomness xfail( "index_put", "" ), # not possible due to dynamic shapes; we support a subset xfail("nn.functional.fractional_max_pool2d"), # random xfail("nn.functional.fractional_max_pool3d"), # random xfail("pca_lowrank", ""), # randomness xfail("svd_lowrank", ""), # randomness xfail("to_sparse", ""), # non-dense output skip( "to" ), # RuntimeError: required rank 4 tensor to use channels_last format xfail("as_strided", "partial_views"), xfail( "NumpyCubeNotComposableAutogradFunction" ), # Not composable autograd.Function # ---------------------------------------------------------------------- # ---------------------------- BUGS ------------------------------------ # All of the following are bugs and need to be fixed skip( "linalg.svdvals" ), # # really annoying thing where it passes correctness check but not has_batch_rule skip("native_batch_norm"), skip("_native_batch_norm_legit"), # TODO: implement batching rule skip("_batch_norm_with_update"), xfail("__getitem__", ""), # dynamic error xfail("nanquantile", device_type="cpu"), # checks q via a .item() call xfail("nn.functional.gaussian_nll_loss"), # checks var for if any value < 0 xfail("narrow"), # .item() call xfail("quantile", device_type="cpu"), # checks q via a .item() call xfail("view_as_complex"), # Tensor must have a last dimension with stride 1 # required rank 4 tensor to use channels_last format xfail("bfloat16"), xfail("double"), xfail("float"), xfail("half"), xfail("cdouble", ""), xfail("cfloat", ""), xfail("chalf", ""), xfail("scatter_reduce", "prod"), # item call # Batching rule not implemented for aten::_use_cudnn_ctc_loss.Tensor xfail("nn.functional.ctc_loss", device_type="cuda"), # NYI: querying is_contiguous inside of vmap for memory_format other 
than torch.contiguous_format xfail("nn.functional.max_unpool2d"), xfail("nn.functional.max_unpool2d", "grad"), xfail("sparse.sampled_addmm", ""), xfail("sparse.mm", "reduce"), xfail("as_strided_scatter", ""), # calls as_strided xfail("index_reduce", "prod"), # .item() call # --------------------------------------------------------------------- } ) @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @opsToleranceOverride( "TestOperators", "test_vmapvjp", ( tol1( "linalg.svd", {torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda", ), tol1( "svd", {torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda" ), tol1( "linalg.householder_product", {torch.float32: tol(atol=3e-04, rtol=9e-04)}, ), tol1( "matrix_exp", {torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda", ), tol1( "nn.functional.layer_norm", {torch.float32: tol(atol=3e-4, rtol=1e-4)}, device_type="cpu", ), tol1( "native_layer_norm", {torch.float32: tol(atol=3e-4, rtol=1e-4)}, device_type="cpu", ), ), ) @skipOps( "TestOperators", "test_vmapvjp", vmapvjp_fail.union( { xfail("as_strided"), xfail("as_strided_copy"), xfail("as_strided", "partial_views"), } ), ) def test_vmapvjp(self, device, dtype, op): if not op.supports_autograd: self.skipTest("Skipped! Autograd not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=True) # TODO: test in-place if is_inplace(op, op.get_op()): self.skipTest("Skipped! NYI: inplace-testing not supported.") return for sample in samples: cotangents = get_sample_cotangents(op, sample) fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents) is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs) generator = get_fallback_and_vmap_exhaustive( fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training ) for loop_out, batched_out in generator: self.assertEqual(loop_out, batched_out) vmapjvpall_fail = { # -------------------- ALLOWED FAILURES -------------------------------- # The following are expected (not a bug) skip("bernoulli", ""), # randomness skip("nn.functional.dropout"), # randomness skip("nn.functional.rrelu"), # randomness skip("nn.functional.dropout2d", ""), skip("nn.functional.dropout3d", ""), skip("nn.functional.scaled_dot_product_attention"), # randomness xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints skip("nn.functional.multi_head_attention_forward"), # randomness skip("nn.functional.alpha_dropout"), # randomness skip("nn.functional.feature_alpha_dropout", "without_train"), skip("nn.functional.feature_alpha_dropout", "with_train"), xfail( "nn.functional.fractional_max_pool2d" ), # Cannot access data pointer of Tensor that doesn't have storage xfail( "nn.functional.fractional_max_pool3d" ), # Cannot access data pointer of Tensor that doesn't have storage # Not actually a problem: embedding with max_norm mutates the weight # and causes different runs to produce different results. # skip because this is flaky depending on what the max_norm is! 
skip("nn.functional.embedding", ""), skip("to"), # RuntimeError: required rank 4 tensor to use channels_last format xfail( "NumpyExpMarkDirtyAutogradFunction" ), # vmap: inplace into a regular tensor # ---------------------------------------------------------------------- # ---------------------------- BUGS ------------------------------------ # The following are bugs that we should fix xfail("masked.mean"), # silent incorrectness (nan difference) xfail("as_strided", "partial_views"), # Tensor-likes are not close! xfail( "nn.functional.soft_margin_loss", "" ), # soft_margin_loss_backward does not support forward-ad xfail("tensor_split"), # data_ptr composite compliance xfail("quantile"), # at::equal batching rule (cpu), also, in-place vmap (cuda) skip("as_strided"), # Test runner cannot handle this # requires special handling, and does not yet have a batching rule. Feel free to file a github issue! xfail("as_strided_scatter"), xfail( "nn.functional.gaussian_nll_loss" ), # .item or data-dependent control flow xfail("scatter"), # forward-mode AD does not support at::scatter xfail( "nanquantile" ), # at::equal batching rule (cpu), also, in-place vmap (cuda) xfail("view_as_complex"), # Tensor must have a last dimension with stride 1 skip("pca_lowrank", ""), # randomness skip("svd_lowrank", ""), # randomness xfail("double"), # required rank 4 tensor to use channels_last format xfail("cdouble"), # required rank 4 tensor to use channels_last format # potential silent incorrectness skip( "nn.functional.max_unpool1d" ), # Flaky, seems to sometimes his max_unpool2d skip("nn.functional.max_unpool2d"), # fails everywhere except on mac skip("nn.functional.max_unpool3d"), # fails everywhere except on mac # erroring because running_mean and running_var aren't differentiable xfail("nn.functional.batch_norm"), xfail("nn.functional.batch_norm", "without_cudnn"), xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), # TODO: implement batching rule xfail("_batch_norm_with_update"), # ---------------------------------------------------------------------- } @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @opsToleranceOverride( "TestOperators", "test_vmapjvpall", ( tol1( "nn.functional.conv_transpose3d", {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type="cuda", ), tol1( "linalg.householder_product", {torch.float32: tol(atol=2e-04, rtol=9e-3)}, ), ), ) @skipOps( "TestOperators", "test_vmapjvpall", vmapjvpall_fail.union( { xfail("as_strided_copy"), decorate( "linalg.det", "singular", decorator=expectedFailureIf(IS_MACOS and IS_X86), ), } ), ) # This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp # or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact # because that corresponds to "batched forward-mode AD" testing in PyTorch core def test_vmapjvpall(self, device, dtype, op): if is_inplace(op, op.get_op()): # TODO: test in-place self.skipTest("Skipped! NYI: inplace-testing not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=False) if not op.supports_forward_ad: self.skipTest("Skipped! 
Forward AD not supported.") return for sample in samples: arg_values = [sample.input] + list(sample.args) kwarg_values = sample.kwargs args = tuple(arg_values) + tuple(kwarg_values) fn, args = get_jvp_variant_primals_tangents(op, sample) is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values) generator = get_fallback_and_vmap_exhaustive( fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training ) for loop_out, batched_out in generator: self.assertEqual(loop_out, batched_out) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps( "TestOperators", "test_vmapjvpall_has_batch_rule", vmapjvpall_fail.union( { skip( "to" ), # RuntimeError: required rank 4 tensor to use channels_last format xfail( "cdouble" ), # RuntimeError: required rank 4 tensor to use channels_last format xfail("cumprod"), xfail("masked_fill"), xfail("fill"), skip("masked.mean"), # ??? xfail("masked_scatter"), xfail("put"), xfail("take"), xfail("nn.functional.feature_alpha_dropout", "without_train"), xfail("nn.functional.dropout2d", ""), xfail("pca_lowrank", ""), xfail("svd_lowrank", ""), xfail("nn.functional.feature_alpha_dropout", "with_train"), xfail("special.log_ndtr", ""), xfail("fft.ihfft2"), # conj_physical fallback xfail("fft.ihfftn"), # conj_physical fallback xfail("nn.functional.max_unpool3d", "grad"), xfail("nn.functional.max_unpool2d", "grad"), xfail("nn.functional.soft_margin_loss", ""), xfail("nn.functional.max_unpool1d", "grad"), xfail("nn.functional.embedding", ""), xfail( "scatter_reduce", "sum" ), # aten::scatter_reduce.two hit the vmap fallback xfail( "scatter_reduce", "mean" ), # aten::scatter_reduce.two hit the vmap fallback xfail( "scatter_reduce", "amin" ), # aten::scatter_reduce.two hit the vmap fallback xfail( "scatter_reduce", "amax" ), # aten::scatter_reduce.two hit the vmap fallback xfail("nn.functional.glu"), xfail("nn.functional.bilinear"), # trilinear doesn't have batching rule xfail("linalg.lu", ""), xfail("nn.functional.dropout3d", ""), xfail("as_strided_scatter", ""), xfail("masked.cumprod", ""), xfail("renorm"), # hit vmap fallback, which is disabled xfail("t_copy"), xfail("unsqueeze_copy"), } ), ) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) def test_vmapjvpall_has_batch_rule(self, device, dtype, op): if is_inplace(op, op.get_op()): # TODO: test in-place self.skipTest("Skipped! NYI: inplace-testing not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=False) if not op.supports_forward_ad: self.skipTest("Skipped! 
Forward AD not supported.") return def test(): for sample in samples: arg_values = [sample.input] + list(sample.args) kwarg_values = sample.kwargs args = tuple(arg_values) + tuple(kwarg_values) fn, args = get_jvp_variant_primals_tangents(op, sample) is_batch_norm_and_training = is_batch_norm_training( op.name, kwarg_values ) for loop_out, batched_out in get_fallback_and_vmap_exhaustive( fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False, ): pass check_vmap_fallback(self, test, op, dry_run=False) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @skipOps( "TestOperators", "test_vmapvjp_has_batch_rule", vmapvjp_fail.union( { skip( "to" ), # RuntimeError: required rank 4 tensor to use channels_last format xfail("view_as_complex"), xfail("cummax"), xfail("cummin"), xfail("fill"), xfail( "narrow" ), # Batching rule not implemented for `narrow.Tensor` (and view op) xfail("special.log_ndtr"), xfail("linalg.householder_product"), xfail("masked_fill"), xfail("masked_scatter"), xfail("masked_select"), xfail("nanquantile"), xfail("ormqr"), xfail("put"), xfail( "scatter_reduce", "sum" ), # aten::scatter_reduce.two hit the vmap fallback xfail( "scatter_reduce", "mean" ), # aten::scatter_reduce.two hit the vmap fallback xfail( "scatter_reduce", "amin" ), # aten::scatter_reduce.two hit the vmap fallback xfail( "scatter_reduce", "amax" ), # aten::scatter_reduce.two hit the vmap fallback xfail("quantile"), xfail("renorm"), xfail("take"), xfail("tensor_split"), xfail("to_sparse"), xfail("unfold"), xfail("unfold_copy"), xfail("nn.functional.dropout"), xfail("fft.ihfft2"), xfail("fft.ihfftn"), xfail("nn.functional.gaussian_nll_loss"), xfail("nn.functional.bilinear"), xfail("nn.functional.fractional_max_pool3d"), xfail("nn.functional.ctc_loss"), xfail("nn.functional.rrelu"), xfail("nn.functional.embedding_bag"), xfail("nn.functional.fractional_max_pool2d"), xfail("nn.functional.feature_alpha_dropout", "with_train"), xfail("pca_lowrank", ""), xfail("nn.functional.dropout2d", ""), xfail("nn.functional.feature_alpha_dropout", "without_train"), xfail("svd_lowrank", ""), xfail("nn.functional.max_unpool2d", ""), xfail("nn.functional.multi_margin_loss", ""), xfail("nn.functional.multilabel_margin_loss", ""), xfail("nn.functional.pdist", ""), xfail("scatter_reduce", "prod"), xfail("nn.functional.max_unpool1d", ""), xfail("nn.functional.max_unpool3d", ""), xfail("nn.functional.max_unpool3d", "grad"), xfail("nn.functional.soft_margin_loss", ""), xfail("nn.functional.max_unpool1d", "grad"), xfail("nn.functional.max_unpool2d", "grad"), xfail("linalg.lu", ""), xfail("cdouble", ""), xfail("cfloat", ""), xfail("chalf", ""), xfail( "index_reduce", "prod" ), # aten::index_reduce hit the vmap fallback which is currently disabled xfail( "index_reduce", "mean" ), # aten::index_reduce hit the vmap fallback which is currently disabled xfail( "index_reduce", "amax" ), # aten::index_reduce hit the vmap fallback which is currently disabled xfail( "index_reduce", "amin" ), # aten::index_reduce hit the vmap fallback which is currently disabled xfail("nn.functional.dropout3d", ""), xfail("as_strided_scatter", ""), xfail("_segment_reduce", "offsets"), xfail("_segment_reduce", "lengths"), xfail("sparse.sampled_addmm", ""), xfail("sparse.mm", "reduce"), xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), # TODO: implement batching rule xfail("_batch_norm_with_update"), 
xfail("native_dropout_backward"), xfail( "index_fill" ), # aten::_unique hit the vmap fallback which is currently disabled xfail("t_copy"), xfail("unsqueeze_copy"), } ), ) def test_vmapvjp_has_batch_rule(self, device, dtype, op): if not op.supports_autograd: self.skipTest("Skipped! Autograd not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=True) # TODO: test in-place if is_inplace(op, op.get_op()): self.skipTest("Skipped! NYI: inplace-testing not supported.") return def test(): for sample in samples: cotangents = get_sample_cotangents(op, sample) fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents) is_batch_norm_and_training = is_batch_norm_training( op.name, sample.kwargs ) for loop_out, batched_out in get_fallback_and_vmap_exhaustive( fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False, ): pass for a_op in op.aliases: fn, args = get_vjp_fn_and_args_with_cotangents( a_op, sample, cotangents ) for loop_out, batched_out in get_fallback_and_vmap_exhaustive( fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False, ): pass check_vmap_fallback(self, test, op, dry_run=False) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps( "TestOperators", "test_vjpvmap", vjp_fail.union( { skip("bernoulli", ""), # vjpvmap testing can't handle randomness skip("normal", ""), # vjpvmap testing can't handle randomness skip( "normal", "number_mean" ), # vjpvmap testing can't handle randomness skip("nn.functional.rrelu"), # randomness skip("nn.functional.feature_alpha_dropout", "with_train"), # randomness skip( "nn.functional.feature_alpha_dropout", "without_train" ), # randomness skip("nn.functional.scaled_dot_product_attention"), xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints skip("nn.functional.multi_head_attention_forward"), # randomness skip("nn.functional.alpha_dropout"), # randomness skip( "to" ), # RuntimeError: required rank 4 tensor to use channels_last format skip("to_sparse", ""), # non-dense output skip("ormqr", ""), # takes too long xfail( "NumpyCubeNotComposableAutogradFunction" ), # Not composable autograd.Function # fallback path doesn't work # All of the following are bugs and need to be fixed xfail("__getitem__", ""), xfail("index_put", ""), xfail("view_as_complex"), xfail("nn.functional.gaussian_nll_loss"), xfail("masked_select"), xfail( "narrow" ), # Batching rule not implemented for `narrow.Tensor` (and view op) skip( "nn.functional.fractional_max_pool3d" ), # generator works on cpu, fails on cuda skip( "nn.functional.fractional_max_pool2d" ), # generator works on cpu, fails on cuda xfail("column_stack", ""), xfail("nn.functional.dropout2d", ""), xfail("svd_lowrank", ""), xfail("pca_lowrank", ""), xfail("clamp"), # something weird happening with channels_last xfail("bfloat16"), xfail("double"), xfail("float"), xfail("half"), xfail("cdouble"), xfail("cfloat"), xfail("nn.functional.dropout3d", ""), xfail("as_strided_scatter", ""), xfail("sparse.sampled_addmm", ""), xfail("sparse.mm", "reduce"), xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), # TODO: implement batching rule xfail("_batch_norm_with_update"), xfail("as_strided", "partial_views"), } ), ) def test_vjpvmap(self, device, dtype, op): # NB: there is no vjpvmap_has_batch_rule test because that is almost # certainly redundant with the vmap_has_batch_rule test in test_vmap.py # one-off skip if op.name == "nn.functional.dropout": 
self.skipTest("Skipped!") if not op.supports_autograd: # If the op doesn't support autograd, vmap(op) won't either self.skipTest("Skipped! Autograd not supported.") return # TODO: test in-place if is_inplace(op, op.get_op()): self.skipTest("Skipped! NYI: inplace-testing not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=True) batch_norm_fns = ( "nn.functional.batch_norm", "nn.functional.instance_norm", ) # instance norm calls batch norm is_batch_norm = op.name in batch_norm_fns for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs is_batch_norm_and_training = is_batch_norm and is_batch_norm_training( op.name, kwargs ) generator = generate_vmap_inputs( args, kwargs, is_batch_norm_and_training=is_batch_norm_and_training ) for batched_args, in_dims, kwargs in generator: vmapped_op = vmap(op, in_dims) fn, primals = normalize_op_input_output2( vmapped_op, batched_args, kwargs, sample.output_process_fn_grad ) result = fn(*primals) cotangents = tree_map(lambda x: torch.randn_like(x), result) _, vjp_fn = vjp(fn, *primals) result_vjps = vjp_fn(cotangents) _, vjp_fn = ref_vjp(fn, *primals) expected_vjps = vjp_fn(cotangents) self.assertEqual(result_vjps, expected_vjps) def _compare_jacobians_of_vjp( self, fn, cotangents_and_primals, argnums=None, atol_rtol=None ): if argnums is None: argnums = tuple(range(len(cotangents_and_primals))) def get_vjp(cotangents, *primals): _, vjp_fn = vjp(fn, *primals) return vjp_fn(cotangents) jacobian_jvp = jacfwd(get_vjp, argnums)(*cotangents_and_primals) jacobian_vjp = jacrev(get_vjp, argnums)(*cotangents_and_primals) # For dtype changing operations, the jacobians have different dtype. jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp) jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp) if atol_rtol is not None: (atol, rtol) = atol_rtol self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol) else: self.assertEqual(jacobian_jvp, jacobian_vjp) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps( "TestOperators", "test_jvpvjp", vjp_fail.union( { xfail("to_sparse", ""), # NYI # RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor, # this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3]. 
xfail("normal", ""), xfail("cdist", ""), # NYI: forward-AD for _cdist_forward xfail("cholesky", ""), # NYI: forward-AD for cholesky xfail( "nn.functional.embedding_bag", "" ), # NYI: forward-AD for _embedding_bag xfail( "nn.functional.grid_sample", "" ), # NYI: forward AD for grid_sampler_2d xfail("grid_sampler_2d", ""), # NYI: forward AD for grid_sampler_2d xfail( "nn.functional.hardsigmoid", "" ), # NYI: forward AD for hardsigmoid_backward xfail( "nn.functional.huber_loss", "" ), # NYI: forward AD for huber_loss_backward xfail("NumpyCubeNotComposableAutogradFunction"), # not composable xfail("ormqr", ""), # NYI: forward AD for ormqr xfail( "nn.functional.multilabel_margin_loss", "" ), # NYI: multilabel_margin_loss_forward xfail( "nn.functional.soft_margin_loss", "" ), # NYI: forward-AD for soft_margin_loss_backward xfail("nn.functional.ctc_loss", ""), # NYI: forward-AD for _ctc_loss xfail("nn.functional.pdist", ""), # NYI: forward-AD with _pdist_forward skip("nn.functional.scaled_dot_product_attention"), xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints xfail( "nn.functional.multi_margin_loss", "" ), # NYI: forward AD with multi_margin_loss skip( "linalg.householder_product", "", device_type="cuda" ), # flaky, I'm not sure why xfail("sparse.sampled_addmm", ""), # Sparse tensors have no strides xfail( "_segment_reduce", "offsets" ), # NYI: forward-AD for _segment_reduce xfail("sparse.mm", "reduce"), # Sparse tensors have no strides xfail("index_reduce", "prod"), # NYI: forward-AD for index_reduce xfail("index_reduce", "mean"), # NYI: forward-AD for index_reduce xfail("index_reduce", "amax"), # NYI: forward-AD for index_reduce xfail("index_reduce", "amin"), # NYI: forward-AD for index_reduce xfail( "_segment_reduce", "lengths" ), # NYI: forward-AD for _segment_reduce xfail("native_dropout_backward"), # NYI } ), ) @opsToleranceOverride( "TestOperators", "test_jvpvjp", ( tol1("masked.prod", {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}), tol1("masked.cumprod", {torch.float32: tol(atol=1e-04, rtol=5e-04)}), tol1( "cumprod", {torch.float32: tol(atol=1e-03, rtol=5e-04)}, device_type="cuda", ), tol1( "linalg.det", {torch.float32: tol(atol=3e-05, rtol=5e-06)}, device_type="cuda", ), tol1( "linalg.vander", {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, device_type="cuda", ), tol1( "nn.functional.group_norm", {torch.float32: tol(atol=1e-03, rtol=1e-03)} ), tol2( "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-03, rtol=5e-03)} ), ), ) def test_jvpvjp(self, device, dtype, op): if not op.supports_autograd: self.skipTest("Skipped! Autograd not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=True) # TODO: test in-place if is_inplace(op, op.get_op()): self.skipTest("Skipped! 
NYI: inplace-testing not supported.") return for sample in samples: fn, primals = normalize_op_input_output(op, sample) result = fn(*primals) cotangents = tree_map(lambda x: torch.randn_like(x), result) primals_tangents = tree_map(lambda x: torch.randn_like(x), primals) cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents) def push_vjp(primals, cotangents): _, vjp_fn = vjp(fn, *primals) return vjp_fn(cotangents) result = jvp( push_vjp, (primals, cotangents), (primals_tangents, cotangents_tangents) ) self.assertEqual(len(result), 2) def tree_map2(fn, first, second): flat_first, spec_first = tree_flatten(first) flat_second, spec_second = tree_flatten(second) assert spec_first == spec_second flat_result = [fn(f, s) for f, s in zip(flat_first, flat_second)] return tree_unflatten(flat_result, spec_first) def reference(primals, cotangents, primals_tangents, cotangents_tangents): with fwAD.dual_level(): primal_duals = tree_map2(fwAD.make_dual, primals, primals_tangents) _, vjp_fn = ref_vjp(fn, *primal_duals) cotangent_duals = tree_map2( fwAD.make_dual, cotangents, cotangents_tangents ) result = vjp_fn(cotangent_duals) flat_result, spec = tree_flatten(result) primals_out, tangents_out = zip( *[fwAD.unpack_dual(r) for r in flat_result] ) tangents_out = [ t if t is not None else torch.zeros_like(p) for p, t in zip(primals_out, tangents_out) ] expected = ( tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec), ) return expected expected = reference( primals, cotangents, primals_tangents, cotangents_tangents ) self.assertEqual(result, expected) @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @skipOps( "TestOperators", "test_vmapjvpvjp", vjp_fail.union( { # Following operators take too long, hence skipped skip("atleast_1d"), skip("atleast_2d"), skip("atleast_3d"), skip("meshgrid", "list_of_tensors"), skip("meshgrid", "variadic_tensors"), skip("broadcast_tensors"), skip("linalg.lstsq"), skip("nn.functional.bilinear"), skip("native_layer_norm"), skip("ormqr"), # Not actually a problem xfail("NumpyCubeNotComposableAutogradFunction"), # not composable xfail( "NumpyExpMarkDirtyAutogradFunction" ), # vmap: inplace into a regular tensor # Potential bugs/errors xfail("as_strided"), # AssertionError: Tensor-likes are not close! xfail( "as_strided", "partial_views" ), # AssertionError: Tensor-likes are not close! xfail("as_strided_copy"), # AssertionError: Tensor-likes are not close! xfail( "as_strided_scatter" ), # AssertionError: Tensor-likes are not close! 
xfail("bernoulli"), # calls random op xfail("bfloat16"), # required rank 4 tensor to use channels_last format xfail("cdist"), # Forward AD not implemented and no decomposition xfail("cdouble"), # required rank 4 tensor to use channels_last format xfail("cfloat"), # required rank 4 tensor to use channels_last format xfail("chalf"), # required rank 4 tensor to use channels_last format xfail("cholesky"), # Forward AD not implemented and no decomposition xfail("ormqr"), # Forward AD not implemented and no decomposition xfail("double"), # required rank 4 tensor to use channels_last format xfail("float"), # required rank 4 tensor to use channels_last format xfail("half"), # required rank 4 tensor to use channels_last format xfail("index_reduce", "prod"), # NYI: forward AD for index_reduce xfail("index_reduce", "mean"), # NYI: forward AD for index_reduce xfail("index_reduce", "amax"), # NYI: forward AD for index_reduce xfail("index_reduce", "amin"), # NYI: forward AD for index_reduce xfail( "mvlgamma", "mvlgamma_p_1" ), # vmap: inplace into a regular tensor xfail( "mvlgamma", "mvlgamma_p_3" ), # vmap: inplace into a regular tensor xfail( "mvlgamma", "mvlgamma_p_5" ), # vmap: inplace into a regular tensor xfail("nanquantile"), # Batching rule not implemented for aten::equal # RuntimeError: Batch norm got a batched tensor as input while the # running_mean or running_var, which will be updated in place, # were not batched. xfail("nn.functional.batch_norm"), xfail("nn.functional.batch_norm", "without_cudnn"), xfail( "nn.functional.ctc_loss" ), # ForwardAD not implemented and no decomposition xfail("nn.functional.dropout2d"), # calls random op xfail("nn.functional.dropout3d"), # calls random op xfail("nn.functional.dropout"), # calls random op xfail("nn.functional.scaled_dot_product_attention"), # randomness xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints xfail("nn.functional.multi_head_attention_forward"), # randomness xfail( "nn.functional.embedding_bag" ), # Forward AD not implemented and no decomposition xfail("nn.functional.alpha_dropout"), # calls randomn op xfail( "nn.functional.feature_alpha_dropout", "with_train" ), # calls random op xfail("nn.functional.fractional_max_pool2d"), # calls random op xfail("nn.functional.fractional_max_pool3d"), # calls random op xfail("nn.functional.gaussian_nll_loss"), # data depenedant flow xfail( "nn.functional.grid_sample" ), # Forward AD not implemented and no decomposition xfail( "grid_sampler_2d" ), # Forward AD not implemented and no decomposition xfail( "nn.functional.hardsigmoid" ), # Forward AD not implemented and no decomposition xfail( "nn.functional.hinge_embedding_loss" ), # vmap: inplace into a regular tensor xfail( "nn.functional.huber_loss" ), # Forward AD not implemented and no decomposition # RuntimeError: Batch norm got a batched tensor as input while the # running_mean or running_var, which will be updated in place, # were not batched. 
xfail("nn.functional.instance_norm"), # NYI: Tensor.clone(memory_format) inside vmap is only supported with # memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast) xfail("nn.functional.max_unpool2d"), xfail("nn.functional.max_unpool2d", "grad"), xfail( "nn.functional.multi_margin_loss" ), # Forward AD not implemented and no decomposition xfail( "nn.functional.multilabel_margin_loss" ), # Forward AD not implemented and no decomposition xfail( "nn.functional.pdist" ), # Forward AD not implemented and no decomposition xfail( "nn.functional.rrelu" ), # vmap: we do not yet support aten::rrelu_with_noise. xfail( "nn.functional.soft_margin_loss" ), # Forward AD not implemented and no decomposition xfail("normal"), # calls random op xfail("normal", "number_mean"), # calls random op xfail("pca_lowrank"), # calls random op xfail("quantile"), # Batching rule not implemented for aten::equal xfail( "scatter_reduce", "prod" ), # Forward AD not implemented and no decomposition xfail( "_segment_reduce", "lengths" ), # Forward AD not implemented and no decomposition xfail( "_segment_reduce", "offsets" ), # Forward AD not implemented and no decomposition xfail( "sparse.sampled_addmm" ), # RuntimeError: Sparse CSR tensors do not have strides xfail( "sparse.mm", "reduce" ), # RuntimeError: Sparse CSR tensors do not have strides xfail("svd_lowrank"), # calls random op xfail( "to" ), # RuntimeError: required rank 4 tensor to use channels_last format xfail("to_sparse"), # Forward AD not implemented and no decomposition xfail( "view_as_complex" ), # RuntimeError: Tensor must have a last dimension with stride 1 # RuntimeError: Batch norm got a batched tensor as # input while the running_mean or running_var, which will be updated in # place, were not batched. xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), # TODO: implement batching rule xfail("_batch_norm_with_update"), xfail("native_dropout_backward"), } ), ) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @opsToleranceOverride( "TestOperators", "test_vmapjvpvjp", ( tol1("linalg.svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), tol1( "linalg.householder_product", {torch.float32: tol(atol=5e-03, rtol=5e-03)}, ), tol1("linalg.multi_dot", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), tol2( "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-04, rtol=5e-04)} ), tol1( "nn.functional.conv_transpose2d", {torch.float32: tol(atol=5e-04, rtol=5e-04)}, ), tol1("svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), tol1("matrix_exp", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), ), ) def test_vmapjvpvjp(self, device, dtype, op): # Since we test `jvpvjp` separately, # in this we just check that vmap of `jvpvjp` # is correct. if not op.supports_autograd: self.skipTest("Skipped! Autograd not supported.") return samples = op.sample_inputs(device, dtype, requires_grad=True) # TODO: test in-place if is_inplace(op, op.get_op()): self.skipTest("Skipped! 
NYI: inplace-testing not supported.") return for sample in samples: fn, primals = normalize_op_input_output(op, sample) result = fn(*primals) cotangents = tree_map(lambda x: torch.randn_like(x), result) primals_tangents = tree_map(lambda x: torch.randn_like(x), primals) cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents) def push_vjp(primals, cotangents): _, vjp_fn = vjp(fn, *primals) return vjp_fn(cotangents) args, spec = tree_flatten( ((primals, cotangents), (primals_tangents, cotangents_tangents)) ) def jvp_of_vjp(*args): (primals, tangents) = tree_unflatten(args, spec) primals_out, tangents_out = jvp(push_vjp, primals, tangents) flat_primals_out = pytree.tree_leaves(primals_out) flat_tangents_out = pytree.tree_leaves(tangents_out) return tuple(flat_primals_out + flat_tangents_out) is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs) generator = get_fallback_and_vmap_exhaustive( jvp_of_vjp, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, ) for loop_out, batched_out in generator: self.assertEqual(loop_out, batched_out) def _make_extremal_inputs(self, shape, device): if shape is None: return (None,) return ( torch.full(shape, -1000.0, device=device), torch.zeros(shape, device=device), torch.full(shape, 1000.0, device=device), ) def _arg_and_kwarg_options(self, args_options, kwargs_options): return itertools.product(*args_options, kwargs_options) def test_extremal_numerics_nll_loss(self, device): N, C = 3, 4 d1, d2, d3 = 5, 6, 7 shapes = ( ((N, C), (N,), (C,)), ((N, C), (N,), None), ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)), ((N, C, d1, d2, d3), (N, d1, d2, d3), None), ) kwargs_options = ( {"ignore_index": 0, "reduction": "mean"}, {"reduction": "sum"}, {"reduction": "none"}, {}, ) for input_shape, target_shape, weight_shape in shapes: input_options = self._make_extremal_inputs(input_shape, device) for input, kwargs in self._arg_and_kwarg_options( (input_options,), kwargs_options ): if weight_shape is None: weight = None else: weight = torch.randn(weight_shape, device=device) target = torch.randint(0, C, target_shape, device=device) target[ 0 ] = 1 # since we're ignoring index 0, at least one element must be non-zero fn = functools.partial( torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs ) result = fn(input) cotangents = torch.randn_like(result, device=device) self._compare_jacobians_of_vjp(fn, (cotangents, input)) def test_extremal_numerics_l1_loss(self, device): N, C, H, W = 3, 4, 5, 6 shapes = ((N, C), (N, C, H), (N, C, H, W)) kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {}) for shape in shapes: input_options = self._make_extremal_inputs(shape, device) target_options = self._make_extremal_inputs(shape, device) for input, target, kwargs in self._arg_and_kwarg_options( (input_options, target_options), kwargs_options ): result = torch.nn.functional.l1_loss(input, target) cotangents = torch.randn_like(result, device=device) self._compare_jacobians_of_vjp( torch.nn.functional.l1_loss, (cotangents, input, target) ) def test_extremal_numerics_mse_loss(self, device): N, C, H, W = 3, 4, 5, 6 shapes = ((N, C), (N, C, H), (N, C, H, W)) kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {}) for shape in shapes: input_options = self._make_extremal_inputs(shape, device) target_options = self._make_extremal_inputs(shape, device) for input, target, kwargs in self._arg_and_kwarg_options( (input_options, target_options), kwargs_options ): result = torch.nn.functional.mse_loss(input, target) 
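                # As in the other extremal tests, the inputs sweep
                # {-1000, 0, +1000} (see _make_extremal_inputs) so the VJP is
                # exercised where backward formulas are most prone to overflow
                # or precision loss; only the cotangent below is drawn from a
                # standard normal at the output's shape.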
cotangents = torch.randn_like(result, device=device) self._compare_jacobians_of_vjp( torch.nn.functional.mse_loss, (cotangents, input, target) ) def test_extremal_numerics_softmax(self, device): N, C, H, W = 3, 4, 5, 6 shapes = ((N, C), (N, C, H), (N, C, H, W)) kwargs_options = ({"dim": 1}, {}) for shape in shapes: input_options = self._make_extremal_inputs(shape, device) for input, kwargs in self._arg_and_kwarg_options( (input_options,), kwargs_options ): result = torch.nn.functional.softmax(input) cotangents = torch.randn_like(result, device=device) self._compare_jacobians_of_vjp( torch.nn.functional.softmax, (cotangents, input) ) def test_extremal_numerics_log_softmax(self, device): N, C, H, W = 3, 4, 5, 6 shapes = ((N, C), (N, C, H), (N, C, H, W)) kwargs_options = ({"dim": 1}, {}) for shape in shapes: input_options = self._make_extremal_inputs(shape, device) for input, kwargs in self._arg_and_kwarg_options( (input_options,), kwargs_options ): result = torch.nn.functional.log_softmax(input) cotangents = torch.randn_like(result, device=device) self._compare_jacobians_of_vjp( torch.nn.functional.log_softmax, (cotangents, input) ) def test_extremal_numerics_cross_entropy(self, device): N, C = 3, 4 d1, d2, d3 = 5, 6, 7 shapes = ( ((N, C), (N,), (C,)), ((N, C), (N,), None), ((N, C), (N, C), (C,)), ((N, C), (N, C), None), ((C,), (), (C,)), ((C,), (), None), ((C,), (C,), (C,)), ((C,), (C,), None), ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)), ((N, C, d1, d2, d3), (N, d1, d2, d3), None), ((N, C, d1, d2, d3), (N, C, d1, d2, d3), (C,)), ((N, C, d1, d2, d3), (N, C, d1, d2, d3), None), ) for input_shape, target_shape, weight_shape in shapes: input_options = self._make_extremal_inputs(input_shape, device) kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}] if input_shape != target_shape: kwargs_options.append({"ignore_index": 0, "reduction": "mean"}) for input, kwargs in self._arg_and_kwarg_options( (input_options,), kwargs_options ): if weight_shape is None: weight = None else: weight = torch.randn(weight_shape, device=device) if input_shape == target_shape: target = torch.rand(target_shape, device=device) elif len(target_shape) == 0: target = torch.tensor( 1, device=device ) # must be non-zero since ignore_index may be 0 else: target = torch.randint(0, C, target_shape, device=device) fn = functools.partial( torch.nn.functional.cross_entropy, target=target, weight=weight, **kwargs, ) result = fn(input) cotangents = torch.randn_like(result, device=device) self._compare_jacobians_of_vjp( fn, (cotangents, input), atol_rtol=(1e-4, 1e-5) ) def test_extremal_numerics_binary_cross_entropy(self, device): N, C, H, W = 3, 4, 5, 6 shapes = ((N, C), (N, C, H), (N, C, H, W)) for shape in shapes: weight_options = self._make_extremal_inputs(shape, device) kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}] for weight, kwargs in self._arg_and_kwarg_options( (weight_options,), kwargs_options ): input = torch.rand(shape, device=device) target = torch.rand(shape, device=device) fn = functools.partial( torch.nn.functional.binary_cross_entropy, target=target, weight=weight, **kwargs, ) result = fn(input) cotangents = torch.randn_like(result, device=device) self._compare_jacobians_of_vjp( fn, (cotangents, input), atol_rtol=(1e-4, 2e-5) ) def test_extremal_numerics_layer_norm(self, device): N, C, H, W = 3, 4, 5, 6 shapes = ((N, C), (N, C, H), (N, C, H, W)) for shape in shapes: input_options = self._make_extremal_inputs(shape, device) normalized_shape = shape[1:] weight_options = 
self._make_extremal_inputs(normalized_shape, device) bias_options = self._make_extremal_inputs(normalized_shape, device) for input, bias, weight in self._arg_and_kwarg_options( (input_options, bias_options, weight_options), () ): def fn(input, weight, bias): return torch.nn.functional.layer_norm( input, normalized_shape, weight=weight, bias=bias ) result = fn(input, weight, bias) cotangents = torch.randn_like(result, device=device) self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias)) @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @ops( op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float32, torch.double), ) @skipOps( "TestOperators", "test_vmap_autograd_grad", { # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0 xfail("masked_select"), xfail("nn.functional.max_unpool2d", "grad"), # contiguous call xfail("nn.functional.max_unpool2d"), # contiguous call xfail("to_sparse"), # dispatch key issue xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints # https://github.com/pytorch/pytorch/issues/96560#issuecomment-2151063723 # ** minor accuracy issue for float32 on ROCm decorate("xlogy", decorator=skipIfRocm), # numerical inconsistencies, look like bugs skip( "matrix_exp", dtypes=(torch.float32,), device_type="cuda" ), # fails on linux, passes on windows skip( "ldexp", dtypes=(torch.float32,), device_type="cpu" ), # fails on all but mac skip("__rmatmul__"), # flaky needs investigation skip("matmul"), # flaky needs investigation skip("nn.functional.conv_transpose3d"), # flaky needs investigation skip("nn.functional.conv_transpose2d"), # flaky needs investigation skip("nn.functional.conv_transpose1d"), # flaky needs investigation skip( "nn.functional.layer_norm", dtypes=(torch.float32,), device_type="cpu" ), # fails on windows skip( "linalg.lu_factor", dtypes=(torch.float32,), device_type="cuda" ), # fails on all but windows skip( "linalg.lu_factor_ex", dtypes=(torch.float32,), device_type="cuda" ), # fails on all but windows skip("linalg.multi_dot", "", device_type="cpu"), skip("sparse.sampled_addmm", ""), skip("sparse.mm", "reduce"), skip("native_layer_norm", "", device_type="cpu"), # RuntimeError: Expected contiguous tensor, but got # non-contiguous tensor for argument #2 'grad_output' decorate( "_batch_norm_with_update", decorator=expectedFailureIf(TEST_WITH_ROCM), device_type="cuda", ), }, ) @opsToleranceOverride( "TestOperators", "test_vmap_autograd_grad", ( tol1( "ldexp", {torch.float32: tol(atol=3e-04, rtol=1.6e-06)}, device_type="cuda", ), tol1( "linalg.householder_product", {torch.float32: tol(atol=5e-04, rtol=9e-03)}, device_type="cuda", ), tol1( "linalg.householder_product", {torch.float32: tol(atol=6e-03, rtol=1e-03)}, device_type="cpu", ), tol1( "linalg.multi_dot", {torch.float32: tol(atol=2e-04, rtol=1e-04)}, device_type="cuda", ), tol2( "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-06, rtol=5e-06)} ), tol1("nn.functional.conv3d", {torch.float32: tol(atol=5e-04, rtol=9e-03)}), tol1( "nn.functional.conv2d", {torch.float32: tol(atol=3e-05, rtol=5e-06)}, device_type="cuda", ), tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), ), ) def test_vmap_autograd_grad(self, device, dtype, op): def is_differentiable(inp): return isinstance(inp, Tensor) and ( inp.grad_fn is not None or inp.requires_grad ) def get_flat_differentiable(tree): flattened = pytree.tree_leaves(tree) return tuple(i 
for i in flattened if is_differentiable(i)) def get_differentiable_linked(list1, list2): paired_list = zip(list1, list2) paired_list = tuple( (first, second) for (first, second) in paired_list if is_differentiable(first) ) return zip(*paired_list) def filter_none(out): flattened = pytree.tree_leaves(out) return tuple(o for o in flattened if o is not None) if not op.supports_autograd: self.skipTest("Skipped! Autograd not supported.") return sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) for sample_input in sample_inputs: fn, primals = normalize_op_input_output(op, sample_input) out = fn(*primals) cotangents = tree_map(torch.randn_like, out) def compute_grad(cotangents): out_flattened = out cotangents_flattened = cotangents if not isinstance(out_flattened, torch.Tensor): out_flattened = pytree.tree_leaves(out) cotangents_flattened = pytree.tree_leaves(cotangents) out_flattened, cotangents_flattened = get_differentiable_linked( out_flattened, cotangents_flattened ) return filter_none( torch.autograd.grad( out_flattened, get_flat_differentiable(primals), cotangents_flattened, retain_graph=True, allow_unused=True, ) ) is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs) generator = get_fallback_and_vmap_exhaustive( compute_grad, (cotangents,), {}, is_batch_norm_and_training=is_batch_norm_and_training, ) for loop_out, batched_out in generator: self.assertEqual(loop_out, batched_out) def test_vmapvmapjvp_linalg_solve(self): ops = [op for op in op_db if op.name == "linalg.solve"] assert len(ops) > 0 # this specializes a lot of code from the get_fallback_and_vmap_exhaustive test. If we need this more # generally, this could go for a refactor B0 = 2 B1 = 3 # we want to check the case where A will be seen as contiguous by jvp but during the vmap calls will become # non-contiguous because vmap will expand. 
This will happen during both levels of vmap A = torch.randn(4, 4) k = torch.randn(4, 5, B1, B0) fn, args = get_jvp_variant_primals_tangents( torch.linalg.solve, SampleInput(A, args=(k,)) ) in_dims_all = (None, -1, None, -1) batched_out = vmap(vmap(fn, in_dims=in_dims_all), in_dims=in_dims_all)(*args) loop_out = loop2(fn, in_dims_all, in_dims_all, 0, 0, B0, B1, *args) self.assertEqual(loop_out, batched_out) @ops( filter(lambda op: op.name in aliasing_ops, op_db + additional_op_db), allowed_dtypes=(torch.float,), ) @parametrize("grad_op", ["jvp", "vjp"]) def test_view_then_inplace(self, device, dtype, op, grad_op): for sample_input in op.sample_inputs(device, dtype): def f(x): op(sample_input.input, *sample_input.args, **sample_input.kwargs).copy_( x ) return x without_grad = op( sample_input.input, *sample_input.args, **sample_input.kwargs ) if grad_op == "jvp": with self.assertRaisesRegex( RuntimeError, "During a grad .* attempted to call in-place operation", ): jvp( f, (torch.randn_like(without_grad),), (torch.randn_like(without_grad),), ) else: assert grad_op == "vjp" with self.assertRaisesRegex( RuntimeError, "During a grad .* attempted to call in-place operation", ): vjp(f, torch.randn_like(without_grad)) @ops( filter( lambda op: op.name in aliasing_ops_list_return, op_db + additional_op_db ), allowed_dtypes=(torch.float,), ) @parametrize("grad_op", ["jvp", "vjp"]) def test_view_then_inplace_list_return(self, device, dtype, op, grad_op): for sample_input in op.sample_inputs(device, dtype): def f(x): op(sample_input.input, *sample_input.args, **sample_input.kwargs)[ 0 ].copy_(x) return x without_grad = op( sample_input.input, *sample_input.args, **sample_input.kwargs )[0] with self.assertRaisesRegex( RuntimeError, "During a grad .* attempted to call in-place operation" ): if grad_op == "jvp": jvp( f, (torch.randn_like(without_grad),), (torch.randn_like(without_grad),), ) else: assert grad_op == "vjp" vjp(f, torch.randn_like(without_grad)) @parametrize("grad_op", ["jvp", "vjp"]) def test_view_then_inplace_special(self, grad_op): # some things in __getitem__ use at::index, which doesn't alias, so this tests a subset of them that do alias ops = [ lambda x: x[0], lambda x: x[0, 0, 0], lambda x: x[:1], lambda x: x[:, :1], lambda x: x[:, :1, :], ] for op in ops: def f(x): op(captured).copy_(x) return x captured = torch.randn(4, 3, 3) without_grad = op(captured) if grad_op == "jvp": with self.assertRaisesRegex( RuntimeError, "During a grad .* attempted to call in-place operation", ): jvp( f, (torch.randn_like(without_grad),), (torch.randn_like(without_grad),), ) else: assert grad_op == "vjp" with self.assertRaisesRegex( RuntimeError, "During a grad .* attempted to call in-place operation", ): vjp(f, torch.randn_like(without_grad)) @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 # NOTE: [three-transform testing] # We only test the autograd_function_db tests here. # # Usually testing the composition of two transforms is sufficient to convince # ourselves that an operator is correctly implemented. For the following cases, # we want to be extra sure, so we send those through some three-transform tests: # - autograd.Function. The mechanism is via PyDispatcher/HigherOrderOperator, not the # regular PyTorch dispatcher, so it's good to exercise more caution. 
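    # A concrete (illustrative) sketch of the kind of composition exercised
    # below -- not the exact harness, which routes through generate_vmap_inputs
    # and normalize_op_input_output2 -- for a toy op with two batch dimensions:
    #
    #     B0, B1 = 2, 3
    #     x = torch.randn(B0, B1, 3)
    #     ct = torch.randn(B0, B1, 3)
    #
    #     def f(t):                      # works on a single (3,) sample
    #         return t.sin()
    #
    #     inner = vmap(f)                # handles the B1 dimension
    #
    #     def pullback(t, ct):
    #         _, vjp_fn = vjp(inner, t)
    #         return vjp_fn(ct)
    #
    #     out = vmap(pullback)(x, ct)    # i.e. vmap(vjp(vmap(f)))
    #
    # The tests compare such a composition against the same pipeline with each
    # vmap replaced by an explicit Python loop (the `loop` helper).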
@ops(autograd_function_db, allowed_dtypes=(torch.float32,)) @skipOps( "TestOperators", "test_vmapvjpvmap", { xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable }, ) def test_vmapvjpvmap(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=True) B = 2 for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs generator = generate_vmap_inputs(args, kwargs, batch_size=B) for batched_args, in_dims, kwargs in generator: inner_vmapped_op = vmap(op, in_dims) inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) inner_vmapped_fn, primals = normalize_op_input_output2( inner_vmapped_op, batched_args, kwargs, sample.output_process_fn_grad, ) inner_mapped_fn, _ = normalize_op_input_output2( inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad ) result = inner_mapped_fn(*primals) cotangents = tree_map(lambda x: torch.rand_like(x), result) def apply_vjp(fn): def inner(primals, cotangents): _, vjp_fn = vjp(fn, *primals) return vjp_fn(cotangents) return inner vjpvmap_fn = apply_vjp(inner_vmapped_fn) vjpmap_fn = apply_vjp(inner_mapped_fn) batched_args = (primals, cotangents) generator = generate_vmap_inputs(batched_args, {}) for batched_args, in_dims, _ in generator: # strategy: compare vmap(vjp(vmap(op)) vs map(vjp(map(op)) vmapvjpvmap_fn = vmap(vjpvmap_fn, in_dims) mapvjpmap_fn = functools.partial(loop, vjpmap_fn, in_dims, 0, B) result = vmapvjpvmap_fn(*batched_args) expected = mapvjpmap_fn(*batched_args) self.assertEqual(result, expected) # See NOTE: [three-transform testing] @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) @skipOps( "TestOperators", "test_vjpvmapvmap", { xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable }, ) def test_vjpvmapvmap(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=True) B = 2 for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs generator = generate_vmap_inputs(args, kwargs, batch_size=B) for batched_args, inner_in_dims, kwargs in generator: inner_vmapped_op = vmap(op, inner_in_dims) inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B) generator = generate_vmap_inputs(batched_args, kwargs) for batched_args, in_dims, kwargs in generator: # strategy: compare vjp(vmap(vmap(op)) vs vjp(map(map(op)) vmapped_op = vmap(inner_vmapped_op, in_dims) mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B) vmapped_fn, primals = normalize_op_input_output2( vmapped_op, batched_args, kwargs, sample.output_process_fn_grad ) mapped_fn, _ = normalize_op_input_output2( mapped_op, batched_args, kwargs, sample.output_process_fn_grad ) result = mapped_fn(*primals) cotangents = tree_map(lambda x: torch.rand_like(x), result) _, vjp_fn = vjp(mapped_fn, *primals) expected_vjps = vjp_fn(cotangents) _, vjp_fn = vjp(vmapped_fn, *primals) result_vjps = vjp_fn(cotangents) self.assertEqual(result_vjps, expected_vjps) # See NOTE: [three-transform testing] @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) @skipOps( "TestOperators", "test_vjpvjpvmap", { xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable }, ) def test_vjpvjpvmap(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=True) B = 2 for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs generator = generate_vmap_inputs(args, kwargs, batch_size=B) for batched_args, in_dims, kwargs in generator: inner_vmapped_op = vmap(op, in_dims) inner_mapped_op = 
functools.partial(loop, op, in_dims, 0, B) vjpmap_fn, args = get_vjpfull_variant2( inner_mapped_op, batched_args, kwargs ) vjpvmap_fn, _ = get_vjpfull_variant2( inner_vmapped_op, batched_args, kwargs ) vjpvjpvmap_fn, new_args = get_vjpfull_variant2(vjpvmap_fn, args, {}) vjpvjpmap_fn, _ = get_vjpfull_variant2(vjpmap_fn, args, {}) expected = vjpvjpmap_fn(*new_args) result = vjpvjpvmap_fn(*new_args) self.assertEqual(result, expected) # We're generally convinced that jvp x vmap works (vmap turns an operator # into another operator and we test jvp support for operators). So # we only test it on the things we're not sure about: # - the autograd.Function <> functorch interaction @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) @skipOps( "TestOperators", "test_jvpvmap", { xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable }, ) def test_jvpvmap(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=True) B = 2 for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs generator = generate_vmap_inputs(args, kwargs, batch_size=B) for batched_args, in_dims, kwargs in generator: inner_vmapped_op = vmap(op, in_dims) inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) jvpvmap_op, primals = get_jvp_variant_primals_tangents2( inner_vmapped_op, batched_args, kwargs, sample.output_process_fn_grad, ) jvpmap_op, _ = get_jvp_variant_primals_tangents2( inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad ) expected = jvpmap_op(*primals) result = jvpvmap_op(*primals) self.assertEqual(result, expected) # See NOTE: [three-transform testing] @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) @skipOps( "TestOperators", "test_jvpvmapvmap", { xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable }, ) def test_jvpvmapvmap(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=True) B = 2 for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs generator = generate_vmap_inputs(args, kwargs, batch_size=B) for batched_args, inner_in_dims, kwargs in generator: inner_vmapped_op = vmap(op, inner_in_dims) inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B) generator = generate_vmap_inputs(batched_args, kwargs) for batched_args, in_dims, kwargs in generator: # strategy: compare jvp(vmap(vmap(op)) vs jvp(map(map(op)) vmapped_op = vmap(inner_vmapped_op, in_dims) mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B) jvpvmapvmap_fn, primals = get_jvp_variant_primals_tangents2( vmapped_op, batched_args, kwargs, sample.output_process_fn_grad ) jvpmapmap_fn, _ = get_jvp_variant_primals_tangents2( mapped_op, batched_args, kwargs, sample.output_process_fn_grad ) expected = jvpmapmap_fn(*primals) result = jvpvmapvmap_fn(*primals) self.assertEqual(result, expected) # See NOTE: [three-transform testing] @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) @skipOps( "TestOperators", "test_vmapjvpvmap", { xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable }, ) def test_vmapjvpvmap(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=True) B = 2 for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs generator = generate_vmap_inputs(args, kwargs, batch_size=B) for batched_args, in_dims, kwargs in generator: inner_vmapped_op = vmap(op, in_dims) inner_mapped_op = functools.partial(loop, 
op, in_dims, 0, B) jvpvmap_fn, primals = get_jvp_variant_primals_tangents2( inner_vmapped_op, batched_args, kwargs, sample.output_process_fn_grad, ) jvpmap_fn, _ = get_jvp_variant_primals_tangents2( inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad ) generator = generate_vmap_inputs(primals, {}) for batched_args, in_dims, _ in generator: # strategy: compare vmap(jvp(vmap(op)) vs map(jvp(map(op)) vmapjvpvmap_fn = vmap(jvpvmap_fn, in_dims) mapjvpmap_fn = functools.partial(loop, jvpmap_fn, in_dims, 0, B) result = vmapjvpvmap_fn(*batched_args) expected = mapjvpmap_fn(*batched_args) self.assertEqual(result, expected) # See NOTE: [three-transform testing] @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) @skipOps( "TestOperators", "test_jvpjvpvmap", { xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable }, ) def test_jvpjvpvmap(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=True) B = 2 for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs generator = generate_vmap_inputs(args, kwargs, batch_size=B) for batched_args, in_dims, kwargs in generator: inner_vmapped_op = vmap(op, in_dims) inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) jvpmap_fn, args = get_jvp_variant_primals_tangents2( inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad ) jvpvmap_fn, _ = get_jvp_variant_primals_tangents2( inner_vmapped_op, batched_args, kwargs, sample.output_process_fn_grad, ) jvpjvpvmap_fn, new_args = get_jvp_variant_primals_tangents2( jvpvmap_fn, args, {} ) jvpjvpmap_fn, _ = get_jvp_variant_primals_tangents2(jvpmap_fn, args, {}) expected = jvpjvpmap_fn(*new_args) result = jvpjvpvmap_fn(*new_args) self.assertEqual(result, expected) # See NOTE: [three-transform testing] @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) @skipOps( "TestOperators", "test_jvpvjpvmap", { xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable }, ) def test_jvpvjpvmap(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=True) B = 2 for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs generator = generate_vmap_inputs(args, kwargs, batch_size=B) for batched_args, in_dims, kwargs in generator: inner_vmapped_op = vmap(op, in_dims) inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) vjpmap_fn, args = get_vjpfull_variant2( inner_mapped_op, batched_args, kwargs ) vjpvmap_fn, _ = get_vjpfull_variant2( inner_vmapped_op, batched_args, kwargs ) jvpvjpvmap_fn, new_args = get_jvp_variant_primals_tangents2( vjpvmap_fn, args, {} ) jvpvjpmap_fn, _ = get_jvp_variant_primals_tangents2(vjpmap_fn, args, {}) expected = jvpvjpmap_fn(*new_args) result = jvpvjpvmap_fn(*new_args) self.assertEqual(result, expected) def test_data_write_errors_under_transform(self, device): t = torch.randn(3, 3, device=device) def fn(t): t.data = torch.randn(3, 3) return t.sum() msg = "mutating directly with `.data` inside functorch transform" with self.assertRaisesRegex(RuntimeError, msg): grad(fn)(t) with self.assertRaisesRegex(RuntimeError, msg): vjp(fn, t) with self.assertRaisesRegex(RuntimeError, msg): jvp(fn, (t,), (torch.randn_like(t),)) def test_tensor_with_scalar_list(self, device): x = torch.randn((), device=device) def func_list_of_scalar(x): return torch.tensor([x], device=device) def func(x): return torch.tensor(x, device=device).view(1) actual_o, actual_fn = vjp(func_list_of_scalar, x) expected_o, expected_fn = vjp(func, x) 
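        # Constructing a 1-element tensor from [x] and constructing it via
        # torch.tensor(x).view(1) should behave identically under vjp: both
        # the primal outputs and the pulled-back cotangents are compared below.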
        self.assertEqual(actual_o, expected_o)
        self.assertEqual(
            expected_fn(torch.ones_like(expected_o)),
            actual_fn(torch.ones_like(actual_o)),
        )


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()
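# ---------------------------------------------------------------------------
# Illustrative appendix (not part of the test suite): a minimal sketch of the
# forward-over-reverse composition that test_jvpvjp checks per operator above,
# written out for one concrete function. The helper name `_demo_jvp_of_vjp`
# is hypothetical and is never invoked by the tests.
def _demo_jvp_of_vjp():
    def f(x):
        return x.sin().sum()

    primal = torch.randn(3)
    cotangent = torch.randn(())  # matches f's scalar output
    primal_tangent = torch.randn(3)
    cotangent_tangent = torch.randn(())

    def push_vjp(primal, cotangent):
        # Reverse-mode pullback of f at `primal`, applied to `cotangent`.
        _, vjp_fn = vjp(f, primal)
        return vjp_fn(cotangent)

    # Forward-mode differentiation of the pullback itself (jvp of vjp).
    primals_out, tangents_out = jvp(
        push_vjp,
        (primal, cotangent),
        (primal_tangent, cotangent_tangent),
    )
    return primals_out, tangents_out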