# xref: /aosp_15_r20/external/pytorch/torch/autograd/gradcheck.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import collections
import functools
import warnings
from itertools import product
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
import torch.testing
from torch._vmap_internals import _vmap, vmap
from torch.overrides import is_tensor_like
from torch.types import _TensorOrTensors


# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public
# since they have been exposed from before we added `__all__`  and we already maintain BC for them
# We should eventually deprecate them and remove them from `__all__`
__all__ = [
    "gradcheck",
    "gradgradcheck",
    "GradcheckError",
    "get_numerical_jacobian",
    "get_analytical_jacobian",
    "get_numerical_jacobian_wrt_specific_input",
]


class GradcheckError(RuntimeError):
    r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`."""


def _is_sparse_compressed_tensor(obj: torch.Tensor):
    return obj.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }


def _is_sparse_any_tensor(obj: torch.Tensor):
    return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo


def _is_float_or_complex_tensor(obj):
    return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex())


def _allocate_jacobians_with_inputs(
    input_tensors: Tuple, numel_output
) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from inputs. If `numel_output` is not None, for
    # each tensor in `input_tensors`, returns a new zero-filled tensor with height
    # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
    # a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
    # the same dtype and device as those of the corresponding input.
    out: List[torch.Tensor] = []
    for t in input_tensors:
        if _is_float_or_complex_tensor(t) and t.requires_grad:
            out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
    return tuple(out)


def _allocate_jacobians_with_outputs(
    output_tensors: Tuple, numel_input, dtype=None, device=None
) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from outputs. If `numel_input` is not None, for each
    # tensor in `output_tensors`, returns a new zero-filled tensor with height of
    # `numel_input` and width of `t.numel`. Otherwise, for each tensor, returns a 1-d
    # tensor with size (t.numel,).
    out: List[torch.Tensor] = []
    options = {"dtype": dtype, "device": device, "layout": torch.strided}
    for t in output_tensors:
        if _is_float_or_complex_tensor(t):
            out.append(t.new_zeros((numel_input, t.numel()), **options))
    return tuple(out)


def _iter_tensors(
    x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False
) -> Iterable[torch.Tensor]:
    if is_tensor_like(x):
        # mypy doesn't narrow type of `x` to torch.Tensor
        if x.requires_grad or not only_requiring_grad:  # type: ignore[union-attr]
            yield x  # type: ignore[misc]
    elif isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        for elem in x:
            yield from _iter_tensors(elem, only_requiring_grad)


def _densify(x):
    # return a copy of sparse x with all unspecified elements
    # "replaced" with zero-valued elements
    if isinstance(x, (list, tuple)):
        return type(x)(map(_densify, x))
    elif not is_tensor_like(x) or x.layout in {torch.strided, torch._mkldnn}:  # type: ignore[attr-defined] # no attr _mkldnn
        return x
    elif x.layout is torch.sparse_coo:
        device = x.device
        indices_dtype = x._indices().dtype
        tmp = torch.ones(x.shape[: x.sparse_dim()], dtype=torch.int8, device=device)
        indices = tmp.nonzero().t().to(dtype=indices_dtype)
        values = torch.zeros(
            (tmp.numel(), *x.shape[x.sparse_dim() :]), dtype=x.dtype, device=device
        )
        x_coalesced = x.detach().coalesce()
        if x_coalesced.numel() > 0:
            stride = tmp.stride()
            flat_indices = (
                x_coalesced.indices()
                .mul(
                    torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze(
                        1
                    )
                )
                .sum(0)
            )
            values[flat_indices] = x_coalesced.values()
        return (
            torch.sparse_coo_tensor(indices, values, x.shape)
            ._coalesced_(True)
            .requires_grad_(x.requires_grad)
        )
    elif _is_sparse_compressed_tensor(x):
        blocksize = (
            x.values().shape[1:3]
            if x.layout in {torch.sparse_bsr, torch.sparse_bsc}
            else None
        )
        compressed_indices = (
            x.crow_indices()
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}
            else x.ccol_indices()
        )
        # We'll use intermediate sparse COO for simplicity
        r = _densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse(
            layout=x.layout, blocksize=blocksize
        )
        # Check that all elements are specified also after `to_sparse` op:
        dense_numel = r.values().numel() // max(1, r.values().shape[0])
        batch_numel = compressed_indices.numel() // compressed_indices.shape[-1]
        sparse_numel = r.numel() // max(1, dense_numel * batch_numel)
        if sparse_numel != r._nnz():
            raise AssertionError(
                f"{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}"
            )
        return r.requires_grad_(x.requires_grad)
    elif _is_sparse_any_tensor(x):
        raise NotImplementedError(x.layout)
    return x

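# Illustrative sketch (not used by this module): `_densify` turns the implicit zeros of
# a sparse tensor into explicitly stored zero-valued elements, so that every element can
# later be perturbed during finite differencing. Roughly:
#
#     >>> s = torch.tensor([[0.0, 2.0], [0.0, 0.0]]).to_sparse()
#     >>> s._nnz()                 # only the specified element is stored
#     1
#     >>> d = _densify(s)
#     >>> d._nnz()                 # all 4 elements are now explicitly specified
#     4
#     >>> torch.equal(d.to_dense(), s.to_dense())
#     True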

def _iter_tensor(x_tensor):
    # (Only used for slow gradcheck) Returns a generator that yields the following
    # elements at each iteration:
    #  1) a tensor: the same tensor is returned across all iterations. The tensor
    #     is not the same as the original x_tensor as given as input - it is
    #     prepared so that it can be modified in-place. Depending on whether the
    #     input tensor is strided, sparse, or dense, the returned tensor may or may
    #     not share storage with x_tensor.
    #  2) a tuple of indices that can be used with advanced indexing (yielded in
    #     dictionary order)
    #  3) flattened index that will be used to index into the Jacobian tensor
    #
    # For a tensor t with size (2, 2), _iter_tensor yields:
    #     `x, (0, 0), 0`, `x, (0, 1), 1`, `x, (1, 0), 2`, `x, (1, 1), 3`
    # where x is the t.data of the original tensor. Perturbing the entry of x
    # at index (1, 1) yields the column at flattened index 3 of the overall Jacobian matrix.
    if _is_sparse_any_tensor(x_tensor):

        def get_stride(size):
            dim = len(size)
            tmp = 1
            stride = [0] * dim
            for i in reversed(range(dim)):
                stride[i] = tmp
                tmp *= size[i]
            return stride

        x_nnz = x_tensor._nnz()
        x_size = list(x_tensor.size())
        if x_tensor.layout is torch.sparse_coo:
            x_indices = x_tensor._indices().t()
            x_values = x_tensor._values()
        elif x_tensor.layout is torch.sparse_csr:
            x_indices = torch._convert_indices_from_csr_to_coo(
                x_tensor.crow_indices(), x_tensor.col_indices()
            ).t()
            x_values = x_tensor.values()
        elif x_tensor.layout is torch.sparse_csc:
            x_indices = torch._convert_indices_from_csr_to_coo(
                x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
            ).t()
            x_values = x_tensor.values()
        elif x_tensor.layout is torch.sparse_bsr:
            x_block_values = x_tensor.values()
            x_blocksize = x_block_values.size()[1:3]
            x_indices = (
                torch._convert_indices_from_csr_to_coo(
                    x_tensor.crow_indices(), x_tensor.col_indices()
                )
                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
                .add_(
                    torch.stack(
                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
                    ).repeat(1, x_nnz)
                )
                .t()
            )
            x_values = x_block_values.flatten(0, 2)
            x_nnz = x_values.size(0)
        elif x_tensor.layout is torch.sparse_bsc:
            x_block_values = x_tensor.values()
            x_blocksize = x_block_values.size()[1:3]
            x_indices = (
                torch._convert_indices_from_csr_to_coo(
                    x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
                )
                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
                .add_(
                    torch.stack(
                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
                    ).repeat(1, x_nnz)
                )
                .t()
            )
            x_values = x_block_values.flatten(0, 2)
            x_nnz = x_values.size(0)
        else:
            raise NotImplementedError(f"_iter_tensor for {x_tensor.layout} input")
        x_stride = get_stride(x_size)
        # Use .data here to get around the version check
        x_values = x_values.data
        for i in range(x_nnz):
            x_value = x_values[i]
            for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
                indices = x_indices[i].tolist() + list(x_idx)
                d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
                yield x_value, x_idx, d_idx
    elif x_tensor.layout == torch._mkldnn:  # type: ignore[attr-defined]
        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
            # this is really inefficient, but without indexing implemented, there's
            # not really a better way than converting back and forth
            x_tensor_dense = x_tensor.to_dense()
            yield x_tensor_dense, x_idx, d_idx
    else:
        # Use .data here to get around the version check
        x_tensor = x_tensor.data
        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
            yield x_tensor, x_idx, d_idx


def _get_numerical_jacobian(
    fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False
) -> List[Tuple[torch.Tensor, ...]]:
    """Compute the numerical Jacobian of `fn(inputs)` with respect to `target`.

    If not specified, the targets are the inputs. Returns M * N Jacobians where N is the
    number of tensors in target that require grad and M is the number of non-integral
    outputs.

    Args:
        fn: the function to compute the jacobian for
        inputs: inputs to `fn`
        outputs: provide precomputed outputs to avoid one extra invocation of fn
        target: the Tensors wrt whom Jacobians are calculated (default=`inputs`)
        eps: the magnitude of the perturbation during finite differencing
             (default=`1e-3`)
        is_forward_ad: if this numerical jacobian is computed to be checked wrt
                       forward AD gradients (this is used for error checking only)

    Returns:
        A list of M N-tuples of tensors

    Note that `target` may not even be part of `input` to `fn`, so please be
    **very careful** in this function not to clone `target`.
281    """
282    jacobians: List[Tuple[torch.Tensor, ...]] = []
283    if outputs is None:
284        outputs = _as_tuple(fn(*_as_tuple(inputs)))
285    if not is_forward_ad and any(o.is_complex() for o in outputs):
286        raise ValueError(
287            "Expected output to be non-complex. get_numerical_jacobian no "
288            "longer supports functions that return complex outputs."
289        )
290    if target is None:
291        target = inputs
292    inp_indices = [
293        i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
294    ]
295    for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
296        jacobians += [
297            get_numerical_jacobian_wrt_specific_input(
298                fn,
299                inp_idx,
300                inputs,
301                outputs,
302                eps,
303                input=inp,
304                is_forward_ad=is_forward_ad,
305            )
306        ]
307    return jacobians
308
309
310@deprecated(
311    "`get_numerical_jacobian` was part of PyTorch's private API and not "
312    "meant to be exposed. We are deprecating it and it will be removed "
313    "in a future version of PyTorch. If you have a specific use for "
314    "this or feature request for this to be a stable API, please file "
315    "us an issue at https://github.com/pytorch/pytorch/issues/new",
316    category=FutureWarning,
317)
318def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0):
319    """Compute the numerical Jacobian for a given fn and its inputs.
320
321    This is a Deprecated API.
322
323    Args:
324        fn: the function to compute the Jacobian for (must take inputs as a tuple)
        inputs: inputs to `fn`
        target: the Tensors wrt whom Jacobians are calculated (default=`inputs`)
        eps: the magnitude of the perturbation during finite differencing
             (default=`1e-3`)

    Returns:
        A list of Jacobians of `fn` (restricted to its first output) with respect to
        each input or target, if provided.

    Note that `target` may not even be part of `input` to `fn`, so please be
    **very careful** in this function not to clone `target`.
336    """
337    if (
338        grad_out != 1.0
339    ):  # grad_out param is only kept for backward compatibility reasons
340        raise ValueError(
341            "Expected grad_out to be 1.0. get_numerical_jacobian no longer "
342            "supports values of grad_out != 1.0."
343        )
344
345    def fn_pack_inps(*inps):
346        return fn(inps)
347
348    jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps)
349
350    return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians)
351
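# Illustrative sketch (not used by this module): the deprecated public
# `get_numerical_jacobian` expects `fn` to take all inputs packed in a single tuple.
# With a hypothetical double-precision input `x`:
#
#     >>> x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
#     >>> fn = lambda inps: inps[0].sin()
#     >>> (jac,) = get_numerical_jacobian(fn, (x,))   # emits a FutureWarning
#     >>> jac.shape        # one (input.numel(), output.numel()) matrix per input
#     torch.Size([4, 4])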

def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
    # Computes numerical directional derivative as finite difference
    # of function `fn` at input `entry`, perturbed by vector `v`.
    if _is_sparse_compressed_tensor(entry):
        # sparse compressed tensors don't implement sub/add/copy_
        # yet. However, in the non-masked semantics context, entry and v
        # have the same sparse indices ...
        assert entry.layout == v.layout, (entry.layout, v.layout)
        assert entry._nnz() == v._nnz(), (entry._nnz(), v._nnz(), entry.shape)
        # ... the finite differencing can be performed on values only:
        entry = entry.values()
        v = v.values()
        # we'll detach to avoid backward computations that sparse
        # tensors have limited support for.
        entry = entry.detach()

    orig = entry.clone()
    entry.copy_(orig - v)
    outa = fn()
    entry.copy_(orig + v)
    outb = fn()
    entry.copy_(orig)

    def compute(a, b):
        nbhd_checks_fn(a, b)
        ret = (b - a) / (2 * norm_v)  # use central difference approx
        return ret.detach().reshape(-1)

    return tuple(compute(a, b) for (a, b) in zip(outa, outb))

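# Illustrative sketch (not used by this module): the helper above applies the standard
# central-difference approximation along a perturbation `v`,
#
#     J v  ~=  (f(x + v) - f(x - v)) / (2 * ||v||)
#
# where `v` is typically eps times a basis vector (slow mode) or a random direction
# (fast mode). For example, for f(x) = x**2 at x = 3.0 with eps = 1e-3:
#
#     >>> x, eps = torch.tensor(3.0), 1e-3
#     >>> ((x + eps) ** 2 - (x - eps) ** 2) / (2 * eps)   # approximates df/dx = 2 * x
#     tensor(6.0000)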

def _compute_numerical_jvps_wrt_specific_input(
    jvp_fn, delta, input_is_complex, is_forward_ad=False
) -> List[torch.Tensor]:
    # Computing the jacobian only works for real delta
    # For details on the algorithm used here, refer:
    # Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
    # s = fn(z) where z = x for real valued input
    # and z = x + yj for complex valued input
    jvps: List[torch.Tensor] = []
    ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta)

    if input_is_complex:  # C -> R
        ds_dy_tup = (
            jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j)
        )
        for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup):
            assert not ds_dx.is_complex()
            # conjugate wirtinger derivative
            conj_w_d = ds_dx + ds_dy * 1j
            jvps.append(conj_w_d)
    else:
        for ds_dx in ds_dx_tup:  # R -> R or (R -> C for the forward AD case)
            assert is_forward_ad or not ds_dx.is_complex()
            jvps.append(ds_dx)
    return jvps

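# Illustrative sketch (not used by this module): for a C -> R function s(z) with
# z = x + yj, the code above assembles ds/dx + 1j * ds/dy (what it calls the conjugate
# Wirtinger derivative), which is what autograd reports for complex inputs of
# real-valued functions. For example, with s(z) = |z|**2, ds/dx = 2x and ds/dy = 2y:
#
#     >>> z = torch.tensor(1.0 + 2.0j, requires_grad=True)
#     >>> torch.autograd.grad(z.abs() ** 2, z)[0]   # -> 2x + 2yj, i.e. tensor(2.+4.j)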

def _combine_jacobian_cols(
    jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel
) -> Tuple[torch.Tensor, ...]:
    # jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor
    # we return a list that maps output_idx -> full jacobian Tensor
    jacobians = _allocate_jacobians_with_outputs(
        outputs, numel, dtype=input.dtype if input.dtype.is_complex else None
    )
    for i, jacobian in enumerate(jacobians):
        for k, v in jacobians_cols.items():
            jacobian[k] = v[i]
    return jacobians


def _prepare_input(
    input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], fast_mode=False
) -> torch.Tensor:
    # Prepares the inputs to be passed into the function while including the new
    # modified input.
    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
        # Convert back to mkldnn
        if maybe_perturbed_input is not None:
            return maybe_perturbed_input.to_mkldnn()
        else:
            return input
    elif _is_sparse_any_tensor(input):
        if fast_mode and maybe_perturbed_input is not None:
            # entry is already a "cloned" version of the original tensor
            # thus changes to entry are not reflected in the input
            return maybe_perturbed_input
        else:
            return input
    else:
        # We cannot use entry (input.data) if we want gradgrad to work because
        # fn (in the gradgrad case) needs to compute grad wrt input
        return input


def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None:
    # Check that the returned outputs don't have different dtype or shape when you
    # perturb the input
    on_index = f"on index {idx} " if idx is not None else ""
    assert output1.shape == output2.shape, (
        f"Expected `func` to return outputs with the same shape"
        f" when inputs are perturbed {on_index}by {eps}, but got:"
        f" shapes {output1.shape} and {output2.shape}."
    )
    assert output1.dtype == output2.dtype, (
        f"Expected `func` to return outputs with the same dtype"
        f" when inputs are perturbed {on_index}by {eps}, but got:"
        f" dtypes {output1.dtype} and {output2.dtype}."
    )


def get_numerical_jacobian_wrt_specific_input(
    fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False
) -> Tuple[torch.Tensor, ...]:
    # Computes the numerical jacobians wrt to a single input. Returns N jacobian
    # tensors, where N is the number of outputs. We use a dictionary for
    # jacobian_cols because indices aren't necessarily consecutive for sparse inputs
    # When we perturb only a single element of the input tensor at a time, the jvp
    # is equivalent to a single col of the Jacobian matrix of fn.
    jacobian_cols: Dict[int, List[torch.Tensor]] = {}
    input = inputs[input_idx] if input is None else input
    assert input.requires_grad
    for x, idx, d_idx in _iter_tensor(input):
        wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x)
        input_to_perturb = x[idx]
        nbhd_checks_fn = functools.partial(
            _check_outputs_same_dtype_and_shape, idx=idx, eps=eps
        )
        jvp_fn = _get_numerical_jvp_fn(
            wrapped_fn, input_to_perturb, eps, nbhd_checks_fn
        )
        jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input(
            jvp_fn, eps, x.is_complex(), is_forward_ad
        )
    return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel())


def _get_analytical_jacobian_forward_ad(
    fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None
) -> Tuple[Tuple[torch.Tensor, ...], ...]:
494    """Compute the analytical Jacobian using forward mode AD of `fn(inputs)` using forward mode AD with respect to `target`.

    Return N * M Jacobians where N is the number of tensors in target that require grad and
    M is the number of non-integral outputs.
    Contrary to other functions here, this function requires "inputs" to actually be used by the function.
    The computed value is expected to be wrong if the function captures the inputs by side effect instead of
    using the passed ones (many torch.nn tests do this).

    Args:
        fn: the function to compute the jacobian for
        inputs: inputs to `fn`
        outputs: provide precomputed outputs to avoid one extra invocation of fn
        check_grad_dtypes: if True, will check that the gradient dtypes are valid
        all_u (optional): if provided, the Jacobian will be right multiplied with this vector

    Returns:
        A tuple of M N-tuples of tensors
    """
    # To avoid early import issues
    fwAD = torch.autograd.forward_ad

    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)

    if any(i.is_complex() for i in tensor_inputs):
        raise ValueError(
            "Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad."
        )

    if all_u:
        jacobians = tuple(
            _allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs
        )
    else:
        jacobians = tuple(
            _allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs
        )

    with fwAD.dual_level():
        fw_grads = []
        dual_inputs = []
        for i, inp in enumerate(inputs):
            if is_tensor_like(inp) and inp.requires_grad:
                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                    raise ValueError(
538                        "MKLDNN inputs are not support for forward AD gradcheck."
                    )

                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
                # If inp is a differentiable view, the dual might not be the tangent given to
                # make_dual, so read it explicitly from the dual tensor
                fw_grads.append(fwAD.unpack_dual(inp)[1])
            dual_inputs.append(inp)

        if all_u:
            # Do the full reduction in one pass
            # To be consistent with numerical evaluation, we actually compute one reduction per input
            for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
                fw_grad.copy_(u.view_as(fw_grad))
                raw_outputs = _as_tuple(fn(*dual_inputs))
                dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
                for index_o, d_o in enumerate(dual_outputs):
                    val, res = fwAD.unpack_dual(d_o)
                    if (
                        check_grad_dtypes
                        and res is not None
                        and val.is_complex() != res.is_complex()
                    ):
                        raise GradcheckError("Forward AD gradient has dtype mismatch.")

                    # Remove extra dimension of size 1 corresponding to the reduced input
                    jacobians[i][index_o].squeeze_(0)
                    if res is None:
                        jacobians[i][index_o].zero_()
                    else:
                        jacobians[i][index_o].copy_(res.reshape(-1))
                fw_grad.zero_()
        else:
            # Reconstruct the full Jacobian column by column
            for i, fw_grad in enumerate(fw_grads):
                for lin_idx, grad_idx in enumerate(
                    product(*[range(m) for m in fw_grad.size()])
                ):
                    fw_grad[grad_idx] = 1.0
                    raw_outputs = _as_tuple(fn(*dual_inputs))
                    dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
                    for index_o, d_o in enumerate(dual_outputs):
                        val, res = fwAD.unpack_dual(d_o)
                        if (
                            check_grad_dtypes
                            and res is not None
                            and val.is_complex() != res.is_complex()
                        ):
                            raise GradcheckError(
                                "Forward AD gradient has dtype mismatch."
                            )

                        if res is None:
                            jacobians[i][index_o][lin_idx].zero_()
                        else:
                            jacobians[i][index_o][lin_idx].copy_(res.reshape(-1))
                    fw_grad[grad_idx] = 0.0

    return jacobians

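# Illustrative sketch (not used by this module): the column-by-column reconstruction
# above relies on forward-mode AD propagating a one-hot tangent through the function,
# which yields one Jacobian column per call. Roughly:
#
#     >>> import torch.autograd.forward_ad as fwAD
#     >>> x = torch.randn(3, dtype=torch.double)
#     >>> tangent = torch.zeros_like(x)
#     >>> tangent[0] = 1.0                       # one-hot tangent: perturb first element
#     >>> with fwAD.dual_level():
#     ...     dual_x = fwAD.make_dual(x, tangent)
#     ...     _, jvp = fwAD.unpack_dual(dual_x.sin())
#     ...     print(torch.allclose(jvp, x.cos() * tangent))   # first column of d sin(x)/dx
#     True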

def _get_input_to_perturb(input):
    # Prepare the input so that it can be modified in-place and do certain
    # operations that require the tensor to have strides. If fast_mode=False,
    # _iter_tensor would handle the below cases:
    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
        # Convert to dense so we can perform operations that require strided tensors
        input_to_perturb = input.to_dense()
    elif _is_sparse_any_tensor(input):
        # Clone because input may require grad, and copy_ calls resize_,
        # which is not allowed for .data
        input_to_perturb = input.clone()
    else:
        input_to_perturb = input.data
    return input_to_perturb


def _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, fast_mode=False):
    # Wraps `fn` so that its inputs are already supplied
    def wrapped_fn():
        inp = tuple(
            _prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode)
            if is_tensor_like(a)
            else a
            for i, a in enumerate(_as_tuple(inputs))
        )
        return tuple(a.clone() for a in _as_tuple(fn(*inp)))

    return wrapped_fn


def _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn):
    # Wraps jvp_fn so that certain arguments are already supplied
    def jvp_fn(delta):
        return _compute_numerical_gradient(
            wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn
        )

    return jvp_fn


def _reshape_tensor_or_tuple(u, shape):
    # We don't need to reshape when input corresponding to u is sparse
    if isinstance(u, tuple):
        if not _is_sparse_any_tensor(u[0]):
            return (u[0].reshape(shape), u[1].reshape(shape))
    else:
        if not _is_sparse_any_tensor(u):
            return u.reshape(shape)
    return u


def _mul_tensor_or_tuple(u, k):
    if isinstance(u, tuple):
        return (k * u[0], k * u[1])
    else:
        return k * u


def _get_numerical_jvp_wrt_specific_input(
    fn, input_idx, inputs, u, eps, is_forward_ad=False
) -> List[torch.Tensor]:
    input = inputs[input_idx]
    input_to_perturb = _get_input_to_perturb(input)
    wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True)
    nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, eps=eps)
    jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn)
    u = _reshape_tensor_or_tuple(u, input_to_perturb.shape)
    u = _mul_tensor_or_tuple(u, eps)
    return _compute_numerical_jvps_wrt_specific_input(
        jvp_fn, u, input.is_complex(), is_forward_ad
    )


def _get_numerical_vJu(
    fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad
):
    # Note that all_v can also be None, in that case, this function only computes Ju.
    reduced_jacobians: List[List[torch.Tensor]] = []
    for i, (inp_idx, u) in enumerate(zip(inp_indices, all_u)):
        all_Ju = _get_numerical_jvp_wrt_specific_input(
            fn, inp_idx, inputs, u, eps, is_forward_ad
        )
        # Filter out the Ju for non floating point outputs
        filtered_Ju = []
        func_out = _as_tuple(func_out)
        assert len(all_Ju) == len(func_out)
        for Ju, output in zip(all_Ju, func_out):
            if _is_float_or_complex_tensor(output):
                filtered_Ju.append(Ju)
            else:
                # TODO: handle the other Ju
                pass
        if all_v is not None:
            jacobian_scalars: List[torch.Tensor] = []
            for v, Ju in zip(all_v, filtered_Ju):
                jacobian_scalars.append(_dot_with_type_promotion(v, Ju))
            reduced_jacobians.append(jacobian_scalars)
        else:
            reduced_jacobians.append(filtered_Ju)
    return reduced_jacobians

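# Illustrative sketch (not used by this module): fast-mode gradcheck does not compare
# full Jacobians. It picks random vectors u (per input) and v (per output) and compares
# the scalar v^T J u computed numerically (above) against the same quantity computed
# analytically via a VJP. For a function f with Jacobian J, roughly:
#
#     numerical:  Ju ~= (f(x + eps * u) - f(x - eps * u)) / (2 * eps),  then dot with v
#     analytical: vJ =  torch.autograd.grad(f(x), x, grad_outputs=v),   then dot with u
#
# Agreement of the two scalars for random u, v gives high confidence that J is correct
# at a fraction of the cost of reconstructing J column by column.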

def _check_jacobians_equal(j1, j2, atol):
    # Check whether the max difference between two Jacobian tensors is within some
    # tolerance `atol`.
    for j1_x, j2_x in zip(j1, j2):
        if j1_x.numel() != 0 and (j1_x - j2_x).abs().max() > atol:
            return False
    return True


def _stack_and_check_tensors(
    list_of_list_of_tensors, inputs, numel_outputs
) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]:
    # For the ith tensor in the inner list, checks whether it has the same size and
    # dtype as the ith differentiable input.
    out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs)
    diff_input_list = list(_iter_tensors(inputs, True))
    correct_grad_sizes = True
    correct_grad_types = True
    for i, tensor_list in enumerate(list_of_list_of_tensors):
        inp = diff_input_list[i]
        out_jacobian = out_jacobians[i]
        for j, tensor in enumerate(tensor_list):
            if tensor is not None and tensor.size() != inp.size():
                correct_grad_sizes = False
            elif tensor is not None and tensor.dtype != inp.dtype:
                correct_grad_types = False
            if tensor is None:
                out_jacobian[:, j].zero_()
            else:
                dense = (
                    tensor.to_dense() if not tensor.layout == torch.strided else tensor
                )
                assert out_jacobian[:, j].numel() == dense.numel()
                out_jacobian[:, j] = dense.reshape(-1)
    return out_jacobians, correct_grad_sizes, correct_grad_types


FAILED_NONDET_MSG = """\n
NOTE: If your op relies on non-deterministic operations i.e., it is listed here:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
this failure might be expected.

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `nondet_tol=<tol>` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `gradcheck_nondet_tol=<tol>`.
- is a Module test (e.g., in common_nn.py), then modify the corresponding
  module_test entry to have `gradcheck_nondet_tol=<tol>`
"""


def _check_analytical_jacobian_attributes(
    inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None
) -> Tuple[torch.Tensor, ...]:
    # This is used by both fast and slow mode:
    #  - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith
    #    input.
    #  - For fast mode, vjps[i][0] is a linear combination of the rows
    #    of the Jacobian wrt the ith input
    diff_input_list = list(_iter_tensors(inputs, True))

    def vjp_fn(grad_output):
        return torch.autograd.grad(
            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
        )

    # Compute everything twice to check for nondeterminism (which we call reentrancy)
    if fast_mode:
        vjps1 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
        vjps2 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
    else:
        vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
        vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())

    output_numel = output.numel() if not fast_mode else 1
    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
        vjps1, inputs, output_numel
    )
    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)

    if not types_ok and check_grad_dtypes:
        raise GradcheckError("Gradient has dtype mismatch")
    if not sizes_ok:
        raise GradcheckError("Analytical gradient has incorrect size")
    if not reentrant:
        raise GradcheckError(
            "Backward is not reentrant, i.e., running backward with "
            "same input and grad_output multiple times gives different values, "
793            "although analytical gradient matches numerical gradient."
794            f"The tolerance for nondeterminism was {nondet_tol}." + FAILED_NONDET_MSG
795        )
796    return jacobians1
797
798
799def _get_analytical_vJu_backward_mode(
800    inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u
801):
802    reduced_jacobians: List[List[torch.Tensor]] = []
803    for output, v in zip(outputs, all_v):
804        all_vJ = _check_analytical_jacobian_attributes(
805            inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v
806        )
807        jacobian_scalars: List[torch.Tensor] = []
808        for vJ, u in zip(all_vJ, all_u):
809            # Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse
810            # the error checking logic from slow mode
811            vJ = vJ.T.squeeze(0)
812            if vJ.is_complex():  # C -> R
813                tv = torch.view_as_real(vJ.resolve_conj())
814                tr = tv.select(-1, 0)
815                ti = tv.select(-1, 1)
816                jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
817            else:  # R -> R
818                jacobian_scalars.append(vJ.dot(u))
819        reduced_jacobians.append(jacobian_scalars)
820    return reduced_jacobians
821
822
823@deprecated(
824    "`get_analytical_jacobian` was part of PyTorch's private API and not "
825    "meant to be exposed. We are deprecating it and it will be removed "
826    "in a future version of PyTorch. If you have a specific use for "
827    "this or feature request for this to be a stable API, please file "
828    "us an issue at https://github.com/pytorch/pytorch/issues/new",
829    category=FutureWarning,
830)
831def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0):
832    # Replicates the behavior of the old get_analytical_jacobian before the refactor
833    # This shares much of its code with _check_analytical_jacobian_attributes
834    if (
835        grad_out != 1.0
836    ):  # grad_out param is only kept for backward compatibility reasons
837        raise ValueError(
838            "Expected grad_out to be 1.0. get_analytical_jacobian no longer "
839            "supports values of grad_out != 1.0."
840        )
841    if output.is_complex():
842        raise ValueError(
843            "Expected output to be non-complex. get_analytical_jacobian no "
844            "longer supports functions that return complex outputs."
845        )
846    diff_input_list = list(_iter_tensors(inputs, True))
847
848    def vjp_fn(grad_output):
849        return torch.autograd.grad(
850            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
851        )
852
853    # Compute everything twice to check for nondeterminism (which we call reentrancy)
854    vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
855    vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
856
857    output_numel = output.numel()
858    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
859        vjps1, inputs, output_numel
860    )
861    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
862    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)
863
864    return jacobians1, reentrant, sizes_ok, types_ok
865
866
867def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx):
868    # Computes the analytical Jacobian in slow mode for a single input-output pair.
869    # Forgoes performing checks on dtype, shape, and reentrancy.
870    jacobians = _check_analytical_jacobian_attributes(
871        inputs, outputs[output_idx], nondet_tol=float("inf"), check_grad_dtypes=False
872    )
873    return jacobians[input_idx]
874
875
876def _compute_analytical_jacobian_rows(
877    vjp_fn, sample_output
878) -> List[List[Optional[torch.Tensor]]]:
879    # Computes Jacobian row-by-row by projecting `vjp_fn` = v^T J on standard basis
880    # vectors: vjp_fn(e) = e^T J is a corresponding row of the Jacobian.
881    # NB: this function does not assume vjp_fn(v) to return tensors with the same
882    # number of elements for different v. This is checked when we later combine the
883    # rows into a single tensor.
884    grad_out_base = torch.zeros_like(
885        sample_output, memory_format=torch.legacy_contiguous_format
886    )
887    flat_grad_out = grad_out_base.view(-1)
888    # jacobians_rows[i][j] is the Jacobian jth row for the ith input
889    jacobians_rows: List[List[Optional[torch.Tensor]]] = []
890    for j in range(flat_grad_out.numel()):
891        flat_grad_out.zero_()
892        flat_grad_out[j] = 1.0  # projection for jth row of Jacobian
893        grad_inputs = vjp_fn(grad_out_base)
894        for i, d_x in enumerate(grad_inputs):
895            if j == 0:
896                jacobians_rows.append([])
897            jacobians_rows[i] += [
898                d_x.clone() if isinstance(d_x, torch.Tensor) else None
899            ]
900    return jacobians_rows
901
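# Illustrative sketch (not used by this module): a VJP with a one-hot grad_output
# extracts a single row of the Jacobian, which is how the rows above are assembled.
# For y = x**2 (elementwise, so the Jacobian is diag(2 * x)):
#
#     >>> x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.double, requires_grad=True)
#     >>> y = x ** 2
#     >>> e0 = torch.tensor([1.0, 0.0, 0.0], dtype=torch.double)   # selects row 0
#     >>> torch.autograd.grad(y, x, grad_outputs=e0, retain_graph=True)
#     (tensor([2., 0., 0.], dtype=torch.float64),)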

def _get_analytical_vjps_wrt_specific_output(
    vjp_fn, sample_output, v
) -> List[List[Optional[torch.Tensor]]]:
    vjps: List[List[Optional[torch.Tensor]]] = []
    grad_inputs = vjp_fn(v.reshape(sample_output.shape))
    for vjp in grad_inputs:
        vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None])
    return vjps


def _check_inputs(tupled_inputs) -> bool:
    # Make sure that gradients are saved for at least one input
    any_input_requiring_grad = False
    for idx, inp in enumerate(tupled_inputs):
        if is_tensor_like(inp) and inp.requires_grad:
            if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
                warnings.warn(
                    f"Input #{idx} requires gradient and "
                    "is not a double precision floating point or complex. "
                    "This check will likely fail if all the inputs are "
                    "not of double precision floating point or complex. "
                )
            if inp.is_sparse:
                content = inp._values()
            elif _is_sparse_compressed_tensor(inp):
                content = inp.values()
            else:
                content = inp
            # TODO: To cover more problematic cases, replace stride = 0 check with
            # "any overlap in memory" once we have a proper function to check it.
            if content.layout is not torch._mkldnn:  # type: ignore[attr-defined]
                if not all(
                    st > 0 or sz <= 1
                    for st, sz in zip(content.stride(), content.size())
                ):
                    raise RuntimeError(
                        f"The {idx}th input has a dimension with stride 0. gradcheck only "
                        "supports inputs that are non-overlapping to be able to "
                        "compute the numerical gradients correctly. You should call "
                        ".contiguous on the input before passing it to gradcheck."
                    )
            any_input_requiring_grad = True

    if not any_input_requiring_grad:
        raise ValueError(
            "gradcheck expects at least one input tensor to require gradient, "
949            "but none of the them have requires_grad=True."
        )
    return True

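# Illustrative sketch (not used by this module): the stride check above rejects inputs
# with overlapping memory, which `expand` produces (stride 0 along the expanded
# dimension), because perturbing one element would silently perturb its aliases:
#
#     >>> x = torch.randn(1, 3, dtype=torch.double, requires_grad=True).expand(4, 3)
#     >>> x.stride()
#     (0, 1)
#     >>> # passing `x` to gradcheck raises: "The 0th input has a dimension with stride 0..."
#     >>> # the error message suggests calling .contiguous on the input instead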

def _check_outputs(outputs) -> None:
    if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)):
        # it is easier to call to_dense() on the sparse output than
        # to modify analytical jacobian
        raise ValueError(
            "Sparse output is not supported at gradcheck yet. "
            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
        )
    if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)):  # type: ignore[attr-defined]
        raise ValueError(
            "MKLDNN output is not supported at gradcheck yet. "
            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
        )


def _check_no_differentiable_outputs(
    func, inputs, func_out, eps, *, is_forward_ad
) -> bool:
    # When there are no differentiable outputs, numerical gradient for a function is
    # expected to be zero.
    jacobians_all_inputs_outputs = _get_numerical_jacobian(
        func, inputs, func_out, eps=eps, is_forward_ad=is_forward_ad
    )
    for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs:
        for jacobian in jacobians_all_outputs_and_fixed_input:
            if torch.ne(jacobian, 0).sum() > 0:
                raise GradcheckError(
                    "Numerical gradient for function expected to be zero"
                )
    return True


def _check_no_differentiable_outputs_fast(
    func, func_out, all_inputs, inputs_indices, all_u, eps, nondet_tol
):
    for inp_idx, u in zip(inputs_indices, all_u):
        jvps = _get_numerical_jvp_wrt_specific_input(func, inp_idx, all_inputs, u, eps)
        for jvp in jvps:
            if jvp.numel() == 0:
                continue
            if (jvp - torch.zeros_like(jvp)).abs().max() > nondet_tol:
                raise GradcheckError(
                    "Numerical gradient for function expected to be zero"
                )
    return True


FAILED_BATCHED_GRAD_MSG = """
gradcheck or gradgradcheck failed while testing batched gradient computation.
This could have been invoked in a number of ways (via a test that calls
gradcheck/gradgradcheck directly or via an autogenerated test).

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `check_batched_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`.

If you're modifying an existing operator that supports batched grad computation,
or wish to make a new operator work with batched grad computation, please read
the following.

To compute batched grads (e.g., jacobians, hessians), we vmap over the backward
computation. The most common failure case is if there is a 'vmap-incompatible
operation' in the backward pass. Please see
NOTE: [How to write vmap-compatible backward formulas]
in the codebase for an explanation of how to fix this.
""".strip()

FAILED_BATCHED_GRAD_MSG_FWD_AD = """
gradcheck failed while testing batched gradient computation with forward-mode AD.
This test is enabled automatically when both `check_batched_grad=True`
and `check_forward_ad=True`, but can be disabled in the following ways
depending on how the test was invoked (via a test that calls gradcheck
directly or via an autogenerated test).

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `check_batched_forward_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `check_batched_forward_grad=False`
"""


def _get_failed_batched_grad_test_msg(
    output_idx, input_idx, res, exp, is_forward_ad=False
):
    return f"""
For output {output_idx} and input {input_idx}:

{FAILED_BATCHED_GRAD_MSG_FWD_AD if is_forward_ad else FAILED_BATCHED_GRAD_MSG}

Got:
{res}

Expected:
{exp}
""".strip()


def _test_batched_grad_forward_ad(func, inputs) -> bool:
    fwAD = torch.autograd.forward_ad  # To avoid early import issues (do we need this?)
    assert isinstance(inputs, tuple)

    for input_idx, current_input in enumerate(inputs):
        if not (is_tensor_like(current_input) and current_input.requires_grad):
            continue

        def jvp(tangent: torch.Tensor):
            with fwAD.dual_level():
                dual = fwAD.make_dual(current_input.detach(), tangent)
                inputs_with_dual = tuple(
                    dual
                    if idx == input_idx
                    else (inp.detach() if is_tensor_like(inp) else inp)
                    for idx, inp in enumerate(inputs)
                )
                dual_outputs = _as_tuple(func(*inputs_with_dual))
                ret = []
                for dual_output in dual_outputs:
                    if dual_output is None:
                        continue
                    primal_out, tangent_out = fwAD.unpack_dual(dual_output)
                    if tangent_out is not None:
                        ret.append(tangent_out)
                    else:
                        ret.append(
                            torch.zeros(
                                [], dtype=primal_out.dtype, device=primal_out.device
                            ).expand(primal_out.shape)
                        )
                return tuple(ret)

        if not _is_float_or_complex_tensor(current_input):
            continue

        tangents = [torch.randn_like(current_input) for _ in range(2)]
        expected = [jvp(t) for t in tangents]
        expected = [torch.stack(shards) for shards in zip(*expected)]

        try:
            result = _vmap(jvp)(torch.stack(tangents))
        except RuntimeError as ex:
            # Rethrow to provide a better error message
            raise GradcheckError(
                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}"
            ) from ex

        for input_idx, (res, exp) in enumerate(zip(result, expected)):
            if torch.allclose(res, exp):
                continue
            raise GradcheckError(
                _get_failed_batched_grad_test_msg(
                    input_idx, input_idx, res, exp, is_forward_ad=True
                )
            )
    return True


def _test_batched_grad(input, output, output_idx) -> bool:
    # NB: _test_batched_grad compares two autograd.grad invocations with a single
    # vmap(autograd.grad) invocation. It's not exactly a "gradcheck" in the
    # sense that we're not comparing an analytical jacobian with a numeric one,
    # but it is morally similar (we could have computed a full analytic jac
    # via vmap, but that is potentially slow)
    diff_input_list = list(_iter_tensors(input, True))
    grad = functools.partial(
        torch.autograd.grad,
        output,
        diff_input_list,
        retain_graph=True,
        allow_unused=True,
    )

    def vjp(v):
        results = grad(v)
        results = tuple(
            grad
            if grad is not None
            else torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape)
            for grad, inp in zip(results, diff_input_list)
        )
        return results

    grad_outputs = [torch.randn_like(output) for _ in range(2)]

    expected = [vjp(gO) for gO in grad_outputs]
    expected = [torch.stack(shards) for shards in zip(*expected)]

    # Squash warnings since these are expected to happen in most cases
    # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="There is a performance drop")
        warnings.filterwarnings("ignore", message="Please use torch.vmap")
        try:
            result = vmap(vjp)(torch.stack(grad_outputs))
        except RuntimeError as ex:
            # It's OK that we're not raising the error at the correct callsite.
            # That's because the callsite is always going to be inside the Python
            # autograd.grad call instead of the C++ traceback pointing at the line
            # in the backward formula.
            raise GradcheckError(
                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}"
            ) from ex

    for input_idx, (res, exp) in enumerate(zip(result, expected)):
        if torch.allclose(res, exp):
            continue
        raise GradcheckError(
            _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp)
        )
    return True

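# Illustrative sketch (not used by this module): the batched-grad check above asserts
# that computing VJPs for a stack of grad_outputs in one vmapped/batched call matches
# computing them one at a time. A minimal version of the same comparison using the
# public `is_grads_batched` flag:
#
#     >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
#     >>> y = x.exp()
#     >>> vs = torch.randn(2, 3, dtype=torch.double)   # a batch of two grad_outputs
#     >>> (batched,) = torch.autograd.grad(y, x, vs, is_grads_batched=True, retain_graph=True)
#     >>> looped = torch.stack([torch.autograd.grad(y, x, v, retain_graph=True)[0] for v in vs])
#     >>> torch.allclose(batched, looped)
#     True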

def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool:
    # Tests that backward is multiplied by grad_output
    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
    if not diff_input_list:
        raise GradcheckError("no Tensors requiring grad found in input")
    grads_input = torch.autograd.grad(
        outputs,
        diff_input_list,
        [
            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
            for o in outputs
        ],
        allow_unused=True,
    )
    for gi, di in zip(grads_input, diff_input_list):
        if gi is None:
            continue
        if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
            if gi.layout != di.layout:
                raise GradcheckError(
                    "grad is incorrect layout ("
                    + str(gi.layout)
                    + " is not "
                    + str(di.layout)
                    + ")"
                )
            if _is_sparse_any_tensor(gi):
                sparse_kind = str(gi.layout).replace("torch.", "").replace("_coo", "")
                if gi.sparse_dim() != di.sparse_dim():
                    raise GradcheckError(
                        f"grad is {sparse_kind} tensor, but has incorrect sparse_dim"
                        f" {gi.sparse_dim()}, expected {di.sparse_dim()}"
                    )
                if gi.dense_dim() != di.dense_dim():
                    raise GradcheckError(
                        f"grad is {sparse_kind} tensor, but has incorrect dense_dim"
                        f" {gi.dense_dim()}, expected {di.dense_dim()}"
                    )
            gi = gi.to_dense()
            di = di.to_dense()
        if masked:
            if not torch.allclose(gi, torch.zeros_like(gi)):
                raise GradcheckError("backward not multiplied by grad_output")
        elif not gi.eq(0).all():
            raise GradcheckError("backward not multiplied by grad_output")
        if gi.dtype != di.dtype:
            raise GradcheckError("grad is incorrect type")
        if gi.device != di.device:
            raise GradcheckError("grad is incorrect device")
        if gi.size() != di.size():
            raise GradcheckError("grad is incorrect size")
    return True

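# Illustrative sketch (not used by this module): the check above feeds an all-zero
# grad_output through backward and expects every input grad to be zero, i.e. the
# backward formula must actually scale by grad_output rather than ignore it:
#
#     >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
#     >>> y = x * 2
#     >>> (g,) = torch.autograd.grad(y, x, torch.zeros_like(y))
#     >>> bool(g.eq(0).all())       # a correct backward yields zeros here
#     True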

def _test_undefined_forward_mode(func, outputs, inputs):
    fwAD = torch.autograd.forward_ad

    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
    all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True)

    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)

    with fwAD.dual_level():
        fw_grads = []
        dual_inputs = []
        tensor_indices = set()
        for i, inp in enumerate(inputs):
            if is_tensor_like(inp) and inp.requires_grad:
                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                    raise ValueError(
1242                        "MKLDNN inputs are not support for forward AD gradcheck."
1243                    )
1244
1245                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
1246                # If inp is a differentiable view, the dual might not be the tangent given to
1247                # make_dual, so read it explicitly from the dual tensor
1248                fw_grads.append(fwAD.unpack_dual(inp)[1])
1249                tensor_indices.add(i)
1250            dual_inputs.append(inp)
1251
1252        for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
1253            fw_grad.copy_(u.view_as(fw_grad))
1254
1255        for idx, inp in enumerate(inputs):
1256            if idx not in tensor_indices:
1257                continue
1258            dual_inp_obj = dual_inputs[idx]
1259
1260            # case 1 (Materialized Zero Tensor Tangent)
1261            dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
1262            raw_outputs = _as_tuple(func(*dual_inputs))
1263            dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)
1264
1265            # case 2 (Efficient Zero Tensor Tangent: we skip making a dual object and pass a regular tensor instead)
1266            dual_inputs[idx] = inp.detach()
1267            raw_outputs = _as_tuple(func(*dual_inputs))
1268            dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)
1269
1270            # reset
1271            dual_inputs[idx] = dual_inp_obj
1272
1273            for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)):
1274                val1, res1 = fwAD.unpack_dual(d_o1)
1275                val2, res2 = fwAD.unpack_dual(d_o2)
1276
1277                if not (res1 is None or res2 is None):
1278                    if not torch.allclose(res1, res2):
1279                        raise GradcheckError(
1280                            f"Mismatch in tangent values for output with index {index_o} "
1281                            f"when input {inp} has an undefined tangent value. "
1282                            f"Got {res1} but expected {res2}."
1283                        )
1290    return True
1291
1292
1293def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
1294    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
1295    if not diff_input_list:
1296        raise GradcheckError("no Tensors requiring grad found in input")
1297
1298    def warn_bc_breaking():
1299        warnings.warn(
1300            "Backwards compatibility: New undefined gradient support checking "
1301            "feature is enabled by default, but it may break existing callers "
1302            "of this function. If this affects you, call this "
1303            'function with "check_undefined_grad=False" to disable the feature.'
1304        )
1305
1306    def check_undefined_grad_support(output_to_check):
1307        grads_output = [
1308            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
1309            for o in output_to_check
1310        ]
1311        try:
1312            grads_input = torch.autograd.grad(
1313                output_to_check, diff_input_list, grads_output, allow_unused=True
1314            )
1315        except RuntimeError as e:
1316            warn_bc_breaking()
1317            raise GradcheckError(
1318                "Expected backward function to handle undefined output grads. "
1319                'Please look at "Notes about undefined output gradients" in '
1320                '"tools/autograd/derivatives.yaml"'
1321            ) from e
1322
1323        for gi in grads_input:
1324            if (gi is not None) and (not gi.eq(0).all()):
1325                warn_bc_breaking()
1326                raise GradcheckError(
1327                    "Expected all input grads to be undefined or zero when all output grads are undefined "
1328                    'or zero. Please look at "Notes about undefined output gradients" in '
1329                    '"tools/autograd/derivatives.yaml"'
1330                )
1331        return True
1332
1333    # All backward functions must work properly if all output grads are undefined
1334    outputs_to_check = [
1335        [
1336            torch._C._functions.UndefinedGrad()(o)
1337            for o in _differentiable_outputs(func(*inputs))
1338            # This check filters out Tensor-likes that aren't instances of Tensor.
1339            if isinstance(o, torch.Tensor)
1340        ]
1341    ]
1342
1343    # If there are multiple output grads, we should be able to leave any one of them undefined at a time without error
1344    if len(outputs_to_check[0]) > 1:
1345        for undef_grad_idx in range(len(outputs)):
1346            output_to_check = _differentiable_outputs(func(*inputs))
1347            outputs_to_check.append(
1348                [
1349                    torch._C._functions.UndefinedGrad()(o)
1350                    if idx == undef_grad_idx
1351                    else o
1352                    for idx, o in enumerate(output_to_check)
1353                ]
1354            )
1355
1356    return all(check_undefined_grad_support(output) for output in outputs_to_check)
1357
1358
1359def _as_tuple(x):
1360    if isinstance(x, tuple):
1361        return x
1362    elif isinstance(x, list):
1363        return tuple(x)
1364    else:
1365        return (x,)
1366
1367
1368def _differentiable_outputs(x):
1369    return tuple(o for o in _as_tuple(x) if o.requires_grad)
1370
1371
1372def _get_notallclose_msg(
1373    analytical,
1374    numerical,
1375    output_idx,
1376    input_idx,
1377    complex_indices,
1378    test_imag=False,
1379    is_forward_ad=False,
1380) -> str:
1381    out_is_complex = (
1382        (not is_forward_ad) and complex_indices and output_idx in complex_indices
1383    )
1384    inp_is_complex = is_forward_ad and complex_indices and input_idx in complex_indices
1385    part = "imaginary" if test_imag else "real"
1386    element = "inputs" if is_forward_ad else "outputs"
1387    prefix = (
1388        ""
1389        if not (out_is_complex or inp_is_complex)
1390        else f"While considering the {part} part of complex {element} only, "
1391    )
1392    mode = "computed with forward mode " if is_forward_ad else ""
1393    return (
1394        prefix + "Jacobian %smismatch for output %d with respect to input %d,\n"
1395        "numerical:%s\nanalytical:%s\n"
1396        % (mode, output_idx, input_idx, numerical, analytical)
1397    )
1398
1399
1400def _transpose(matrix_of_tensors):
1401    # returns list of tuples
1402    return list(zip(*matrix_of_tensors))
1403
1404
1405def _real_and_imag_output(fn):
1406    # returns new functions real(fn) and imag(fn) that behave the same as the original fn,
1407    # except that torch.real or torch.imag, respectively, is applied to the complex outputs
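    # For instance (illustrative only): if fn returns a complex tensor c alongside a
    # real tensor r, real(fn) returns (torch.real(c), r) and imag(fn) returns
    # (torch.imag(c), r); real-valued outputs are passed through unchanged.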
1408    def apply_to_c_outs(fn, fn_to_apply):
1409        def wrapped_fn(*inputs):
1410            outs = _as_tuple(fn(*inputs))
1411            return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)
1412
1413        return wrapped_fn
1414
1415    return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
1416
1417
1418def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs):
1419    # returns new functions that take real inputs in place of the complex ones: viewing
1420    # fn as (x, y) -> fn(x + y * 1j), real_fn computes inp -> fn(inp + y * 1j) and imag_fn
1421    # computes inp -> fn(x + inp * 1j); in each case the other part is held constant.
1422    # We do not use 0 for the constant, to make sure we always call the user function with a valid input.
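    # For instance (illustrative only): if tupled_inputs[k] is the complex tensor
    # 2 + 3j, then real_fn maps a real tensor r at position k to fn(..., r + 3j, ...)
    # and imag_fn maps a real tensor s at position k to fn(..., 2 + s * 1j, ...).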
1423    def apply_to_c_inps(fn, fn_to_apply):
1424        def wrapped_fn(*inputs):
1425            new_inputs = list(inputs)
1426            for should_be_complex in complex_inp_indices:
1427                new_inputs[should_be_complex] = fn_to_apply(
1428                    new_inputs[should_be_complex], tupled_inputs[should_be_complex]
1429                )
1430            return _as_tuple(fn(*new_inputs))
1431
1432        return wrapped_fn
1433
1434    real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j)
1435    imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j)
1436    return real_fn, imag_fn
1437
1438
1439def _gradcheck_real_imag(
1440    gradcheck_fn,
1441    func,
1442    func_out,
1443    tupled_inputs,
1444    outputs,
1445    eps,
1446    rtol,
1447    atol,
1448    check_grad_dtypes,
1449    check_forward_ad,
1450    check_backward_ad,
1451    nondet_tol,
1452    check_undefined_grad,
1453):
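    # Strategy: with backward AD, complex *outputs* are handled by running the check
    # separately on their real and imaginary parts; with forward AD, complex *inputs*
    # are handled by running it separately on their real and imaginary parts. Purely
    # real cases fall through to a single plain call to gradcheck_fn.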
1454    complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()]
1455    has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out))
1456    if check_backward_ad:
1457        if has_any_complex_output:
1458            real_fn, imag_fn = _real_and_imag_output(func)
1459
1460            imag_func_out = imag_fn(*tupled_inputs)
1461            imag_outputs = _differentiable_outputs(imag_func_out)
1462            gradcheck_fn(
1463                imag_fn,
1464                imag_func_out,
1465                tupled_inputs,
1466                imag_outputs,
1467                eps,
1468                rtol,
1469                atol,
1470                check_grad_dtypes,
1471                nondet_tol,
1472                complex_indices=complex_out_indices,
1473                test_imag=True,
1474            )
1475
1476            real_func_out = real_fn(*tupled_inputs)
1477            real_outputs = _differentiable_outputs(real_func_out)
1478            gradcheck_fn(
1479                real_fn,
1480                real_func_out,
1481                tupled_inputs,
1482                real_outputs,
1483                eps,
1484                rtol,
1485                atol,
1486                check_grad_dtypes,
1487                nondet_tol,
1488                complex_indices=complex_out_indices,
1489            )
1490        else:
1491            gradcheck_fn(
1492                func,
1493                func_out,
1494                tupled_inputs,
1495                outputs,
1496                eps,
1497                rtol,
1498                atol,
1499                check_grad_dtypes,
1500                nondet_tol,
1501            )
1502
1503    if check_forward_ad:
1504        complex_inp_indices = [
1505            i
1506            for i, inp in enumerate(tupled_inputs)
1507            if is_tensor_like(inp) and inp.is_complex()
1508        ]
1509        if complex_inp_indices:
1510            real_fn, imag_fn = _real_and_imag_input(
1511                func, complex_inp_indices, tupled_inputs
1512            )
1513
1514            imag_inputs = [
1515                inp.imag if is_tensor_like(inp) and inp.is_complex() else inp
1516                for inp in tupled_inputs
1517            ]
1518            imag_func_out = imag_fn(*imag_inputs)
1519            diff_imag_func_out = _differentiable_outputs(imag_func_out)
1520            gradcheck_fn(
1521                imag_fn,
1522                imag_func_out,
1523                imag_inputs,
1524                diff_imag_func_out,
1525                eps,
1526                rtol,
1527                atol,
1528                check_grad_dtypes,
1529                nondet_tol,
1530                complex_indices=complex_inp_indices,
1531                test_imag=True,
1532                use_forward_ad=True,
1533            )
1534
1535            real_inputs = [
1536                inp.real if is_tensor_like(inp) and inp.is_complex() else inp
1537                for inp in tupled_inputs
1538            ]
1539            real_func_out = real_fn(*real_inputs)
1540            diff_real_func_out = _differentiable_outputs(real_func_out)
1541            gradcheck_fn(
1542                real_fn,
1543                real_func_out,
1544                real_inputs,
1545                diff_real_func_out,
1546                eps,
1547                rtol,
1548                atol,
1549                check_grad_dtypes,
1550                nondet_tol,
1551                complex_indices=complex_inp_indices,
1552                use_forward_ad=True,
1553            )
1554            if check_undefined_grad:
1555                _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs)
1556                _test_undefined_forward_mode(real_fn, real_func_out, real_inputs)
1557        else:
1558            gradcheck_fn(
1559                func,
1560                func_out,
1561                tupled_inputs,
1562                outputs,
1563                eps,
1564                rtol,
1565                atol,
1566                check_grad_dtypes,
1567                nondet_tol,
1568                use_forward_ad=True,
1569            )
1570            if check_undefined_grad:
1571                _test_undefined_forward_mode(func, outputs, tupled_inputs)
1572
1573
1574def _slow_gradcheck(
1575    func,
1576    func_out,
1577    tupled_inputs,
1578    outputs,
1579    eps,
1580    rtol,
1581    atol,
1582    check_grad_dtypes,
1583    nondet_tol,
1584    *,
1585    use_forward_ad=False,
1586    complex_indices=None,
1587    test_imag=False,
1588    masked=False,
1589):
1590    func_out = _as_tuple(func_out)
1591    if not outputs:
1592        return _check_no_differentiable_outputs(
1593            func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad
1594        )
1595    tupled_inputs_numerical = tupled_inputs if masked else _densify(tupled_inputs)
1596
1597    numerical = _transpose(
1598        _get_numerical_jacobian(
1599            func,
1600            tupled_inputs_numerical,
1601            func_out,
1602            eps=eps,
1603            is_forward_ad=use_forward_ad,
1604        )
1605    )
1606    # Note: [numerical vs analytical output length]
1607    # The numerical path returns a jacobian quantity for every output, even if requires_grad of that
1608    # output is False. This behavior is necessary for _check_no_differentiable_outputs to work.
1609    numerical = [nj for o, nj in zip(func_out, numerical) if o.requires_grad]
1610    if use_forward_ad:
1611        analytical_forward = _get_analytical_jacobian_forward_ad(
1612            func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes
1613        )
1614
1615        for i, n_per_out in enumerate(numerical):
1616            for j, n in enumerate(n_per_out):
1617                a = analytical_forward[j][i]
1618                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
1619                    raise GradcheckError(
1620                        _get_notallclose_msg(
1621                            a, n, i, j, complex_indices, test_imag, is_forward_ad=True
1622                        )
1623                    )
1624    else:
1625        for i, o in enumerate(outputs):
1626            analytical = _check_analytical_jacobian_attributes(
1627                tupled_inputs, o, nondet_tol, check_grad_dtypes
1628            )
1629
1630            for j, (a, n) in enumerate(zip(analytical, numerical[i])):
1631                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
1632                    raise GradcheckError(
1633                        _get_notallclose_msg(a, n, i, j, complex_indices, test_imag)
1634                    )
1635
1636    return True
1637
1638
1639def _dot_with_type_promotion(u, v):
1640    assert u.dim() == 1 and v.dim() == 1
1641    return (u * v).sum()
1642
1643
1644def _allclose_with_type_promotion(a, b, rtol, atol):
1645    promoted_type = torch.promote_types(a.dtype, b.dtype)
1646    a = a.to(dtype=promoted_type)
1647    b = b.to(dtype=promoted_type)
1648    return torch.allclose(a, b, rtol, atol)
1649
1650
1651def _to_real_dtype(dtype):
1652    if dtype == torch.complex128:
1653        return torch.float64
1654    elif dtype == torch.complex64:
1655        return torch.float32
1656    else:
1657        return dtype
1658
1659
1660def _vec_from_tensor(x, generator, downcast_complex=False):
1661    # Create a random vector with the same number of elements as x and the same
1662    # dtype/device. If x is complex and downcast_complex is False, we create a
1663    # complex tensor with only a real component.
1664    if x.layout == torch.sparse_coo:
1665        # For sparse tensors, create a random sparse vector with random values at the
1666        # same indices. Make sure size is set so that it isn't inferred to be smaller.
1667        x_values = x._values()
1668        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
1669        values = (
1670            torch.rand(x_values.numel(), generator=generator)
1671            .to(dtype=dtype, device=x.device)
1672            .view(x_values.shape)
1673        )
1674        values /= values.norm()
1675        vec = torch.sparse_coo_tensor(x._indices(), values, x.size(), device=x.device)
1676    elif _is_sparse_compressed_tensor(x):
1677        if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
1678            compressed_indices, plain_indices = x.crow_indices(), x.col_indices()
1679        else:
1680            compressed_indices, plain_indices = x.ccol_indices(), x.row_indices()
1681        x_values = x.values()
1682        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
1683        values = (
1684            torch.rand(x_values.numel(), generator=generator)
1685            .to(dtype=dtype, device=x.device)
1686            .view(x_values.shape)
1687        )
1688        values /= values.norm()
1689        vec = torch.sparse_compressed_tensor(
1690            compressed_indices,
1691            plain_indices,
1692            values,
1693            x.size(),
1694            layout=x.layout,
1695            device=x.device,
1696        )
1697    else:
1698        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
1699        vec = torch.rand(x.numel(), generator=generator).to(
1700            dtype=dtype, device=x.device
1701        )
1702        vec /= vec.norm()
1703    return vec
1704
1705
1706def _get_inp_tensors(tupled_inputs):
1707    inp_idx_tup = [
1708        (i, t)
1709        for i, t in enumerate(tupled_inputs)
1710        if is_tensor_like(t) and t.requires_grad
1711    ]
1712    return [tup[0] for tup in inp_idx_tup], [tup[1] for tup in inp_idx_tup]
1713
1714
1715def _adjusted_atol(atol, u, v):
1716    # In slow gradcheck, we compare A and B element-wise, i.e., for some a, b we
1717    # allow: |a - b| < atol + rtol * |b|. But since we now compare q1 = v^T A u and
1718    # q2 = v^T B u, we must allow |q1 - q2| < v^T E u + rtol * |v^T B u|, where E is
1719    # the correctly sized matrix in which each entry is atol.
1720    #
1721    # We see that atol needs to be scaled by v^T M u (where M is an all-ones matrix
1722    # of the same shape as A): v^T M u = \sum_{i} \sum_{j} v_i * u_j = (\sum_{i} v_i)(\sum_{j} u_j)
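    # For instance (illustrative numbers only): if u = (1, 2) and v = (3, 4), then
    # v^T M u = (3 + 4) * (1 + 2) = 21, so the returned tolerance is atol * 21.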
1723    # TODO: properly handle case when u is tuple instead of only taking first element
1724    u = u[0] if isinstance(u, tuple) else u
1725    sum_u = u.sum()
1726    sum_v = 1.0 if v is None else v.sum()
1727    return atol * float(sum_u) * float(sum_v)
1728
1729
1730FAST_FAIL_SLOW_OK_MSG = """
1731Fast gradcheck failed but element-wise differences are small. This means that the
1732test might've passed in slow mode (i.e., with `fast_mode=False`)!
1733
1734If you are adding a new operator, please file an issue and then use one of the
1735workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck:
1736
1737If the test
1738- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
1739  with `fast_mode=False` as a keyword argument.
1740- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
1741  to have `gradcheck_fast_mode=False`
1742- is a Module test (e.g., in common_nn.py), then modify the corresponding
1743  module_test entry to have `gradcheck_fast_mode=False`
1744""".strip()
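# For instance (illustrative only; `fn` and `inp` below stand in for whatever the
# failing test actually passes), the first workaround above amounts to calling:
#
#     torch.autograd.gradcheck(fn, (inp,), fast_mode=False)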
1745
1746
1747def _run_slow_mode_and_get_error(
1748    func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, eps, is_forward_ad
1749):
1750    # Compute jacobians in slow mode for better error message
1751    slow_numerical = _get_numerical_jacobian(
1752        func, tupled_inputs, outputs, eps=eps, is_forward_ad=is_forward_ad
1753    )[input_idx][output_idx]
1754    if is_forward_ad:
1755
1756        def new_fn(inp):
1757            new_inputs = list(tupled_inputs)
1758            new_inputs[input_idx] = inp
1759            return _as_tuple(func(*new_inputs))[output_idx]
1760
1761        slow_analytical = _get_analytical_jacobian_forward_ad(
1762            new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],)
1763        )[0][0]
1764    else:
1765        slow_analytical = _get_analytical_jacobian(
1766            tupled_inputs, outputs, input_idx, output_idx
1767        )
1768
1769    # Assume jacobians are non-empty and have the same shape
1770    slow_max_diff = (slow_numerical - slow_analytical).abs().max()
1771
1772    slow_allclose = torch.allclose(slow_analytical, slow_numerical, rtol, atol)
1773    msg = (
1774        "\nThe above quantities relating the numerical and analytical jacobians are computed \n"
1775        "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n"
1776        "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n"
1777        f"Numerical:\n {slow_numerical}\n"
1778        f"Analytical:\n{slow_analytical}\n\n"
1779        f"The max per-element difference (slow mode) is: {slow_max_diff}.\n"
1780    )
1781    if slow_allclose:
1782        # Slow gradcheck would've passed!
1783        msg += FAST_FAIL_SLOW_OK_MSG
1784    return msg
1785
1786
1787def _to_flat_dense_if_sparse(tensor):
1788    if _is_sparse_any_tensor(tensor):
1789        return tensor.to_dense().reshape(-1)
1790    else:
1791        return tensor
1792
1793
1794def _make_vectors(inp_tensors, outputs, *, use_forward_ad):
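    # Returns (all_v, all_u, all_u_dense): one random vector v per output for the
    # backward-mode reduction (None when use_forward_ad, since only J u is needed),
    # one random direction u per differentiable input (a (u_real, u_imag) pair for
    # complex inputs), and copies of each u in which sparse tensors are densified
    # and flattened.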
1795    # Use our own generator to avoid messing with the user's RNG state
1796    g_cpu = torch.Generator()
1797
1798    def _vec_from_tensor_cpu(*args):
1799        # Allocate all tensors on CPU by default, so they are on the same device as the
1800        # generator even if the user specified a different default device
1801        with torch.device("cpu"):
1802            return _vec_from_tensor(*args)
1803
1804    all_u = []
1805    all_u_dense = []
1806    for inp in inp_tensors:
1807        ur = _vec_from_tensor_cpu(inp, g_cpu, True)
1808        ur_dense = _to_flat_dense_if_sparse(ur)
1809        if inp.is_complex():
1810            ui = _vec_from_tensor_cpu(inp, g_cpu, True)
1811            all_u.append((ur, ui))
1812            ui_dense = _to_flat_dense_if_sparse(ui)
1813            all_u_dense.append((ur_dense, ui_dense))
1814        else:
1815            all_u.append(ur)
1816            all_u_dense.append(ur_dense)
1817    all_v = (
1818        None
1819        if use_forward_ad
1820        else [_vec_from_tensor_cpu(out, g_cpu) for out in outputs]
1821    )
1822    return all_v, all_u, all_u_dense
1823
1824
1825def _check_analytical_numerical_equal(
1826    all_analytical,
1827    all_numerical,
1828    complex_indices,
1829    tupled_inputs,
1830    outputs,
1831    func,
1832    all_v,
1833    all_u,
1834    rtol,
1835    atol,
1836    eps,
1837    test_imag,
1838    *,
1839    is_forward_ad=False,
1840):
1841    for i, all_numerical_for_input_i in enumerate(all_numerical):
1842        for j, n in enumerate(all_numerical_for_input_i):
1843            # Forward AD generates the transpose of what this function expects
1844            if is_forward_ad:
1845                a = all_analytical[i][j]
1846            else:
1847                a = all_analytical[j][i]
1848            n = n.to(device=a.device)
1849            updated_atol = _adjusted_atol(atol, all_u[i], all_v[j] if all_v else None)
1850            if not _allclose_with_type_promotion(a, n.to(a.device), rtol, updated_atol):
1851                jacobians_str = _run_slow_mode_and_get_error(
1852                    func, tupled_inputs, outputs, i, j, rtol, atol, eps, is_forward_ad
1853                )
1854                raise GradcheckError(
1855                    _get_notallclose_msg(
1856                        a, n, j, i, complex_indices, test_imag, is_forward_ad
1857                    )
1858                    + jacobians_str
1859                )
1860
1861
1862def _fast_gradcheck(
1863    func,
1864    func_out,
1865    inputs,
1866    outputs,
1867    eps,
1868    rtol,
1869    atol,
1870    check_grad_dtypes,
1871    nondet_tol,
1872    *,
1873    use_forward_ad=False,
1874    complex_indices=None,
1875    test_imag=False,
1876    masked=False,
1877):
1878    # See https://github.com/pytorch/pytorch/issues/53876 for details
1879    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
1880    # Backward mode computes v^T J (VJP)
1881    # Since we computed J u (JVP) through the finite difference method, we perform an equality check
1882    # between (v^T J) u and v^T (J u)
1883    # ----
1884    # Forward mode computes J u (JVP)
1885    # Since we already compute the JVP through the finite difference method,
1886    # we don't need v for the correctness check here, as asserted below
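    # Shape sketch (illustrative): with J of shape (out_numel, inp_numel),
    # u in R^inp_numel and v in R^out_numel, each (input, output) pair is checked by
    # comparing the scalar v^T (J u) (numerical, finite differences) against
    # (v^T J) u (analytical, backward mode); in forward mode the vectors J u are
    # compared directly and no v is used.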
1887    all_v, all_u, all_u_dense = _make_vectors(
1888        inp_tensors, outputs, use_forward_ad=use_forward_ad
1889    )
1890
1891    inputs_numerical, all_u_numerical, all_v_numerical = (
1892        (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v))
1893    )
1894
1895    numerical_vJu = _get_numerical_vJu(
1896        func,
1897        inputs_numerical,
1898        inp_tensors_idx,
1899        func_out,
1900        all_u_numerical,
1901        all_v_numerical,
1902        eps,
1903        is_forward_ad=use_forward_ad,
1904    )
1905    # TODO: replicate https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well
1906    if use_forward_ad:
1907        assert all_v is None
1908        analytical_vJu = _get_analytical_jacobian_forward_ad(
1909            func,
1910            inputs,
1911            _as_tuple(func_out),
1912            all_u=all_u,
1913            check_grad_dtypes=check_grad_dtypes,
1914        )
1915    else:
1916        if not outputs:
1917            _check_no_differentiable_outputs_fast(
1918                func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol
1919            )
1920
1921        analytical_vJu = _get_analytical_vJu_backward_mode(
1922            inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense
1923        )
1924
1925    _check_analytical_numerical_equal(
1926        analytical_vJu,
1927        numerical_vJu,
1928        complex_indices,
1929        inputs,
1930        outputs,
1931        func,
1932        all_v,
1933        all_u,
1934        rtol,
1935        atol,
1936        eps,
1937        test_imag,
1938        is_forward_ad=use_forward_ad,
1939    )
1940
1941    return True
1942
1943
1944# Note [VarArg of Tensors]
1945# ~~~~~~~~~~~~~~~~~~~~~~~~
1946# 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment.
1947# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
1948# the '...' first argument of Callable can be replaced with VarArg(Tensor).
1949# For now, we permit any input.
1950def gradcheck(
1951    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
1952    inputs: _TensorOrTensors,
1953    *,
1954    eps: float = 1e-6,
1955    atol: float = 1e-5,
1956    rtol: float = 1e-3,
1957    raise_exception: bool = True,
1958    nondet_tol: float = 0.0,
1959    check_undefined_grad: bool = True,
1960    check_grad_dtypes: bool = False,
1961    check_batched_grad: bool = False,
1962    check_batched_forward_grad: bool = False,
1963    check_forward_ad: bool = False,
1964    check_backward_ad: bool = True,
1965    fast_mode: bool = False,
1966    masked: Optional[bool] = None,
1967) -> bool:  # noqa: D400,D205
1968    r"""Check gradients computed via small finite differences against analytical
1969    gradients wrt tensors in :attr:`inputs` that are of floating point or complex type
1970    and with ``requires_grad=True``.
1971
1972    The check between numerical and analytical gradients uses :func:`~torch.allclose`.
1973
1974    For most of the complex functions we consider for optimization purposes, no notion of
1975    Jacobian exists. Instead, gradcheck verifies if the numerical and analytical values of
1976    the Wirtinger and Conjugate Wirtinger derivatives are consistent. Because the gradient
1977    computation is done under the assumption that the overall function has a real-valued
1978    output, we treat functions with complex output in a special way. For these functions,
1979    gradcheck is applied to two real-valued functions corresponding to taking the real
1980    components of the complex outputs for the first, and taking the imaginary components
1981    of the complex outputs for the second. For more details, check out
1982    :ref:`complex_autograd-doc`.
1983
1984    .. note::
1985        The default values are designed for :attr:`input` of double precision.
1986        This check will likely fail if :attr:`input` is of less precision, e.g.,
1987        ``FloatTensor``.
1988
1989    .. note::
1990        Gradcheck may fail when evaluated on non-differentiable points
1991        because the numerically computed gradients via finite differencing may differ from
1992        those computed analytically (not necessarily because either is incorrect).
1993        For more context, see :ref:`non-differentiable-func-grad`.
1994
1995    .. warning::
1996       If any checked tensor in :attr:`input` has overlapping memory, i.e.,
1997       different indices pointing to the same memory address (e.g., from
1998       :func:`torch.expand`), this check will likely fail because the numerical
1999       gradients computed by point perturbation at such indices will change
2000       values at all other indices that share the same memory address.
2001
2002    Args:
2003        func (function): a Python function that takes Tensor inputs and returns
2004            a Tensor or a tuple of Tensors
2005        inputs (tuple of Tensor or Tensor): inputs to the function
2006        eps (float, optional): perturbation for finite differences
2007        atol (float, optional): absolute tolerance
2008        rtol (float, optional): relative tolerance
2009        raise_exception (bool, optional): indicating whether to raise an exception if
2010            the check fails. The exception gives more information about the
2011            exact nature of the failure. This is helpful when debugging gradchecks.
2012        nondet_tol (float, optional): tolerance for non-determinism. When running
2013            identical inputs through the differentiation, the results must either match
2014            exactly (default, 0.0) or be within this tolerance.
2015        check_undefined_grad (bool, optional): if ``True``, check if undefined output grads
2016            are supported and treated as zeros, for ``Tensor`` outputs.
2017        check_batched_grad (bool, optional): if ``True``, check if we can compute
2018            batched gradients using prototype vmap support. Defaults to False.
2019        check_batched_forward_grad (bool, optional): if ``True``, checks if we can compute
2020            batched forward gradients using forward ad and prototype vmap support. Defaults to ``False``.
2021        check_forward_ad (bool, optional): if ``True``, check that the gradients computed with forward
2022            mode AD match the numerical ones. Defaults to ``False``.
2023        check_backward_ad (bool, optional): if ``False``, do not perform any checks that rely on
2024            backward mode AD to be implemented. Defaults to ``True``.
2025        fast_mode (bool, optional): Fast mode for gradcheck and gradgradcheck is currently only
2026            implemented for R to R functions. If none of the inputs and outputs are complex
2027            implemented for R to R functions. If none of the inputs and outputs are complex,
2028            is run; otherwise, we fall back to the slow implementation.
2029        masked (bool, optional): if ``True``, the gradients of unspecified elements of
2030            sparse tensors are ignored. Defaults to ``False``.
2031    Returns:
2032        ``True`` if all differences satisfy allclose condition
2033
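    Example::

        >>> # A minimal usage sketch (the checked function and input here are
        >>> # arbitrary placeholders, not part of the API):
        >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
        >>> torch.autograd.gradcheck(lambda x: x.exp().sum(), (x,))
        True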
2034    """
2035    assert (
2036        check_forward_ad or check_backward_ad
2037    ), "Expected at least one of check_forward_ad or check_backward_ad to be True"
2038    assert not (
2039        check_batched_grad and not check_backward_ad
2040    ), "Setting check_batched_grad=True requires check_backward_ad to be True"
2041    assert not (
2042        check_batched_forward_grad and not check_forward_ad
2043    ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True"
2044    args = locals().copy()
2045    args.pop("raise_exception")
2046    if not raise_exception:
2047        try:
2048            return _gradcheck_helper(**args)
2049        except GradcheckError:
2050            return False
2051    else:
2052        return _gradcheck_helper(**args)
2053
2054
2055def _gradcheck_helper(
2056    func,
2057    inputs,
2058    eps,
2059    atol,
2060    rtol,
2061    nondet_tol,
2062    check_undefined_grad,
2063    check_grad_dtypes,
2064    check_batched_grad,
2065    check_batched_forward_grad,
2066    check_forward_ad,
2067    check_backward_ad,
2068    fast_mode,
2069    masked,
2070):
2071    tupled_inputs = _as_tuple(inputs)
2072    _check_inputs(tupled_inputs)
2073
2074    func_out = func(*tupled_inputs)
2075    outputs = _differentiable_outputs(func_out)
2076    _check_outputs(outputs)
2077
2078    gradcheck_fn = functools.partial(
2079        _fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked
2080    )
2081    _gradcheck_real_imag(
2082        gradcheck_fn,
2083        func,
2084        func_out,
2085        tupled_inputs,
2086        outputs,
2087        eps,
2088        rtol,
2089        atol,
2090        check_grad_dtypes,
2091        check_forward_ad=check_forward_ad,
2092        check_backward_ad=check_backward_ad,
2093        nondet_tol=nondet_tol,
2094        check_undefined_grad=check_undefined_grad,
2095    )
2096
2097    if check_batched_forward_grad:
2098        _test_batched_grad_forward_ad(func, tupled_inputs)
2099
2100    # Short circuit because remaining tests rely on backward AD to be implemented
2101    if not check_backward_ad:
2102        return True
2103
2104    for i, o in enumerate(outputs):
2105        if check_batched_grad:
2106            _test_batched_grad(tupled_inputs, o, i)
2107
2108    _test_backward_mul_by_grad_output(outputs, tupled_inputs, masked)
2109
2110    if check_undefined_grad and check_backward_ad:
2111        _test_undefined_backward_mode(func, outputs, tupled_inputs)
2112    return True
2113
2114
2115def gradgradcheck(
2116    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
2117    inputs: _TensorOrTensors,
2118    grad_outputs: Optional[_TensorOrTensors] = None,
2119    *,
2120    eps: float = 1e-6,
2121    atol: float = 1e-5,
2122    rtol: float = 1e-3,
2123    gen_non_contig_grad_outputs: bool = False,
2124    raise_exception: bool = True,
2125    nondet_tol: float = 0.0,
2126    check_undefined_grad: bool = True,
2127    check_grad_dtypes: bool = False,
2128    check_batched_grad: bool = False,
2129    check_fwd_over_rev: bool = False,
2130    check_rev_over_rev: bool = True,
2131    fast_mode: bool = False,
2132    masked: bool = False,
2133) -> bool:  # noqa: D400,D205
2134    r"""Check gradients of gradients computed via small finite differences
2135    against analytical gradients wrt tensors in :attr:`inputs` and
2136    :attr:`grad_outputs` that are of floating point or complex type and with
2137    ``requires_grad=True``.
2138
2139    This function checks that backpropagating through the gradients computed
2140    with the given :attr:`grad_outputs` is correct.
2141
2142    The check between numerical and analytical gradients uses :func:`~torch.allclose`.
2143
2144    .. note::
2145        The default values are designed for :attr:`input` and
2146        :attr:`grad_outputs` of double precision. This check will likely fail if
2147        they are of less precision, e.g., ``FloatTensor``.
2148
2149    .. warning::
2150       If any checked tensor in :attr:`input` and :attr:`grad_outputs` has
2151       overlapping memory, i.e., different indices pointing to the same memory
2152       address (e.g., from :func:`torch.expand`), this check will likely fail
2153       because the numerical gradients computed by point perturbation at such
2154       indices will change values at all other indices that share the same
2155       memory address.
2156
2157    Args:
2158        func (function): a Python function that takes Tensor inputs and returns
2159            a Tensor or a tuple of Tensors
2160        inputs (tuple of Tensor or Tensor): inputs to the function
2161        grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
2162            respect to the function's outputs.
2163        eps (float, optional): perturbation for finite differences
2164        atol (float, optional): absolute tolerance
2165        rtol (float, optional): relative tolerance
2166        gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
2167            ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
2168            randomly generated gradient outputs are made to be noncontiguous
2169        raise_exception (bool, optional): indicating whether to raise an exception if
2170            the check fails. The exception gives more information about the
2171            exact nature of the failure. This is helpful when debugging gradchecks.
2172        nondet_tol (float, optional): tolerance for non-determinism. When running
2173            identical inputs through the differentiation, the results must either match
2174            exactly (default, 0.0) or be within this tolerance. Note that a small amount
2175            of nondeterminism in the gradient will lead to larger inaccuracies in
2176            the second derivative.
2177        check_undefined_grad (bool, optional): if True, check if undefined output grads
2178            are supported and treated as zeros
2179        check_batched_grad (bool, optional): if True, check if we can compute
2180            batched gradients using prototype vmap support. Defaults to False.
2181        fast_mode (bool, optional): if True, run a faster implementation of gradgradcheck that
2182            no longer computes the entire jacobian.
2183        masked (bool, optional): if True, the gradients of unspecified elements of
2184            sparse tensors are ignored (default, False).
2185    Returns:
2186        True if all differences satisfy allclose condition
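
    Example::

        >>> # A minimal usage sketch (the checked function and input here are
        >>> # arbitrary placeholders, not part of the API):
        >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
        >>> torch.autograd.gradgradcheck(lambda x: (x**3).sum(), (x,))
        True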
2187    """
2188    assert (
2189        check_fwd_over_rev or check_rev_over_rev
2190    ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True"
2191    assert not (
2192        check_undefined_grad and not check_rev_over_rev
2193    ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True"
2194    assert not (
2195        check_batched_grad and not check_rev_over_rev
2196    ), "Setting check_batched_grad=True requires check_rev_over_rev to be True"
2197    # TODO: do we want to test this too?
2198    # assert not (check_batched_forward_grad and not check_fwd_over_rev), (
2199    #     "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True")
2200    tupled_inputs = _as_tuple(inputs)
2201
2202    if grad_outputs is None:
2203        # If grad_outputs is not specified, create random Tensors of the same shape, type, and device as the outputs
2204
2205        outputs = _differentiable_outputs(func(*tupled_inputs))
2206        tupled_grad_outputs = tuple(
2207            torch.testing.make_tensor(
2208                x.shape,
2209                dtype=x.dtype
2210                if x.is_floating_point() or x.is_complex()
2211                else torch.double,
2212                device=x.device,
2213                low=-1,
2214                high=1,
2215                requires_grad=True,
2216                noncontiguous=gen_non_contig_grad_outputs,
2217            )
2218            for x in outputs
2219        )
2220    else:
2221        tupled_grad_outputs = _as_tuple(grad_outputs)
2222
2223    num_outputs = len(tupled_grad_outputs)
2224
2225    # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
2226    #     before running forward mode AD
2227    diff_input_args_indices = {
2228        i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad
2229    }
2230    diff_grad_output_indices = {
2231        i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad
2232    }
2233
2234    def new_func(*args):
2235        # Restore the requires_grad information
2236        input_args = tuple(
2237            x.requires_grad_() if i in diff_input_args_indices else x
2238            for i, x in enumerate(args[:-num_outputs])
2239        )
2240        outputs = _differentiable_outputs(func(*input_args))
2241        grad_outputs = tuple(
2242            x.requires_grad_() if i in diff_grad_output_indices else x
2243            for i, x in enumerate(args[-num_outputs:])
2244        )
2245        diff_input_args = tuple(
2246            x for i, x in enumerate(input_args) if i in diff_input_args_indices
2247        )
2248        grad_inputs = torch.autograd.grad(
2249            outputs, diff_input_args, grad_outputs, create_graph=True, allow_unused=True
2250        )
2251        grad_inputs = tuple(g for g in grad_inputs if g is not None)
2252        return grad_inputs
2253
2254    return gradcheck(
2255        new_func,
2256        tupled_inputs + tupled_grad_outputs,
2257        eps=eps,
2258        atol=atol,
2259        rtol=rtol,
2260        raise_exception=raise_exception,
2261        nondet_tol=nondet_tol,
2262        check_undefined_grad=check_undefined_grad,
2263        check_grad_dtypes=check_grad_dtypes,
2264        check_batched_grad=check_batched_grad,
2265        fast_mode=fast_mode,
2266        check_forward_ad=check_fwd_over_rev,
2267        check_backward_ad=check_rev_over_rev,
2268        masked=masked,
2269    )
2270