1# mypy: ignore-errors 2 3import itertools 4import random 5import unittest 6from functools import partial 7from itertools import chain, product 8from typing import Iterable, List, Tuple 9 10import numpy as np 11from numpy import inf 12 13import torch 14from torch.testing import make_tensor 15from torch.testing._internal.common_cuda import ( 16 _get_magma_version, 17 _get_torch_cuda_version, 18 with_tf32_off, 19) 20from torch.testing._internal.common_device_type import ( 21 has_cusolver, 22 skipCPUIfNoLapack, 23 skipCUDAIf, 24 skipCUDAIfNoCusolver, 25 skipCUDAIfNoMagma, 26 skipCUDAIfNoMagmaAndNoCusolver, 27 skipCUDAIfNoMagmaAndNoLinalgsolver, 28 skipCUDAIfRocm, 29 tol, 30 toleranceOverride, 31) 32from torch.testing._internal.common_dtype import ( 33 all_types_and_complex, 34 all_types_and_complex_and, 35 floating_and_complex_types, 36 floating_and_complex_types_and, 37 get_all_complex_dtypes, 38) 39from torch.testing._internal.common_utils import ( 40 GRADCHECK_NONDET_TOL, 41 IS_MACOS, 42 make_fullrank_matrices_with_distinct_singular_values, 43 skipIfSlowGradcheckEnv, 44 slowTest, 45 TEST_WITH_ROCM, 46) 47from torch.testing._internal.opinfo.core import ( 48 clone_sample, 49 DecorateInfo, 50 ErrorInput, 51 gradcheck_wrapper_hermitian_input, 52 L, 53 M, 54 OpInfo, 55 ReductionOpInfo, 56 S, 57 SampleInput, 58) 59from torch.testing._internal.opinfo.refs import PythonRefInfo, ReductionPythonRefInfo 60 61 62def sample_kwargs_vector_norm(t, **kwargs): 63 # orders with / without identity 64 def ords(): 65 has_id = (6, 4, 2, 1, 0, 0.9) 66 no_id = (inf, -2.1, -inf) 67 if t.numel() == 0: 68 dim = kwargs.get("dim") 69 if dim is None: 70 return has_id 71 if not isinstance(dim, Iterable): 72 dim = (dim,) 73 for d in dim: 74 if t.size(d) == 0: 75 return has_id 76 return has_id + no_id 77 78 return (((), dict(ord=o)) for o in ords()) 79 80 81def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs): 82 make_fullrank = make_fullrank_matrices_with_distinct_singular_values 83 make_arg = partial( 84 make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad 85 ) 86 87 is_linalg_svd = "linalg.svd" in op_info.name 88 batches = [(), (0,), (3,)] 89 ns = [0, 3, 5] 90 91 def uniformize(usv): 92 S = usv[1] 93 k = S.shape[-1] 94 U = usv[0][..., :k] 95 Vh = usv[2] if is_linalg_svd else usv[2].mH 96 Vh = Vh[..., :k, :] 97 return U, S, Vh 98 99 def fn_U(usv): 100 U, _, _ = uniformize(usv) 101 return U.abs() 102 103 def fn_S(usv): 104 return uniformize(usv)[1] 105 106 def fn_Vh(usv): 107 # We also return S to test 108 _, S, Vh = uniformize(usv) 109 return S, Vh.abs() 110 111 def fn_UVh(usv): 112 U, S, Vh = uniformize(usv) 113 return U @ Vh, S 114 115 fns = (fn_U, fn_S, fn_Vh, fn_UVh) 116 117 fullmat = "full_matrices" if is_linalg_svd else "some" 118 119 for batch, n, k, fullmat_val, fn in product(batches, ns, ns, (True, False), fns): 120 shape = batch + (n, k) 121 yield SampleInput( 122 make_arg(*shape), kwargs={fullmat: fullmat_val}, output_process_fn_grad=fn 123 ) 124 125 126def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs): 127 make_arg = partial( 128 make_tensor, dtype=dtype, device=device, requires_grad=requires_grad 129 ) 130 yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),)) 131 yield SampleInput( 132 make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1) 133 ) 134 yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1)) 135 136 137def error_inputs_cross(op_info, device, **kwargs): 138 make_arg = 
partial(make_tensor, device=device, dtype=torch.float32) 139 140 sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),)) 141 err = "inputs dimension -1 must have length 3" 142 yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) 143 144 sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),)) 145 err = "inputs must have the same number of dimensions" 146 yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) 147 148 sample = SampleInput(input=make_arg((S, 2)), args=(make_arg((S, 2)),)) 149 err = "must have length 3" 150 yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) 151 152 sample = SampleInput( 153 input=make_arg((S, 2)), args=(make_arg((S, 2)),), kwargs=dict(dim=2) 154 ) 155 err = "Dimension out of range" 156 yield ErrorInput(sample, error_regex=err, error_type=IndexError) 157 158 159def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs): 160 """ 161 This function generates input for torch.linalg.householder_product (torch.orgqr). 162 The first argument should be a square matrix or batch of square matrices, the second argument is a vector or batch of vectors. 163 Empty, square, rectangular, batched square and batched rectangular input is generated. 164 """ 165 make_arg = partial( 166 make_tensor, 167 device=device, 168 dtype=dtype, 169 requires_grad=requires_grad, 170 low=-2, 171 high=2, 172 ) 173 # Each column of the matrix is getting multiplied many times leading to very large values for 174 # the Jacobian matrix entries and making the finite-difference result of grad check less accurate. 175 # That's why gradcheck with the default range [-9, 9] fails and [-2, 2] is used here. 176 yield SampleInput(make_arg((S, S)), make_arg((S,))) 177 yield SampleInput(make_arg((S + 1, S)), make_arg((S,))) 178 yield SampleInput(make_arg((2, 1, S, S)), make_arg((2, 1, S))) 179 yield SampleInput(make_arg((2, 1, S + 1, S)), make_arg((2, 1, S))) 180 yield SampleInput( 181 make_arg((0, 0), low=None, high=None), 182 make_arg((0,), low=None, high=None), 183 ) 184 yield SampleInput(make_arg((S, S)), make_arg((0,), low=None, high=None)) 185 # m = n = S, k = S - 2 186 yield SampleInput(make_arg((S, S)), make_arg((S - 2,), low=None, high=None)) 187 # m = S, n = S -1, k = S - 2 188 yield SampleInput(make_arg((S, S - 1)), make_arg((S - 2,), low=None, high=None)) 189 190 191def sample_inputs_linalg_det_singular(op_info, device, dtype, requires_grad, **kwargs): 192 make_arg = partial(make_tensor, device=device, dtype=dtype) 193 194 def make_singular_matrix_batch_base(size, rank): 195 assert size[-1] == size[-2] 196 assert rank > 0 and rank < size[-1] 197 198 n = size[-1] 199 a = make_arg(size[:-2] + (n, rank)) / 10 200 b = make_arg(size[:-2] + (rank, n)) / 10 201 x = a @ b 202 lu, pivs, _ = torch.linalg.lu_factor_ex(x) 203 p, l, u = torch.lu_unpack(lu, pivs) 204 u_diag_abs = u.diagonal(0, -2, -1).abs() 205 u_diag_abs_largest = u_diag_abs.max(dim=-1, keepdim=True).values 206 u_diag_abs_smallest_idxs = torch.topk( 207 u_diag_abs, k=(n - rank), largest=False 208 ).indices 209 u.diagonal(0, -2, -1).div_(u_diag_abs_largest) 210 u.diagonal(0, -2, -1)[..., u_diag_abs_smallest_idxs] = torch.finfo(dtype).eps 211 matrix = p @ l @ u 212 213 matrix.requires_grad_(requires_grad) 214 return matrix 215 216 for batch, size in product(((), (2,), (2, 2)), range(6)): 217 shape = batch + (size, size) 218 for rank in range(1, size): 219 yield SampleInput(make_singular_matrix_batch_base(shape, rank)) 220 221 222def 
sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad, **kwargs):
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    make_arg_fullrank = partial(
        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
    )
    # (<matrix_size>, (<batch_sizes, ...>))
    test_sizes = [
        (1, ()),
        (2, (0,)),
        (2, (2,)),
    ]

    for matrix_size, batch_sizes in test_sizes:
        size = batch_sizes + (matrix_size, matrix_size)
        for n in (0, 3, 5):
            yield SampleInput(make_arg(size), args=(n,))
        for n in [-4, -2, -1]:
            yield SampleInput(make_arg_fullrank(*size), args=(n,))


def sample_inputs_linalg_det_logdet_slogdet(
    op_info, device, dtype, requires_grad, **kwargs
):
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(
        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
    )
    batches = [(), (0,), (3,)]
    ns = [0, 1, 5]

    is_logdet = op_info.name == "logdet"

    for (
        batch,
        n,
    ) in product(batches, ns):
        shape = batch + (n, n)
        A = make_arg(*shape)
        # Need to make the matrices in A have positive determinant for autograd
        # To do so, we multiply A by the sign of its determinant, which flips the
        # sign of the determinant when it is negative (the matrix sizes used here are odd)
        if is_logdet and not A.is_complex() and A.numel() > 0:
            s = torch.linalg.slogdet(A).sign
            A = A * s.unsqueeze(-1).unsqueeze(-1)
            A.requires_grad_(requires_grad)
        yield SampleInput(A)


def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs):
    """Samples the inputs for both linalg.lu_solve and lu_solve"""
    make_fn = make_fullrank_matrices_with_distinct_singular_values
    make_a = partial(make_fn, dtype=dtype, device=device)
    make_b = partial(make_tensor, dtype=dtype, device=device)

    def clone(X, requires_grad):
        Y = X.clone()
        Y.requires_grad_(requires_grad)
        return Y

    is_linalg_lu_solve = op_info.name == "linalg.lu_solve"

    batches = ((), (0,), (2,))
    ns = (3, 1, 0)
    nrhs = (4, 1, 0)

    for n, batch, rhs in product(ns, batches, nrhs):
        A = make_a(*(batch + (n, n)))
        LU, pivots = torch.linalg.lu_factor(A)

        B = make_b(batch + (n, rhs))

        grads = (False,) if not requires_grad else (True, False)
        # we try all possible combinations of requires_grad for each input
        for LU_grad, B_grad in product(grads, grads):
            # when requires_grad == True, at least one input has to have requires_grad enabled
            if requires_grad and not LU_grad and not B_grad:
                continue

            if is_linalg_lu_solve:
                for adjoint, left in product((True, False), repeat=2):
                    yield SampleInput(
                        clone(LU, LU_grad),
                        args=(pivots, clone(B if left else B.mT, B_grad)),
                        kwargs=dict(adjoint=adjoint, left=left),
                    )
            else:
                yield SampleInput(clone(B, B_grad), args=(clone(LU, LU_grad), pivots))


def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
    # Each test case consists of the sizes in the chain of multiplications
    # e.g.
[2, 3, 4, 5] generates matrices (2, 3) @ (3, 4) @ (4, 5) 316 test_cases = [ 317 [1, 2, 1], 318 [2, 0, 2], 319 [0, 2, 2], 320 [2, 2, 2, 2], 321 [2, 3, 4, 5], 322 [5, 4, 0, 2], 323 [2, 4, 3, 5, 3, 2], 324 ] 325 326 for sizes in test_cases: 327 tensors = [] 328 for size in zip(sizes[:-1], sizes[1:]): 329 t = make_tensor( 330 size, dtype=dtype, device=device, requires_grad=requires_grad 331 ) 332 tensors.append(t) 333 yield SampleInput(tensors) 334 335 336def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs): 337 low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) 338 make_arg = partial( 339 make_tensor, device=device, dtype=dtype, requires_grad=requires_grad 340 ) 341 342 sizes = ((2, 2), (2, 3, 2)) 343 if dtype in low_precision_dtypes: 344 # svdvals not supported for low precision dtypes 345 ords = ("fro", inf, -inf, 1, -1) 346 else: 347 ords = ("fro", "nuc", inf, -inf, 1, -1, 2, -2) 348 dims = ((-2, -1), (-1, 0)) 349 350 for size, ord, dim, keepdim in product(sizes, ords, dims, [True, False]): 351 yield SampleInput(make_arg(size), args=(ord, dim, keepdim)) 352 353 354def sample_inputs_linalg_norm( 355 op_info, device, dtype, requires_grad, *, variant=None, **kwargs 356): 357 if variant is not None and variant not in ("subgradient_at_zero",): 358 raise ValueError( 359 f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}" 360 ) 361 362 test_sizes = [ 363 (S,), 364 (0,), 365 (S, S), 366 (0, 0), 367 (S, 0), 368 (0, S), 369 (S, S, S), 370 (0, S, S), 371 (S, 0, S), 372 (0, 0, 0), 373 ] 374 375 vector_ords = (None, 0, 0.5, 1, 2, 3.5, inf, -0.5, -1, -2, -3.5, -inf) 376 if dtype in {torch.float16, torch.bfloat16, torch.complex32}: 377 # svdvals not supported for low precision dtypes 378 matrix_ords = ("fro", inf, -inf, 1, -1) 379 else: 380 matrix_ords = (None, "fro", "nuc", inf, -inf, 1, -1, 2, -2) 381 382 make_arg = partial( 383 make_tensor, 384 dtype=dtype, 385 device=device, 386 requires_grad=requires_grad, 387 low=None, 388 high=None, 389 ) 390 391 for test_size in test_sizes: 392 is_vector_norm = len(test_size) == 1 393 is_matrix_norm = len(test_size) == 2 394 395 # IndexError: amax(): Expected reduction dim 0 to have non-zero size. 396 is_valid_for_p2 = is_vector_norm or (test_size[-1] != 0 and test_size[-2] != 0) 397 398 for keepdim in [False, True]: 399 if variant != "subgradient_at_zero" and is_valid_for_p2: 400 yield SampleInput(make_arg(test_size), keepdim=keepdim) 401 402 if not (is_vector_norm or is_matrix_norm): 403 continue 404 405 ords = vector_ords if is_vector_norm else matrix_ords 406 407 for ord in ords: 408 if is_vector_norm and test_size[-1] == 0: 409 if ord == np.inf or (ord is not None and ord < 0): 410 # RuntimeError: linalg.vector_norm cannot compute the 411 # {ord} norm on an empty tensor because the operation 412 # does not have an identity 413 continue 414 elif is_matrix_norm: 415 dims_to_check = { 416 None: (0,), 417 np.inf: (0,), 418 2: (0, 1), 419 1: (1,), 420 -1: (1,), 421 -2: (0, 1), 422 -np.inf: (0,), 423 }.get(ord, ()) 424 425 if any(test_size[d] == 0 for d in dims_to_check): 426 # IndexError: amax(): Expected reduction dim {dim} to 427 # have non-zero size. 
428 continue 429 430 if variant == "subgradient_at_zero": 431 yield SampleInput( 432 torch.zeros( 433 test_size, 434 dtype=dtype, 435 device=device, 436 requires_grad=requires_grad, 437 ), 438 ord, 439 keepdim=keepdim, 440 ) 441 else: 442 yield SampleInput(make_arg(test_size), ord, keepdim=keepdim) 443 444 if ord in ["nuc", "fro"]: 445 yield SampleInput( 446 make_arg(test_size), ord=ord, keepdim=keepdim, dim=(0, 1) 447 ) 448 449 450def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs): 451 make_arg = partial( 452 make_tensor, device=device, dtype=dtype, requires_grad=requires_grad 453 ) 454 batches = ((), (0,), (1,), (5,)) 455 ns = (0, 1, 3, 5) 456 for b, n in product(batches, ns): 457 shape = b + (n,) 458 yield SampleInput(make_arg(shape), args=(make_arg(shape),)) 459 for i in range(len(shape)): 460 yield SampleInput( 461 make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i) 462 ) 463 464 465def sample_inputs_linalg_invertible( 466 op_info, device, dtype, requires_grad=False, **kwargs 467): 468 """ 469 This function generates invertible inputs for linear algebra ops 470 The input is generated as the itertools.product of 'batches' and 'ns'. 471 In total this function generates 8 SampleInputs 472 'batches' cases include: 473 () - single input, 474 (0,) - zero batched dimension, 475 (2,) - batch of two matrices, 476 (1, 1) - 1x1 batch of matrices 477 'ns' gives 0x0 and 5x5 matrices. 478 Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes. 479 """ 480 make_fn = make_fullrank_matrices_with_distinct_singular_values 481 make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad) 482 483 batches = [(), (0,), (2,), (1, 1)] 484 ns = [5, 0] 485 486 for batch, n in product(batches, ns): 487 yield SampleInput(make_arg(*batch, n, n)) 488 489 490def sample_inputs_matrix_rank(op_info, device, dtype, requires_grad=False, **kwargs): 491 """ 492 This function produces inputs for matrix rank that test 493 all possible combinations for atol and rtol 494 """ 495 496 def make_tol_arg(kwarg_type, inp): 497 if kwarg_type == "none": 498 return None 499 if kwarg_type == "float": 500 return 1.0 501 assert kwarg_type == "tensor" 502 return torch.ones(inp.shape[:-2], device=device) 503 504 for tol_type in ["float", "tensor"]: 505 for atol_type, rtol_type in product(["none", tol_type], repeat=2): 506 if ( 507 not atol_type and not rtol_type 508 ): # default behavior, so skipped here so it's not tested 2 extra times 509 continue 510 for sample in sample_inputs_linalg_invertible( 511 op_info, device, dtype, requires_grad 512 ): 513 assert sample.kwargs == {} 514 sample.kwargs = { 515 "atol": make_tol_arg(atol_type, sample.input), 516 "rtol": make_tol_arg(rtol_type, sample.input), 517 } 518 yield sample 519 520 # default kwargs 521 yield from sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad) 522 523 524def sample_inputs_linalg_pinv_singular( 525 op_info, device, dtype, requires_grad=False, **kwargs 526): 527 """ 528 This function produces factors `a` and `b` to generate inputs of the form `a @ b.t()` to 529 test the backward method of `linalg_pinv`. That way we always preserve the rank of the 530 input no matter the perturbations applied to it by the gradcheck. 531 Note that `pinv` is Frechet-differentiable in a rank-preserving neighborhood. 
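    A minimal sketch of the construction (illustrative only; `m`, `n` and `k` here
    are placeholders with k <= min(m, n), not names used below):

        a = torch.linalg.qr(torch.rand(m, k)).Q   # orthonormal columns
        b = torch.linalg.qr(torch.rand(n, k)).Q   # orthonormal columns
        x = a @ b.mH                              # rank-k, condition number 1 on its image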
532 """ 533 batches = [(), (0,), (2,), (1, 1)] 534 # the size of at least 30 is required to cause failures for the previous implicit implementation 535 # of the pinv's backward method, albeit it is slow. 536 size = [0, 3, 50] 537 538 for batch, m, n in product(batches, size, size): 539 for k in range(min(3, m, n)): 540 # Note that by making the columns of `a` and `b` orthonormal we make sure that 541 # the product matrix `a @ b.t()` has condition number 1 when restricted to its image 542 a = ( 543 torch.rand(*batch, m, k, device=device, dtype=dtype) 544 .qr() 545 .Q.requires_grad_(requires_grad) 546 ) 547 b = ( 548 torch.rand(*batch, n, k, device=device, dtype=dtype) 549 .qr() 550 .Q.requires_grad_(requires_grad) 551 ) 552 yield SampleInput(a, args=(b,)) 553 554 555def sample_inputs_linalg_cond(op_info, device, dtype, requires_grad=False, **kwargs): 556 make_arg = partial( 557 make_tensor, dtype=dtype, device=device, requires_grad=requires_grad 558 ) 559 560 # autograd is not supported for inputs with zero number of elements 561 shapes = ( 562 (S, S), 563 (2, S, S), 564 (2, 1, S, S), 565 ) 566 567 for shape in shapes: 568 yield SampleInput(make_arg(shape)) 569 570 571def sample_inputs_linalg_vander(op_info, device, dtype, requires_grad=False, **kwargs): 572 make_arg = partial( 573 make_tensor, dtype=dtype, device=device, requires_grad=requires_grad 574 ) 575 576 shapes = ( 577 (), 578 (1,), 579 (S,), 580 (2, S), 581 ) 582 583 for shape in shapes: 584 if len(shape) > 0 and shape[-1] > 1: 585 yield SampleInput(make_arg(shape)) 586 n = shape[-1] if len(shape) > 0 else 1 587 for i in range(3): 588 # n-1, n, n+1 589 N = n + i - 1 590 if N < 2: 591 continue 592 yield SampleInput(make_arg(shape), kwargs=dict(N=N)) 593 594 595def np_vander_batched(x, N=None): 596 # Wrapper around np.vander that supports batches of 1 dimension (enough for the tests) 597 if x.ndim == 0: 598 x = x[np.newaxis] 599 if x.ndim == 1: 600 y = np.vander(x, N=N, increasing=True) 601 return y 602 else: 603 if N is None: 604 N = x.shape[-1] 605 y = np.vander(x.ravel(), N=N, increasing=True).reshape((*x.shape, N)) 606 return y 607 608 609def sample_inputs_linalg_cholesky_inverse( 610 op_info, device, dtype, requires_grad=False, **kwargs 611): 612 from torch.testing._internal.common_utils import random_well_conditioned_matrix 613 614 # Cholesky factorization is for positive-definite matrices 615 single_well_conditioned_matrix = random_well_conditioned_matrix( 616 S, S, dtype=dtype, device=device 617 ) 618 batch_well_conditioned_matrices = random_well_conditioned_matrix( 619 2, S, S, dtype=dtype, device=device 620 ) 621 single_pd = single_well_conditioned_matrix @ single_well_conditioned_matrix.mH 622 batch_pd = batch_well_conditioned_matrices @ batch_well_conditioned_matrices.mH 623 624 inputs = ( 625 torch.zeros(0, 0, dtype=dtype, device=device), # 0x0 matrix 626 torch.zeros(0, 2, 2, dtype=dtype, device=device), # zero batch of matrices 627 single_pd, 628 batch_pd, 629 ) 630 test_cases = (torch.linalg.cholesky(a, upper=False) for a in inputs) 631 for l in test_cases: 632 # generated lower-triangular samples 633 l.requires_grad = requires_grad 634 yield SampleInput(l) # upper=False by default 635 yield SampleInput( 636 l.detach().clone().requires_grad_(requires_grad), kwargs=dict(upper=False) 637 ) 638 639 # generate upper-triangular inputs 640 u = l.detach().clone().mT.contiguous().requires_grad_(requires_grad) 641 yield SampleInput(u, kwargs=dict(upper=True)) 642 643 644def sample_inputs_linalg_ldl_factor( 645 op_info, device, 
dtype, requires_grad=False, **kwargs 646): 647 from torch.testing._internal.common_utils import ( 648 random_hermitian_pd_matrix, 649 random_symmetric_pd_matrix, 650 ) 651 652 device = torch.device(device) 653 654 # Symmetric inputs 655 yield SampleInput( 656 random_symmetric_pd_matrix(S, dtype=dtype, device=device), 657 kwargs=dict(hermitian=False), 658 ) # single matrix 659 yield SampleInput( 660 random_symmetric_pd_matrix(S, 2, dtype=dtype, device=device), 661 kwargs=dict(hermitian=False), 662 ) # batch of matrices 663 yield SampleInput( 664 torch.zeros(0, 0, dtype=dtype, device=device), kwargs=dict(hermitian=False) 665 ) # 0x0 matrix 666 yield SampleInput( 667 torch.zeros(0, 2, 2, dtype=dtype, device=device), kwargs=dict(hermitian=False) 668 ) # zero batch of matrices 669 670 # Hermitian inputs 671 # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+ 672 magma_254_available = device.type == "cuda" and _get_magma_version() >= (2, 5, 4) 673 if dtype.is_complex and (device.type == "cpu" or magma_254_available): 674 yield SampleInput( 675 random_hermitian_pd_matrix(S, dtype=dtype, device=device), 676 kwargs=dict(hermitian=True), 677 ) # single matrix 678 yield SampleInput( 679 random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device), 680 kwargs=dict(hermitian=True), 681 ) # batch of matrices 682 683 684def sample_inputs_linalg_ldl_solve( 685 op_info, device, dtype, requires_grad=False, **kwargs 686): 687 # Generate LDL factors of symmetric (and Hermitian on CPU) matrices 688 from torch.testing._internal.common_utils import ( 689 random_hermitian_pd_matrix, 690 random_symmetric_pd_matrix, 691 ) 692 693 device = torch.device(device) 694 symmetric_inputs = ( 695 random_symmetric_pd_matrix(S, dtype=dtype, device=device), # single matrix 696 random_symmetric_pd_matrix( 697 S, 2, dtype=dtype, device=device 698 ), # batch of matrices 699 torch.zeros(0, 0, dtype=dtype, device=device), # 0x0 matrix 700 torch.zeros(0, 2, 2, dtype=dtype, device=device), # zero batch of matrices 701 ) 702 hermitian_inputs = ( 703 ( 704 random_hermitian_pd_matrix(S, dtype=dtype, device=device), 705 random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device), 706 ) 707 if device.type == "cpu" and dtype.is_complex 708 else () 709 ) 710 test_cases1 = ( 711 torch.linalg.ldl_factor_ex(a, hermitian=False) for a in symmetric_inputs 712 ) 713 test_cases2 = ( 714 torch.linalg.ldl_factor_ex(a, hermitian=True) for a in hermitian_inputs 715 ) 716 717 # Symmetric case 718 make_arg = partial( 719 make_tensor, device=device, dtype=dtype, requires_grad=requires_grad 720 ) 721 for test_case in test_cases1: 722 factors, pivots, _ = test_case 723 factors.requires_grad = requires_grad 724 for B_batch_shape in ((), factors.shape[:-2]): 725 B = make_arg((*B_batch_shape, factors.shape[-1], S)) 726 yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=False)) 727 clone_factors = factors.detach().clone().requires_grad_(requires_grad) 728 yield SampleInput( 729 clone_factors, args=(pivots, B), kwargs=dict(hermitian=False) 730 ) 731 732 # Hermitian case 733 for test_case in test_cases2: 734 factors, pivots, _ = test_case 735 factors.requires_grad = requires_grad 736 for B_batch_shape in ((), factors.shape[:-2]): 737 B = make_arg((*B_batch_shape, factors.shape[-1], S)) 738 yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=True)) 739 clone_factors = factors.detach().clone().requires_grad_(requires_grad) 740 yield SampleInput( 741 clone_factors, args=(pivots, B), 
kwargs=dict(hermitian=True)
            )


def sample_inputs_linalg_lstsq(op_info, device, dtype, requires_grad=False, **kwargs):
    from torch.testing._internal.common_utils import random_well_conditioned_matrix

    device = torch.device(device)

    drivers: Tuple[str, ...]
    if device.type == "cuda":
        drivers = ("gels",)
    else:
        drivers = ("gels", "gelsy", "gelss", "gelsd")

    # we generate matrices of shape (..., n + delta, n)
    deltas: Tuple[int, ...]
    if device.type == "cpu" or has_cusolver():
        deltas = (-1, 0, +1)
    # only square systems if Cusolver is not available
    # because we solve a lstsq problem with a transposed matrix in the backward
    else:
        deltas = (0,)

    for batch, driver, delta in product(((), (3,), (3, 3)), drivers, deltas):
        shape = batch + (3 + delta, 3)
        a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
        a.requires_grad_(requires_grad)
        b = make_tensor(
            shape,
            dtype=dtype,
            device=device,
            low=None,
            high=None,
            requires_grad=requires_grad,
        )
        yield SampleInput(a, b, driver=driver)


def error_inputs_lstsq(op_info, device, **kwargs):
    zero_d = torch.randn((), device=device)
    yield ErrorInput(
        SampleInput(zero_d, args=(zero_d,)),
        error_type=RuntimeError,
        error_regex="at least 2 dimensions",
    )


def error_inputs_lstsq_grad_oriented(op_info, device, **kwargs):
    zero_d = torch.randn((), device=device)
    yield ErrorInput(
        SampleInput(zero_d, args=(zero_d, None)),
        error_type=RuntimeError,
        error_regex="at least 2 dimensions",
    )


def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )

    # Shapes for 2D Tensors
    shapes_2d = ((S, S), (3, 5), (5, 3))

    # Shapes for 3D Tensors
    shapes_3d = ((S, S, S),)

    kwargs_2d = ({}, dict(offset=2), dict(offset=2), dict(offset=1))
    kwargs_3d = (
        dict(offset=1, dim1=1, dim2=2),
        dict(offset=2, dim1=0, dim2=1),
        dict(offset=-2, dim1=0, dim2=1),
    )

    for shape, kwarg in chain(
        product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)
    ):
        yield SampleInput(make_arg(shape), kwargs=kwarg)


def error_inputs_diagonal_diag_embed(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    shapes1d = (0, 1, (0,), (1,))
    shapes2d = ((M, L),)
    shapes3d = ((M, S, L),)

    kwargs1d = {}

    kwargs2d = (
        # dim1 == dim2 is not allowed
        dict(dim1=1, dim2=1),
        # out of bounds dims are not allowed
        dict(dim1=10000),
        dict(dim2=10000),
    )

    kwargs3d = kwargs2d

    samples1d = product(shapes1d, kwargs1d)
    samples2d = product(shapes2d, kwargs2d)
    samples3d = product(shapes3d, kwargs3d)

    for shape, kwargs in chain(samples1d, samples2d, samples3d):
        arg = make_arg(shape)
        sample = SampleInput(input=arg, kwargs=kwargs)

        dim1 = kwargs.get("dim1")
        dim2 = kwargs.get("dim2")

        if "diagonal" in op_info.name:
            num_dim = arg.dim()
        elif op_info.name in ("diag_embed", "_refs.diag_embed"):
            # these are valid inputs for diag_embed
            if shape in ((0,), (1,)):
                continue
            num_dim = arg.dim() + 1
        else:
            raise RuntimeError("should be unreachable")

        bound1 = -num_dim
        bound2 = num_dim - 1
        dim_range = range(bound1, bound2 + 1)
        dim1_cond =
dim1 and dim1 not in dim_range 866 dim2_cond = dim2 and dim2 not in dim_range 867 868 if dim1 == dim2: 869 err = f"diagonal dimensions cannot be identical {dim1}, {dim2}" 870 yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) 871 elif dim1_cond or dim2_cond: 872 err_dim = dim1 if dim1_cond else dim2 873 err = ( 874 r"Dimension out of range \(expected to be in range of " 875 rf"\[{bound1}, {bound2}\], but got {err_dim}\)" 876 ) 877 yield ErrorInput(sample, error_regex=err, error_type=IndexError) 878 else: 879 raise RuntimeError("should be unreachable") 880 881 882def sample_inputs_linalg_cholesky( 883 op_info, device, dtype, requires_grad=False, **kwargs 884): 885 """ 886 This function generates always positive-definite input for torch.linalg.cholesky using 887 random_hermitian_pd_matrix. 888 The input is generated as the itertools.product of 'batches' and 'ns'. 889 In total this function generates 8 SampleInputs 890 'batches' cases include: 891 () - single input, 892 (0,) - zero batched dimension, 893 (2,) - batch of two matrices, 894 (1, 1) - 1x1 batch of matrices 895 'ns' gives 0x0 and 5x5 matrices. 896 Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes. 897 """ 898 from torch.testing._internal.common_utils import random_hermitian_pd_matrix 899 900 batches = [(), (0,), (2,), (1, 1)] 901 ns = [5, 0] 902 for batch, n, upper in product(batches, ns, [True, False]): 903 a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device) 904 a.requires_grad = requires_grad 905 yield SampleInput(a, upper=upper) 906 907 908def sample_inputs_linalg_eig(op_info, device, dtype, requires_grad=False, **kwargs): 909 """ 910 This function generates input for torch.linalg.eig 911 """ 912 913 def out_fn(output): 914 return output[0], abs(output[1]) 915 916 samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad) 917 for sample in samples: 918 sample.output_process_fn_grad = out_fn 919 yield sample 920 921 922def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs): 923 """ 924 This function generates input for torch.linalg.eigh/eigvalsh with UPLO="U" or "L" keyword argument. 925 """ 926 927 def out_fn(output): 928 if isinstance(output, tuple): 929 # eigh function 930 return output[0], abs(output[1]) 931 else: 932 # eigvalsh function 933 return output 934 935 # Samples do not need to be Hermitian, as we're using gradcheck_wrapper_hermitian_input 936 samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad) 937 for sample in samples: 938 # Note: we cannot use np.random.choice here as TorchDynamo 939 # does not support tensors of strings. 940 sample.kwargs = {"UPLO": random.choice(["L", "U"])} 941 sample.output_process_fn_grad = out_fn 942 yield sample 943 944 945def sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad=False, **kwargs): 946 """ 947 This function generates input for torch.linalg.pinv with hermitian=False keyword argument. 
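    For illustration, every invertible sample is repeated with the three rtol variants

        {"rtol": None}, {"rtol": 1.0}, {"rtol": torch.tensor(1.0)}

    where the tensor variant is created without requires_grad, since the
    requires_grad path for a tensor rtol is not implemented.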
948 """ 949 for o in sample_inputs_linalg_invertible( 950 op_info, device, dtype, requires_grad, **kwargs 951 ): 952 real_dtype = o.input.real.dtype if dtype.is_complex else dtype 953 # requires_grad path for rtol tensor is not implemented 954 for rtol in (None, 1.0, torch.tensor(1.0, dtype=real_dtype, device=device)): 955 o = clone_sample(o) 956 o.kwargs = {"rtol": rtol} 957 yield o 958 959 960def sample_inputs_linalg_pinv_hermitian( 961 op_info, device, dtype, requires_grad=False, **kwargs 962): 963 """ 964 This function generates input for torch.linalg.pinv with hermitian=True keyword argument. 965 """ 966 for o in sample_inputs_linalg_invertible( 967 op_info, device, dtype, requires_grad, **kwargs 968 ): 969 o.kwargs = {"hermitian": True} 970 yield o 971 972 973def sample_inputs_linalg_solve( 974 op_info, device, dtype, requires_grad=False, vector_rhs_allowed=True, **kwargs 975): 976 """ 977 This function generates always solvable input for torch.linalg.solve 978 We sample a fullrank square matrix (i.e. invertible) A 979 The first input to torch.linalg.solve is generated as the itertools.product of 'batches' and 'ns'. 980 The second input is generated as the product of 'batches', 'ns' and 'nrhs'. 981 In total this function generates 18 SampleInputs 982 'batches' cases include: 983 () - single input, 984 (0,) - zero batched dimension, 985 (2,) - batch of two matrices. 986 'ns' gives 0x0 and 5x5 matrices. 987 and 'nrhs' controls the number of vectors to solve for: 988 () - using 1 as the number of vectors implicitly 989 (1,) - same as () but explicit 990 (3,) - solve for 3 vectors. 991 Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes. 992 'vector_rhs_allowed' controls whether to include nrhs = () to the list of SampleInputs. 993 torch.solve / triangular_solve / cholesky_solve (opposed to torch.linalg.solve) do not allow 994 1D tensors (vectors) as the right-hand-side. 995 Once torch.solve / triangular_solve / cholesky_solve and its testing are removed, 996 'vector_rhs_allowed' may be removed here as well. 
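    For illustration, with the defaults above this yields (A, b) shape pairs such as

        A: (5, 5),    b: (5,)       # nrhs == ()
        A: (5, 5),    b: (5, 3)     # nrhs == (3,)
        A: (2, 0, 0), b: (2, 0, 1)  # zero-sized edge case

    for a total of 3 batches x 2 ns x 3 nrhs = 18 samples.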
997 """ 998 make_fullrank = make_fullrank_matrices_with_distinct_singular_values 999 make_a = partial( 1000 make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad 1001 ) 1002 make_b = partial( 1003 make_tensor, dtype=dtype, device=device, requires_grad=requires_grad 1004 ) 1005 1006 batches = [(), (0,), (2,)] 1007 ns = [5, 0] 1008 if vector_rhs_allowed: 1009 nrhs = [(), (1,), (3,)] 1010 else: 1011 nrhs = [(1,), (3,)] 1012 1013 for n, batch, rhs in product(ns, batches, nrhs): 1014 yield SampleInput(make_a(*batch, n, n), args=(make_b(batch + (n,) + rhs),)) 1015 1016 1017def sample_inputs_linalg_solve_triangular( 1018 op_info, device, dtype, requires_grad=False, **kwargs 1019): 1020 make_arg = partial(make_tensor, dtype=dtype, device=device) 1021 bs = (1, 2, 0) 1022 ns = (3, 0) 1023 ks = (1, 3, 0) 1024 1025 for b, n, k, (left, upper, uni) in product( 1026 bs, ns, ks, product((True, False), repeat=3) 1027 ): 1028 if b == 1: 1029 A = make_arg((n, n)) if left else make_arg((k, k)) 1030 B = make_arg((n, k)) 1031 else: 1032 A = make_arg((b, n, n)) if left else make_arg((b, k, k)) 1033 B = make_arg((b, n, k)) 1034 if uni: 1035 # Not really necessary, but writing it for consistency 1036 A.diagonal(0, -2, -1).fill_(1.0) 1037 else: 1038 d = A.diagonal(0, -2, -1) 1039 d[d.abs() < 1e-6] = 1.0 1040 if upper: 1041 A.triu_() 1042 else: 1043 A.tril_() 1044 kwargs = {"upper": upper, "left": left, "unitriangular": uni} 1045 if requires_grad: 1046 for grad_A, grad_B in product((True, False), repeat=2): 1047 # Either A or B needs to have a gradient 1048 if not grad_A and not grad_B: 1049 continue 1050 yield SampleInput( 1051 A.clone().requires_grad_(grad_A), 1052 args=(B.clone().requires_grad_(grad_B),), 1053 kwargs=kwargs, 1054 ) 1055 else: 1056 yield SampleInput(A, args=(B,), kwargs=kwargs) 1057 1058 1059def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kwargs): 1060 """ 1061 This function generates always solvable input for legacy solve functions 1062 (the ones that are not in torch.linalg module). 1063 The difference from sample_inputs_linalg_solve is that here the right-hand-side of A x = b equation 1064 should have b.ndim >= 2, vectors are not allowed. 1065 Also the arguments order is swapped. 
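    For illustration: where sample_inputs_linalg_solve would yield SampleInput(A, args=(b,)),
    this function yields SampleInput(b, args=(A,)) instead.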
1066 """ 1067 out = sample_inputs_linalg_solve( 1068 op_info, device, dtype, requires_grad=requires_grad, vector_rhs_allowed=False 1069 ) 1070 1071 def out_fn(output): 1072 return output[0] 1073 1074 # Reverses tensor order 1075 for sample in out: 1076 sample.input, sample.args = sample.args[0], (sample.input,) 1077 if op_info.name == "solve": 1078 sample.output_process_fn_grad = out_fn 1079 yield sample 1080 1081 1082def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwargs): 1083 full_rank = op_info.name == "linalg.lu_factor" 1084 make_fn = ( 1085 make_tensor 1086 if not full_rank 1087 else make_fullrank_matrices_with_distinct_singular_values 1088 ) 1089 make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad) 1090 1091 def out_fn(output): 1092 if op_info.name == "linalg.lu": 1093 return output[1], output[2] 1094 else: 1095 return output 1096 1097 batch_shapes = ((), (3,), (3, 3)) 1098 # pivot=False only supported in CUDA 1099 pivots = (True, False) if torch.device(device).type == "cuda" else (True,) 1100 deltas = (-2, -1, 0, +1, +2) 1101 for batch_shape, pivot, delta in product(batch_shapes, pivots, deltas): 1102 shape = batch_shape + (S + delta, S) 1103 # Insanely annoying that make_fullrank_blablabla accepts a *shape and not a tuple! 1104 A = make_arg(shape) if not full_rank else make_arg(*shape) 1105 yield SampleInput(A, kwargs={"pivot": pivot}, output_process_fn_grad=out_fn) 1106 1107 1108def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, **kwargs): 1109 make_arg = partial( 1110 make_tensor, dtype=dtype, device=device, requires_grad=requires_grad 1111 ) 1112 1113 batches = [(), (0,), (2,), (1, 1)] 1114 ns = [5, 2, 0] 1115 1116 for batch, m, n in product(batches, ns, ns): 1117 yield SampleInput(make_arg(batch + (m, n))) 1118 1119 1120def sample_inputs_linalg_qr_geqrf( 1121 op_info, device, dtype, requires_grad=False, **kwargs 1122): 1123 # QR is just well defined when the matrix is full rank 1124 make_fullrank = make_fullrank_matrices_with_distinct_singular_values 1125 make_arg = partial( 1126 make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad 1127 ) 1128 1129 batches = [(), (0,), (2,), (1, 1)] 1130 ns = [5, 2, 0] 1131 1132 for batch, (m, n) in product(batches, product(ns, ns)): 1133 shape = batch + (m, n) 1134 yield SampleInput(make_arg(*shape)) 1135 1136 1137def sample_inputs_tensorsolve(op_info, device, dtype, requires_grad, **kwargs): 1138 a_shapes = [(2, 3, 6), (3, 4, 4, 3)] 1139 # Zero-dim tensors are not supported in NumPy, so we skip them for now. 1140 # NumPy is used in reference check tests. 1141 # See https://github.com/numpy/numpy/pull/20482 for tracking NumPy bugfix. 
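    # Shape note (illustrative): torch.linalg.tensorsolve(a, b) expects the leading
    # dimensions of `a` to match b.shape and the remaining dimensions of `a` to have
    # the same total number of elements as `b`, e.g. a: (2, 3, 6) with b: (2, 3),
    # since 6 == 2 * 3.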
1142 # a_shapes += [(0, 0, 1, 2, 3, 0)] 1143 dimss = [None, (0, 2)] 1144 1145 make_arg = partial( 1146 make_tensor, dtype=dtype, device=device, requires_grad=requires_grad 1147 ) 1148 for a_shape, dims in itertools.product(a_shapes, dimss): 1149 a = make_arg(a_shape) 1150 b = make_arg(a_shape[:2]) 1151 yield SampleInput(a, b, dims=dims) 1152 1153 1154def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs): 1155 make_arg = make_fullrank_matrices_with_distinct_singular_values 1156 1157 def make_input(): 1158 return make_arg(12, 12, device=device, dtype=dtype, requires_grad=requires_grad) 1159 1160 # lhs / rhs shape can have any number of dimensions as long as their product equals 12 1161 shapes = [ 1162 ((2, 2, 3), (12, 1)), 1163 ((4, 3), (6, 1, 2)), 1164 ] 1165 1166 for shape_lhs, shape_rhs in shapes: 1167 inp = make_input().reshape(*shape_lhs, *shape_rhs).detach() 1168 inp.requires_grad_(requires_grad) 1169 yield SampleInput(inp, ind=len(shape_lhs)) 1170 1171 1172op_db: List[OpInfo] = [ 1173 OpInfo( 1174 "linalg.cross", 1175 ref=lambda x, y, dim=-1: np.cross(x, y, axis=dim), 1176 op=torch.linalg.cross, 1177 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 1178 aten_name="linalg_cross", 1179 sample_inputs_func=sample_inputs_cross, 1180 error_inputs_func=error_inputs_cross, 1181 supports_out=True, 1182 supports_fwgrad_bwgrad=True, 1183 supports_forward_ad=True, 1184 skips=( 1185 DecorateInfo( 1186 unittest.skip("Unsupported on MPS for now"), 1187 "TestCommon", 1188 "test_numpy_ref_mps", 1189 ), 1190 ), 1191 ), 1192 OpInfo( 1193 "linalg.det", 1194 aten_name="linalg_det", 1195 op=torch.linalg.det, 1196 aliases=("det",), 1197 dtypes=floating_and_complex_types(), 1198 supports_forward_ad=True, 1199 supports_fwgrad_bwgrad=True, 1200 sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, 1201 decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver], 1202 check_batched_gradgrad=False, 1203 ), 1204 OpInfo( 1205 "linalg.det", 1206 aten_name="linalg_det", 1207 op=torch.linalg.det, 1208 variant_test_name="singular", 1209 aliases=("det",), 1210 dtypes=floating_and_complex_types(), 1211 supports_forward_ad=True, 1212 supports_fwgrad_bwgrad=True, 1213 check_batched_gradgrad=False, 1214 sample_inputs_func=sample_inputs_linalg_det_singular, 1215 decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver], 1216 skips=( 1217 DecorateInfo( 1218 unittest.skip("The backward may give different results"), 1219 "TestCommon", 1220 "test_noncontiguous_samples", 1221 ), 1222 DecorateInfo( 1223 unittest.skip("Gradients are incorrect on macos"), 1224 "TestBwdGradients", 1225 "test_fn_grad", 1226 device_type="cpu", 1227 dtypes=(torch.float64,), 1228 active_if=IS_MACOS, 1229 ), 1230 DecorateInfo( 1231 unittest.skip("Gradients are incorrect on macos"), 1232 "TestFwdGradients", 1233 "test_forward_mode_AD", 1234 device_type="cpu", 1235 dtypes=(torch.float64,), 1236 active_if=IS_MACOS, 1237 ), 1238 # Both Hessians are incorrect on complex inputs?? 
1239 DecorateInfo( 1240 unittest.expectedFailure, 1241 "TestBwdGradients", 1242 "test_fn_gradgrad", 1243 dtypes=(torch.complex128,), 1244 ), 1245 DecorateInfo( 1246 unittest.expectedFailure, 1247 "TestFwdGradients", 1248 "test_fn_fwgrad_bwgrad", 1249 dtypes=(torch.complex128,), 1250 ), 1251 DecorateInfo( 1252 unittest.skip("Skipped, see https://github.com//issues/84192"), 1253 "TestBwdGradients", 1254 "test_fn_gradgrad", 1255 device_type="cuda", 1256 ), 1257 DecorateInfo( 1258 unittest.skip("Skipped, see https://github.com//issues/84192"), 1259 "TestFwdGradients", 1260 "test_fn_fwgrad_bwgrad", 1261 device_type="cuda", 1262 ), 1263 DecorateInfo( 1264 unittest.skip( 1265 "Flaky on ROCm https://github.com/pytorch/pytorch/issues/93044" 1266 ), 1267 "TestBwdGradients", 1268 "test_fn_grad", 1269 device_type="cuda", 1270 dtypes=get_all_complex_dtypes(), 1271 active_if=TEST_WITH_ROCM, 1272 ), 1273 DecorateInfo( 1274 unittest.skip( 1275 "Flaky on ROCm https://github.com/pytorch/pytorch/issues/93045" 1276 ), 1277 "TestFwdGradients", 1278 "test_forward_mode_AD", 1279 device_type="cuda", 1280 dtypes=get_all_complex_dtypes(), 1281 active_if=TEST_WITH_ROCM, 1282 ), 1283 ), 1284 ), 1285 OpInfo( 1286 "linalg.diagonal", 1287 aten_name="linalg_diagonal", 1288 aten_backward_name="diagonal_backward", 1289 dtypes=all_types_and_complex_and( 1290 torch.bool, torch.bfloat16, torch.float16, torch.chalf 1291 ), 1292 supports_out=False, 1293 supports_forward_ad=True, 1294 supports_fwgrad_bwgrad=True, 1295 sample_inputs_func=sample_inputs_diagonal_diag_embed, 1296 error_inputs_func=error_inputs_diagonal_diag_embed, 1297 ), 1298 OpInfo( 1299 "linalg.cholesky", 1300 aten_name="linalg_cholesky", 1301 dtypes=floating_and_complex_types(), 1302 supports_forward_ad=True, 1303 supports_fwgrad_bwgrad=True, 1304 # See https://github.com/pytorch/pytorch/pull/78358 1305 check_batched_forward_grad=False, 1306 sample_inputs_func=sample_inputs_linalg_cholesky, 1307 gradcheck_wrapper=gradcheck_wrapper_hermitian_input, 1308 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], 1309 ), 1310 OpInfo( 1311 "linalg.cholesky_ex", 1312 aten_name="linalg_cholesky_ex", 1313 dtypes=floating_and_complex_types(), 1314 supports_forward_ad=True, 1315 supports_fwgrad_bwgrad=True, 1316 # See https://github.com/pytorch/pytorch/pull/78358 1317 check_batched_forward_grad=False, 1318 sample_inputs_func=sample_inputs_linalg_cholesky, 1319 gradcheck_wrapper=gradcheck_wrapper_hermitian_input, 1320 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], 1321 ), 1322 OpInfo( 1323 "linalg.vecdot", 1324 aten_name="linalg_vecdot", 1325 ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim), 1326 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 1327 sample_inputs_func=sample_inputs_linalg_vecdot, 1328 check_batched_forward_grad=False, 1329 supports_forward_ad=True, 1330 supports_fwgrad_bwgrad=True, 1331 skips=( 1332 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 1333 DecorateInfo( 1334 unittest.skip("Skipped!"), 1335 "TestSchemaCheckModeOpInfo", 1336 "test_schema_correctness", 1337 dtypes=(torch.complex64, torch.complex128), 1338 ), 1339 DecorateInfo( 1340 unittest.skip("Unsupported on MPS for now"), 1341 "TestCommon", 1342 "test_numpy_ref_mps", 1343 ), 1344 DecorateInfo( 1345 toleranceOverride({torch.half: tol(atol=1.2e-2, rtol=1.7e-2)}), 1346 "TestInductorOpInfo", 1347 "test_comprehensive", 1348 device_type="cuda", 1349 ), 1350 ), 1351 ), 1352 OpInfo( 1353 "linalg.cond", 1354 
aten_name="linalg_cond", 1355 dtypes=floating_and_complex_types(), 1356 sample_inputs_func=sample_inputs_linalg_cond, 1357 check_batched_gradgrad=False, 1358 check_batched_forward_grad=False, 1359 supports_forward_ad=True, 1360 supports_fwgrad_bwgrad=True, 1361 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 1362 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], 1363 skips=( 1364 DecorateInfo( 1365 unittest.skip("Skipped!"), 1366 "TestFakeTensor", 1367 "test_fake_crossref_backward_amp", 1368 device_type="cuda", 1369 dtypes=[torch.float32], 1370 active_if=TEST_WITH_ROCM, 1371 ), 1372 DecorateInfo( 1373 unittest.skip("Skipped!"), 1374 "TestFakeTensor", 1375 "test_fake_crossref_backward_no_amp", 1376 device_type="cuda", 1377 dtypes=[torch.float32], 1378 active_if=TEST_WITH_ROCM, 1379 ), 1380 ), 1381 ), 1382 OpInfo( 1383 "linalg.eig", 1384 aten_name="linalg_eig", 1385 op=torch.linalg.eig, 1386 dtypes=floating_and_complex_types(), 1387 sample_inputs_func=sample_inputs_linalg_eig, 1388 check_batched_forward_grad=False, 1389 check_batched_grad=False, 1390 check_batched_gradgrad=False, 1391 supports_forward_ad=True, 1392 supports_fwgrad_bwgrad=True, 1393 skips=( 1394 # AssertionError: Scalars are not equal! 1395 DecorateInfo( 1396 unittest.expectedFailure, "TestCommon", "test_out", device_type="cpu" 1397 ), 1398 DecorateInfo( 1399 unittest.skip("Skipped!"), 1400 "TestCommon", 1401 "test_out", 1402 device_type="mps", 1403 dtypes=[torch.float32], 1404 ), 1405 DecorateInfo( 1406 unittest.skip("Skipped!"), 1407 "TestCommon", 1408 "test_variant_consistency_eager", 1409 device_type="mps", 1410 dtypes=[torch.float32], 1411 ), 1412 DecorateInfo( 1413 unittest.skip("Skipped!"), 1414 "TestJit", 1415 "test_variant_consistency_jit", 1416 device_type="mps", 1417 dtypes=[torch.float32], 1418 ), 1419 ), 1420 decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off], 1421 ), 1422 OpInfo( 1423 "linalg.eigvals", 1424 aten_name="linalg_eigvals", 1425 op=torch.linalg.eigvals, 1426 dtypes=floating_and_complex_types(), 1427 sample_inputs_func=sample_inputs_linalg_invertible, 1428 check_batched_forward_grad=False, 1429 check_batched_grad=False, 1430 check_batched_gradgrad=False, 1431 supports_forward_ad=True, 1432 supports_fwgrad_bwgrad=True, 1433 decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], 1434 skips=( 1435 DecorateInfo( 1436 unittest.skip("Skipped!"), 1437 "TestCommon", 1438 "test_out", 1439 device_type="mps", 1440 dtypes=[torch.float32], 1441 ), 1442 DecorateInfo( 1443 unittest.skip("Skipped!"), 1444 "TestCommon", 1445 "test_variant_consistency_eager", 1446 device_type="mps", 1447 dtypes=[torch.float32], 1448 ), 1449 DecorateInfo( 1450 unittest.skip("Skipped!"), 1451 "TestJit", 1452 "test_variant_consistency_jit", 1453 device_type="mps", 1454 dtypes=[torch.float32], 1455 ), 1456 ), 1457 ), 1458 OpInfo( 1459 "linalg.eigh", 1460 aten_name="linalg_eigh", 1461 dtypes=floating_and_complex_types(), 1462 sample_inputs_func=sample_inputs_linalg_eigh, 1463 gradcheck_wrapper=gradcheck_wrapper_hermitian_input, 1464 check_batched_forward_grad=False, 1465 check_batched_grad=False, 1466 check_batched_gradgrad=False, 1467 supports_forward_ad=True, 1468 supports_fwgrad_bwgrad=True, 1469 decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off], 1470 skips=( 1471 DecorateInfo( 1472 unittest.skip("Skipped!"), 1473 "TestCommon", 1474 "test_out", 1475 device_type="mps", 1476 dtypes=[torch.float32], 1477 ), 1478 DecorateInfo( 1479 unittest.skip("Skipped!"), 1480 "TestCommon", 1481 
"test_variant_consistency_eager", 1482 device_type="mps", 1483 dtypes=[torch.float32], 1484 ), 1485 DecorateInfo( 1486 unittest.skip("Skipped!"), 1487 "TestJit", 1488 "test_variant_consistency_jit", 1489 device_type="mps", 1490 dtypes=[torch.float32], 1491 ), 1492 ), 1493 ), 1494 OpInfo( 1495 "linalg.eigvalsh", 1496 aten_name="linalg_eigvalsh", 1497 dtypes=floating_and_complex_types(), 1498 sample_inputs_func=sample_inputs_linalg_eigh, 1499 gradcheck_wrapper=gradcheck_wrapper_hermitian_input, 1500 check_batched_forward_grad=False, 1501 check_batched_grad=False, 1502 check_batched_gradgrad=False, 1503 supports_forward_ad=True, 1504 supports_fwgrad_bwgrad=True, 1505 decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], 1506 skips=( 1507 # Pre-existing condition; Needs to be fixed 1508 DecorateInfo( 1509 unittest.skip("Skipped!"), 1510 "TestCommon", 1511 "test_out", 1512 device_type="mps", 1513 dtypes=[torch.float32], 1514 ), 1515 DecorateInfo( 1516 unittest.skip("Skipped!"), 1517 "TestCommon", 1518 "test_variant_consistency_eager", 1519 device_type="mps", 1520 dtypes=[torch.float32], 1521 ), 1522 DecorateInfo( 1523 unittest.skip("Skipped!"), 1524 "TestJit", 1525 "test_variant_consistency_jit", 1526 device_type="mps", 1527 dtypes=[torch.float32], 1528 ), 1529 ), 1530 ), 1531 OpInfo( 1532 "linalg.householder_product", 1533 aten_name="linalg_householder_product", 1534 op=torch.linalg.householder_product, 1535 aliases=("orgqr",), 1536 dtypes=floating_and_complex_types(), 1537 # https://github.com/pytorch/pytorch/issues/80411 1538 gradcheck_fast_mode=True, 1539 # TODO: backward uses in-place operations that vmap doesn't like 1540 check_batched_grad=False, 1541 check_batched_gradgrad=False, 1542 supports_forward_ad=True, 1543 supports_fwgrad_bwgrad=True, 1544 check_batched_forward_grad=False, 1545 sample_inputs_func=sample_inputs_householder_product, 1546 decorators=[ 1547 skipCUDAIfNoCusolver, 1548 skipCPUIfNoLapack, 1549 DecorateInfo( 1550 toleranceOverride({torch.complex64: tol(atol=1e-3, rtol=1e-3)}) 1551 ), 1552 DecorateInfo( 1553 unittest.skip("Skipped! 
Flaky"), 1554 "TestFwdGradients", 1555 "test_fn_fwgrad_bwgrad", 1556 device_type="cpu", 1557 dtypes=(torch.complex128,), 1558 ), 1559 ], 1560 ), 1561 OpInfo( 1562 "linalg.ldl_factor", 1563 aten_name="linalg_ldl_factor", 1564 dtypes=floating_and_complex_types(), 1565 supports_autograd=False, 1566 sample_inputs_func=sample_inputs_linalg_ldl_factor, 1567 decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack], 1568 ), 1569 OpInfo( 1570 "linalg.ldl_factor_ex", 1571 aten_name="linalg_ldl_factor_ex", 1572 dtypes=floating_and_complex_types(), 1573 supports_autograd=False, 1574 sample_inputs_func=sample_inputs_linalg_ldl_factor, 1575 decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack], 1576 ), 1577 OpInfo( 1578 "linalg.ldl_solve", 1579 aten_name="linalg_ldl_solve", 1580 dtypes=floating_and_complex_types(), 1581 supports_autograd=False, 1582 sample_inputs_func=sample_inputs_linalg_ldl_solve, 1583 decorators=[ 1584 skipCUDAIf( 1585 _get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1" 1586 ), 1587 skipCUDAIfNoCusolver, 1588 skipCUDAIfRocm, 1589 skipCPUIfNoLapack, 1590 ], 1591 ), 1592 OpInfo( 1593 "linalg.lstsq", 1594 aten_name="linalg_lstsq", 1595 dtypes=floating_and_complex_types(), 1596 supports_out=True, 1597 sample_inputs_func=sample_inputs_linalg_lstsq, 1598 error_inputs_func=error_inputs_lstsq, 1599 decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], 1600 skips=( 1601 # we skip gradient checks for this suite as they are tested in 1602 # variant_test_name='grad_oriented' 1603 DecorateInfo(unittest.skip("Skipped!"), "TestFwdGradients"), 1604 DecorateInfo(unittest.skip("Skipped!"), "TestBwdGradients"), 1605 # The values for attribute 'shape' do not match 1606 DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"), 1607 DecorateInfo( 1608 unittest.skip("Skipped!"), 1609 "TestCommon", 1610 "test_out", 1611 device_type="mps", 1612 dtypes=[torch.float32], 1613 ), 1614 DecorateInfo( 1615 unittest.skip("Skipped!"), 1616 "TestCommon", 1617 "test_variant_consistency_eager", 1618 device_type="mps", 1619 dtypes=[torch.float32], 1620 ), 1621 DecorateInfo( 1622 unittest.skip("Skipped!"), 1623 "TestJit", 1624 "test_variant_consistency_jit", 1625 device_type="mps", 1626 dtypes=[torch.float32], 1627 ), 1628 ), 1629 ), 1630 OpInfo( 1631 "linalg.lstsq", 1632 aten_name="linalg_lstsq", 1633 variant_test_name="grad_oriented", 1634 # gradchecks for forward AD fails with multi-Tensor outputs 1635 op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[0], 1636 supports_out=False, 1637 dtypes=floating_and_complex_types(), 1638 sample_inputs_func=sample_inputs_linalg_lstsq, 1639 error_inputs_func=error_inputs_lstsq_grad_oriented, 1640 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 1641 gradcheck_fast_mode=True, 1642 supports_autograd=True, 1643 supports_forward_ad=True, 1644 supports_fwgrad_bwgrad=True, 1645 decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], 1646 skips=( 1647 # tests do not work with passing lambda for op 1648 DecorateInfo( 1649 unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" 1650 ), 1651 DecorateInfo( 1652 unittest.expectedFailure, 1653 "TestOperatorSignatures", 1654 "test_get_torch_func_signature_exhaustive", 1655 ), 1656 ), 1657 ), 1658 OpInfo( 1659 "linalg.matrix_power", 1660 aliases=("matrix_power",), 1661 aten_name="linalg_matrix_power", 1662 dtypes=floating_and_complex_types(), 1663 # https://github.com/pytorch/pytorch/issues/80411 1664 gradcheck_fast_mode=True, 1665 
supports_inplace_autograd=False, 1666 supports_forward_ad=True, 1667 supports_fwgrad_bwgrad=True, 1668 check_batched_grad=False, 1669 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], 1670 sample_inputs_func=sample_inputs_linalg_matrix_power, 1671 ), 1672 OpInfo( 1673 "linalg.multi_dot", 1674 # Need this lambda because gradcheck does not work with TensorList inputs 1675 aten_name="linalg_multi_dot", 1676 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 1677 dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), 1678 supports_inplace_autograd=False, 1679 # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407) 1680 check_batched_grad=False, 1681 check_batched_gradgrad=False, 1682 supports_forward_ad=True, 1683 supports_fwgrad_bwgrad=True, 1684 # https://github.com/pytorch/pytorch/issues/66357 1685 check_batched_forward_grad=False, 1686 sample_inputs_func=sample_inputs_linalg_multi_dot, 1687 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 1688 skips=( 1689 # https://github.com/pytorch/pytorch/issues/67470 1690 DecorateInfo( 1691 unittest.skip("67470!"), "TestCommon", "test_noncontiguous_samples" 1692 ), 1693 # Fails on XLA. 1694 # AssertionError: False is not true : Tensors failed to compare as equal! 1695 DecorateInfo( 1696 unittest.skip("Skipped!"), 1697 "TestOpInfo", 1698 device_type="xla", 1699 dtypes=(torch.long,), 1700 ), 1701 # https://github.com/pytorch/pytorch/issues/71774 1702 DecorateInfo( 1703 unittest.skip("Skipped!"), 1704 "TestNNCOpInfo", 1705 "test_nnc_correctness", 1706 device_type="cpu", 1707 dtypes=(torch.long,), 1708 ), 1709 ), 1710 ), 1711 # NB: linalg.norm has two variants so that different skips can be used for different sample inputs 1712 OpInfo( 1713 "linalg.norm", 1714 aten_name="linalg_norm", 1715 op=torch.linalg.norm, 1716 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 1717 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], 1718 sample_inputs_func=sample_inputs_linalg_norm, 1719 supports_forward_ad=True, 1720 check_batched_forward_grad=False, 1721 supports_fwgrad_bwgrad=True, 1722 skips=( 1723 DecorateInfo( 1724 unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad" 1725 ), 1726 DecorateInfo( 1727 unittest.skip("Skipped!"), 1728 "TestFakeTensor", 1729 "test_fake_crossref_backward_amp", 1730 device_type="cuda", 1731 dtypes=[torch.float32], 1732 active_if=TEST_WITH_ROCM, 1733 ), 1734 DecorateInfo( 1735 unittest.skip("Skipped!"), 1736 "TestFakeTensor", 1737 "test_fake_crossref_backward_no_amp", 1738 device_type="cuda", 1739 dtypes=[torch.float32], 1740 active_if=TEST_WITH_ROCM, 1741 ), 1742 ), 1743 ), 1744 OpInfo( 1745 "linalg.norm", 1746 op=torch.linalg.norm, 1747 variant_test_name="subgradients_at_zero", 1748 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 1749 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], 1750 sample_inputs_func=partial( 1751 sample_inputs_linalg_norm, variant="subgradient_at_zero" 1752 ), 1753 aten_name="linalg_norm", 1754 supports_forward_ad=True, 1755 # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: 1756 # Could not allocate memory to change Tensor SizesAndStrides! 
1757 check_batched_forward_grad=False, 1758 supports_fwgrad_bwgrad=True, 1759 skips=( 1760 # [NEW] Skips specifically for sample inputs at zero 1761 # norm's vjp/jvp are not well-conditioned near zero 1762 DecorateInfo( 1763 unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad" 1764 ), 1765 DecorateInfo( 1766 unittest.expectedFailure, "TestFwdGradients", "test_fn_fwgrad_bwgrad" 1767 ), 1768 DecorateInfo( 1769 unittest.expectedFailure, "TestFwdGradients", "test_forward_mode_AD" 1770 ), 1771 DecorateInfo(unittest.expectedFailure, "TestBwdGradients", "test_fn_grad"), 1772 ), 1773 ), 1774 OpInfo( 1775 "linalg.matrix_norm", 1776 aten_name="linalg_matrix_norm", 1777 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 1778 supports_forward_ad=True, 1779 check_batched_forward_grad=False, 1780 check_batched_gradgrad=False, 1781 supports_fwgrad_bwgrad=True, 1782 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], 1783 sample_inputs_func=sample_inputs_linalg_matrix_norm, 1784 skips=( 1785 DecorateInfo( 1786 unittest.skip("Skipped!"), 1787 "TestFakeTensor", 1788 "test_fake_crossref_backward_amp", 1789 device_type="cuda", 1790 dtypes=[torch.float32], 1791 active_if=TEST_WITH_ROCM, 1792 ), 1793 DecorateInfo( 1794 unittest.skip("Skipped!"), 1795 "TestFakeTensor", 1796 "test_fake_crossref_backward_no_amp", 1797 device_type="cuda", 1798 dtypes=[torch.float32], 1799 active_if=TEST_WITH_ROCM, 1800 ), 1801 ), 1802 ), 1803 OpInfo( 1804 "linalg.qr", 1805 aten_name="linalg_qr", 1806 op=torch.linalg.qr, 1807 dtypes=floating_and_complex_types(), 1808 supports_forward_ad=True, 1809 supports_fwgrad_bwgrad=True, 1810 # In-place ops 1811 check_batched_gradgrad=False, 1812 sample_inputs_func=sample_inputs_linalg_qr_geqrf, 1813 decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack], 1814 ), 1815 OpInfo( 1816 "linalg.slogdet", 1817 aten_name="linalg_slogdet", 1818 op=torch.linalg.slogdet, 1819 dtypes=floating_and_complex_types(), 1820 supports_forward_ad=True, 1821 supports_fwgrad_bwgrad=True, 1822 sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, 1823 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], 1824 ), 1825 OpInfo( 1826 "linalg.vander", 1827 aten_name="linalg_vander", 1828 ref=np_vander_batched, 1829 op=torch.linalg.vander, 1830 dtypes=all_types_and_complex(), 1831 supports_forward_ad=True, 1832 supports_fwgrad_bwgrad=True, 1833 supports_out=False, 1834 sample_inputs_func=sample_inputs_linalg_vander, 1835 skips=( 1836 DecorateInfo( 1837 unittest.skip("Unsupported on MPS for now"), 1838 "TestCommon", 1839 "test_numpy_ref_mps", 1840 ), 1841 ), 1842 ), 1843 ReductionOpInfo( 1844 "linalg.vector_norm", 1845 op=torch.linalg.vector_norm, 1846 identity=0, 1847 nan_policy="propagate", 1848 supports_multiple_dims=True, 1849 complex_to_real=True, 1850 supports_forward_ad=True, 1851 # torch.autograd.gradcheck.GradcheckError: While computing batched gradients 1852 # got: Could not allocate memory to change Tensor SizesAndStrides! 
        check_batched_forward_grad=False,
        supports_fwgrad_bwgrad=True,
        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
        generate_args_kwargs=sample_kwargs_vector_norm,
        aten_name="linalg_vector_norm",
        skips=(
            # FIXME: sum reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
        ),
    ),
    OpInfo(
        "linalg.lu_factor",
        aten_name="linalg_lu_factor",
        op=torch.linalg.lu_factor,
        dtypes=floating_and_complex_types(),
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_lu,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
    OpInfo(
        "linalg.lu_factor_ex",
        aten_name="linalg_lu_factor_ex",
        op=torch.linalg.lu_factor_ex,
        dtypes=floating_and_complex_types(),
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_lu,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
    OpInfo(
        "linalg.lu",
        aten_name="linalg_lu",
        op=torch.linalg.lu,
        dtypes=floating_and_complex_types(),
        # https://github.com/pytorch/pytorch/issues/80411
        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_lu,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
    OpInfo(
        "linalg.lu_solve",
        op=torch.linalg.lu_solve,
        aten_name="linalg_lu_solve",
        dtypes=floating_and_complex_types(),
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        check_batched_forward_grad=False,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_lu_solve,
        skips=(
            DecorateInfo(
                unittest.skip("Tests different backward paths"),
                "TestCommon",
                "test_floating_inputs_are_differentiable",
            ),
        ),
        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
    ),
    OpInfo(
        "linalg.inv",
        aten_name="linalg_inv",
        op=torch.linalg.inv,
        aliases=("inverse",),
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_invertible,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.inv_ex",
        aten_name="linalg_inv_ex",
        op=torch.linalg.inv_ex,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_invertible,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.solve",
        aten_name="linalg_solve",
        op=torch.linalg.solve,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_solve,
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[
            skipCUDAIfNoMagmaAndNoCusolver,
            skipCPUIfNoLapack,
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
                "TestCommon",
                "test_noncontiguous_samples",
                device_type="cpu",
            ),
        ],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.solve_ex",
        aten_name="linalg_solve_ex",
        op=torch.linalg.solve_ex,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_linalg_solve,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[
            skipCUDAIfNoMagmaAndNoCusolver,
            skipCPUIfNoLapack,
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
                "TestCommon",
                "test_noncontiguous_samples",
                device_type="cpu",
            ),
        ],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
        ),
    ),
    OpInfo(
        "linalg.solve_triangular",
aten_name="linalg_solve_triangular", 2094 op=torch.linalg.solve_triangular, 2095 dtypes=floating_and_complex_types(), 2096 sample_inputs_func=sample_inputs_linalg_solve_triangular, 2097 supports_fwgrad_bwgrad=True, 2098 skips=(skipCPUIfNoLapack,), 2099 # linalg.solve_triangular cannot be batched over because of a call to out.copy_(result); 2100 supports_forward_ad=True, 2101 ), 2102 OpInfo( 2103 "linalg.matrix_rank", 2104 aten_name="linalg_matrix_rank", 2105 dtypes=floating_and_complex_types(), 2106 supports_autograd=False, 2107 sample_inputs_func=sample_inputs_matrix_rank, 2108 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], 2109 skips=( 2110 DecorateInfo( 2111 unittest.skip("Skipped!"), 2112 "TestCommon", 2113 "test_out", 2114 device_type="mps", 2115 dtypes=[torch.float32], 2116 ), 2117 DecorateInfo( 2118 unittest.skip("Skipped!"), 2119 "TestCommon", 2120 "test_variant_consistency_eager", 2121 device_type="mps", 2122 dtypes=[torch.float32], 2123 ), 2124 # jit doesn't accept tensor inputs for matrix rank 2125 DecorateInfo( 2126 unittest.skip("Skipped!"), 2127 "TestJit", 2128 "test_variant_consistency_jit", 2129 dtypes=[torch.complex64, torch.float32], 2130 ), 2131 ), 2132 ), 2133 OpInfo( 2134 "linalg.matrix_rank", 2135 aten_name="linalg_matrix_rank", 2136 variant_test_name="hermitian", 2137 dtypes=floating_and_complex_types(), 2138 supports_autograd=False, 2139 sample_inputs_func=sample_inputs_linalg_pinv_hermitian, 2140 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], 2141 skips=( 2142 DecorateInfo( 2143 unittest.skip("Skipped!"), 2144 "TestCommon", 2145 "test_out", 2146 device_type="mps", 2147 dtypes=[torch.float32], 2148 ), 2149 DecorateInfo( 2150 unittest.skip("Skipped!"), 2151 "TestJit", 2152 "test_variant_consistency_jit", 2153 device_type="mps", 2154 dtypes=[torch.float32], 2155 ), 2156 ), 2157 ), 2158 OpInfo( 2159 "linalg.pinv", 2160 aten_name="linalg_pinv", 2161 op=torch.linalg.pinv, 2162 dtypes=floating_and_complex_types(), 2163 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 2164 gradcheck_fast_mode=True, 2165 check_batched_grad=False, 2166 check_batched_gradgrad=False, 2167 supports_forward_ad=True, 2168 supports_fwgrad_bwgrad=True, 2169 sample_inputs_func=sample_inputs_linalg_pinv, 2170 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], 2171 skips=( 2172 # errors with "leaked XXXX bytes CUDA memory on device 0" 2173 DecorateInfo( 2174 unittest.skip("Skipped!"), 2175 "TestJit", 2176 "test_variant_consistency_jit", 2177 device_type="cuda", 2178 ), 2179 ), 2180 ), 2181 OpInfo( 2182 "linalg.pinv", 2183 aten_name="linalg_pinv", 2184 variant_test_name="singular", 2185 # pinv is Frechet-differentiable in a rank-preserving neighborhood, 2186 # so we feed inputs that are the products of two full-rank factors, 2187 # to avoid any rank changes caused by the perturbations in the gradcheck 2188 op=lambda a, b: torch.linalg.pinv(a @ b.mT), 2189 dtypes=floating_and_complex_types(), 2190 supports_out=False, 2191 check_batched_grad=False, 2192 check_batched_gradgrad=False, 2193 supports_forward_ad=True, 2194 supports_fwgrad_bwgrad=True, 2195 sample_inputs_func=sample_inputs_linalg_pinv_singular, 2196 # Only large tensors show issues with implicit backward used prior to 2197 # explicit backward implementation. 
        decorators=[slowTest, skipCUDAIfNoCusolver, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            # CUDA runs out of memory
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestFwdGradients",
                "test_fn_fwgrad_bwgrad",
                device_type="cuda",
                dtypes=[torch.cdouble],
            ),
            # This test takes almost 2 hours to run!
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestBwdGradients",
                "test_fn_gradgrad",
                device_type="cuda",
                dtypes=[torch.cdouble],
            ),
        ),
    ),
    OpInfo(
        "linalg.pinv",
        aten_name="linalg_pinv",
        variant_test_name="hermitian",
        dtypes=floating_and_complex_types(),
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
                "TestCommon",
                "test_noncontiguous_samples",
                device_type="cuda",
            ),
            # This test is flaky under slow gradcheck, likely due to rounding issues
            DecorateInfo(
                skipIfSlowGradcheckEnv,
                "TestFwdGradients",
                "test_fn_fwgrad_bwgrad",
                device_type="cuda",
            ),
        ),
    ),
    OpInfo(
        "linalg.svd",
        op=torch.linalg.svd,
        aten_name="linalg_svd",
        decomp_aten_name="_linalg_svd",
        dtypes=floating_and_complex_types(),
        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_fwgrad_bwgrad=True,
        supports_forward_ad=True,
        check_batched_forward_grad=False,
        # We're using at::allclose, which does not have a batching rule
        check_batched_grad=False,
        check_batched_gradgrad=False,
        sample_inputs_func=sample_inputs_svd,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_out",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestJit",
                "test_variant_consistency_jit",
                device_type="mps",
                dtypes=[torch.float32],
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestFakeTensor",
                "test_fake_crossref_backward_amp",
                device_type="cuda",
                dtypes=[torch.float32],
                active_if=TEST_WITH_ROCM,
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestFakeTensor",
                "test_fake_crossref_backward_no_amp",
                device_type="cuda",
                dtypes=[torch.float32],
                active_if=TEST_WITH_ROCM,
            ),
        ),
    ),
    OpInfo(
        "linalg.svdvals",
        op=torch.linalg.svdvals,
        aten_name="linalg_svdvals",
        decomp_aten_name="_linalg_svd",
        dtypes=floating_and_complex_types(),
        check_batched_forward_grad=False,
        supports_fwgrad_bwgrad=True,
        supports_forward_ad=True,
        # We're using at::allclose, which does not have a batching rule
        check_batched_gradgrad=False,
        sample_inputs_func=sample_inputs_linalg_svdvals,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestFakeTensor",
                "test_fake_crossref_backward_amp",
                device_type="cuda",
                dtypes=[torch.float32],
                active_if=TEST_WITH_ROCM,
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestFakeTensor",
                "test_fake_crossref_backward_no_amp",
                device_type="cuda",
                dtypes=[torch.float32],
                active_if=TEST_WITH_ROCM,
            ),
        ),
    ),
    OpInfo(
        "linalg.tensorinv",
        ref=np.linalg.tensorinv,
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_tensorinv,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
        skips=(
            DecorateInfo(
                unittest.skip("Unsupported on MPS for now"),
                "TestCommon",
                "test_numpy_ref_mps",
            ),
        ),
    ),
    OpInfo(
        "linalg.tensorsolve",
        ref=lambda a, b, dims=None: np.linalg.tensorsolve(a, b, axes=dims),
        dtypes=floating_and_complex_types(),
        sample_inputs_func=sample_inputs_tensorsolve,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=[
            skipCUDAIfNoMagmaAndNoCusolver,
            skipCPUIfNoLapack,
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
                "TestCommon",
                "test_noncontiguous_samples",
                device_type="cuda",
            ),
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=8e-04, rtol=7e-06)}),
                "TestCommon",
                "test_noncontiguous_samples",
                device_type="cpu",
            ),
        ],
        skips=(
            DecorateInfo(
                unittest.skip("Unsupported on MPS for now"),
                "TestCommon",
                "test_numpy_ref_mps",
            ),
        ),
    ),
]
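
# A minimal sketch (not used by any OpInfo above) of the rank-preserving input
# construction referenced in the "singular" variant of linalg.pinv: the matrix
# fed to pinv is the product of two full-rank factors, so the small
# perturbations gradcheck applies move the factors without changing the rank
# of their product. The helper name and shapes below are illustrative
# assumptions only, not part of the test database.
def _example_rank_preserving_pinv_factors(
    batch=(), m=5, n=4, rank=3, dtype=torch.float64, device="cpu"
):
    # a has shape batch + (m, rank) and b has shape batch + (n, rank); random
    # factors of this shape are generically full rank, so a @ b.mT keeps rank
    # `rank` throughout a neighborhood of (a, b).
    a = make_tensor(
        batch + (m, rank), dtype=dtype, device=device, requires_grad=True
    )
    b = make_tensor(
        batch + (n, rank), dtype=dtype, device=device, requires_grad=True
    )
    return a, b  # consumed as torch.linalg.pinv(a @ b.mT)
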
python_ref_db: List[OpInfo] = [
    #
    # torch.linalg
    #
    PythonRefInfo(
        "_refs.linalg.cross",
        torch_opinfo_name="linalg.cross",
        supports_out=True,
        op_db=op_db,
        skips=(
            # TODO: is this really needed?
            DecorateInfo(
                unittest.expectedFailure, "TestCommon", "test_python_ref_errors"
            ),
        ),
    ),
    PythonRefInfo(
        "_refs.linalg.diagonal",
        torch_opinfo_name="linalg.diagonal",
        supports_out=False,
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.vecdot",
        torch_opinfo_name="linalg.vecdot",
        op_db=op_db,
    ),
    ReductionPythonRefInfo(
        "_refs.linalg.vector_norm",
        torch_opinfo_name="linalg.vector_norm",
        supports_out=True,
        op_db=op_db,
        skips=(
            # FIXME: sum reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
        ),
    ),
    PythonRefInfo(
        "_refs.linalg.matrix_norm",
        torch_opinfo_name="linalg.matrix_norm",
        supports_out=True,
        # Uses vector_norm inside and vector_norm is affected by
        # https://github.com/pytorch/pytorch/issues/77216
        validate_view_consistency=False,
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.norm",
        torch_opinfo_name="linalg.norm",
        supports_out=True,
        # Uses vector_norm inside and vector_norm is affected by
        # https://github.com/pytorch/pytorch/issues/77216
        validate_view_consistency=False,
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.svd",
        torch_opinfo_name="linalg.svd",
        supports_out=True,
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.svdvals",
        torch_opinfo_name="linalg.svdvals",
        supports_out=True,
        op_db=op_db,
    ),
]
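
# A minimal sketch (not part of the OpInfo database) illustrating the
# "subgradients_at_zero" variant of linalg.norm above: at and near x == 0 the
# norm's derivative is ill-conditioned (for the 2-norm it involves x / ||x||),
# so the finite-difference comparison done by gradcheck is expected to fail
# there, which is why that variant marks its gradient tests as expectedFailure.
if __name__ == "__main__":
    x = torch.zeros(3, dtype=torch.float64, requires_grad=True)
    torch.linalg.norm(x).backward()
    print(x.grad)  # whichever subgradient autograd picks at the kink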