xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/opinfo/definitions/linalg.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import itertools
4import random
5import unittest
6from functools import partial
7from itertools import chain, product
8from typing import Iterable, List, Tuple
9
10import numpy as np
11from numpy import inf
12
13import torch
14from torch.testing import make_tensor
15from torch.testing._internal.common_cuda import (
16    _get_magma_version,
17    _get_torch_cuda_version,
18    with_tf32_off,
19)
20from torch.testing._internal.common_device_type import (
21    has_cusolver,
22    skipCPUIfNoLapack,
23    skipCUDAIf,
24    skipCUDAIfNoCusolver,
25    skipCUDAIfNoMagma,
26    skipCUDAIfNoMagmaAndNoCusolver,
27    skipCUDAIfNoMagmaAndNoLinalgsolver,
28    skipCUDAIfRocm,
29    tol,
30    toleranceOverride,
31)
32from torch.testing._internal.common_dtype import (
33    all_types_and_complex,
34    all_types_and_complex_and,
35    floating_and_complex_types,
36    floating_and_complex_types_and,
37    get_all_complex_dtypes,
38)
39from torch.testing._internal.common_utils import (
40    GRADCHECK_NONDET_TOL,
41    IS_MACOS,
42    make_fullrank_matrices_with_distinct_singular_values,
43    skipIfSlowGradcheckEnv,
44    slowTest,
45    TEST_WITH_ROCM,
46)
47from torch.testing._internal.opinfo.core import (
48    clone_sample,
49    DecorateInfo,
50    ErrorInput,
51    gradcheck_wrapper_hermitian_input,
52    L,
53    M,
54    OpInfo,
55    ReductionOpInfo,
56    S,
57    SampleInput,
58)
59from torch.testing._internal.opinfo.refs import PythonRefInfo, ReductionPythonRefInfo
60
61
62def sample_kwargs_vector_norm(t, **kwargs):
63    # orders with / without identity
64    def ords():
65        has_id = (6, 4, 2, 1, 0, 0.9)
66        no_id = (inf, -2.1, -inf)
67        if t.numel() == 0:
68            dim = kwargs.get("dim")
69            if dim is None:
70                return has_id
71            if not isinstance(dim, Iterable):
72                dim = (dim,)
73            for d in dim:
74                if t.size(d) == 0:
75                    return has_id
76        return has_id + no_id
77
78    return (((), dict(ord=o)) for o in ords())
79
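# Illustrative sketch (hypothetical helper, not part of the upstream file): how the
# kwargs generator above behaves. For an empty tensor only orders with an identity
# are yielded, because e.g. the inf norm of an empty tensor is not defined.
def _example_vector_norm_kwargs():
    empty_ords = [kw["ord"] for _, kw in sample_kwargs_vector_norm(torch.zeros(0))]
    full_ords = [kw["ord"] for _, kw in sample_kwargs_vector_norm(torch.ones(3))]
    assert inf not in empty_ords and inf in full_ords
    return empty_ords, full_ords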
80
81def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
82    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
83    make_arg = partial(
84        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
85    )
86
87    is_linalg_svd = "linalg.svd" in op_info.name
88    batches = [(), (0,), (3,)]
89    ns = [0, 3, 5]
90
91    def uniformize(usv):
92        S = usv[1]
93        k = S.shape[-1]
94        U = usv[0][..., :k]
95        Vh = usv[2] if is_linalg_svd else usv[2].mH
96        Vh = Vh[..., :k, :]
97        return U, S, Vh
98
99    def fn_U(usv):
100        U, _, _ = uniformize(usv)
101        return U.abs()
102
103    def fn_S(usv):
104        return uniformize(usv)[1]
105
106    def fn_Vh(usv):
107        # We also return S so that its gradient is checked as well
108        _, S, Vh = uniformize(usv)
109        return S, Vh.abs()
110
111    def fn_UVh(usv):
112        U, S, Vh = uniformize(usv)
113        return U @ Vh, S
114
115    fns = (fn_U, fn_S, fn_Vh, fn_UVh)
116
117    fullmat = "full_matrices" if is_linalg_svd else "some"
118
119    for batch, n, k, fullmat_val, fn in product(batches, ns, ns, (True, False), fns):
120        shape = batch + (n, k)
121        yield SampleInput(
122            make_arg(*shape), kwargs={fullmat: fullmat_val}, output_process_fn_grad=fn
123        )
124
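# Illustrative sketch (hypothetical helper, not part of the upstream file): the
# uniformize() helper above truncates the full SVD to the reduced factors
# U[..., :k] and Vh[..., :k, :] with k = min(m, n), so a single
# output_process_fn_grad covers both full_matrices=True and False.
def _example_reduced_svd_shapes():
    A = torch.randn(5, 3, dtype=torch.float64)
    U, S, Vh = torch.linalg.svd(A, full_matrices=True)  # U: (5, 5), Vh: (3, 3)
    k = S.shape[-1]  # k = min(5, 3) = 3
    U, Vh = U[..., :k], Vh[..., :k, :]
    assert torch.allclose(U @ torch.diag(S) @ Vh, A)
    return U.shape, S.shape, Vh.shape  # (5, 3), (3,), (3, 3)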
125
126def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs):
127    make_arg = partial(
128        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
129    )
130    yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),))
131    yield SampleInput(
132        make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1)
133    )
134    yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1))
135
136
137def error_inputs_cross(op_info, device, **kwargs):
138    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
139
140    sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),))
141    err = "inputs dimension -1 must have length 3"
142    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
143
144    sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),))
145    err = "inputs must have the same number of dimensions"
146    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
147
148    sample = SampleInput(input=make_arg((S, 2)), args=(make_arg((S, 2)),))
149    err = "must have length 3"
150    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
151
152    sample = SampleInput(
153        input=make_arg((S, 2)), args=(make_arg((S, 2)),), kwargs=dict(dim=2)
154    )
155    err = "Dimension out of range"
156    yield ErrorInput(sample, error_regex=err, error_type=IndexError)
157
158
159def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs):
160    """
161    This function generates input for torch.linalg.householder_product (torch.orgqr).
162    The first argument is an (m, n) matrix (or batch of matrices) with m >= n; the second argument is a vector (or batch of vectors) of k <= n Householder coefficients.
163    Empty, square, rectangular, batched square, and batched rectangular inputs are generated.
164    """
165    make_arg = partial(
166        make_tensor,
167        device=device,
168        dtype=dtype,
169        requires_grad=requires_grad,
170        low=-2,
171        high=2,
172    )
173    # Each column of the matrix is getting multiplied many times leading to very large values for
174    # the Jacobian matrix entries and making the finite-difference result of grad check less accurate.
175    # That's why gradcheck with the default range [-9, 9] fails and [-2, 2] is used here.
176    yield SampleInput(make_arg((S, S)), make_arg((S,)))
177    yield SampleInput(make_arg((S + 1, S)), make_arg((S,)))
178    yield SampleInput(make_arg((2, 1, S, S)), make_arg((2, 1, S)))
179    yield SampleInput(make_arg((2, 1, S + 1, S)), make_arg((2, 1, S)))
180    yield SampleInput(
181        make_arg((0, 0), low=None, high=None),
182        make_arg((0,), low=None, high=None),
183    )
184    yield SampleInput(make_arg((S, S)), make_arg((0,), low=None, high=None))
185    # m = n = S, k = S - 2
186    yield SampleInput(make_arg((S, S)), make_arg((S - 2,), low=None, high=None))
187    # m = S, n = S - 1, k = S - 2
188    yield SampleInput(make_arg((S, S - 1)), make_arg((S - 2,), low=None, high=None))
189
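# Illustrative sketch (hypothetical helper, not part of the upstream file):
# householder_product (orgqr) assembles Q from the reflectors and tau produced by
# geqrf, which is the (m, n) matrix / k <= n coefficients shape pattern the
# samples above follow.
def _example_householder_product_roundtrip():
    A = torch.randn(5, 3, dtype=torch.float64)
    reflectors, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(reflectors, tau)  # shape (5, 3)
    R = torch.triu(reflectors)[: A.shape[-1]]  # shape (3, 3)
    assert torch.allclose(Q @ R, A)
    return Q.shape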
190
191def sample_inputs_linalg_det_singular(op_info, device, dtype, requires_grad, **kwargs):
192    make_arg = partial(make_tensor, device=device, dtype=dtype)
193
194    def make_singular_matrix_batch_base(size, rank):
195        assert size[-1] == size[-2]
196        assert rank > 0 and rank < size[-1]
197
198        n = size[-1]
199        a = make_arg(size[:-2] + (n, rank)) / 10
200        b = make_arg(size[:-2] + (rank, n)) / 10
201        x = a @ b
202        lu, pivs, _ = torch.linalg.lu_factor_ex(x)
203        p, l, u = torch.lu_unpack(lu, pivs)
204        u_diag_abs = u.diagonal(0, -2, -1).abs()
205        u_diag_abs_largest = u_diag_abs.max(dim=-1, keepdim=True).values
206        u_diag_abs_smallest_idxs = torch.topk(
207            u_diag_abs, k=(n - rank), largest=False
208        ).indices
209        u.diagonal(0, -2, -1).div_(u_diag_abs_largest)
210        u.diagonal(0, -2, -1)[..., u_diag_abs_smallest_idxs] = torch.finfo(dtype).eps
211        matrix = p @ l @ u
212
213        matrix.requires_grad_(requires_grad)
214        return matrix
215
216    for batch, size in product(((), (2,), (2, 2)), range(6)):
217        shape = batch + (size, size)
218        for rank in range(1, size):
219            yield SampleInput(make_singular_matrix_batch_base(shape, rank))
220
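# Illustrative sketch (hypothetical helper, not part of the upstream file): the
# a @ b construction above yields a matrix of the requested deficient rank, so
# its determinant is (numerically) zero.
def _example_singular_det():
    a = torch.randn(4, 2, dtype=torch.float64)
    b = torch.randn(2, 4, dtype=torch.float64)
    x = a @ b  # rank <= 2 < 4, hence singular
    assert torch.linalg.matrix_rank(x) <= 2
    assert torch.linalg.det(x).abs() < 1e-10
    return x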
221
222def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad, **kwargs):
223    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
224    make_arg = partial(
225        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
226    )
227    make_arg_fullrank = partial(
228        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
229    )
230    # (<matrix_size>, (<batch_sizes, ...>))
231    test_sizes = [
232        (1, ()),
233        (2, (0,)),
234        (2, (2,)),
235    ]
236
237    for matrix_size, batch_sizes in test_sizes:
238        size = batch_sizes + (matrix_size, matrix_size)
239        for n in (0, 3, 5):
240            yield SampleInput(make_arg(size), args=(n,))
241        for n in [-4, -2, -1]:
242            yield SampleInput(make_arg_fullrank(*size), args=(n,))
243
244
245def sample_inputs_linalg_det_logdet_slogdet(
246    op_info, device, dtype, requires_grad, **kwargs
247):
248    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
249    make_arg = partial(
250        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
251    )
252    batches = [(), (0,), (3,)]
253    ns = [0, 1, 5]
254
255    is_logdet = op_info.name == "logdet"
256
257    for (
258        batch,
259        n,
260    ) in product(batches, ns):
261        shape = batch + (n, n)
262        A = make_arg(*shape)
263        # Need to make the matrices in A have positive determinant for autograd
264        # To do so, we multiply A by the sign of its determinant, which flips a negative determinant to positive (n is odd here)
265        if is_logdet and not A.is_complex() and A.numel() > 0:
266            s = torch.linalg.slogdet(A).sign
267            A = A * s.unsqueeze(-1).unsqueeze(-1)
268            A.requires_grad_(requires_grad)
269        yield SampleInput(A)
270
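# Illustrative sketch (hypothetical helper, not part of the upstream file):
# multiplying a real square matrix by the sign of its determinant makes the
# determinant positive when n is odd (det(s * A) = s**n * det(A)), which keeps
# logdet finite on the generated samples.
def _example_positive_det():
    A = torch.randn(5, 5, dtype=torch.float64)
    s = torch.linalg.slogdet(A).sign
    A = A * s
    assert torch.linalg.det(A) > 0
    return torch.logdet(A)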
271
272def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs):
273    """Samples the inputs for both linalg.lu_solve and lu_solve"""
274    make_fn = make_fullrank_matrices_with_distinct_singular_values
275    make_a = partial(make_fn, dtype=dtype, device=device)
276    make_b = partial(make_tensor, dtype=dtype, device=device)
277
278    def clone(X, requires_grad):
279        Y = X.clone()
280        Y.requires_grad_(requires_grad)
281        return Y
282
283    is_linalg_lu_solve = op_info.name == "linalg.lu_solve"
284
285    batches = ((), (0,), (2,))
286    ns = (3, 1, 0)
287    nrhs = (4, 1, 0)
288
289    for n, batch, rhs in product(ns, batches, nrhs):
290        A = make_a(*(batch + (n, n)))
291        LU, pivots = torch.linalg.lu_factor(A)
292
293        B = make_b(batch + (n, rhs))
294
295        grads = (False,) if not requires_grad else (True, False)
296        # we try all possible combinations of requires_grad for each input
297        for LU_grad, B_grad in product(grads, grads):
298            # when requires_grad == True, at least one input has to have requires_grad enabled
299            if requires_grad and not LU_grad and not B_grad:
300                continue
301
302            if is_linalg_lu_solve:
303                for adjoint, left in product((True, False), repeat=2):
304                    yield SampleInput(
305                        clone(LU, LU_grad),
306                        args=(pivots, clone(B if left else B.mT, B_grad)),
307                        kwargs=dict(adjoint=adjoint, left=left),
308                    )
309            else:
310                yield SampleInput(clone(B, B_grad), args=(clone(LU, LU_grad), pivots))
311
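# Illustrative sketch (hypothetical helper, not part of the upstream file):
# linalg.lu_solve consumes exactly the (LU, pivots) pair produced by
# linalg.lu_factor, as wired up by the generator above.
def _example_lu_solve_roundtrip():
    A = torch.randn(3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)
    B = torch.randn(3, 2, dtype=torch.float64)
    LU, pivots = torch.linalg.lu_factor(A)
    X = torch.linalg.lu_solve(LU, pivots, B)
    assert torch.allclose(A @ X, B)
    return X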
312
313def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
314    # Each test case consists of the sizes in the chain of multiplications
315    # e.g. [2, 3, 4, 5] generates matrices (2, 3) @ (3, 4) @ (4, 5)
316    test_cases = [
317        [1, 2, 1],
318        [2, 0, 2],
319        [0, 2, 2],
320        [2, 2, 2, 2],
321        [2, 3, 4, 5],
322        [5, 4, 0, 2],
323        [2, 4, 3, 5, 3, 2],
324    ]
325
326    for sizes in test_cases:
327        tensors = []
328        for size in zip(sizes[:-1], sizes[1:]):
329            t = make_tensor(
330                size, dtype=dtype, device=device, requires_grad=requires_grad
331            )
332            tensors.append(t)
333        yield SampleInput(tensors)
334
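# Illustrative sketch (hypothetical helper, not part of the upstream file): a size
# chain such as [2, 3, 4, 5] becomes matrices of shape (2, 3), (3, 4), (4, 5),
# and multi_dot agrees with plain chained matmul.
def _example_multi_dot_chain():
    sizes = [2, 3, 4, 5]
    mats = [
        torch.randn(m, n, dtype=torch.float64)
        for m, n in zip(sizes[:-1], sizes[1:])
    ]
    out = torch.linalg.multi_dot(mats)
    assert torch.allclose(out, mats[0] @ mats[1] @ mats[2])
    return out.shape  # (2, 5)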
335
336def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs):
337    low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
338    make_arg = partial(
339        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
340    )
341
342    sizes = ((2, 2), (2, 3, 2))
343    if dtype in low_precision_dtypes:
344        # svdvals not supported for low precision dtypes
345        ords = ("fro", inf, -inf, 1, -1)
346    else:
347        ords = ("fro", "nuc", inf, -inf, 1, -1, 2, -2)
348    dims = ((-2, -1), (-1, 0))
349
350    for size, ord, dim, keepdim in product(sizes, ords, dims, [True, False]):
351        yield SampleInput(make_arg(size), args=(ord, dim, keepdim))
352
353
354def sample_inputs_linalg_norm(
355    op_info, device, dtype, requires_grad, *, variant=None, **kwargs
356):
357    if variant is not None and variant not in ("subgradient_at_zero",):
358        raise ValueError(
359            f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}"
360        )
361
362    test_sizes = [
363        (S,),
364        (0,),
365        (S, S),
366        (0, 0),
367        (S, 0),
368        (0, S),
369        (S, S, S),
370        (0, S, S),
371        (S, 0, S),
372        (0, 0, 0),
373    ]
374
375    vector_ords = (None, 0, 0.5, 1, 2, 3.5, inf, -0.5, -1, -2, -3.5, -inf)
376    if dtype in {torch.float16, torch.bfloat16, torch.complex32}:
377        # svdvals not supported for low precision dtypes
378        matrix_ords = ("fro", inf, -inf, 1, -1)
379    else:
380        matrix_ords = (None, "fro", "nuc", inf, -inf, 1, -1, 2, -2)
381
382    make_arg = partial(
383        make_tensor,
384        dtype=dtype,
385        device=device,
386        requires_grad=requires_grad,
387        low=None,
388        high=None,
389    )
390
391    for test_size in test_sizes:
392        is_vector_norm = len(test_size) == 1
393        is_matrix_norm = len(test_size) == 2
394
395        # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
396        is_valid_for_p2 = is_vector_norm or (test_size[-1] != 0 and test_size[-2] != 0)
397
398        for keepdim in [False, True]:
399            if variant != "subgradient_at_zero" and is_valid_for_p2:
400                yield SampleInput(make_arg(test_size), keepdim=keepdim)
401
402            if not (is_vector_norm or is_matrix_norm):
403                continue
404
405            ords = vector_ords if is_vector_norm else matrix_ords
406
407            for ord in ords:
408                if is_vector_norm and test_size[-1] == 0:
409                    if ord == np.inf or (ord is not None and ord < 0):
410                        # RuntimeError: linalg.vector_norm cannot compute the
411                        # {ord} norm on an empty tensor because the operation
412                        # does not have an identity
413                        continue
414                elif is_matrix_norm:
415                    dims_to_check = {
416                        None: (0,),
417                        np.inf: (0,),
418                        2: (0, 1),
419                        1: (1,),
420                        -1: (1,),
421                        -2: (0, 1),
422                        -np.inf: (0,),
423                    }.get(ord, ())
424
425                    if any(test_size[d] == 0 for d in dims_to_check):
426                        # IndexError: amax(): Expected reduction dim {dim} to
427                        # have non-zero size.
428                        continue
429
430                if variant == "subgradient_at_zero":
431                    yield SampleInput(
432                        torch.zeros(
433                            test_size,
434                            dtype=dtype,
435                            device=device,
436                            requires_grad=requires_grad,
437                        ),
438                        ord,
439                        keepdim=keepdim,
440                    )
441                else:
442                    yield SampleInput(make_arg(test_size), ord, keepdim=keepdim)
443
444                    if ord in ["nuc", "fro"]:
445                        yield SampleInput(
446                            make_arg(test_size), ord=ord, keepdim=keepdim, dim=(0, 1)
447                        )
448
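# Illustrative sketch (hypothetical helper, not part of the upstream file): the
# skips above mirror actual behavior, e.g. the 2-norm of an empty vector is 0
# (the underlying sum has an identity), while the inf norm is rejected.
def _example_norm_on_empty_tensor():
    t = torch.zeros(0)
    assert torch.linalg.vector_norm(t, ord=2) == 0
    try:
        torch.linalg.vector_norm(t, ord=inf)
    except RuntimeError as e:
        return str(e)
    raise AssertionError("expected the inf norm of an empty tensor to fail")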
449
450def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
451    make_arg = partial(
452        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
453    )
454    batches = ((), (0,), (1,), (5,))
455    ns = (0, 1, 3, 5)
456    for b, n in product(batches, ns):
457        shape = b + (n,)
458        yield SampleInput(make_arg(shape), args=(make_arg(shape),))
459        for i in range(len(shape)):
460            yield SampleInput(
461                make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i)
462            )
463
464
465def sample_inputs_linalg_invertible(
466    op_info, device, dtype, requires_grad=False, **kwargs
467):
468    """
469    This function generates invertible inputs for linear algebra ops
470    The input is generated as the itertools.product of 'batches' and 'ns'.
471    In total this function generates 8 SampleInputs
472    'batches' cases include:
473        () - single input,
474        (0,) - zero batched dimension,
475        (2,) - batch of two matrices,
476        (1, 1) - 1x1 batch of matrices
477    'ns' gives 0x0 and 5x5 matrices.
478    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
479    """
480    make_fn = make_fullrank_matrices_with_distinct_singular_values
481    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
482
483    batches = [(), (0,), (2,), (1, 1)]
484    ns = [5, 0]
485
486    for batch, n in product(batches, ns):
487        yield SampleInput(make_arg(*batch, n, n))
488
489
490def sample_inputs_matrix_rank(op_info, device, dtype, requires_grad=False, **kwargs):
491    """
492    This function produces inputs for matrix rank that test
493    all possible combinations for atol and rtol
494    """
495
496    def make_tol_arg(kwarg_type, inp):
497        if kwarg_type == "none":
498            return None
499        if kwarg_type == "float":
500            return 1.0
501        assert kwarg_type == "tensor"
502        return torch.ones(inp.shape[:-2], device=device)
503
504    for tol_type in ["float", "tensor"]:
505        for atol_type, rtol_type in product(["none", tol_type], repeat=2):
506            if atol_type == "none" and rtol_type == "none":
507                # both tolerances at their default; skip so the default case
508                # is not tested 2 extra times
509                continue
510            for sample in sample_inputs_linalg_invertible(
511                op_info, device, dtype, requires_grad
512            ):
513                assert sample.kwargs == {}
514                sample.kwargs = {
515                    "atol": make_tol_arg(atol_type, sample.input),
516                    "rtol": make_tol_arg(rtol_type, sample.input),
517                }
518                yield sample
519
520    # default kwargs
521    yield from sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
522
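# Illustrative sketch (hypothetical helper, not part of the upstream file):
# matrix_rank takes atol/rtol either as floats or as tensors over the batch,
# which is the combination space the generator above walks through.
def _example_matrix_rank_tolerances():
    A = torch.eye(4, dtype=torch.float64).expand(2, 4, 4).clone()
    A[0, 3, 3] = 1e-6  # make one matrix in the batch nearly rank deficient
    atol = torch.tensor([1e-3, 0.0], dtype=torch.float64)  # one tolerance per matrix
    ranks = torch.linalg.matrix_rank(A, atol=atol, rtol=0.0)
    assert ranks.tolist() == [3, 4]
    return ranks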
523
524def sample_inputs_linalg_pinv_singular(
525    op_info, device, dtype, requires_grad=False, **kwargs
526):
527    """
528    This function produces factors `a` and `b` to generate inputs of the form `a @ b.t()` to
529    test the backward method of `linalg_pinv`. That way we always preserve the rank of the
530    input no matter the perturbations applied to it by the gradcheck.
531    Note that `pinv` is Frechet-differentiable in a rank-preserving neighborhood.
532    """
533    batches = [(), (0,), (2,), (1, 1)]
534    # a size of at least 30 is required to cause failures in the previous implicit implementation
535    # of pinv's backward method, although it makes the check slow.
536    size = [0, 3, 50]
537
538    for batch, m, n in product(batches, size, size):
539        for k in range(min(3, m, n)):
540            # Note that by making the columns of `a` and `b` orthonormal we make sure that
541            # the product matrix `a @ b.t()` has condition number 1 when restricted to its image
542            a = (
543                torch.rand(*batch, m, k, device=device, dtype=dtype)
544                .qr()
545                .Q.requires_grad_(requires_grad)
546            )
547            b = (
548                torch.rand(*batch, n, k, device=device, dtype=dtype)
549                .qr()
550                .Q.requires_grad_(requires_grad)
551            )
552            yield SampleInput(a, args=(b,))
553
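# Illustrative sketch (hypothetical helper, not part of the upstream file): with
# orthonormal columns in both factors, a @ b.mH has exactly rank k and unit
# nonzero singular values, so gradcheck, which perturbs a and b rather than the
# product, stays in a rank-preserving neighborhood.
def _example_low_rank_pinv():
    a = torch.linalg.qr(torch.randn(5, 2, dtype=torch.float64)).Q
    b = torch.linalg.qr(torch.randn(4, 2, dtype=torch.float64)).Q
    x = a @ b.mH  # shape (5, 4), rank 2
    x_pinv = torch.linalg.pinv(x)
    assert torch.linalg.matrix_rank(x) == 2
    assert torch.allclose(x @ x_pinv @ x, x)
    return x_pinv.shape  # (4, 5)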
554
555def sample_inputs_linalg_cond(op_info, device, dtype, requires_grad=False, **kwargs):
556    make_arg = partial(
557        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
558    )
559
560    # autograd is not supported for inputs with zero elements
561    shapes = (
562        (S, S),
563        (2, S, S),
564        (2, 1, S, S),
565    )
566
567    for shape in shapes:
568        yield SampleInput(make_arg(shape))
569
570
571def sample_inputs_linalg_vander(op_info, device, dtype, requires_grad=False, **kwargs):
572    make_arg = partial(
573        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
574    )
575
576    shapes = (
577        (),
578        (1,),
579        (S,),
580        (2, S),
581    )
582
583    for shape in shapes:
584        if len(shape) > 0 and shape[-1] > 1:
585            yield SampleInput(make_arg(shape))
586        n = shape[-1] if len(shape) > 0 else 1
587        for i in range(3):
588            # n-1, n, n+1
589            N = n + i - 1
590            if N < 2:
591                continue
592            yield SampleInput(make_arg(shape), kwargs=dict(N=N))
593
594
595def np_vander_batched(x, N=None):
596    # Wrapper around np.vander that supports a single batch dimension (enough for the tests)
597    if x.ndim == 0:
598        x = x[np.newaxis]
599    if x.ndim == 1:
600        y = np.vander(x, N=N, increasing=True)
601        return y
602    else:
603        if N is None:
604            N = x.shape[-1]
605        y = np.vander(x.ravel(), N=N, increasing=True).reshape((*x.shape, N))
606        return y
607
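# Illustrative sketch (hypothetical helper, not part of the upstream file):
# torch.linalg.vander uses increasing powers, matching the batched NumPy
# reference wrapper above.
def _example_vander_reference():
    x = torch.arange(1.0, 4.0, dtype=torch.float64)  # [1., 2., 3.]
    expected = np_vander_batched(x.numpy(), N=4)
    actual = torch.linalg.vander(x, N=4)
    assert np.allclose(actual.numpy(), expected)
    return actual  # row i is [1, x_i, x_i**2, x_i**3]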
608
609def sample_inputs_linalg_cholesky_inverse(
610    op_info, device, dtype, requires_grad=False, **kwargs
611):
612    from torch.testing._internal.common_utils import random_well_conditioned_matrix
613
614    # Cholesky factorization is for positive-definite matrices
615    single_well_conditioned_matrix = random_well_conditioned_matrix(
616        S, S, dtype=dtype, device=device
617    )
618    batch_well_conditioned_matrices = random_well_conditioned_matrix(
619        2, S, S, dtype=dtype, device=device
620    )
621    single_pd = single_well_conditioned_matrix @ single_well_conditioned_matrix.mH
622    batch_pd = batch_well_conditioned_matrices @ batch_well_conditioned_matrices.mH
623
624    inputs = (
625        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
626        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
627        single_pd,
628        batch_pd,
629    )
630    test_cases = (torch.linalg.cholesky(a, upper=False) for a in inputs)
631    for l in test_cases:
632        # generate lower-triangular samples
633        l.requires_grad = requires_grad
634        yield SampleInput(l)  # upper=False by default
635        yield SampleInput(
636            l.detach().clone().requires_grad_(requires_grad), kwargs=dict(upper=False)
637        )
638
639        # generate upper-triangular inputs
640        u = l.detach().clone().mT.contiguous().requires_grad_(requires_grad)
641        yield SampleInput(u, kwargs=dict(upper=True))
642
643
644def sample_inputs_linalg_ldl_factor(
645    op_info, device, dtype, requires_grad=False, **kwargs
646):
647    from torch.testing._internal.common_utils import (
648        random_hermitian_pd_matrix,
649        random_symmetric_pd_matrix,
650    )
651
652    device = torch.device(device)
653
654    # Symmetric inputs
655    yield SampleInput(
656        random_symmetric_pd_matrix(S, dtype=dtype, device=device),
657        kwargs=dict(hermitian=False),
658    )  # single matrix
659    yield SampleInput(
660        random_symmetric_pd_matrix(S, 2, dtype=dtype, device=device),
661        kwargs=dict(hermitian=False),
662    )  # batch of matrices
663    yield SampleInput(
664        torch.zeros(0, 0, dtype=dtype, device=device), kwargs=dict(hermitian=False)
665    )  # 0x0 matrix
666    yield SampleInput(
667        torch.zeros(0, 2, 2, dtype=dtype, device=device), kwargs=dict(hermitian=False)
668    )  # zero batch of matrices
669
670    # Hermitian inputs
671    # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+
672    magma_254_available = device.type == "cuda" and _get_magma_version() >= (2, 5, 4)
673    if dtype.is_complex and (device.type == "cpu" or magma_254_available):
674        yield SampleInput(
675            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
676            kwargs=dict(hermitian=True),
677        )  # single matrix
678        yield SampleInput(
679            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
680            kwargs=dict(hermitian=True),
681        )  # batch of matrices
682
683
684def sample_inputs_linalg_ldl_solve(
685    op_info, device, dtype, requires_grad=False, **kwargs
686):
687    # Generate LDL factors of symmetric (and Hermitian on CPU) matrices
688    from torch.testing._internal.common_utils import (
689        random_hermitian_pd_matrix,
690        random_symmetric_pd_matrix,
691    )
692
693    device = torch.device(device)
694    symmetric_inputs = (
695        random_symmetric_pd_matrix(S, dtype=dtype, device=device),  # single matrix
696        random_symmetric_pd_matrix(
697            S, 2, dtype=dtype, device=device
698        ),  # batch of matrices
699        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
700        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
701    )
702    hermitian_inputs = (
703        (
704            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
705            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
706        )
707        if device.type == "cpu" and dtype.is_complex
708        else ()
709    )
710    test_cases1 = (
711        torch.linalg.ldl_factor_ex(a, hermitian=False) for a in symmetric_inputs
712    )
713    test_cases2 = (
714        torch.linalg.ldl_factor_ex(a, hermitian=True) for a in hermitian_inputs
715    )
716
717    # Symmetric case
718    make_arg = partial(
719        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
720    )
721    for test_case in test_cases1:
722        factors, pivots, _ = test_case
723        factors.requires_grad = requires_grad
724        for B_batch_shape in ((), factors.shape[:-2]):
725            B = make_arg((*B_batch_shape, factors.shape[-1], S))
726            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=False))
727            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
728            yield SampleInput(
729                clone_factors, args=(pivots, B), kwargs=dict(hermitian=False)
730            )
731
732    # Hermitian case
733    for test_case in test_cases2:
734        factors, pivots, _ = test_case
735        factors.requires_grad = requires_grad
736        for B_batch_shape in ((), factors.shape[:-2]):
737            B = make_arg((*B_batch_shape, factors.shape[-1], S))
738            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=True))
739            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
740            yield SampleInput(
741                clone_factors, args=(pivots, B), kwargs=dict(hermitian=True)
742            )
743
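# Illustrative sketch (hypothetical helper, not part of the upstream file):
# ldl_solve consumes the (factors, pivots) pair produced by ldl_factor_ex,
# exactly as the generator above wires them together.
def _example_ldl_roundtrip():
    A = torch.randn(4, 4, dtype=torch.float64)
    A = A + A.mT + 8 * torch.eye(4, dtype=torch.float64)  # symmetric, well conditioned
    B = torch.randn(4, 2, dtype=torch.float64)
    factors, pivots, _ = torch.linalg.ldl_factor_ex(A, hermitian=False)
    X = torch.linalg.ldl_solve(factors, pivots, B, hermitian=False)
    assert torch.allclose(A @ X, B)
    return X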
744
745def sample_inputs_linalg_lstsq(op_info, device, dtype, requires_grad=False, **kwargs):
746    from torch.testing._internal.common_utils import random_well_conditioned_matrix
747
748    device = torch.device(device)
749
750    drivers: Tuple[str, ...]
751    if device.type == "cuda":
752        drivers = ("gels",)
753    else:
754        drivers = ("gels", "gelsy", "gelss", "gelsd")
755
756    # we generate matrices of shape (..., n + delta, n)
757    deltas: Tuple[int, ...]
758    if device.type == "cpu" or has_cusolver():
759        deltas = (-1, 0, +1)
760    # only square systems if Cusolver is not available
761    # because we solve an lstsq problem with a transposed matrix in the backward
762    else:
763        deltas = (0,)
764
765    for batch, driver, delta in product(((), (3,), (3, 3)), drivers, deltas):
766        shape = batch + (3 + delta, 3)
767        a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
768        a.requires_grad_(requires_grad)
769        b = make_tensor(
770            shape,
771            dtype=dtype,
772            device=device,
773            low=None,
774            high=None,
775            requires_grad=requires_grad,
776        )
777        yield SampleInput(a, b, driver=driver)
778
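# Illustrative sketch (hypothetical helper, not part of the upstream file): the
# (3 + delta, 3) shapes above cover under-, exactly and over-determined systems;
# lstsq returns the least-squares solution, so the normal equations hold at the
# optimum.
def _example_lstsq_overdetermined():
    a = torch.randn(4, 3, dtype=torch.float64)  # delta = +1
    b = torch.randn(4, 2, dtype=torch.float64)
    x = torch.linalg.lstsq(a, b).solution  # shape (3, 2)
    residual_grad = a.mT @ (a @ x - b)
    assert torch.allclose(residual_grad, torch.zeros_like(residual_grad), atol=1e-8)
    return x.shape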
779
780def error_inputs_lstsq(op_info, device, **kwargs):
781    zero_d = torch.randn((), device=device)
782    yield ErrorInput(
783        SampleInput(zero_d, args=(zero_d,)),
784        error_type=RuntimeError,
785        error_regex="at least 2 dimensions",
786    )
787
788
789def error_inputs_lstsq_grad_oriented(op_info, device, **kwargs):
790    zero_d = torch.randn((), device=device)
791    yield ErrorInput(
792        SampleInput(zero_d, args=(zero_d, None)),
793        error_type=RuntimeError,
794        error_regex="at least 2 dimensions",
795    )
796
797
798def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs):
799    make_arg = partial(
800        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
801    )
802
803    # Shapes for 2D Tensors
804    shapes_2d = ((S, S), (3, 5), (5, 3))
805
806    # Shapes for 3D Tensors
807    shapes_3d = ((S, S, S),)
808
809    kwargs_2d = ({}, dict(offset=2), dict(offset=2), dict(offset=1))
810    kwargs_3d = (
811        dict(offset=1, dim1=1, dim2=2),
812        dict(offset=2, dim1=0, dim2=1),
813        dict(offset=-2, dim1=0, dim2=1),
814    )
815
816    for shape, kwarg in chain(
817        product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)
818    ):
819        yield SampleInput(make_arg(shape), kwargs=kwarg)
820
821
822def error_inputs_diagonal_diag_embed(op_info, device, **kwargs):
823    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
824
825    shapes1d = (0, 1, (0,), (1,))
826    shapes2d = ((M, L),)
827    shapes3d = ((M, S, L),)
828
829    kwargs1d = {}
830
831    kwargs2d = (
832        # dim1 == dim2 is not allowed
833        dict(dim1=1, dim2=1),
834        # out of bounds dims are not allowed
835        dict(dim1=10000),
836        dict(dim2=10000),
837    )
838
839    kwargs3d = kwargs2d
840
841    samples1d = product(shapes1d, kwargs1d)
842    samples2d = product(shapes2d, kwargs2d)
843    samples3d = product(shapes3d, kwargs3d)
844
845    for shape, kwargs in chain(samples1d, samples2d, samples3d):
846        arg = make_arg(shape)
847        sample = SampleInput(input=arg, kwargs=kwargs)
848
849        dim1 = kwargs.get("dim1")
850        dim2 = kwargs.get("dim2")
851
852        if "diagonal" in op_info.name:
853            num_dim = arg.dim()
854        elif op_info.name in ("diag_embed", "_refs.diag_embed"):
855            # these are valid inputs for diag_embed
856            if shape in ((0,), (1,)):
857                continue
858            num_dim = arg.dim() + 1
859        else:
860            raise RuntimeError("should be unreachable")
861
862        bound1 = -num_dim
863        bound2 = num_dim - 1
864        dim_range = range(bound1, bound2 + 1)
865        dim1_cond = dim1 and dim1 not in dim_range
866        dim2_cond = dim2 and dim2 not in dim_range
867
868        if dim1 == dim2:
869            err = f"diagonal dimensions cannot be identical {dim1}, {dim2}"
870            yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
871        elif dim1_cond or dim2_cond:
872            err_dim = dim1 if dim1_cond else dim2
873            err = (
874                r"Dimension out of range \(expected to be in range of "
875                rf"\[{bound1}, {bound2}\], but got {err_dim}\)"
876            )
877            yield ErrorInput(sample, error_regex=err, error_type=IndexError)
878        else:
879            raise RuntimeError("should be unreachable")
880
881
882def sample_inputs_linalg_cholesky(
883    op_info, device, dtype, requires_grad=False, **kwargs
884):
885    """
886    This function generates always positive-definite input for torch.linalg.cholesky using
887    random_hermitian_pd_matrix.
888    The input is generated as the itertools.product of 'batches' and 'ns'.
889    In total this function generates 8 SampleInputs
890    'batches' cases include:
891        () - single input,
892        (0,) - zero batched dimension,
893        (2,) - batch of two matrices,
894        (1, 1) - 1x1 batch of matrices
895    'ns' gives 0x0 and 5x5 matrices.
896    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
897    """
898    from torch.testing._internal.common_utils import random_hermitian_pd_matrix
899
900    batches = [(), (0,), (2,), (1, 1)]
901    ns = [5, 0]
902    for batch, n, upper in product(batches, ns, [True, False]):
903        a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
904        a.requires_grad = requires_grad
905        yield SampleInput(a, upper=upper)
906
907
908def sample_inputs_linalg_eig(op_info, device, dtype, requires_grad=False, **kwargs):
909    """
910    This function generates input for torch.linalg.eig
911    """
912
913    def out_fn(output):
914        return output[0], abs(output[1])
915
916    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
917    for sample in samples:
918        sample.output_process_fn_grad = out_fn
919        yield sample
920
921
922def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs):
923    """
924    This function generates input for torch.linalg.eigh/eigvalsh with UPLO="U" or "L" keyword argument.
925    """
926
927    def out_fn(output):
928        if isinstance(output, tuple):
929            # eigh function
930            return output[0], abs(output[1])
931        else:
932            # eigvalsh function
933            return output
934
935    # Samples do not need to be Hermitian, as we're using gradcheck_wrapper_hermitian_input
936    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
937    for sample in samples:
938        # Note: we cannot use np.random.choice here as TorchDynamo
939        # does not support tensors of strings.
940        sample.kwargs = {"UPLO": random.choice(["L", "U"])}
941        sample.output_process_fn_grad = out_fn
942        yield sample
943
944
945def sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad=False, **kwargs):
946    """
947    This function generates input for torch.linalg.pinv with hermitian=False keyword argument.
948    """
949    for o in sample_inputs_linalg_invertible(
950        op_info, device, dtype, requires_grad, **kwargs
951    ):
952        real_dtype = o.input.real.dtype if dtype.is_complex else dtype
953        # requires_grad path for rtol tensor is not implemented
954        for rtol in (None, 1.0, torch.tensor(1.0, dtype=real_dtype, device=device)):
955            o = clone_sample(o)
956            o.kwargs = {"rtol": rtol}
957            yield o
958
959
960def sample_inputs_linalg_pinv_hermitian(
961    op_info, device, dtype, requires_grad=False, **kwargs
962):
963    """
964    This function generates input for torch.linalg.pinv with hermitian=True keyword argument.
965    """
966    for o in sample_inputs_linalg_invertible(
967        op_info, device, dtype, requires_grad, **kwargs
968    ):
969        o.kwargs = {"hermitian": True}
970        yield o
971
972
973def sample_inputs_linalg_solve(
974    op_info, device, dtype, requires_grad=False, vector_rhs_allowed=True, **kwargs
975):
976    """
977    This function generates always solvable input for torch.linalg.solve
978    We sample a fullrank square matrix (i.e. invertible) A
979    The first input to torch.linalg.solve is generated as the itertools.product of 'batches' and 'ns'.
980    The second input is generated as the product of 'batches', 'ns' and 'nrhs'.
981    In total this function generates 18 SampleInputs
982    'batches' cases include:
983        () - single input,
984        (0,) - zero batched dimension,
985        (2,) - batch of two matrices.
986    'ns' gives 0x0 and 5x5 matrices.
987    and 'nrhs' controls the number of vectors to solve for:
988        () - using 1 as the number of vectors implicitly
989        (1,) - same as () but explicit
990        (3,) - solve for 3 vectors.
991    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
992    'vector_rhs_allowed' controls whether to include nrhs = () to the list of SampleInputs.
993    torch.solve / triangular_solve / cholesky_solve (as opposed to torch.linalg.solve) do not allow
994    1D tensors (vectors) as the right-hand-side.
995    Once torch.solve / triangular_solve / cholesky_solve and its testing are removed,
996    'vector_rhs_allowed' may be removed here as well.
997    """
998    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
999    make_a = partial(
1000        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
1001    )
1002    make_b = partial(
1003        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
1004    )
1005
1006    batches = [(), (0,), (2,)]
1007    ns = [5, 0]
1008    if vector_rhs_allowed:
1009        nrhs = [(), (1,), (3,)]
1010    else:
1011        nrhs = [(1,), (3,)]
1012
1013    for n, batch, rhs in product(ns, batches, nrhs):
1014        yield SampleInput(make_a(*batch, n, n), args=(make_b(batch + (n,) + rhs),))
1015
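# Illustrative sketch (hypothetical helper, not part of the upstream file):
# linalg.solve accepts both a vector and a matrix right-hand side, which is what
# the nrhs = (), (1,), (3,) cases above exercise.
def _example_solve_rhs_shapes():
    A = torch.randn(5, 5, dtype=torch.float64) + 5 * torch.eye(5, dtype=torch.float64)
    b_vec = torch.randn(5, dtype=torch.float64)
    b_mat = torch.randn(5, 3, dtype=torch.float64)
    x_vec = torch.linalg.solve(A, b_vec)  # shape (5,)
    x_mat = torch.linalg.solve(A, b_mat)  # shape (5, 3)
    assert torch.allclose(A @ x_vec, b_vec) and torch.allclose(A @ x_mat, b_mat)
    return x_vec.shape, x_mat.shape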
1016
1017def sample_inputs_linalg_solve_triangular(
1018    op_info, device, dtype, requires_grad=False, **kwargs
1019):
1020    make_arg = partial(make_tensor, dtype=dtype, device=device)
1021    bs = (1, 2, 0)
1022    ns = (3, 0)
1023    ks = (1, 3, 0)
1024
1025    for b, n, k, (left, upper, uni) in product(
1026        bs, ns, ks, product((True, False), repeat=3)
1027    ):
1028        if b == 1:
1029            A = make_arg((n, n)) if left else make_arg((k, k))
1030            B = make_arg((n, k))
1031        else:
1032            A = make_arg((b, n, n)) if left else make_arg((b, k, k))
1033            B = make_arg((b, n, k))
1034        if uni:
1035            # Not really necessary, but writing it for consistency
1036            A.diagonal(0, -2, -1).fill_(1.0)
1037        else:
1038            d = A.diagonal(0, -2, -1)
1039            d[d.abs() < 1e-6] = 1.0
1040        if upper:
1041            A.triu_()
1042        else:
1043            A.tril_()
1044        kwargs = {"upper": upper, "left": left, "unitriangular": uni}
1045        if requires_grad:
1046            for grad_A, grad_B in product((True, False), repeat=2):
1047                # Either A or B needs to have a gradient
1048                if not grad_A and not grad_B:
1049                    continue
1050                yield SampleInput(
1051                    A.clone().requires_grad_(grad_A),
1052                    args=(B.clone().requires_grad_(grad_B),),
1053                    kwargs=kwargs,
1054                )
1055        else:
1056            yield SampleInput(A, args=(B,), kwargs=kwargs)
1057
1058
1059def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kwargs):
1060    """
1061    This function generates always solvable input for legacy solve functions
1062    (the ones that are not in torch.linalg module).
1063    The difference from sample_inputs_linalg_solve is that here the right-hand-side of A x = b equation
1064    should have b.ndim >= 2, vectors are not allowed.
1065    Also, the argument order is swapped.
1066    """
1067    out = sample_inputs_linalg_solve(
1068        op_info, device, dtype, requires_grad=requires_grad, vector_rhs_allowed=False
1069    )
1070
1071    def out_fn(output):
1072        return output[0]
1073
1074    # Reverses tensor order
1075    for sample in out:
1076        sample.input, sample.args = sample.args[0], (sample.input,)
1077        if op_info.name == "solve":
1078            sample.output_process_fn_grad = out_fn
1079        yield sample
1080
1081
1082def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwargs):
1083    full_rank = op_info.name == "linalg.lu_factor"
1084    make_fn = (
1085        make_tensor
1086        if not full_rank
1087        else make_fullrank_matrices_with_distinct_singular_values
1088    )
1089    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
1090
1091    def out_fn(output):
1092        if op_info.name == "linalg.lu":
1093            return output[1], output[2]
1094        else:
1095            return output
1096
1097    batch_shapes = ((), (3,), (3, 3))
1098    # pivot=False is only supported on CUDA
1099    pivots = (True, False) if torch.device(device).type == "cuda" else (True,)
1100    deltas = (-2, -1, 0, +1, +2)
1101    for batch_shape, pivot, delta in product(batch_shapes, pivots, deltas):
1102        shape = batch_shape + (S + delta, S)
1103        # Insanely annoying that make_fullrank_blablabla accepts a *shape and not a tuple!
1104        A = make_arg(shape) if not full_rank else make_arg(*shape)
1105        yield SampleInput(A, kwargs={"pivot": pivot}, output_process_fn_grad=out_fn)
1106
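# Illustrative sketch (hypothetical helper, not part of the upstream file):
# linalg.lu returns P, L, U with P @ L @ U reconstructing the (possibly
# rectangular) input, which is what the S + delta shapes above cover.
def _example_lu_reconstruction():
    A = torch.randn(4, 3, dtype=torch.float64)
    P, L_, U = torch.linalg.lu(A)  # L_ avoids shadowing the size constant L
    assert torch.allclose(P @ L_ @ U, A)
    return P.shape, L_.shape, U.shape  # (4, 4), (4, 3), (3, 3)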
1107
1108def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, **kwargs):
1109    make_arg = partial(
1110        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
1111    )
1112
1113    batches = [(), (0,), (2,), (1, 1)]
1114    ns = [5, 2, 0]
1115
1116    for batch, m, n in product(batches, ns, ns):
1117        yield SampleInput(make_arg(batch + (m, n)))
1118
1119
1120def sample_inputs_linalg_qr_geqrf(
1121    op_info, device, dtype, requires_grad=False, **kwargs
1122):
1123    # QR is only well defined (unique and differentiable) when the matrix has full rank
1124    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
1125    make_arg = partial(
1126        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
1127    )
1128
1129    batches = [(), (0,), (2,), (1, 1)]
1130    ns = [5, 2, 0]
1131
1132    for batch, (m, n) in product(batches, product(ns, ns)):
1133        shape = batch + (m, n)
1134        yield SampleInput(make_arg(*shape))
1135
1136
1137def sample_inputs_tensorsolve(op_info, device, dtype, requires_grad, **kwargs):
1138    a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
1139    # Zero-dim tensors are not supported in NumPy, so we skip them for now.
1140    # NumPy is used in reference check tests.
1141    # See https://github.com/numpy/numpy/pull/20482 for tracking NumPy bugfix.
1142    # a_shapes += [(0, 0, 1, 2, 3, 0)]
1143    dimss = [None, (0, 2)]
1144
1145    make_arg = partial(
1146        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
1147    )
1148    for a_shape, dims in itertools.product(a_shapes, dimss):
1149        a = make_arg(a_shape)
1150        b = make_arg(a_shape[:2])
1151        yield SampleInput(a, b, dims=dims)
1152
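# Illustrative sketch (hypothetical helper, not part of the upstream file):
# tensorsolve flattens the trailing dimensions of `a` into a square matrix, so an
# `a` of shape (2, 3, 6) pairs with a right-hand side of shape (2, 3), matching
# the shapes used above.
def _example_tensorsolve_shapes():
    a = torch.randn(2, 3, 6, dtype=torch.float64)
    b = torch.randn(2, 3, dtype=torch.float64)
    x = torch.linalg.tensorsolve(a, b)  # shape (6,)
    assert torch.allclose(torch.tensordot(a, x, dims=x.ndim), b)
    return x.shape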
1153
1154def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
1155    make_arg = make_fullrank_matrices_with_distinct_singular_values
1156
1157    def make_input():
1158        return make_arg(12, 12, device=device, dtype=dtype, requires_grad=requires_grad)
1159
1160    # lhs / rhs shape can have any number of dimensions as long as their product equals 12
1161    shapes = [
1162        ((2, 2, 3), (12, 1)),
1163        ((4, 3), (6, 1, 2)),
1164    ]
1165
1166    for shape_lhs, shape_rhs in shapes:
1167        inp = make_input().reshape(*shape_lhs, *shape_rhs).detach()
1168        inp.requires_grad_(requires_grad)
1169        yield SampleInput(inp, ind=len(shape_lhs))
1170
1171
1172op_db: List[OpInfo] = [
1173    OpInfo(
1174        "linalg.cross",
1175        ref=lambda x, y, dim=-1: np.cross(x, y, axis=dim),
1176        op=torch.linalg.cross,
1177        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
1178        aten_name="linalg_cross",
1179        sample_inputs_func=sample_inputs_cross,
1180        error_inputs_func=error_inputs_cross,
1181        supports_out=True,
1182        supports_fwgrad_bwgrad=True,
1183        supports_forward_ad=True,
1184        skips=(
1185            DecorateInfo(
1186                unittest.skip("Unsupported on MPS for now"),
1187                "TestCommon",
1188                "test_numpy_ref_mps",
1189            ),
1190        ),
1191    ),
1192    OpInfo(
1193        "linalg.det",
1194        aten_name="linalg_det",
1195        op=torch.linalg.det,
1196        aliases=("det",),
1197        dtypes=floating_and_complex_types(),
1198        supports_forward_ad=True,
1199        supports_fwgrad_bwgrad=True,
1200        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
1201        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
1202        check_batched_gradgrad=False,
1203    ),
1204    OpInfo(
1205        "linalg.det",
1206        aten_name="linalg_det",
1207        op=torch.linalg.det,
1208        variant_test_name="singular",
1209        aliases=("det",),
1210        dtypes=floating_and_complex_types(),
1211        supports_forward_ad=True,
1212        supports_fwgrad_bwgrad=True,
1213        check_batched_gradgrad=False,
1214        sample_inputs_func=sample_inputs_linalg_det_singular,
1215        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
1216        skips=(
1217            DecorateInfo(
1218                unittest.skip("The backward may give different results"),
1219                "TestCommon",
1220                "test_noncontiguous_samples",
1221            ),
1222            DecorateInfo(
1223                unittest.skip("Gradients are incorrect on macos"),
1224                "TestBwdGradients",
1225                "test_fn_grad",
1226                device_type="cpu",
1227                dtypes=(torch.float64,),
1228                active_if=IS_MACOS,
1229            ),
1230            DecorateInfo(
1231                unittest.skip("Gradients are incorrect on macos"),
1232                "TestFwdGradients",
1233                "test_forward_mode_AD",
1234                device_type="cpu",
1235                dtypes=(torch.float64,),
1236                active_if=IS_MACOS,
1237            ),
1238            # Both Hessians are incorrect on complex inputs??
1239            DecorateInfo(
1240                unittest.expectedFailure,
1241                "TestBwdGradients",
1242                "test_fn_gradgrad",
1243                dtypes=(torch.complex128,),
1244            ),
1245            DecorateInfo(
1246                unittest.expectedFailure,
1247                "TestFwdGradients",
1248                "test_fn_fwgrad_bwgrad",
1249                dtypes=(torch.complex128,),
1250            ),
1251            DecorateInfo(
1252                unittest.skip("Skipped, see https://github.com//issues/84192"),
1253                "TestBwdGradients",
1254                "test_fn_gradgrad",
1255                device_type="cuda",
1256            ),
1257            DecorateInfo(
1258                unittest.skip("Skipped, see https://github.com//issues/84192"),
1259                "TestFwdGradients",
1260                "test_fn_fwgrad_bwgrad",
1261                device_type="cuda",
1262            ),
1263            DecorateInfo(
1264                unittest.skip(
1265                    "Flaky on ROCm https://github.com/pytorch/pytorch/issues/93044"
1266                ),
1267                "TestBwdGradients",
1268                "test_fn_grad",
1269                device_type="cuda",
1270                dtypes=get_all_complex_dtypes(),
1271                active_if=TEST_WITH_ROCM,
1272            ),
1273            DecorateInfo(
1274                unittest.skip(
1275                    "Flaky on ROCm https://github.com/pytorch/pytorch/issues/93045"
1276                ),
1277                "TestFwdGradients",
1278                "test_forward_mode_AD",
1279                device_type="cuda",
1280                dtypes=get_all_complex_dtypes(),
1281                active_if=TEST_WITH_ROCM,
1282            ),
1283        ),
1284    ),
1285    OpInfo(
1286        "linalg.diagonal",
1287        aten_name="linalg_diagonal",
1288        aten_backward_name="diagonal_backward",
1289        dtypes=all_types_and_complex_and(
1290            torch.bool, torch.bfloat16, torch.float16, torch.chalf
1291        ),
1292        supports_out=False,
1293        supports_forward_ad=True,
1294        supports_fwgrad_bwgrad=True,
1295        sample_inputs_func=sample_inputs_diagonal_diag_embed,
1296        error_inputs_func=error_inputs_diagonal_diag_embed,
1297    ),
1298    OpInfo(
1299        "linalg.cholesky",
1300        aten_name="linalg_cholesky",
1301        dtypes=floating_and_complex_types(),
1302        supports_forward_ad=True,
1303        supports_fwgrad_bwgrad=True,
1304        # See https://github.com/pytorch/pytorch/pull/78358
1305        check_batched_forward_grad=False,
1306        sample_inputs_func=sample_inputs_linalg_cholesky,
1307        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
1308        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
1309    ),
1310    OpInfo(
1311        "linalg.cholesky_ex",
1312        aten_name="linalg_cholesky_ex",
1313        dtypes=floating_and_complex_types(),
1314        supports_forward_ad=True,
1315        supports_fwgrad_bwgrad=True,
1316        # See https://github.com/pytorch/pytorch/pull/78358
1317        check_batched_forward_grad=False,
1318        sample_inputs_func=sample_inputs_linalg_cholesky,
1319        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
1320        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
1321    ),
1322    OpInfo(
1323        "linalg.vecdot",
1324        aten_name="linalg_vecdot",
1325        ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
1326        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
1327        sample_inputs_func=sample_inputs_linalg_vecdot,
1328        check_batched_forward_grad=False,
1329        supports_forward_ad=True,
1330        supports_fwgrad_bwgrad=True,
1331        skips=(
1332            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
1333            DecorateInfo(
1334                unittest.skip("Skipped!"),
1335                "TestSchemaCheckModeOpInfo",
1336                "test_schema_correctness",
1337                dtypes=(torch.complex64, torch.complex128),
1338            ),
1339            DecorateInfo(
1340                unittest.skip("Unsupported on MPS for now"),
1341                "TestCommon",
1342                "test_numpy_ref_mps",
1343            ),
1344            DecorateInfo(
1345                toleranceOverride({torch.half: tol(atol=1.2e-2, rtol=1.7e-2)}),
1346                "TestInductorOpInfo",
1347                "test_comprehensive",
1348                device_type="cuda",
1349            ),
1350        ),
1351    ),
1352    OpInfo(
1353        "linalg.cond",
1354        aten_name="linalg_cond",
1355        dtypes=floating_and_complex_types(),
1356        sample_inputs_func=sample_inputs_linalg_cond,
1357        check_batched_gradgrad=False,
1358        check_batched_forward_grad=False,
1359        supports_forward_ad=True,
1360        supports_fwgrad_bwgrad=True,
1361        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
1362        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
1363        skips=(
1364            DecorateInfo(
1365                unittest.skip("Skipped!"),
1366                "TestFakeTensor",
1367                "test_fake_crossref_backward_amp",
1368                device_type="cuda",
1369                dtypes=[torch.float32],
1370                active_if=TEST_WITH_ROCM,
1371            ),
1372            DecorateInfo(
1373                unittest.skip("Skipped!"),
1374                "TestFakeTensor",
1375                "test_fake_crossref_backward_no_amp",
1376                device_type="cuda",
1377                dtypes=[torch.float32],
1378                active_if=TEST_WITH_ROCM,
1379            ),
1380        ),
1381    ),
1382    OpInfo(
1383        "linalg.eig",
1384        aten_name="linalg_eig",
1385        op=torch.linalg.eig,
1386        dtypes=floating_and_complex_types(),
1387        sample_inputs_func=sample_inputs_linalg_eig,
1388        check_batched_forward_grad=False,
1389        check_batched_grad=False,
1390        check_batched_gradgrad=False,
1391        supports_forward_ad=True,
1392        supports_fwgrad_bwgrad=True,
1393        skips=(
1394            # AssertionError: Scalars are not equal!
1395            DecorateInfo(
1396                unittest.expectedFailure, "TestCommon", "test_out", device_type="cpu"
1397            ),
1398            DecorateInfo(
1399                unittest.skip("Skipped!"),
1400                "TestCommon",
1401                "test_out",
1402                device_type="mps",
1403                dtypes=[torch.float32],
1404            ),
1405            DecorateInfo(
1406                unittest.skip("Skipped!"),
1407                "TestCommon",
1408                "test_variant_consistency_eager",
1409                device_type="mps",
1410                dtypes=[torch.float32],
1411            ),
1412            DecorateInfo(
1413                unittest.skip("Skipped!"),
1414                "TestJit",
1415                "test_variant_consistency_jit",
1416                device_type="mps",
1417                dtypes=[torch.float32],
1418            ),
1419        ),
1420        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
1421    ),
1422    OpInfo(
1423        "linalg.eigvals",
1424        aten_name="linalg_eigvals",
1425        op=torch.linalg.eigvals,
1426        dtypes=floating_and_complex_types(),
1427        sample_inputs_func=sample_inputs_linalg_invertible,
1428        check_batched_forward_grad=False,
1429        check_batched_grad=False,
1430        check_batched_gradgrad=False,
1431        supports_forward_ad=True,
1432        supports_fwgrad_bwgrad=True,
1433        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
1434        skips=(
1435            DecorateInfo(
1436                unittest.skip("Skipped!"),
1437                "TestCommon",
1438                "test_out",
1439                device_type="mps",
1440                dtypes=[torch.float32],
1441            ),
1442            DecorateInfo(
1443                unittest.skip("Skipped!"),
1444                "TestCommon",
1445                "test_variant_consistency_eager",
1446                device_type="mps",
1447                dtypes=[torch.float32],
1448            ),
1449            DecorateInfo(
1450                unittest.skip("Skipped!"),
1451                "TestJit",
1452                "test_variant_consistency_jit",
1453                device_type="mps",
1454                dtypes=[torch.float32],
1455            ),
1456        ),
1457    ),
1458    OpInfo(
1459        "linalg.eigh",
1460        aten_name="linalg_eigh",
1461        dtypes=floating_and_complex_types(),
1462        sample_inputs_func=sample_inputs_linalg_eigh,
1463        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
1464        check_batched_forward_grad=False,
1465        check_batched_grad=False,
1466        check_batched_gradgrad=False,
1467        supports_forward_ad=True,
1468        supports_fwgrad_bwgrad=True,
1469        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
1470        skips=(
1471            DecorateInfo(
1472                unittest.skip("Skipped!"),
1473                "TestCommon",
1474                "test_out",
1475                device_type="mps",
1476                dtypes=[torch.float32],
1477            ),
1478            DecorateInfo(
1479                unittest.skip("Skipped!"),
1480                "TestCommon",
1481                "test_variant_consistency_eager",
1482                device_type="mps",
1483                dtypes=[torch.float32],
1484            ),
1485            DecorateInfo(
1486                unittest.skip("Skipped!"),
1487                "TestJit",
1488                "test_variant_consistency_jit",
1489                device_type="mps",
1490                dtypes=[torch.float32],
1491            ),
1492        ),
1493    ),
1494    OpInfo(
1495        "linalg.eigvalsh",
1496        aten_name="linalg_eigvalsh",
1497        dtypes=floating_and_complex_types(),
1498        sample_inputs_func=sample_inputs_linalg_eigh,
1499        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
1500        check_batched_forward_grad=False,
1501        check_batched_grad=False,
1502        check_batched_gradgrad=False,
1503        supports_forward_ad=True,
1504        supports_fwgrad_bwgrad=True,
1505        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
1506        skips=(
1507            # Pre-existing condition; needs to be fixed
1508            DecorateInfo(
1509                unittest.skip("Skipped!"),
1510                "TestCommon",
1511                "test_out",
1512                device_type="mps",
1513                dtypes=[torch.float32],
1514            ),
1515            DecorateInfo(
1516                unittest.skip("Skipped!"),
1517                "TestCommon",
1518                "test_variant_consistency_eager",
1519                device_type="mps",
1520                dtypes=[torch.float32],
1521            ),
1522            DecorateInfo(
1523                unittest.skip("Skipped!"),
1524                "TestJit",
1525                "test_variant_consistency_jit",
1526                device_type="mps",
1527                dtypes=[torch.float32],
1528            ),
1529        ),
1530    ),
1531    OpInfo(
1532        "linalg.householder_product",
1533        aten_name="linalg_householder_product",
1534        op=torch.linalg.householder_product,
1535        aliases=("orgqr",),
1536        dtypes=floating_and_complex_types(),
1537        # https://github.com/pytorch/pytorch/issues/80411
1538        gradcheck_fast_mode=True,
1539        # TODO: backward uses in-place operations that vmap doesn't like
1540        check_batched_grad=False,
1541        check_batched_gradgrad=False,
1542        supports_forward_ad=True,
1543        supports_fwgrad_bwgrad=True,
1544        check_batched_forward_grad=False,
1545        sample_inputs_func=sample_inputs_householder_product,
1546        decorators=[
1547            skipCUDAIfNoCusolver,
1548            skipCPUIfNoLapack,
1549            DecorateInfo(
1550                toleranceOverride({torch.complex64: tol(atol=1e-3, rtol=1e-3)})
1551            ),
1552            DecorateInfo(
1553                unittest.skip("Skipped! Flaky"),
1554                "TestFwdGradients",
1555                "test_fn_fwgrad_bwgrad",
1556                device_type="cpu",
1557                dtypes=(torch.complex128,),
1558            ),
1559        ],
1560    ),
1561    OpInfo(
1562        "linalg.ldl_factor",
1563        aten_name="linalg_ldl_factor",
1564        dtypes=floating_and_complex_types(),
1565        supports_autograd=False,
1566        sample_inputs_func=sample_inputs_linalg_ldl_factor,
1567        decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack],
1568    ),
1569    OpInfo(
1570        "linalg.ldl_factor_ex",
1571        aten_name="linalg_ldl_factor_ex",
1572        dtypes=floating_and_complex_types(),
1573        supports_autograd=False,
1574        sample_inputs_func=sample_inputs_linalg_ldl_factor,
1575        decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack],
1576    ),
1577    OpInfo(
1578        "linalg.ldl_solve",
1579        aten_name="linalg_ldl_solve",
1580        dtypes=floating_and_complex_types(),
1581        supports_autograd=False,
1582        sample_inputs_func=sample_inputs_linalg_ldl_solve,
1583        decorators=[
1584            skipCUDAIf(
1585                _get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1"
1586            ),
1587            skipCUDAIfNoCusolver,
1588            skipCUDAIfRocm,
1589            skipCPUIfNoLapack,
1590        ],
1591    ),
1592    OpInfo(
1593        "linalg.lstsq",
1594        aten_name="linalg_lstsq",
1595        dtypes=floating_and_complex_types(),
1596        supports_out=True,
1597        sample_inputs_func=sample_inputs_linalg_lstsq,
1598        error_inputs_func=error_inputs_lstsq,
1599        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
1600        skips=(
1601            # we skip gradient checks for this suite as they are tested in
1602            # variant_test_name='grad_oriented'
1603            DecorateInfo(unittest.skip("Skipped!"), "TestFwdGradients"),
1604            DecorateInfo(unittest.skip("Skipped!"), "TestBwdGradients"),
1605            # The values for attribute 'shape' do not match
1606            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"),
1607            DecorateInfo(
1608                unittest.skip("Skipped!"),
1609                "TestCommon",
1610                "test_out",
1611                device_type="mps",
1612                dtypes=[torch.float32],
1613            ),
1614            DecorateInfo(
1615                unittest.skip("Skipped!"),
1616                "TestCommon",
1617                "test_variant_consistency_eager",
1618                device_type="mps",
1619                dtypes=[torch.float32],
1620            ),
1621            DecorateInfo(
1622                unittest.skip("Skipped!"),
1623                "TestJit",
1624                "test_variant_consistency_jit",
1625                device_type="mps",
1626                dtypes=[torch.float32],
1627            ),
1628        ),
1629    ),
1630    OpInfo(
1631        "linalg.lstsq",
1632        aten_name="linalg_lstsq",
1633        variant_test_name="grad_oriented",
1634        # gradchecks for forward AD fails with multi-Tensor outputs
1635        # gradchecks for forward AD fail with multi-Tensor outputs
1636        supports_out=False,
1637        dtypes=floating_and_complex_types(),
1638        sample_inputs_func=sample_inputs_linalg_lstsq,
1639        error_inputs_func=error_inputs_lstsq_grad_oriented,
1640        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
1641        gradcheck_fast_mode=True,
1642        supports_autograd=True,
1643        supports_forward_ad=True,
1644        supports_fwgrad_bwgrad=True,
1645        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
1646        skips=(
1647            # these tests do not work when a lambda is passed as the op
1648            DecorateInfo(
1649                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
1650            ),
1651            DecorateInfo(
1652                unittest.expectedFailure,
1653                "TestOperatorSignatures",
1654                "test_get_torch_func_signature_exhaustive",
1655            ),
1656        ),
1657    ),
1658    OpInfo(
1659        "linalg.matrix_power",
1660        aliases=("matrix_power",),
1661        aten_name="linalg_matrix_power",
1662        dtypes=floating_and_complex_types(),
1663        # https://github.com/pytorch/pytorch/issues/80411
1664        gradcheck_fast_mode=True,
1665        supports_inplace_autograd=False,
1666        supports_forward_ad=True,
1667        supports_fwgrad_bwgrad=True,
1668        check_batched_grad=False,
1669        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
1670        sample_inputs_func=sample_inputs_linalg_matrix_power,
1671    ),
1672    OpInfo(
1673        "linalg.multi_dot",
1674        # NB: gradcheck does not work with TensorList inputs
1675        aten_name="linalg_multi_dot",
1676        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
1677        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
1678        supports_inplace_autograd=False,
1679        # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407)
1680        check_batched_grad=False,
1681        check_batched_gradgrad=False,
1682        supports_forward_ad=True,
1683        supports_fwgrad_bwgrad=True,
1684        # https://github.com/pytorch/pytorch/issues/66357
1685        check_batched_forward_grad=False,
1686        sample_inputs_func=sample_inputs_linalg_multi_dot,
1687        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
1688        skips=(
1689            # https://github.com/pytorch/pytorch/issues/67470
1690            DecorateInfo(
1691                unittest.skip("67470!"), "TestCommon", "test_noncontiguous_samples"
1692            ),
1693            # Fails on XLA.
1694            # AssertionError: False is not true : Tensors failed to compare as equal!
1695            DecorateInfo(
1696                unittest.skip("Skipped!"),
1697                "TestOpInfo",
1698                device_type="xla",
1699                dtypes=(torch.long,),
1700            ),
1701            # https://github.com/pytorch/pytorch/issues/71774
1702            DecorateInfo(
1703                unittest.skip("Skipped!"),
1704                "TestNNCOpInfo",
1705                "test_nnc_correctness",
1706                device_type="cpu",
1707                dtypes=(torch.long,),
1708            ),
1709        ),
1710    ),
1711    # NB: linalg.norm has two variants so that different skips can be used for different sample inputs
1712    OpInfo(
1713        "linalg.norm",
1714        aten_name="linalg_norm",
1715        op=torch.linalg.norm,
1716        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
1717        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
1718        sample_inputs_func=sample_inputs_linalg_norm,
1719        supports_forward_ad=True,
1720        check_batched_forward_grad=False,
1721        supports_fwgrad_bwgrad=True,
1722        skips=(
1723            DecorateInfo(
1724                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
1725            ),
1726            DecorateInfo(
1727                unittest.skip("Skipped!"),
1728                "TestFakeTensor",
1729                "test_fake_crossref_backward_amp",
1730                device_type="cuda",
1731                dtypes=[torch.float32],
1732                active_if=TEST_WITH_ROCM,
1733            ),
1734            DecorateInfo(
1735                unittest.skip("Skipped!"),
1736                "TestFakeTensor",
1737                "test_fake_crossref_backward_no_amp",
1738                device_type="cuda",
1739                dtypes=[torch.float32],
1740                active_if=TEST_WITH_ROCM,
1741            ),
1742        ),
1743    ),
1744    OpInfo(
1745        "linalg.norm",
1746        op=torch.linalg.norm,
1747        variant_test_name="subgradients_at_zero",
1748        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
1749        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
1750        sample_inputs_func=partial(
1751            sample_inputs_linalg_norm, variant="subgradient_at_zero"
1752        ),
1753        aten_name="linalg_norm",
1754        supports_forward_ad=True,
1755        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
1756        # Could not allocate memory to change Tensor SizesAndStrides!
1757        check_batched_forward_grad=False,
1758        supports_fwgrad_bwgrad=True,
1759        skips=(
1760            # Skips specifically for sample inputs at zero
1761            # norm's vjp/jvp are not well-conditioned near zero
1762            DecorateInfo(
1763                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
1764            ),
1765            DecorateInfo(
1766                unittest.expectedFailure, "TestFwdGradients", "test_fn_fwgrad_bwgrad"
1767            ),
1768            DecorateInfo(
1769                unittest.expectedFailure, "TestFwdGradients", "test_forward_mode_AD"
1770            ),
1771            DecorateInfo(unittest.expectedFailure, "TestBwdGradients", "test_fn_grad"),
1772        ),
1773    ),
1774    OpInfo(
1775        "linalg.matrix_norm",
1776        aten_name="linalg_matrix_norm",
1777        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
1778        supports_forward_ad=True,
1779        check_batched_forward_grad=False,
1780        check_batched_gradgrad=False,
1781        supports_fwgrad_bwgrad=True,
1782        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
1783        sample_inputs_func=sample_inputs_linalg_matrix_norm,
1784        skips=(
1785            DecorateInfo(
1786                unittest.skip("Skipped!"),
1787                "TestFakeTensor",
1788                "test_fake_crossref_backward_amp",
1789                device_type="cuda",
1790                dtypes=[torch.float32],
1791                active_if=TEST_WITH_ROCM,
1792            ),
1793            DecorateInfo(
1794                unittest.skip("Skipped!"),
1795                "TestFakeTensor",
1796                "test_fake_crossref_backward_no_amp",
1797                device_type="cuda",
1798                dtypes=[torch.float32],
1799                active_if=TEST_WITH_ROCM,
1800            ),
1801        ),
1802    ),
1803    OpInfo(
1804        "linalg.qr",
1805        aten_name="linalg_qr",
1806        op=torch.linalg.qr,
1807        dtypes=floating_and_complex_types(),
1808        supports_forward_ad=True,
1809        supports_fwgrad_bwgrad=True,
1810        # In-place ops break batched gradgrad checks
1811        check_batched_gradgrad=False,
1812        sample_inputs_func=sample_inputs_linalg_qr_geqrf,
1813        decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack],
1814    ),
1815    OpInfo(
1816        "linalg.slogdet",
1817        aten_name="linalg_slogdet",
1818        op=torch.linalg.slogdet,
1819        dtypes=floating_and_complex_types(),
1820        supports_forward_ad=True,
1821        supports_fwgrad_bwgrad=True,
1822        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
1823        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
1824    ),
1825    OpInfo(
1826        "linalg.vander",
1827        aten_name="linalg_vander",
1828        ref=np_vander_batched,
1829        op=torch.linalg.vander,
1830        dtypes=all_types_and_complex(),
1831        supports_forward_ad=True,
1832        supports_fwgrad_bwgrad=True,
1833        supports_out=False,
1834        sample_inputs_func=sample_inputs_linalg_vander,
1835        skips=(
1836            DecorateInfo(
1837                unittest.skip("Unsupported on MPS for now"),
1838                "TestCommon",
1839                "test_numpy_ref_mps",
1840            ),
1841        ),
1842    ),
1843    ReductionOpInfo(
1844        "linalg.vector_norm",
1845        op=torch.linalg.vector_norm,
1846        identity=0,
1847        nan_policy="propagate",
1848        supports_multiple_dims=True,
1849        complex_to_real=True,
1850        supports_forward_ad=True,
1851        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
1852        # got: Could not allocate memory to change Tensor SizesAndStrides!
1853        check_batched_forward_grad=False,
1854        supports_fwgrad_bwgrad=True,
1855        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
1856        generate_args_kwargs=sample_kwargs_vector_norm,
1857        aten_name="linalg_vector_norm",
1858        skips=(
1859            # FIXME: sum reduces all dimensions when dim=[]
1860            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
1861            DecorateInfo(
1862                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
1863            ),
1864        ),
1865    ),
1866    OpInfo(
1867        "linalg.lu_factor",
1868        aten_name="linalg_lu_factor",
1869        op=torch.linalg.lu_factor,
1870        dtypes=floating_and_complex_types(),
1871        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
1872        # https://github.com/pytorch/pytorch/issues/80411
1873        gradcheck_fast_mode=True,
1874        supports_forward_ad=True,
1875        supports_fwgrad_bwgrad=True,
1876        sample_inputs_func=sample_inputs_linalg_lu,
1877        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
1878        skips=(
1879            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
1880            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
1881        ),
1882    ),
1883    OpInfo(
1884        "linalg.lu_factor_ex",
1885        aten_name="linalg_lu_factor_ex",
1886        op=torch.linalg.lu_factor_ex,
1887        dtypes=floating_and_complex_types(),
1888        # https://github.com/pytorch/pytorch/issues/80411
1889        gradcheck_fast_mode=True,
1890        supports_forward_ad=True,
1891        supports_fwgrad_bwgrad=True,
1892        sample_inputs_func=sample_inputs_linalg_lu,
1893        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
1894        skips=(
1895            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
1896            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
1897        ),
1898    ),
1899    OpInfo(
1900        "linalg.lu",
1901        aten_name="linalg_lu",
1902        op=torch.linalg.lu,
1903        dtypes=floating_and_complex_types(),
1904        # https://github.com/pytorch/pytorch/issues/80411
1905        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
1906        gradcheck_fast_mode=True,
1907        supports_forward_ad=True,
1908        supports_fwgrad_bwgrad=True,
1909        sample_inputs_func=sample_inputs_linalg_lu,
1910        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
1911        skips=(
1912            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
1913            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
1914        ),
1915    ),
1916    OpInfo(
1917        "linalg.lu_solve",
1918        op=torch.linalg.lu_solve,
1919        aten_name="linalg_lu_solve",
1920        dtypes=floating_and_complex_types(),
1921        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
1922        gradcheck_fast_mode=True,
1923        supports_forward_ad=True,
1924        check_batched_forward_grad=False,
1925        supports_fwgrad_bwgrad=True,
1926        sample_inputs_func=sample_inputs_lu_solve,
1927        skips=(
1928            DecorateInfo(
1929                unittest.skip("Tests different backward paths"),
1930                "TestCommon",
1931                "test_floating_inputs_are_differentiable",
1932            ),
1933        ),
1934        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
1935    ),
1936    OpInfo(
1937        "linalg.inv",
1938        aten_name="linalg_inv",
1939        op=torch.linalg.inv,
1940        aliases=("inverse",),
1941        dtypes=floating_and_complex_types(),
1942        sample_inputs_func=sample_inputs_linalg_invertible,
1943        check_batched_gradgrad=False,
1944        supports_forward_ad=True,
1945        supports_fwgrad_bwgrad=True,
1946        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
1947        skips=(
1948            DecorateInfo(
1949                unittest.skip("Skipped!"),
1950                "TestCommon",
1951                "test_out",
1952                device_type="mps",
1953                dtypes=[torch.float32],
1954            ),
1955            DecorateInfo(
1956                unittest.skip("Skipped!"),
1957                "TestCommon",
1958                "test_variant_consistency_eager",
1959                device_type="mps",
1960                dtypes=[torch.float32],
1961            ),
1962            DecorateInfo(
1963                unittest.skip("Skipped!"),
1964                "TestJit",
1965                "test_variant_consistency_jit",
1966                device_type="mps",
1967                dtypes=[torch.float32],
1968            ),
1969        ),
1970    ),
1971    OpInfo(
1972        "linalg.inv_ex",
1973        aten_name="linalg_inv_ex",
1974        op=torch.linalg.inv_ex,
1975        dtypes=floating_and_complex_types(),
1976        sample_inputs_func=sample_inputs_linalg_invertible,
1977        check_batched_gradgrad=False,
1978        supports_forward_ad=True,
1979        supports_fwgrad_bwgrad=True,
1980        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
1981        skips=(
1982            DecorateInfo(
1983                unittest.skip("Skipped!"),
1984                "TestCommon",
1985                "test_out",
1986                device_type="mps",
1987                dtypes=[torch.float32],
1988            ),
1989            DecorateInfo(
1990                unittest.skip("Skipped!"),
1991                "TestCommon",
1992                "test_variant_consistency_eager",
1993                device_type="mps",
1994                dtypes=[torch.float32],
1995            ),
1996            DecorateInfo(
1997                unittest.skip("Skipped!"),
1998                "TestJit",
1999                "test_variant_consistency_jit",
2000                device_type="mps",
2001                dtypes=[torch.float32],
2002            ),
2003        ),
2004    ),
2005    OpInfo(
2006        "linalg.solve",
2007        aten_name="linalg_solve",
2008        op=torch.linalg.solve,
2009        dtypes=floating_and_complex_types(),
2010        sample_inputs_func=sample_inputs_linalg_solve,
2011        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
2012        gradcheck_fast_mode=True,
2013        supports_forward_ad=True,
2014        supports_fwgrad_bwgrad=True,
2015        decorators=[
2016            skipCUDAIfNoMagmaAndNoCusolver,
2017            skipCPUIfNoLapack,
2018            DecorateInfo(
2019                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
2020                "TestCommon",
2021                "test_noncontiguous_samples",
2022                device_type="cpu",
2023            ),
2024        ],
2025        skips=(
2026            DecorateInfo(
2027                unittest.skip("Skipped!"),
2028                "TestCommon",
2029                "test_out",
2030                device_type="mps",
2031                dtypes=[torch.float32],
2032            ),
2033            DecorateInfo(
2034                unittest.skip("Skipped!"),
2035                "TestCommon",
2036                "test_variant_consistency_eager",
2037                device_type="mps",
2038                dtypes=[torch.float32],
2039            ),
2040            DecorateInfo(
2041                unittest.skip("Skipped!"),
2042                "TestJit",
2043                "test_variant_consistency_jit",
2044                device_type="mps",
2045                dtypes=[torch.float32],
2046            ),
2047        ),
2048    ),
2049    OpInfo(
2050        "linalg.solve_ex",
2051        aten_name="linalg_solve_ex",
2052        op=torch.linalg.solve_ex,
2053        dtypes=floating_and_complex_types(),
2054        sample_inputs_func=sample_inputs_linalg_solve,
2055        supports_forward_ad=True,
2056        supports_fwgrad_bwgrad=True,
2057        decorators=[
2058            skipCUDAIfNoMagmaAndNoCusolver,
2059            skipCPUIfNoLapack,
2060            DecorateInfo(
2061                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
2062                "TestCommon",
2063                "test_noncontiguous_samples",
2064                device_type="cpu",
2065            ),
2066        ],
2067        skips=(
2068            DecorateInfo(
2069                unittest.skip("Skipped!"),
2070                "TestCommon",
2071                "test_out",
2072                device_type="mps",
2073                dtypes=[torch.float32],
2074            ),
2075            DecorateInfo(
2076                unittest.skip("Skipped!"),
2077                "TestCommon",
2078                "test_variant_consistency_eager",
2079                device_type="mps",
2080                dtypes=[torch.float32],
2081            ),
2082            DecorateInfo(
2083                unittest.skip("Skipped!"),
2084                "TestJit",
2085                "test_variant_consistency_jit",
2086                device_type="mps",
2087                dtypes=[torch.float32],
2088            ),
2089        ),
2090    ),
2091    OpInfo(
2092        "linalg.solve_triangular",
2093        aten_name="linalg_solve_triangular",
2094        op=torch.linalg.solve_triangular,
2095        dtypes=floating_and_complex_types(),
2096        sample_inputs_func=sample_inputs_linalg_solve_triangular,
2097        supports_fwgrad_bwgrad=True,
2098        skips=(skipCPUIfNoLapack,),
2099        # linalg.solve_triangular cannot be batched over because of a call to out.copy_(result).
2100        supports_forward_ad=True,
2101    ),
2102    OpInfo(
2103        "linalg.matrix_rank",
2104        aten_name="linalg_matrix_rank",
2105        dtypes=floating_and_complex_types(),
2106        supports_autograd=False,
2107        sample_inputs_func=sample_inputs_matrix_rank,
2108        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
2109        skips=(
2110            DecorateInfo(
2111                unittest.skip("Skipped!"),
2112                "TestCommon",
2113                "test_out",
2114                device_type="mps",
2115                dtypes=[torch.float32],
2116            ),
2117            DecorateInfo(
2118                unittest.skip("Skipped!"),
2119                "TestCommon",
2120                "test_variant_consistency_eager",
2121                device_type="mps",
2122                dtypes=[torch.float32],
2123            ),
2124            # jit doesn't accept tensor inputs for matrix rank
2125            DecorateInfo(
2126                unittest.skip("Skipped!"),
2127                "TestJit",
2128                "test_variant_consistency_jit",
2129                dtypes=[torch.complex64, torch.float32],
2130            ),
2131        ),
2132    ),
2133    OpInfo(
2134        "linalg.matrix_rank",
2135        aten_name="linalg_matrix_rank",
2136        variant_test_name="hermitian",
2137        dtypes=floating_and_complex_types(),
2138        supports_autograd=False,
2139        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
2140        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
2141        skips=(
2142            DecorateInfo(
2143                unittest.skip("Skipped!"),
2144                "TestCommon",
2145                "test_out",
2146                device_type="mps",
2147                dtypes=[torch.float32],
2148            ),
2149            DecorateInfo(
2150                unittest.skip("Skipped!"),
2151                "TestJit",
2152                "test_variant_consistency_jit",
2153                device_type="mps",
2154                dtypes=[torch.float32],
2155            ),
2156        ),
2157    ),
2158    OpInfo(
2159        "linalg.pinv",
2160        aten_name="linalg_pinv",
2161        op=torch.linalg.pinv,
2162        dtypes=floating_and_complex_types(),
2163        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
2164        gradcheck_fast_mode=True,
2165        check_batched_grad=False,
2166        check_batched_gradgrad=False,
2167        supports_forward_ad=True,
2168        supports_fwgrad_bwgrad=True,
2169        sample_inputs_func=sample_inputs_linalg_pinv,
2170        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
2171        skips=(
2172            # errors with "leaked XXXX bytes CUDA memory on device 0"
2173            DecorateInfo(
2174                unittest.skip("Skipped!"),
2175                "TestJit",
2176                "test_variant_consistency_jit",
2177                device_type="cuda",
2178            ),
2179        ),
2180    ),
2181    OpInfo(
2182        "linalg.pinv",
2183        aten_name="linalg_pinv",
2184        variant_test_name="singular",
2185        # pinv is Frechet-differentiable in a rank-preserving neighborhood,
2186        # so we feed inputs that are the products of two full-rank factors,
2187        # to avoid any rank changes caused by the perturbations in the gradcheck
2188        op=lambda a, b: torch.linalg.pinv(a @ b.mT),
2189        dtypes=floating_and_complex_types(),
2190        supports_out=False,
2191        check_batched_grad=False,
2192        check_batched_gradgrad=False,
2193        supports_forward_ad=True,
2194        supports_fwgrad_bwgrad=True,
2195        sample_inputs_func=sample_inputs_linalg_pinv_singular,
2196        # Only large tensors show issues with the implicit backward that was used
2197        # prior to the explicit backward implementation.
2198        decorators=[slowTest, skipCUDAIfNoCusolver, skipCPUIfNoLapack],
2199        skips=(
2200            DecorateInfo(
2201                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
2202            ),
2203            # CUDA runs out of memory
2204            DecorateInfo(
2205                unittest.skip("Skipped!"),
2206                "TestFwdGradients",
2207                "test_fn_fwgrad_bwgrad",
2208                device_type="cuda",
2209                dtypes=[torch.cdouble],
2210            ),
2211            # This test takes almost 2 hours to run!
2212            DecorateInfo(
2213                unittest.skip("Skipped!"),
2214                "TestBwdGradients",
2215                "test_fn_gradgrad",
2216                device_type="cuda",
2217                dtypes=[torch.cdouble],
2218            ),
2219        ),
2220    ),
2221    OpInfo(
2222        "linalg.pinv",
2223        aten_name="linalg_pinv",
2224        variant_test_name="hermitian",
2225        dtypes=floating_and_complex_types(),
2226        check_batched_grad=False,
2227        check_batched_gradgrad=False,
2228        supports_forward_ad=True,
2229        supports_fwgrad_bwgrad=True,
2230        # See https://github.com/pytorch/pytorch/pull/78358
2231        check_batched_forward_grad=False,
2232        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
2233        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
2234        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
2235        skips=(
2236            DecorateInfo(
2237                unittest.skip("Skipped!"),
2238                "TestCommon",
2239                "test_out",
2240                device_type="mps",
2241                dtypes=[torch.float32],
2242            ),
2243            DecorateInfo(
2244                unittest.skip("Skipped!"),
2245                "TestCommon",
2246                "test_variant_consistency_eager",
2247                device_type="mps",
2248                dtypes=[torch.float32],
2249            ),
2250            DecorateInfo(
2251                unittest.skip("Skipped!"),
2252                "TestJit",
2253                "test_variant_consistency_jit",
2254                device_type="mps",
2255                dtypes=[torch.float32],
2256            ),
2257            DecorateInfo(
2258                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
2259                "TestCommon",
2260                "test_noncontiguous_samples",
2261                device_type="cuda",
2262            ),
2263            # This test is flaky under slow gradcheck, likely due to rounding issues
2264            DecorateInfo(
2265                skipIfSlowGradcheckEnv,
2266                "TestFwdGradients",
2267                "test_fn_fwgrad_bwgrad",
2268                device_type="cuda",
2269            ),
2270        ),
2271    ),
2272    OpInfo(
2273        "linalg.svd",
2274        op=torch.linalg.svd,
2275        aten_name="linalg_svd",
2276        decomp_aten_name="_linalg_svd",
2277        dtypes=floating_and_complex_types(),
2278        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
2279        gradcheck_fast_mode=True,
2280        supports_fwgrad_bwgrad=True,
2281        supports_forward_ad=True,
2282        check_batched_forward_grad=False,
2283        # We're using at::allclose, which does not have a batching rule
2284        check_batched_grad=False,
2285        check_batched_gradgrad=False,
2286        sample_inputs_func=sample_inputs_svd,
2287        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
2288        skips=(
2289            DecorateInfo(
2290                unittest.skip("Skipped!"),
2291                "TestCommon",
2292                "test_out",
2293                device_type="mps",
2294                dtypes=[torch.float32],
2295            ),
2296            DecorateInfo(
2297                unittest.skip("Skipped!"),
2298                "TestCommon",
2299                "test_variant_consistency_eager",
2300                device_type="mps",
2301                dtypes=[torch.float32],
2302            ),
2303            DecorateInfo(
2304                unittest.skip("Skipped!"),
2305                "TestJit",
2306                "test_variant_consistency_jit",
2307                device_type="mps",
2308                dtypes=[torch.float32],
2309            ),
2310            DecorateInfo(
2311                unittest.skip("Skipped!"),
2312                "TestFakeTensor",
2313                "test_fake_crossref_backward_amp",
2314                device_type="cuda",
2315                dtypes=[torch.float32],
2316                active_if=TEST_WITH_ROCM,
2317            ),
2318            DecorateInfo(
2319                unittest.skip("Skipped!"),
2320                "TestFakeTensor",
2321                "test_fake_crossref_backward_no_amp",
2322                device_type="cuda",
2323                dtypes=[torch.float32],
2324                active_if=TEST_WITH_ROCM,
2325            ),
2326        ),
2327    ),
2328    OpInfo(
2329        "linalg.svdvals",
2330        op=torch.linalg.svdvals,
2331        aten_name="linalg_svdvals",
2332        decomp_aten_name="_linalg_svd",
2333        dtypes=floating_and_complex_types(),
2334        check_batched_forward_grad=False,
2335        supports_fwgrad_bwgrad=True,
2336        supports_forward_ad=True,
2337        # We're using at::allclose, which does not have a batching rule
2338        check_batched_gradgrad=False,
2339        sample_inputs_func=sample_inputs_linalg_svdvals,
2340        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
2341        skips=(
2342            DecorateInfo(
2343                unittest.skip("Skipped!"),
2344                "TestFakeTensor",
2345                "test_fake_crossref_backward_amp",
2346                device_type="cuda",
2347                dtypes=[torch.float32],
2348                active_if=TEST_WITH_ROCM,
2349            ),
2350            DecorateInfo(
2351                unittest.skip("Skipped!"),
2352                "TestFakeTensor",
2353                "test_fake_crossref_backward_no_amp",
2354                device_type="cuda",
2355                dtypes=[torch.float32],
2356                active_if=TEST_WITH_ROCM,
2357            ),
2358        ),
2359    ),
2360    OpInfo(
2361        "linalg.tensorinv",
2362        ref=np.linalg.tensorinv,
2363        dtypes=floating_and_complex_types(),
2364        sample_inputs_func=sample_inputs_tensorinv,
2365        supports_forward_ad=True,
2366        supports_fwgrad_bwgrad=True,
2367        # See https://github.com/pytorch/pytorch/pull/78358
2368        check_batched_forward_grad=False,
2369        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
2370        skips=(
2371            DecorateInfo(
2372                unittest.skip("Unsupported on MPS for now"),
2373                "TestCommon",
2374                "test_numpy_ref_mps",
2375            ),
2376        ),
2377    ),
2378    OpInfo(
2379        "linalg.tensorsolve",
2380        ref=lambda a, b, dims=None: np.linalg.tensorsolve(a, b, axes=dims),
2381        dtypes=floating_and_complex_types(),
2382        sample_inputs_func=sample_inputs_tensorsolve,
2383        supports_forward_ad=True,
2384        supports_fwgrad_bwgrad=True,
2385        decorators=[
2386            skipCUDAIfNoMagmaAndNoCusolver,
2387            skipCPUIfNoLapack,
2388            DecorateInfo(
2389                toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
2390                "TestCommon",
2391                "test_noncontiguous_samples",
2392                device_type="cuda",
2393            ),
2394            DecorateInfo(
2395                toleranceOverride({torch.float32: tol(atol=8e-04, rtol=7e-06)}),
2396                "TestCommon",
2397                "test_noncontiguous_samples",
2398                device_type="cpu",
2399            ),
2400        ],
2401        skips=(
2402            DecorateInfo(
2403                unittest.skip("Unsupported on MPS for now"),
2404                "TestCommon",
2405                "test_numpy_ref_mps",
2406            ),
2407        ),
2408    ),
2409]
2410
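
# A minimal sketch (illustrative only) of how the OpInfo entries in op_db are typically
# consumed: the device-generic test suites iterate an op's advertised sample inputs and
# call its eager variant. The helper name below is hypothetical and not used by the test
# framework; it assumes the public OpInfo methods `supported_dtypes` and `sample_inputs`
# and the SampleInput fields `input`, `args`, and `kwargs`.
def _smoke_test_linalg_op_db(device="cpu", dtype=torch.float64):
    for opinfo in op_db:
        # Skip entries that do not advertise support for this dtype on this device type.
        if dtype not in opinfo.supported_dtypes(device):
            continue
        for sample in opinfo.sample_inputs(device, dtype, requires_grad=False):
            # Each SampleInput bundles the primary input with positional and keyword args.
            opinfo.op(sample.input, *sample.args, **sample.kwargs)


# Usage (not executed here): _smoke_test_linalg_op_db("cpu", torch.float32)
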
2411python_ref_db: List[OpInfo] = [
2412    #
2413    # torch.linalg
2414    #
2415    PythonRefInfo(
2416        "_refs.linalg.cross",
2417        torch_opinfo_name="linalg.cross",
2418        supports_out=True,
2419        op_db=op_db,
2420        skips=(
2421            # TODO: is this really needed?
2422            DecorateInfo(
2423                unittest.expectedFailure, "TestCommon", "test_python_ref_errors"
2424            ),
2425        ),
2426    ),
2427    PythonRefInfo(
2428        "_refs.linalg.diagonal",
2429        torch_opinfo_name="linalg.diagonal",
2430        supports_out=False,
2431        op_db=op_db,
2432    ),
2433    PythonRefInfo(
2434        "_refs.linalg.vecdot",
2435        torch_opinfo_name="linalg.vecdot",
2436        op_db=op_db,
2437    ),
2438    ReductionPythonRefInfo(
2439        "_refs.linalg.vector_norm",
2440        torch_opinfo_name="linalg.vector_norm",
2441        supports_out=True,
2442        op_db=op_db,
2443        skips=(
2444            # FIXME: sum reduces all dimensions when dim=[]
2445            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
2446            DecorateInfo(
2447                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
2448            ),
2449        ),
2450    ),
2451    PythonRefInfo(
2452        "_refs.linalg.matrix_norm",
2453        torch_opinfo_name="linalg.matrix_norm",
2454        supports_out=True,
2455        # Uses vector_norm internally, and vector_norm is affected by
2456        # https://github.com/pytorch/pytorch/issues/77216
2457        validate_view_consistency=False,
2458        op_db=op_db,
2459    ),
2460    PythonRefInfo(
2461        "_refs.linalg.norm",
2462        torch_opinfo_name="linalg.norm",
2463        supports_out=True,
2464        # Uses vector_norm internally, and vector_norm is affected by
2465        # https://github.com/pytorch/pytorch/issues/77216
2466        validate_view_consistency=False,
2467        op_db=op_db,
2468    ),
2469    PythonRefInfo(
2470        "_refs.linalg.svd",
2471        torch_opinfo_name="linalg.svd",
2472        supports_out=True,
2473        op_db=op_db,
2474    ),
2475    PythonRefInfo(
2476        "_refs.linalg.svdvals",
2477        torch_opinfo_name="linalg.svdvals",
2478        supports_out=True,
2479        op_db=op_db,
2480    ),
2481]
2482
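
# A minimal sketch (illustrative only) of how the PythonRefInfo entries above tie a
# reference implementation under torch._refs.linalg to the eager torch.linalg op it
# mirrors. The helper name is hypothetical, and it assumes the `name` and
# `torch_opinfo_name` constructor arguments are stored as attributes of the infos.
def _paired_ref_and_eager_infos(torch_name="linalg.svd"):
    refs = [r for r in python_ref_db if r.torch_opinfo_name == torch_name]
    eager = [o for o in op_db if o.name == torch_name]
    return refs, eager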