# Owner(s): ["module: linear algebra"]

import torch
import numpy as np

import unittest
import itertools
import warnings
import math
from math import inf, nan, isnan
import re
import random
from random import randrange
from itertools import product
from functools import reduce, partial

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
     setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
     onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
     skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA,
     onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm,
     dtypesIfMPS, largeTensorTest)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    all_types, all_types_and_complex_and, floating_and_complex_types, integral_types,
    floating_and_complex_types_and, floating_types_and, complex_types,
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
    _get_torch_cuda_version, CDNA2OrLater
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
import operator

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

if TEST_SCIPY:
    import scipy

def blaslt_supported_device():
    if torch.cuda.is_available():
        if torch.version.hip:
            for arch in ['gfx90a', 'gfx94']:
                if arch in torch.cuda.get_device_properties(0).gcnArchName:
                    return True
        else:
            return True
    return False

def set_tunableop_defaults():
    if not torch.cuda.is_available():
        # TunableOp not supported on CPU at this time.
        return

    # disable TunableOp and restore to default values
    ordinal = torch.cuda.current_device()
    filename = f"tunableop_results{ordinal}.csv"
    torch.cuda.tunable.enable(False)
    torch.cuda.tunable.tuning_enable(True)
    torch.cuda.tunable.set_filename(filename)  # reset back to default filename for next unit test
    torch.cuda.tunable.set_max_tuning_duration(30)
    torch.cuda.tunable.set_max_tuning_iterations(100)


class TestLinalg(TestCase):
    def setUp(self):
        super(self.__class__, self).setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super(self.__class__, self).tearDown()

    exact_dtype = True

    @dtypes(torch.float, torch.cfloat)
    @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06})
    @tf32_on_and_off(5e-3)
    @bf32_on_and_off(5e-3)
    def test_inner(self, device, dtype):
        def check(a_sizes_, b_sizes_):
            for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)):
                a = torch.randn(a_sizes, dtype=dtype, device=device)
                b = torch.randn(b_sizes, dtype=dtype, device=device)
                res = torch.inner(a, b)
                ref = np.inner(a.cpu().numpy(), b.cpu().numpy())
                self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref)))
                out = torch.zeros_like(res)
                torch.inner(a, b, out=out)
                self.assertEqual(res, out)

        check([], [])                       # scalar x scalar
        check([], [0])                      # scalar x empty
        check([], [3])                      # scalar x 1D
        check([], [2, 3, 4])                # scalar x 3D

        check([0], [0])                     # empty x empty
        check([0], [2, 0])                  # empty x 2D

        check([2], [2])                     # 1D x 1D
        check([2], [3, 1, 2])               # 1D x 3D
        check([2], [3, 0, 2])               # 1D x 3D empty

        check([1, 2], [3, 2])               # 2D x 2D
        check([1, 2], [3, 4, 2])            # 2D x 3D
        check([2, 1, 3, 2], [1, 3, 2, 2])   # 4D x 4D

        # Test error message
        with self.assertRaisesRegex(RuntimeError,
                                    r"inner\(\) the last dimension must match on both "
                                    r"input tensors but got shapes \[2, 3\] and \[2, 2\]"):
            torch.randn(2, 3, device=device, dtype=dtype).inner(torch.randn(2, 2, device=device, dtype=dtype))
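
    # Illustrative sketch (not part of the checks above; tensor names are hypothetical):
    # torch.inner contracts the last dimension of both operands, so for a of shape
    # (2, 3) and b of shape (4, 3) the result has shape (2, 4) and agrees with an
    # explicit einsum contraction:
    #
    #   a = torch.randn(2, 3)
    #   b = torch.randn(4, 3)
    #   assert torch.inner(a, b).shape == (2, 4)
    #   assert torch.allclose(torch.inner(a, b), torch.einsum('ij,kj->ik', a, b))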

    # Tests torch.outer, and its alias, torch.ger, vs. NumPy
    @precisionOverride({torch.bfloat16: 1e-1})
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_outer(self, device, dtype):
        def run_test_case(a, b):
            if dtype == torch.bfloat16:
                a_np = a.to(torch.double).cpu().numpy()
                b_np = b.to(torch.double).cpu().numpy()
                exact_dtype = False
            else:
                a_np = a.cpu().numpy()
                b_np = b.cpu().numpy()
                exact_dtype = True
            expected = np.outer(a_np, b_np)

            self.assertEqual(torch.outer(a, b), expected, exact_dtype=False)
            self.assertEqual(torch.Tensor.outer(a, b), expected, exact_dtype=False)

            self.assertEqual(torch.ger(a, b), expected, exact_dtype=False)
            self.assertEqual(torch.Tensor.ger(a, b), expected, exact_dtype=False)

            # test out variant
            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
            torch.outer(a, b, out=out)
            self.assertEqual(out, expected, exact_dtype=False)

            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
            torch.ger(a, b, out=out)
            self.assertEqual(out, expected, exact_dtype=False)

        a = torch.randn(50).to(device=device, dtype=dtype)
        b = torch.randn(50).to(device=device, dtype=dtype)
        run_test_case(a, b)

        # test 0 strided tensor
        zero_strided = torch.randn(1).to(device=device, dtype=dtype).expand(50)
        run_test_case(zero_strided, b)
        run_test_case(a, zero_strided)

    def test_matrix_rank_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.matrix_rank(a)

    def test_solve_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        b = make_tensor(5, 1, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.solve(b, a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            b.solve(a)

    def test_eig_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.eig(a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.eig()

    def test_symeig_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.symeig(a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.symeig()

    def test_lstsq_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.lstsq(a, a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.lstsq(a)
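
    # Illustrative sketch of the documented torch.linalg replacements for the removed
    # functions exercised above (note the swapped argument order for the solvers):
    #
    #   torch.linalg.matrix_rank(a)   # replaces torch.matrix_rank(a)
    #   torch.linalg.solve(a, b)      # replaces torch.solve(b, a)
    #   torch.linalg.eig(a)           # replaces torch.eig(a, eigenvectors=True)
    #   torch.linalg.eigh(a)          # replaces torch.symeig(a, eigenvectors=True)
    #   torch.linalg.lstsq(a, b)      # replaces torch.lstsq(b, a)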

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @skipIfTorchDynamo("flaky, needs investigation")
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_lstsq(self, device, dtype):
        from torch.testing._internal.common_utils import random_well_conditioned_matrix
        if self.device_type == 'cpu':
            drivers = ('gels', 'gelsy', 'gelsd', 'gelss', None)
        else:
            drivers = ('gels', None)

        def check_solution_correctness(a, b, sol):
            sol2 = a.pinverse() @ b
            self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)

        def check_correctness_ref(a, b, res, ref, driver="default"):
            def apply_if_not_empty(t, f):
                if t.numel():
                    return f(t)
                else:
                    return t

            def select_if_not_empty(t, i):
                selected = apply_if_not_empty(t, lambda x: x.select(0, i))
                return selected

            m = a.size(-2)
            n = a.size(-1)
            nrhs = b.size(-1)
            batch_size = int(np.prod(a.shape[:-2]))
            if batch_size == 0:
                batch_size = 1
            a_3d = a.view(batch_size, m, n)
            b_3d = b.view(batch_size, m, nrhs)

            solution_3d = res.solution.view(batch_size, n, nrhs)
            residuals_2d = apply_if_not_empty(res.residuals, lambda t: t.view(-1, nrhs))
            rank_1d = apply_if_not_empty(res.rank, lambda t: t.view(-1))
            singular_values_2d = res.singular_values.view(batch_size, res.singular_values.shape[-1])

            if a.numel() > 0:
                for i in range(batch_size):
                    sol, residuals, rank, singular_values = ref(
                        a_3d.select(0, i).numpy(),
                        b_3d.select(0, i).numpy()
                    )
                    # Singular values are None when lapack_driver='gelsy' in SciPy
                    if singular_values is None:
                        singular_values = []
                    self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
                    self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
                    self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)

                    # SciPy and NumPy operate only on non-batched input and
                    # return an empty array with shape (0,) if rank(a) != n.
                    # In PyTorch batched inputs are supported and the matrices
                    # in a batch can have different ranks, so residuals are
                    # computed only if all matrices have rank == n.
                    # See https://github.com/pytorch/pytorch/issues/56483
                    if m > n:
                        if torch.all(rank_1d == n):
                            self.assertEqual(
                                residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5, exact_dtype=False
                            )
                        else:
                            self.assertTrue(residuals_2d.numel() == 0)

            else:
                self.assertEqual(res.solution.shape, (*a.shape[:-2], n, nrhs))
                self.assertEqual(res.rank.shape, a.shape[:-2])

                # residuals are not always computed (and have non-zero shape)
                if m > n and driver != "gelsy":
                    self.assertEqual(res.residuals.shape, (*a.shape[:-2], 0))
                else:
                    self.assertEqual(res.residuals.shape, (0, ))

                # singular_values are not always computed (and have non-zero shape)
                if driver == "default" or driver == "gelsd" or driver == "gelss":
                    self.assertEqual(res.singular_values.shape, (*a.shape[:-2], min(m, n)))
                else:
                    self.assertEqual(res.singular_values.shape, (0, ))

        def check_correctness_scipy(a, b, res, driver, cond):
            # SciPy provides 3 driver options: gelsd, gelss, gelsy
            if TEST_SCIPY and driver in ('gelsd', 'gelss', 'gelsy'):
                import scipy.linalg

                def scipy_ref(a, b):
                    return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
                check_correctness_ref(a, b, res, scipy_ref, driver=driver)

        def check_correctness_numpy(a, b, res, driver, rcond):
            # NumPy uses only gelsd routine
            if driver == 'gelsd':

                def numpy_ref(a, b):
                    return np.linalg.lstsq(a, b, rcond=rcond)
                check_correctness_ref(a, b, res, numpy_ref)

        ms = [2 ** i for i in range(5)]
        m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
        # cases m < n are only supported on CPU and for cuSOLVER path on CUDA
        m_l_n_sizes = [(m // 2, m) for m in ms]
        include_m_l_n_case = (has_cusolver() or device == 'cpu')
        matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if include_m_l_n_case else [])
        batches = [(), (2,), (2, 2), (2, 2, 2)]
        # The matrices are generated with singular values sampled from a normal
        # distribution, so `cond=1.0` (the mean) cuts roughly half of all the
        # singular values; we then compare whether torch.linalg.lstsq agrees with
        # SciPy and NumPy.
        # If rcond is True, a driver-specific value is substituted for it below.
        # rcond == -1 (or any other negative value) forces LAPACK to use machine precision as the tolerance.
        rconds = (None, True, -1)

        for batch, matrix_size, driver, rcond in itertools.product(batches, matrix_sizes, drivers, rconds):
            # keep the rcond value if it is None or -1, set the driver-specific value if it is True
            if rcond and rcond != -1:
                if driver in ('gelss', 'gelsd'):
                    # SVD based algorithm; set to zero roughly half of all the singular values
                    rcond = 1.0
                else:
                    # driver == 'gelsy'
                    # QR based algorithm; setting the value too high might lead to non-unique solutions and flaky tests
                    # so we skip this case
                    continue

            # specifying rcond value has no effect for gels driver so no need to run the tests again
            if driver == 'gels' and rcond is not None:
                continue

            shape = batch + matrix_size
            a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
            b = torch.rand(*shape, dtype=dtype, device=device)

            m = a.size(-2)
            n = a.size(-1)
            res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
            sol = res.solution

            # Only checks gelsd, gelss, gelsy drivers
            check_correctness_scipy(a, b, res, driver, rcond)

            # Only checks gelsd driver
            check_correctness_numpy(a, b, res, driver, rcond)

            # gels driver is not checked by comparing to NumPy or SciPy implementation
            # because NumPy and SciPy do not implement this driver
            if driver == 'gels' and rcond is None:
                check_solution_correctness(a, b, sol)
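
    # Illustrative sketch of the residuals convention relied on above (assumes a CPU
    # build with LAPACK; tensor names are hypothetical): for a tall, full-rank `a`
    # the SVD-based drivers return the column-wise sums of squared residuals,
    # matching np.linalg.lstsq.
    #
    #   a = torch.randn(6, 3, dtype=torch.float64)
    #   b = torch.randn(6, 2, dtype=torch.float64)
    #   res = torch.linalg.lstsq(a, b, driver='gelsd')
    #   assert torch.allclose(res.residuals, (a @ res.solution - b).pow(2).sum(dim=-2))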

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_lstsq_batch_broadcasting(self, device, dtype):
        from torch.testing._internal.common_utils import random_well_conditioned_matrix

        def check_correctness(a, b):
            sol = torch.linalg.lstsq(a, b).solution
            sol2 = a.pinverse() @ b
            self.assertEqual(sol, sol2, rtol=1e-5, atol=1e-5)

        ms = [2 ** i for i in range(5)]
        batches = [(), (0,), (2,), (2, 2), (2, 2, 2)]
        # the case when a single matrix is batch-broadcasted over the rhs
        for m, batch in itertools.product(ms, batches):
            a = random_well_conditioned_matrix(m, m, dtype=dtype, device=device).view(*([1] * len(batch)), m, m)
            b = torch.rand(*(batch + (m, m)), dtype=dtype, device=device)
            check_correctness(a, b)

        # cases with broadcastable shapes
        for m in ms:
            a = random_well_conditioned_matrix(1, 3, 1, 3, m, m, dtype=dtype, device=device)
            b = torch.rand(3, 1, 3, 1, m, m // 2, dtype=dtype, device=device)
            check_correctness(a, b)

            # rhs are vectors, not matrices in this test
            b = torch.rand(3, 1, 3, 1, m, dtype=dtype, device=device)
            # unsqueeze for b because `check_correctness` checks against
            # a.pinverse() @ b, which requires b to be a matrix
            check_correctness(a, b.unsqueeze(-1))

            a = random_well_conditioned_matrix(3, 1, 3, 1, m, m, dtype=dtype, device=device)
            b = torch.rand(1, 3, 1, 3, m, m // 2, dtype=dtype, device=device)
            check_correctness(a, b)

            # rhs are vectors, not matrices in this test
            b = torch.rand(1, 3, 1, 3, m, dtype=dtype, device=device)
            check_correctness(a, b.unsqueeze(-1))

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_lstsq_input_checks(self, device, dtype):
        # check empty inputs
        # empty batches
        a = torch.rand(0, 0, 3, 3, dtype=dtype, device=device)
        b = torch.rand(0, 0, 3, 2, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(0, 0, 3, 2, dtype=dtype, device=device)
        )
        # empty a and b
        a = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
        b = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
        )
        # empty a and b
        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
        b = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
        )
        # empty a but not b
        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
        b = torch.rand(2, 2, 3, 2, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(2, 2, 0, 2, dtype=dtype, device=device)
        )

        # empty a and b
        if torch.device(device).type == 'cpu':
            # only CPU since CUDA does not support underdetermined systems
            a = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
            b = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
            self.assertEqual(
                torch.linalg.lstsq(a, b)[0],
                torch.zeros(2, 2, 3, 3, dtype=dtype, device=device)
            )

        a = torch.rand(2, 3, dtype=dtype, device=device)
        b = torch.rand(3, dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, 'input must have at least 2 dimensions'):
            torch.linalg.lstsq(b, b)

        with self.assertRaisesRegex(RuntimeError, 'other must have at least 1 dimension'):
            torch.linalg.lstsq(a, torch.tensor(1, dtype=dtype, device=device))

        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-1\)'):
            torch.linalg.lstsq(a, b)

        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
            torch.linalg.lstsq(a, b.unsqueeze(-1))

        a = torch.randn(1, 1, 1, dtype=dtype, device=device)
        b = torch.randn(3, 1, dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
            torch.linalg.lstsq(a, b)

        def complement_device(device):
            if device == 'cpu' and torch.cuda.is_available():
                return 'cuda'
            else:
                return 'cpu'

        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
        b = torch.rand(2, 2, 2, dtype=dtype, device=complement_device(device))
        if a.device != b.device:
            with self.assertRaisesRegex(RuntimeError, 'be on the same device'):
                torch.linalg.lstsq(a, b)

        b = (torch.rand(2, 2, 2, dtype=dtype, device=device) * 100).long()
        with self.assertRaisesRegex(RuntimeError, 'the same dtype'):
            torch.linalg.lstsq(a, b)

        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
        b = torch.rand(2, 2, 2, dtype=dtype, device=device)

        if device != 'cpu':
            with self.assertRaisesRegex(RuntimeError, '`driver` other than `gels` is not supported on CUDA'):
                torch.linalg.lstsq(a, b, driver='fictitious_driver')
        # if on cpu
        else:
            with self.assertRaisesRegex(RuntimeError, r'parameter `driver` should be one of \(gels, gelsy, gelsd, gelss\)'):
                torch.linalg.lstsq(a, b, driver='fictitious_driver')

        # cuSOLVER path supports underdetermined systems
        version = torch.testing._internal.common_cuda._get_torch_cuda_version()
        cusolver_not_available = (version < (10, 1))

        if device != 'cpu' and cusolver_not_available:
            a = torch.rand(2, 3, dtype=dtype, device=device)
            b = torch.rand(2, 1, dtype=dtype, device=device)
            with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems'):
                torch.linalg.lstsq(a, b)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(shape, batch, contiguous):
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
            if A.numel() > 0 and not contiguous:
                A = A.mT
                self.assertFalse(A.is_contiguous())
            expected_L = np.linalg.cholesky(A.cpu().numpy())
            actual_L = torch.linalg.cholesky(A)

            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
            # Let's compare the norms of matrices instead
            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
                # axis is specified to calculate matrix norm for batched input
                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
                # Compare the norms with standard tolerances
                self.assertEqual(actual_norm, expected_norm)
                # and individual values with a higher tolerance
                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
            else:
                self.assertEqual(actual_L, expected_L)

        shapes = (0, 3, 5)
        batches = ((), (3, ), (2, 2))
        larger_input_case = [(100, (5, ), True)]
        for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case:
            run_test(shape, batch, contiguous)

        # check the out= variant
        A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device)
        out = torch.empty_like(A)
        ans = torch.linalg.cholesky(A, out=out)
        self.assertEqual(ans, out)
        expected = torch.linalg.cholesky(A)
        self.assertEqual(expected, out)

        # check the upper= variant
        expected = torch.linalg.cholesky(A).mH
        actual = torch.linalg.cholesky(A, upper=True)
        self.assertEqual(expected, actual)
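
    # Illustrative sketch of the identities behind the checks above (names are
    # illustrative): the factor returned by torch.linalg.cholesky is lower
    # triangular and reconstructs the input, while upper=True returns its
    # conjugate transpose.
    #
    #   A = torch.randn(4, 4, dtype=torch.float64)
    #   A = A @ A.T + 4 * torch.eye(4, dtype=torch.float64)  # make A positive definite
    #   L = torch.linalg.cholesky(A)
    #   assert torch.allclose(L @ L.mH, A)
    #   assert torch.allclose(torch.linalg.cholesky(A, upper=True), L.mH)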

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky_errors_and_warnings(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        # cholesky requires the input to be a square matrix or batch of square matrices
        A = torch.randn(2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
            torch.linalg.cholesky(A)
        A = torch.randn(2, 2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
            torch.linalg.cholesky(A)
        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'):
            np.linalg.cholesky(A.cpu().numpy())

        # cholesky requires the input to be at least 2 dimensional tensor
        A = torch.randn(2, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
            torch.linalg.cholesky(A)
        with self.assertRaisesRegex(np.linalg.LinAlgError,
                                    r'1-dimensional array given\. Array must be at least two-dimensional'):
            np.linalg.cholesky(A.cpu().numpy())

        # if the input matrix is not positive definite, an error should be raised
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A[-1, -1] = 0  # Now A is not positive definite
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
            torch.linalg.cholesky(A)
        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
            np.linalg.cholesky(A.cpu().numpy())

        # if at least one matrix in the batch is singular, an error should be raised
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A = A.reshape((1, 3, 3))
        A = A.repeat(5, 1, 1)
        A[4, -1, -1] = 0  # Now A[4] is not positive definite
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 4\): The factorization could not be completed'):
            torch.linalg.cholesky(A)

        # if out tensor with wrong shape is passed a warning is given
        A = random_hermitian_pd_matrix(3, dtype=dtype, device=device)
        out = torch.empty(2, 3, dtype=dtype, device=device)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.cholesky(A, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out = torch.empty(*A.shape, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.cholesky(A, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                torch.linalg.cholesky(A, out=out)

    # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py
    @slowTest
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_old_cholesky_batched_many_batches(self, device, dtype):
        from torch.testing._internal.common_utils import random_symmetric_pd_matrix

        def cholesky_test_helper(n, batchsize, device, upper):
            A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device)
            chol_fact = torch.cholesky(A, upper=upper)
            if upper:
                # Correctness check
                self.assertEqual(A, chol_fact.mT.matmul(chol_fact))
                # Upper triangular check
                self.assertEqual(chol_fact, chol_fact.triu())
            else:
                # Correctness check
                self.assertEqual(A, chol_fact.matmul(chol_fact.mT))
                # Lower triangular check
                self.assertEqual(chol_fact, chol_fact.tril())

        for upper, batchsize in itertools.product([True, False], [262144, 524288]):
            cholesky_test_helper(2, batchsize, device, upper)

    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_old_cholesky_batched(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def cholesky_test_helper(n, batch_dims, upper):
            A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
            cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)])
            cholesky_exp = cholesky_exp.reshape_as(A)
            self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))

        for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]):
            cholesky_test_helper(3, batchsize, upper)

    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @tf32_on_and_off(0.01)
    @bf32_on_and_off(0.01)
    def test_old_cholesky(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        A = random_hermitian_pd_matrix(10, dtype=dtype, device=device)

        # default Case
        C = torch.cholesky(A)
        B = torch.mm(C, C.t().conj())
        self.assertEqual(A, B, atol=1e-14, rtol=0)

        # test Upper Triangular
        U = torch.cholesky(A, True)
        B = torch.mm(U.t().conj(), U)
        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix')

        # test Lower Triangular
        L = torch.cholesky(A, False)
        B = torch.mm(L, L.t().conj())
        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix')

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_old_cholesky_empty(self, device, dtype):
        def run_test(upper):
            A = torch.empty(0, 0, dtype=dtype, device=device)
            chol = torch.cholesky(A, upper)
            chol_A = torch.matmul(chol, chol.t().conj())
            self.assertEqual(A, chol_A)
        for upper in [True, False]:
            run_test(upper)

    # Test for issue https://github.com/pytorch/pytorch/issues/57032:
    # torch.cholesky with upper=True for batched CUDA inputs was wrong;
    # it was using the lower triangular part instead of the upper one.
    @onlyCUDA
    @skipCUDAIfNoMagma
    @dtypes(*floating_and_complex_types())
    def test_old_cholesky_batched_upper(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        batchsize = 2
        A = random_hermitian_pd_matrix(3, batchsize, dtype=dtype, device=device)
        A_triu = A.triu()  # fill the lower triangular part with zero

        U = torch.cholesky(A_triu, upper=True)

        reconstruct_A = U.mH @ U
        self.assertEqual(A, reconstruct_A)
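
    # Illustrative sketch of the info convention checked below: cholesky_ex reports
    # info == 0 on success and info == k when the leading minor of order k is not
    # positive definite, instead of raising (unless check_errors=True).
    #
    #   A = torch.eye(3, dtype=torch.float64)
    #   A[-1, -1] = 0
    #   L, info = torch.linalg.cholesky_ex(A)
    #   assert info.item() == 3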

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky_ex(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(n, batch):
            A = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
            expected_L = np.linalg.cholesky(A.cpu().numpy())
            expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
            actual_L, actual_info = torch.linalg.cholesky_ex(A)

            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
            # Let's compare the norms of matrices instead
            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
                # axis is specified to calculate matrix norm for batched input
                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
                # Compare the norms with standard tolerances
                self.assertEqual(actual_norm, expected_norm)
                # and individual values with a higher tolerance
                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
            else:
                self.assertEqual(actual_L, expected_L)
            self.assertEqual(actual_info, expected_info)

        ns = (0, 3, 5)
        batches = ((), (2, ), (2, 1))
        for n, batch in itertools.product(ns, batches):
            run_test(n, batch)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky_ex_non_pd(self, device, dtype):
        # if the input matrix is not positive definite, info with positive integer is returned
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A[-1, -1] = 0  # Now A is singular
        _, info = torch.linalg.cholesky_ex(A)
        self.assertEqual(info, 3)
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
            torch.linalg.cholesky_ex(A, check_errors=True)

        # if at least one matrix in the batch is not positive definite,
        # batched info with positive integer for the corresponding matrix is returned
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A = A.reshape((1, 3, 3))
        A = A.repeat(5, 1, 1)
        A[3, -2, -2] = 0  # Now A[3] is singular
        _, info = torch.linalg.cholesky_ex(A)

        expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
        expected_info[3] = 2
        self.assertEqual(info, expected_info)
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'):
            torch.linalg.cholesky_ex(A, check_errors=True)

    def _test_addr_vs_numpy(self, device, dtype, beta=1, alpha=1):
        def check(m, a, b, beta, alpha):
            if dtype == torch.bfloat16:
                a_np = a.to(torch.double).cpu().numpy()
                b_np = b.to(torch.double).cpu().numpy()
                m_np = m.to(torch.double).cpu().numpy()
                exact_dtype = False
            else:
                a_np = a.cpu().numpy()
                b_np = b.cpu().numpy()
                m_np = m.cpu().numpy()
                exact_dtype = True
            if beta == 0:
                expected = alpha * np.outer(a_np, b_np)
            else:
                expected = beta * m_np + alpha * np.outer(a_np, b_np)

            res = torch.addr(m, a, b, beta=beta, alpha=alpha)
            self.assertEqual(res, expected, exact_dtype=exact_dtype)

            # Test out variant
            out = torch.empty_like(res)
            torch.addr(m, a, b, beta=beta, alpha=alpha, out=out)
            self.assertEqual(out, expected, exact_dtype=exact_dtype)

        m = make_tensor((50, 50), device=device, dtype=dtype, low=-2, high=2)
        a = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
        b = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)

        check(m, a, b, beta, alpha)

        # test transpose
        m_transpose = torch.transpose(m, 0, 1)
        check(m_transpose, a, b, beta, alpha)

        # test 0 strided tensor
        zero_strided = make_tensor((1,), device=device, dtype=dtype, low=-2, high=2).expand(50)
        check(m, zero_strided, b, beta, alpha)

        # test scalar
        m_scalar = torch.tensor(1, device=device, dtype=dtype)
        check(m_scalar, a, b, beta, alpha)

        # test nans and infs are not propagated to the output when beta == 0
        float_and_complex_dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
        if beta == 0 and dtype in float_and_complex_dtypes:
            m[0][10] = m[10][10] = m[20][20] = float('inf')
            m[1][10] = m[11][10] = m[21][20] = float('nan')
        check(m, a, b, 0, alpha)

    @dtypes(torch.bool)
    def test_addr_bool(self, device, dtype):
        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=False)
        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=True)
        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=False)
        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=True)

    @dtypes(*integral_types())
    def test_addr_integral(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError,
                                    'argument beta must not be a floating point number.'):
            self._test_addr_vs_numpy(device, dtype, beta=2., alpha=1)
        with self.assertRaisesRegex(RuntimeError,
                                    'argument alpha must not be a floating point number.'):
            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=1.)
        with self.assertRaisesRegex(RuntimeError,
                                    'Boolean beta only supported for Boolean results.'):
            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
        with self.assertRaisesRegex(RuntimeError,
                                    'Boolean alpha only supported for Boolean results.'):
            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)

        # when beta is zero
        self._test_addr_vs_numpy(device, dtype, beta=0, alpha=2)
        # when beta is not zero
        self._test_addr_vs_numpy(device, dtype, beta=2, alpha=2)

    @precisionOverride({torch.bfloat16: 1e-1})
    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
    def test_addr_float_and_complex(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError,
                                    'Boolean beta only supported for Boolean results.'):
            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
        with self.assertRaisesRegex(RuntimeError,
                                    'Boolean alpha only supported for Boolean results.'):
            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)

        # when beta is zero
        self._test_addr_vs_numpy(device, dtype, beta=0., alpha=2)
        # when beta is not zero
        self._test_addr_vs_numpy(device, dtype, beta=0.5, alpha=2)
        if dtype in complex_types():
            self._test_addr_vs_numpy(device, dtype, beta=(0 + 0.1j), alpha=(0.2 - 0.2j))

    @dtypes(*itertools.product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
                               all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)))
    def test_outer_type_promotion(self, device, dtypes):
        a = torch.randn(5).to(device=device, dtype=dtypes[0])
        b = torch.randn(5).to(device=device, dtype=dtypes[1])
        for op in (torch.outer, torch.Tensor.outer, torch.ger, torch.Tensor.ger):
            result = op(a, b)
            self.assertEqual(result.dtype, torch.result_type(a, b))

    # don't use @dtypes decorator to avoid generating ~1700 tests per device
    def test_addr_type_promotion(self, device):
        for dtypes0, dtypes1, dtypes2 in product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), repeat=3):
            a = make_tensor((5,), device=device, dtype=dtypes0, low=-2, high=2)
            b = make_tensor((5,), device=device, dtype=dtypes1, low=-2, high=2)
            m = make_tensor((5, 5), device=device, dtype=dtypes2, low=-2, high=2)

            desired_dtype = torch.promote_types(torch.promote_types(dtypes0, dtypes1),
                                                dtypes2)
            for op in (torch.addr, torch.Tensor.addr):
                result = op(m, a, b)
                self.assertEqual(result.dtype, desired_dtype)

    # Tests migrated from test_torch.py
    # 1) test the shape of the result tensor when there is empty input tensor
    # 2) test the Runtime Exception when there is scalar input tensor
    def test_outer_ger_addr_legacy_tests(self, device):
        for size in ((0, 0), (0, 5), (5, 0)):
            a = torch.rand(size[0], device=device)
            b = torch.rand(size[1], device=device)

            self.assertEqual(torch.outer(a, b).shape, size)
            self.assertEqual(torch.ger(a, b).shape, size)

            m = torch.empty(size, device=device)
            self.assertEqual(torch.addr(m, a, b).shape, size)

        m = torch.randn(5, 6, device=device)
        a = torch.randn(5, device=device)
        b = torch.tensor(6, device=device)
        self.assertRaises(RuntimeError, lambda: torch.outer(a, b))
        self.assertRaises(RuntimeError, lambda: torch.outer(b, a))
        self.assertRaises(RuntimeError, lambda: torch.ger(a, b))
        self.assertRaises(RuntimeError, lambda: torch.ger(b, a))
        self.assertRaises(RuntimeError, lambda: torch.addr(m, a, b))
        self.assertRaises(RuntimeError, lambda: torch.addr(m, b, a))

    # Tests torch.det and its alias, torch.linalg.det, vs. NumPy
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double, torch.cdouble)
    def test_det(self, device, dtype):
        tensors = (
            torch.randn((2, 2), device=device, dtype=dtype),
            torch.randn((129, 129), device=device, dtype=dtype),
            torch.randn((3, 52, 52), device=device, dtype=dtype),
            torch.randn((4, 2, 26, 26), device=device, dtype=dtype))

        ops = (torch.det, torch.Tensor.det,
               torch.linalg.det)
        for t in tensors:
            expected = np.linalg.det(t.cpu().numpy())
            for op in ops:
                actual = op(t)
                self.assertEqual(actual, expected)
                self.compare_with_numpy(op, np.linalg.det, t)

        # NOTE: det requires a 2D+ tensor
        t = torch.randn(1, device=device, dtype=dtype)
        with self.assertRaises(RuntimeError):
            op(t)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    def test_eigh(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_matrix

        def run_test(shape, batch, uplo):
            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
            self.assertEqual(actual_w, expected_w)
            # sign of eigenvectors is not unique and therefore absolute values are compared
            self.assertEqual(abs(actual_v), abs(expected_v))
            # additionally we can multiply the eigenvector with a phase factor e^{i\phi} and then compare the values
            # let's choose the convention that the first element of the eigenvectors from torch and numpy be the same
            # for real inputs, this phase factor is plus or minus one
            if matrix.numel() > 0:
                phase = torch.from_numpy(expected_v[..., 0, :]).to(device=device).div(actual_v[..., 0, :])
                actual_v_rotated = actual_v * phase.unsqueeze(-2).expand_as(actual_v)
                self.assertEqual(actual_v_rotated, expected_v)

            # check the out= variant
            out_w = torch.empty_like(actual_w)
            out_v = torch.empty_like(actual_v)
            ans_w, ans_v = torch.linalg.eigh(matrix, UPLO=uplo, out=(out_w, out_v))
            self.assertEqual(ans_w, out_w)
            self.assertEqual(ans_v, out_v)
            self.assertEqual(ans_w, actual_w)
            self.assertEqual(abs(ans_v), abs(actual_v))

        shapes = (0, 3, 5)
        batches = ((), (3, ), (2, 2))
        uplos = ["U", "L"]
        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
            run_test(shape, batch, uplo)
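
    # Illustrative sketch of the decomposition verified above, for a real symmetric
    # input (eigenvectors are only unique up to sign/phase, which is why the test
    # compares absolute values or rotates by a phase factor):
    #
    #   A = torch.randn(5, 5, dtype=torch.float64)
    #   A = A + A.T
    #   w, V = torch.linalg.eigh(A)
    #   assert torch.allclose(V @ torch.diag(w) @ V.T, A)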

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    def test_eigh_lower_uplo(self, device, dtype):
        def run_test(shape, batch, uplo):
            # check lower case uplo
            # use non-symmetric input to check whether uplo argument is working as intended
            matrix = torch.randn(shape, shape, *batch, dtype=dtype, device=device)
            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
            self.assertEqual(actual_w, expected_w)
            self.assertEqual(abs(actual_v), abs(expected_v))

        uplos = ["u", "l"]
        for uplo in uplos:
            run_test(3, (2, 2), uplo)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_eigh_errors_and_warnings(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_matrix

        # eigh requires a square matrix
        t = torch.randn(2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.linalg.eigh(t)

        # eigh requires 'uplo' parameter to be 'U' or 'L'
        t = torch.randn(3, 3, device=device, dtype=dtype)
        for uplo in ["a", "wrong"]:
            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
                torch.linalg.eigh(t, UPLO=uplo)
            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
                np.linalg.eigh(t.cpu().numpy(), UPLO=uplo)

        # if non-empty out tensor with wrong shape is passed a warning is given
        a = random_hermitian_matrix(3, dtype=dtype, device=device)
        real_dtype = a.real.dtype if dtype.is_complex else dtype
        out_w = torch.empty(7, 7, dtype=real_dtype, device=device)
        out_v = torch.empty(7, 7, dtype=dtype, device=device)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.eigh(a, out=(out_w, out_v))
            # Check warning occurs
            self.assertEqual(len(w), 2)
            self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out_w = torch.empty(0, dtype=real_dtype, device=device)
        out_v = torch.empty(0, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.eigh(a, out=(out_w, out_v))

        out_w = torch.empty(0, dtype=torch.int, device=device)
        out_v = torch.empty(0, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.eigh(a, out=(out_w, out_v))

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out_w = torch.empty(0, device=wrong_device, dtype=dtype)
            out_v = torch.empty(0, device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.eigh(a, out=(out_w, out_v))
            out_w = torch.empty(0, device=device, dtype=dtype)
            out_v = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.eigh(a, out=(out_w, out_v))

    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    @unittest.skipIf(_get_torch_cuda_version() < (12, 1), "Test is fixed on cuda 12.1 update 1.")
    def test_eigh_svd_illcondition_matrix_input_should_not_crash(self, device, dtype):
        # See https://github.com/pytorch/pytorch/issues/94772, https://github.com/pytorch/pytorch/issues/105359
        # This test crashes with `cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED` on cuda 11.8,
        # but passes on cuda 12.1 update 1 or later.
        a = torch.ones(512, 512, dtype=dtype, device=device)
        a[0, 0] = 1.0e-5
        a[-1, -1] = 1.0e5

        eigh_out = torch.linalg.eigh(a)
        svd_out = torch.linalg.svd(a)

        # Matrix input a is too ill-conditioned.
        # We'll just compare the first two singular values/eigenvalues. They are 1.0e5 and 511.0
        # The precision override with tolerance of 1.0 makes sense since ill-conditioned inputs are hard to converge
        # to exact values.
        self.assertEqual(eigh_out.eigenvalues.sort(descending=True).values[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
        self.assertEqual(svd_out.S[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    def test_eigvalsh(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_matrix

        def run_test(shape, batch, uplo):
            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
            expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo)
            actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo)
            self.assertEqual(actual_w, expected_w)

            # check the out= variant
            out = torch.empty_like(actual_w)
            ans = torch.linalg.eigvalsh(matrix, UPLO=uplo, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, actual_w)

        shapes = (0, 3, 5)
        batches = ((), (3, ), (2, 2))
        uplos = ["U", "L"]
        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
            run_test(shape, batch, uplo)
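
    # Illustrative sketch of the relation between the two APIs tested here:
    # eigvalsh returns only the (ascending) eigenvalues that eigh would compute.
    #
    #   A = torch.randn(4, 4, dtype=torch.float64)
    #   A = A + A.T
    #   assert torch.allclose(torch.linalg.eigvalsh(A), torch.linalg.eigh(A).eigenvalues)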

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_eigvalsh_errors_and_warnings(self, device, dtype):
        # eigvalsh requires a square matrix
        t = torch.randn(2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.linalg.eigvalsh(t)

        # eigvalsh requires 'uplo' parameter to be 'U' or 'L'
        t = torch.randn(3, 3, device=device, dtype=dtype)
        for uplo in ["a", "wrong"]:
            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
                torch.linalg.eigvalsh(t, UPLO=uplo)
            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
                np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo)

        # if non-empty out tensor with wrong shape is passed a warning is given
        real_dtype = t.real.dtype if dtype.is_complex else dtype
        out = torch.empty_like(t).to(real_dtype)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.eigvalsh(t, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out = torch.empty(0, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.eigvalsh(t, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.eigvalsh(t, out=out)

    @dtypes(*floating_and_complex_types())
    def test_kron(self, device, dtype):

        def run_test_case(a_shape, b_shape):
            a = torch.rand(a_shape, dtype=dtype, device=device)
            b = torch.rand(b_shape, dtype=dtype, device=device)

            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
            result = torch.kron(a, b)
            self.assertEqual(result, expected)

            # check the out= variant
            out = torch.empty_like(result)
            ans = torch.kron(a, b, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, result)

        shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)]
        for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
            run_test_case(a_shape, b_shape)

    @dtypes(*floating_and_complex_types())
    def test_kron_empty(self, device, dtype):

        def run_test_case(empty_shape):
            a = torch.eye(3, dtype=dtype, device=device)
            b = torch.empty(empty_shape, dtype=dtype, device=device)
            result = torch.kron(a, b)
            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
            self.assertEqual(result, expected)

            # NumPy doesn't work if the first argument is empty
            result = torch.kron(b, a)
            self.assertEqual(result.shape, expected.shape)

        empty_shapes = [(0,), (2, 0), (1, 0, 3)]
        for empty_shape in empty_shapes:
            run_test_case(empty_shape)

    @dtypes(*floating_and_complex_types())
    def test_kron_errors_and_warnings(self, device, dtype):
        # if non-empty out tensor with wrong shape is passed a warning is given
        a = torch.eye(3, dtype=dtype, device=device)
        b = torch.ones((2, 2), dtype=dtype, device=device)
        out = torch.empty_like(a)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.kron(a, b, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should match
        out = torch.empty_like(a).to(torch.int)
        with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
            torch.kron(a, b, out=out)
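
    # Illustrative sketch of the Kronecker-product shape rule exercised above:
    # for a of shape (m, n) and b of shape (p, q), torch.kron(a, b) has shape
    # (m*p, n*q), with each a[i, j] scaling a full copy of b.
    #
    #   a = torch.ones(2, 3)
    #   b = torch.ones(4, 5)
    #   assert torch.kron(a, b).shape == (8, 15)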

    # This test confirms that torch.linalg.norm's dtype argument works
    # as expected, according to the function's documentation
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
    def test_norm_dtype(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)

        def run_test_case(input_size, ord, keepdim, to_dtype):
            msg = (
                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
                f'dtype={dtype}, to_dtype={to_dtype}')
            input = make_arg(input_size)
            result = torch.linalg.norm(input, ord, keepdim=keepdim)
            self.assertEqual(result.dtype, input.real.dtype, msg=msg)

            result_out = torch.empty((0), dtype=result.dtype, device=device)
            torch.linalg.norm(input, ord, keepdim=keepdim, out=result_out)
            self.assertEqual(result, result_out, msg=msg)

            result = torch.linalg.norm(input.to(to_dtype), ord, keepdim=keepdim)
            result_with_dtype = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype)
            self.assertEqual(result, result_with_dtype, msg=msg)

            result_out_with_dtype = torch.empty_like(result_with_dtype)
            torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_with_dtype)
            self.assertEqual(result_with_dtype, result_out_with_dtype, msg=msg)

        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]

        # For these orders we compute the 10-th power and 10-th root of the entries;
        # we avoid them for half-precision types as that makes the test too badly conditioned
        if dtype != torch.float16 and dtype != torch.bfloat16:
            ord_vector.extend([0.1, -0.1])
        ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None]
        S = 10

        if dtype == torch.cfloat:
            norm_dtypes = (torch.cfloat, torch.cdouble)
        elif dtype == torch.cdouble:
            norm_dtypes = (torch.cdouble,)
        elif dtype in (torch.float16, torch.bfloat16, torch.float):
            norm_dtypes = (torch.float, torch.double)
        elif dtype == torch.double:
            norm_dtypes = (torch.double,)
        else:
            raise RuntimeError("Unsupported dtype")

        for ord, keepdim, norm_dtype in product(ord_vector, (True, False), norm_dtypes):
            run_test_case((S,), ord, keepdim, norm_dtype)

        for ord, keepdim, norm_dtype in product(ord_matrix, (True, False), norm_dtypes):
            if ord in [2, -2, 'nuc']:
                # We need torch.svdvals
                if dtype == torch.float16 or dtype == torch.bfloat16:
                    continue

                # We need LAPACK or equivalent
                if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
                        (torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
                    continue
            run_test_case((S, S), ord, keepdim, norm_dtype)

    # This test confirms that torch.linalg.norm produces the right result for bfloat16 and half inputs.
    @dtypes(torch.bfloat16, torch.float16)
    def test_norm_bfloat16_and_half(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)

        def run_test_case(input_size, ord, keepdim):
            msg = (
                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
                f'dtype={dtype}')
            input = make_arg(input_size).fill_(1)
            result_ref = torch.linalg.norm(input.float(), ord, keepdim=keepdim).to(dtype=dtype)
            result = torch.linalg.norm(input, ord, keepdim=keepdim)
            self.assertEqual(result_ref, result, msg=msg)

        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
        for S, ord, keepdim in product((10, 2049), ord_vector, (True, False)):
            run_test_case((S,), ord, keepdim)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
    def test_vector_norm(self, device, dtype):
        if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]:
            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
        # have to use torch.randn(...).to(bfloat16) instead of
        # This test compares torch.linalg.vector_norm's output with
        # torch.linalg.norm given a flattened tensor
        ord_vector = [0, 0.9, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
        input_sizes = [
            (1, ),
            (10, ),
            (4, 5),
            (3, 4, 5),
            (0, ),
            (0, 10),
            (0, 0),
            (10, 0, 10),
        ]

        def vector_norm_reference(input, ord, dim=None, keepdim=False, dtype=None):
            if dim is None:
                input_maybe_flat = input.flatten(0, -1)
            else:
                input_maybe_flat = input

            result = torch.linalg.norm(input_maybe_flat, ord, dim=dim, keepdim=keepdim, dtype=dtype)
            if keepdim and dim is None:
                result = result.reshape([1] * input.dim())
            return result

        def run_test_case(input, ord, dim, keepdim, norm_dtype):
            if (input.numel() == 0 and
                    (ord < 0. or ord == inf) and
                    (dim is None or input.shape[dim] == 0)):
                # The operation does not have an identity.
                error_msg = "linalg.vector_norm cannot compute"
                with self.assertRaisesRegex(RuntimeError, error_msg):
                    torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim)
            else:
                msg = (f'input.size()={input.size()}, ord={ord}, dim={dim}, '
                       f'keepdim={keepdim}, dtype={dtype}, norm_dtype={norm_dtype}')
                result_dtype_reference = vector_norm_reference(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
                result_dtype = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
                if dtype.is_complex:
                    result_dtype_reference = result_dtype_reference.real
                self.assertEqual(result_dtype, result_dtype_reference, msg=msg)

                if norm_dtype is not None:
                    ref = torch.linalg.vector_norm(input.to(norm_dtype), ord, dim=dim, keepdim=keepdim)
                    actual = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
                    self.assertEqual(ref, actual, msg=msg)

        if dtype == torch.cfloat:
            norm_dtypes = (None, torch.cfloat, torch.cdouble)
        elif dtype == torch.cdouble:
            norm_dtypes = (None, torch.cdouble)
        elif dtype in (torch.float16, torch.bfloat16, torch.float):
            norm_dtypes = (None, torch.float, torch.double)
        elif dtype == torch.double:
            norm_dtypes = (None, torch.double)
        else:
            raise RuntimeError("Unsupported dtype")

        for amp in [False, True]:
            with torch.autocast(device_type=device, enabled=amp):
                for input_size, ord, keepdim, norm_dtype in product(input_sizes, ord_vector, [True, False], norm_dtypes):
                    input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
                    for dim in [None, random.randint(0, len(input_size) - 1)]:
                        run_test_case(
                            input,
                            ord,
                            dim,
                            keepdim,
                            norm_dtype)

    def test_vector_norm_dim_tuple_arg(self, device):
        test_cases = [
            # input size, dim, error, error message
            ((4, ), (0, ), None, None),
            ((4, ), (1, ), IndexError, r'Dimension out of range'),
            ((4, ), (-2, ), IndexError, r'Dimension out of range'),
            ((4, 3), (0, -1), None, None),
            ((4, 3), (0, 0), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
            ((4, 3), (0, -2), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
            ((4, 3), (0, 1.0), TypeError, r"argument 'dim' must be tuple of ints"),
            ((4, 3), (None, ), TypeError, r"argument 'dim' must be tuple of ints"),
        ]
        for input_size, dim_tuple, error, error_msg in test_cases:
            input = torch.randn(input_size, device=device)
            # vector_norm should accept a tuple or a list for dim arg
            for dim in [dim_tuple, list(dim_tuple)]:
                if error is None:
                    torch.linalg.vector_norm(input, dim=dim)
                else:
                    with self.assertRaises(error):
                        torch.linalg.vector_norm(input, dim=dim)
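
    # Illustrative sketch of a few vector-norm orders covered by the sweeps above:
    #
    #   x = torch.tensor([3., -4., 0.])
    #   torch.linalg.vector_norm(x)           # 5.0  (default ord=2, Euclidean norm)
    #   torch.linalg.vector_norm(x, ord=inf)  # 4.0  (largest absolute value)
    #   torch.linalg.vector_norm(x, ord=0)    # 2.0  (number of non-zero entries)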
result_numpy, msg=msg) 1355 1356 result_out = torch.empty_like(result) 1357 torch.linalg.norm(input, ord, dim, keepdim, out=result_out) 1358 self.assertEqual(result, result_out, msg=msg) 1359 1360 ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf] 1361 S = 10 1362 test_cases = [ 1363 # input size, p settings, dim 1364 ((S, ), ord_vector, None), 1365 ((S, ), ord_vector, 0), 1366 ((S, S, S), ord_vector, 0), 1367 ((S, S, S), ord_vector, 1), 1368 ((S, S, S), ord_vector, 2), 1369 ((S, S, S), ord_vector, -1), 1370 ((S, S, S), ord_vector, -2), 1371 ] 1372 L = 1_000_000 1373 if dtype == torch.double: 1374 test_cases.append(((L, ), ord_vector, None)) 1375 for keepdim in [True, False]: 1376 for input_size, ord_settings, dim in test_cases: 1377 input = torch.randn(*input_size, dtype=dtype, device=device) 1378 for ord in ord_settings: 1379 run_test_case(input, ord, dim, keepdim) 1380 1381 # This test compares torch.linalg.norm, torch.linalg.matrix_norm and numpy.linalg.norm to 1382 # ensure that their matrix norm results match. 1383 @skipMeta # https://github.com/pytorch/pytorch/issues/54082 1384 @skipCUDAIfNoMagma 1385 @dtypes(torch.float, torch.double) 1386 @precisionOverride({torch.float32: 2e-4}) 1387 def test_norm_matrix(self, device, dtype): 1388 make_arg = partial(make_tensor, dtype=dtype, device=device) 1389 1390 def run_test_case(input, ord, dim, keepdim): 1391 msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' 1392 result = torch.linalg.norm(input, ord, dim, keepdim) 1393 input_numpy = input.cpu().numpy() 1394 result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) 1395 1396 result = torch.linalg.norm(input, ord, dim, keepdim) 1397 self.assertEqual(result, result_numpy, msg=msg) 1398 if ord is not None and dim is not None: 1399 result = torch.linalg.matrix_norm(input, ord, dim, keepdim) 1400 self.assertEqual(result, result_numpy, msg=msg) 1401 1402 ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro'] 1403 S = 10 1404 test_cases = [ 1405 # input size, dim 1406 ((S, S), None), 1407 ((S, S), (0, 1)), 1408 ((S, S), (1, 0)), 1409 ((S, S, S, S), (2, 0)), 1410 ((S, S, S, S), (-1, -2)), 1411 ((S, S, S, S), (-1, -3)), 1412 ((S, S, S, S), (-3, 2)), 1413 ] 1414 1415 for (shape, dim), keepdim, ord in product(test_cases, [True, False], ord_matrix): 1416 if ord in [2, -2, 'nuc']: 1417 # We need torch.svdvals 1418 if dtype == torch.float16 or dtype == torch.bfloat16: 1419 continue 1420 # We need LAPACK or equivalent 1421 if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or 1422 (torch.device(device).type == 'cpu' and not torch._C.has_lapack)): 1423 continue 1424 run_test_case(make_arg(shape), ord, dim, keepdim) 1425 1426 1427 @onlyCUDA 1428 @dtypes(torch.bfloat16, torch.float16) 1429 def test_norm_fused_type_promotion(self, device, dtype): 1430 x = torch.randn(10, device=device, dtype=dtype) 1431 1432 def profile_and_check(fn, x, kwargs): 1433 with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p: 1434 fn(x, **kwargs, dtype=torch.float) 1435 # smoke check that profiler returned some events 1436 self.assertTrue("aten::linalg_vector_norm" in (e.name for e in p.events())) 1437 # test that there was no explicit copy 1438 self.assertFalse("aten::to" in (e.name for e in p.events())) 1439 1440 for f, kwargs, in zip((torch.linalg.vector_norm, torch.norm), ({}, {"p" : 2})): 1441 profile_and_check(f, x, kwargs) 1442 1443 @skipMeta # https://github.com/pytorch/pytorch/issues/53739 1444 
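    # torch.linalg.cond(A, p) is defined as norm(A, p) * norm(inv(A), p); for the 2-norm
    # (the default) this reduces to the ratio of the largest to the smallest singular value.
    # A rough equivalent, shown only as an illustrative sketch (not used by the test below):
    #   s = torch.linalg.svdvals(A)      # singular values, sorted in descending order
    #   cond_2 = s[..., 0] / s[..., -1]  # same value as torch.linalg.cond(A) for invertible A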
@skipCPUIfNoLapack 1445 @skipCUDAIfNoMagma 1446 @dtypes(*floating_and_complex_types()) 1447 @precisionOverride({torch.float32: 1e-3}) 1448 def test_cond(self, device, dtype): 1449 def run_test_case(input, p): 1450 result = torch.linalg.cond(input, p) 1451 result_numpy = np.linalg.cond(input.cpu().numpy(), p) 1452 self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision, exact_dtype=False) 1453 self.assertEqual(result.shape, result_numpy.shape) 1454 1455 # test out= variant 1456 out = torch.empty_like(result) 1457 ans = torch.linalg.cond(input, p, out=out) 1458 self.assertEqual(ans, out) 1459 self.assertEqual(ans, result) 1460 1461 norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] 1462 input_sizes = [(32, 32), (2, 3, 3, 3)] 1463 for input_size in input_sizes: 1464 input = torch.randn(*input_size, dtype=dtype, device=device) 1465 for p in norm_types: 1466 run_test_case(input, p) 1467 1468 # test empty batch sizes 1469 input_sizes = [(0, 3, 3), (0, 2, 5, 5)] 1470 for input_size in input_sizes: 1471 input = torch.randn(*input_size, dtype=dtype, device=device) 1472 for p in norm_types: 1473 run_test_case(input, p) 1474 1475 # test non-square input 1476 input_sizes = [(16, 32), (32, 16), (2, 3, 5, 3), (2, 3, 3, 5)] 1477 for input_size in input_sizes: 1478 input = torch.randn(*input_size, dtype=dtype, device=device) 1479 for p in [2, -2, None]: 1480 run_test_case(input, p) 1481 1482 # test for singular input 1483 a = torch.eye(3, dtype=dtype, device=device) 1484 a[-1, -1] = 0 # make 'a' singular 1485 for p in norm_types: 1486 try: 1487 run_test_case(a, p) 1488 except np.linalg.LinAlgError: 1489 # Numpy may fail to converge for some BLAS backends (although this is very rare) 1490 # See the discussion in https://github.com/pytorch/pytorch/issues/67675 1491 pass 1492 1493 # test for 0x0 matrices. 
NumPy doesn't work for such input, we return 0 1494 input_sizes = [(0, 0), (2, 5, 0, 0)] 1495 for input_size in input_sizes: 1496 input = torch.randn(*input_size, dtype=dtype, device=device) 1497 for p in ['fro', 2]: 1498 expected_dtype = a.real.dtype if dtype.is_complex else dtype 1499 expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device) 1500 actual = torch.linalg.cond(input, p) 1501 self.assertEqual(actual, expected) 1502 1503 @skipMeta # https://github.com/pytorch/pytorch/issues/53739 1504 @skipCPUIfNoLapack 1505 @skipCUDAIfNoMagma 1506 @dtypes(*floating_and_complex_types()) 1507 @precisionOverride({torch.float32: 1e-3}) 1508 def test_cond_errors_and_warnings(self, device, dtype): 1509 norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] 1510 1511 # cond expects the input to be at least 2-dimensional 1512 a = torch.ones(3, dtype=dtype, device=device) 1513 for p in norm_types: 1514 with self.assertRaisesRegex(RuntimeError, r'at least 2 dimensions'): 1515 torch.linalg.cond(a, p) 1516 1517 # for some norm types cond expects the input to be square 1518 a = torch.ones(3, 2, dtype=dtype, device=device) 1519 norm_types = [1, -1, inf, -inf, 'fro', 'nuc'] 1520 for p in norm_types: 1521 with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): 1522 torch.linalg.cond(a, p) 1523 1524 # if non-empty out tensor with wrong shape is passed a warning is given 1525 a = torch.ones((2, 2), dtype=dtype, device=device) 1526 for p in ['fro', 2]: 1527 real_dtype = a.real.dtype if dtype.is_complex else dtype 1528 out = torch.empty(a.shape, dtype=real_dtype, device=device) 1529 with warnings.catch_warnings(record=True) as w: 1530 # Trigger warning 1531 torch.linalg.cond(a, p, out=out) 1532 # Check warning occurs 1533 self.assertEqual(len(w), 1) 1534 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 1535 1536 # dtypes should be safely castable 1537 out = torch.empty(0, dtype=torch.int, device=device) 1538 for p in ['fro', 2]: 1539 with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): 1540 torch.linalg.cond(a, p, out=out) 1541 1542 # device should match 1543 if torch.cuda.is_available(): 1544 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 1545 out = torch.empty(0, dtype=dtype, device=wrong_device) 1546 for p in ['fro', 2]: 1547 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 1548 torch.linalg.cond(a, p, out=out) 1549 1550 # for batched input if at least one matrix in the batch is not invertible, 1551 # we can't get the result for all other (possibly) invertible matrices in the batch without an explicit for loop. 
1552 # this should change when at::inverse works with silent errors 1553 # NumPy works fine in this case because it's possible to silence the error and get the inverse matrix results 1554 # possibly filled with NANs 1555 batch_dim = 3 1556 a = torch.eye(3, 3, dtype=dtype, device=device) 1557 a = a.reshape((1, 3, 3)) 1558 a = a.repeat(batch_dim, 1, 1) 1559 a[1, -1, -1] = 0 # now a[1] is singular 1560 for p in [1, -1, inf, -inf, 'fro', 'nuc']: 1561 result = torch.linalg.cond(a, p) 1562 self.assertEqual(result[1], float('inf')) 1563 1564 # check invalid norm type 1565 a = torch.ones(3, 3, dtype=dtype, device=device) 1566 for p in ['wrong_norm', 5]: 1567 with self.assertRaisesRegex(RuntimeError, f"linalg.cond got an invalid norm type: {p}"): 1568 torch.linalg.cond(a, p) 1569 1570 # This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments 1571 # to ensure that they both throw errors 1572 @dtypes(torch.float, torch.double) 1573 def test_norm_errors(self, device, dtype): 1574 def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex): 1575 test_case_info = ( 1576 f'test case input.size()={input.size()}, ord={ord}, dim={dim}, ' 1577 f'keepdim={keepdim}, dtype={dtype}') 1578 1579 with self.assertRaisesRegex(error_type, error_regex, msg=test_case_info): 1580 torch.linalg.norm(input, ord, dim, keepdim) 1581 1582 input_numpy = input.cpu().numpy() 1583 1584 msg = f'numpy does not raise error but pytorch does, for case "{test_case_info}"' 1585 with self.assertRaises(Exception, msg=test_case_info): 1586 np.linalg.norm(input_numpy, ord, dim, keepdim) 1587 1588 S = 10 1589 error_test_cases = [ 1590 # input size, p settings, dim, error type, error regex 1591 ((S, ), ['fro', 'nuc'], None, RuntimeError, r'A must have at least 2 dimensions'), 1592 ((S, S), [3.5], None, RuntimeError, r'matrix_norm: Order 3.5 not supported'), 1593 ((S, S), [0], None, RuntimeError, r'matrix_norm: Order 0 not supported'), 1594 ((S, S), ['fail'], None, RuntimeError, r'matrix_norm: Order fail not supported'), 1595 ((S, S), ['fro', 'nuc'], 0, RuntimeError, r'matrix_norm: dim must be a 2-tuple'), 1596 ((S, S), ['fro', 'nuc', 2], (0, 0), RuntimeError, r'dims must be different'), 1597 ((S, S), ['fro', 'nuc', 2], (-1, 1), RuntimeError, r'dims must be different'), 1598 ((S, S), ['fro', 'nuc', 2], (0, 4), IndexError, r'Dimension out of range'), 1599 ((S, ), [0], (4, ), IndexError, r'Dimension out of range'), 1600 ((S, ), [None], (0, 0), RuntimeError, r'dim 0 appears multiple times'), 1601 ((S, S, S), [1], (0, 1, 2), RuntimeError, r"If dim is specified, it must be of length 1 or 2."), 1602 ((S, S, S), [1], None, RuntimeError, r"If dim is not specified but ord is, the input must be 1D or 2D"), 1603 ] 1604 for keepdim in [True, False]: 1605 for input_size, ord_settings, dim, error_type, error_regex in error_test_cases: 1606 input = torch.randn(*input_size, dtype=dtype, device=device) 1607 for ord in ord_settings: 1608 run_error_test_case(input, ord, dim, keepdim, error_type, error_regex) 1609 1610 # Test complex number inputs for linalg.norm 1611 @skipCUDAIfNoMagma 1612 @skipCPUIfNoLapack 1613 @dtypes(torch.cfloat, torch.cdouble) 1614 @precisionOverride({torch.cfloat: 5e-4}) 1615 def test_norm_complex(self, device, dtype): 1616 def gen_error_message(input_size, ord, keepdim, dim=None): 1617 return f"complex norm failed for input size {input_size}, ord={ord}, keepdim={keepdim}, dim={dim}" 1618 1619 vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf] 1620 matrix_ords = [None, 'fro', 'nuc', 1, 2, inf, 
                        -1, -2, -inf]

        # Test supported ords
        for keepdim in [False, True]:
            # vector norm
            x = torch.randn(25, device=device, dtype=dtype)
            xn = x.cpu().numpy()
            for ord in vector_ords:
                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
                msg = gen_error_message(x.size(), ord, keepdim)
                self.assertEqual(res.shape, expected.shape, msg=msg)
                self.assertEqual(res, expected, msg=msg, exact_dtype=False)

                res_out = torch.tensor([], device=device, dtype=res.dtype)
                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
                self.assertEqual(res_out.shape, expected.shape, msg=msg)
                self.assertEqual(res_out, expected, msg=msg)

            # matrix norm
            x = torch.randn(25, 25, device=device, dtype=dtype)
            xn = x.cpu().numpy()
            for ord in matrix_ords:
                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
                msg = gen_error_message(x.size(), ord, keepdim)
                self.assertEqual(res.shape, expected.shape, msg=msg)
                self.assertEqual(res, expected, msg=msg, exact_dtype=False)

                res_out = torch.tensor([], device=device, dtype=res.dtype)
                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
                self.assertEqual(res_out.shape, expected.shape, msg=msg)
                self.assertEqual(res_out, expected, msg=msg)

    # Test that linalg.vector_norm gives the same result as numpy when inputs
    # contain extreme values (inf, -inf, nan)
    def test_vector_norm_extreme_values(self, device):
        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
        vectors = []
        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
            vectors.append(list(pair))
        for vector in vectors:
            x = torch.tensor(vector, device=device)
            x_n = x.cpu().numpy()
            for ord in vector_ords:
                msg = f'ord={ord}, vector={vector}'
                result = torch.linalg.vector_norm(x, ord=ord)
                result_n = np.linalg.norm(x_n, ord=ord)
                self.assertEqual(result, result_n, msg=msg)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_vector_norm_reduce_over_1D_vector(self, device, dtype):
        input_sizes_and_dims = [
            ((6, 1), -1),
            ((3, 1, 2, 1), (1, 3)),
            ((1,), None),
        ]
        orders = [float('inf'), -float('inf'), 0, 1, -1, 2, -2]
        keepdims = [True, False]

        for input_size_and_dim, ord, keepdim in product(input_sizes_and_dims, orders, keepdims):
            input_size = input_size_and_dim[0]
            dim = input_size_and_dim[1]
            if type(dim) is tuple and ord == 0:
                # skip because np.linalg.norm raises 'ValueError: Invalid norm order for matrices.'
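                # (np.linalg.norm interprets a 2-element dim/axis tuple as a matrix norm
                # request, and ord=0 is not a valid matrix norm order, e.g. the sketch
                #   np.linalg.norm(np.ones((2, 2)), 0, axis=(0, 1))
                # raises ValueError, so there is no NumPy reference value to compare against.)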
1685 continue 1686 input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9) 1687 result = torch.linalg.vector_norm(input, ord, dim, keepdim) 1688 result_numpy = np.linalg.norm(input.cpu().numpy(), ord, dim, keepdim) 1689 1690 msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' 1691 self.assertEqual(result, result_numpy, msg=msg) 1692 1693 @skipCUDAIfNoMagmaAndNoCusolver 1694 @skipCPUIfNoLapack 1695 @dtypes(torch.float, torch.double) 1696 @precisionOverride({torch.float32: 2e-5}) 1697 def test_matrix_norm(self, device, dtype): 1698 # Test only inputs for which torch.linalg.matrix_norm diverges from torch.linalg.norm 1699 A = make_tensor((2, 2, 2), dtype=dtype, device=device) 1700 1701 with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must have at least 2 dimensions.*'): 1702 torch.linalg.matrix_norm(make_tensor((2,), dtype=dtype, device=device)) 1703 with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must be a 2-tuple.*'): 1704 torch.linalg.matrix_norm(A, dim=(0,)) 1705 with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'): 1706 torch.linalg.matrix_norm(A, ord=0) 1707 with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'): 1708 torch.linalg.matrix_norm(A, ord=3.0) 1709 1710 # Test dim=None behavior 1711 ref = torch.linalg.norm(A, dim=(-2, -1)) 1712 res = torch.linalg.matrix_norm(A) 1713 self.assertEqual(ref, res) 1714 1715 # Test that linal.norm gives the same result as numpy when inputs 1716 # contain extreme values (inf, -inf, nan) 1717 @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") 1718 @unittest.skipIf(IS_MACOS, "Skipped on MacOS!") 1719 @skipCUDAIfNoMagma 1720 @skipCPUIfNoLapack 1721 def test_norm_extreme_values(self, device): 1722 vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf] 1723 # matrix_ords 'nuc', 2, -2 are skipped currently 1724 # See issue https://github.com/pytorch/pytorch/issues/71911 1725 matrix_ords = ['fro', 1, inf, -1, -inf] 1726 vectors = [] 1727 matrices = [] 1728 for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2): 1729 vectors.append(list(pair)) 1730 matrices.append([[pair[0], pair[1]]]) 1731 matrices.append([[pair[0]], [pair[1]]]) 1732 for vector in vectors: 1733 x = torch.tensor(vector).to(device) 1734 x_n = x.cpu().numpy() 1735 for ord in vector_ords: 1736 msg = f'ord={ord}, vector={vector}' 1737 result = torch.linalg.norm(x, ord=ord) 1738 result_n = np.linalg.norm(x_n, ord=ord) 1739 self.assertEqual(result, result_n, msg=msg) 1740 1741 # TODO: Remove this function once the broken cases are fixed 1742 def is_broken_matrix_norm_case(ord, x): 1743 if self.device_type == 'cuda': 1744 if x.size() == torch.Size([1, 2]): 1745 if ord in ['nuc', 2, -2] and isnan(x[0][0]) and x[0][1] == 1: 1746 # These cases are broken because of an issue with svd 1747 # https://github.com/pytorch/pytorch/issues/43567 1748 return True 1749 if ord in ['nuc', 2, -2]: 1750 # These cases are broken because of another issue with svd 1751 # https://github.com/pytorch/pytorch/issues/52633 1752 return True 1753 return False 1754 1755 for matrix in matrices: 1756 x = torch.tensor(matrix).to(device) 1757 x_n = x.cpu().numpy() 1758 for ord in matrix_ords: 1759 msg = f'ord={ord}, matrix={matrix}' 1760 if is_broken_matrix_norm_case(ord, x): 1761 continue 1762 else: 1763 result_n = np.linalg.norm(x_n, ord=ord) 1764 result = torch.linalg.norm(x, ord=ord) 1765 self.assertEqual(result, result_n, msg=msg) 1766 1767 # Test degenerate shape results match numpy for linalg.norm 
vector norms 1768 @skipCUDAIfNoMagma 1769 @skipCPUIfNoLapack 1770 @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 1771 def test_norm_vector_degenerate_shapes(self, device, dtype): 1772 def run_test_case(input, ord, dim, keepdim): 1773 msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' 1774 if (input.numel() == 0 and 1775 (ord < 0. or ord == inf) and 1776 (dim is None or input.shape[dim] == 0)): 1777 with self.assertRaises(RuntimeError): 1778 torch.linalg.norm(input, ord, dim, keepdim) 1779 else: 1780 input_numpy = input.cpu().numpy() 1781 result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) 1782 result = torch.linalg.norm(input, ord, dim, keepdim) 1783 self.assertEqual(result, result_numpy, msg=msg) 1784 1785 ord_vector = [0, 0.5, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf] 1786 S = 10 1787 test_cases = [ 1788 # input size, dim 1789 ((0, ), None), 1790 ((0, S), 0), 1791 ((0, S), 1), 1792 ((S, 0), 0), 1793 ((S, 0), 1), 1794 ] 1795 for keepdim in [True, False]: 1796 for input_size, dim in test_cases: 1797 input = torch.randn(*input_size, dtype=dtype, device=device) 1798 for ord in ord_vector: 1799 run_test_case(input, ord, dim, keepdim) 1800 1801 # Test degenerate shape results match numpy for linalg.norm matrix norms 1802 @skipCUDAIfNoMagma 1803 @skipCPUIfNoLapack 1804 @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 1805 def test_norm_matrix_degenerate_shapes(self, device, dtype): 1806 def run_test_case(input, ord, dim, keepdim, should_error): 1807 msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' 1808 input_numpy = input.cpu().numpy() 1809 ops = [torch.linalg.norm] 1810 1811 if ord is not None and dim is not None: 1812 ops.append(torch.linalg.matrix_norm) 1813 1814 if should_error: 1815 with self.assertRaises(ValueError): 1816 np.linalg.norm(input_numpy, ord, dim, keepdim) 1817 for op in ops: 1818 with self.assertRaises(IndexError): 1819 op(input, ord, dim, keepdim) 1820 else: 1821 result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) 1822 for op in ops: 1823 result = op(input, ord, dim, keepdim) 1824 self.assertEqual(result, result_numpy, msg=msg) 1825 1826 ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None] 1827 S = 10 1828 test_cases = [ 1829 # input size, p settings that cause error, dim 1830 ((0, 0), [1, 2, inf, -1, -2, -inf], None), 1831 ((0, S), [2, inf, -2, -inf], None), 1832 ((S, 0), [1, 2, -1, -2], None), 1833 ((S, S, 0), [], (0, 1)), 1834 ((1, S, 0), [], (0, 1)), 1835 ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)), 1836 ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)), 1837 ] 1838 1839 for keepdim in [True, False]: 1840 for input_size, error_ords, dim in test_cases: 1841 input = torch.randn(*input_size, dtype=dtype, device=device) 1842 for ord in ord_matrix: 1843 run_test_case(input, ord, dim, keepdim, ord in error_ords) 1844 1845 def test_norm_fastpaths(self, device): 1846 x = torch.randn(3, 5, device=device) 1847 1848 # slow path 1849 result = torch.linalg.norm(x, 4.5, 1) 1850 expected = torch.pow(x.abs().pow(4.5).sum(1), 1.0 / 4.5) 1851 self.assertEqual(result, expected) 1852 1853 # fast 0-norm 1854 result = torch.linalg.norm(x, 0, 1) 1855 expected = (x != 0).type_as(x).sum(1) 1856 self.assertEqual(result, expected) 1857 1858 # fast 1-norm 1859 result = torch.linalg.norm(x, 1, 1) 1860 expected = x.abs().sum(1) 1861 self.assertEqual(result, expected) 1862 1863 # fast 2-norm 1864 result = torch.linalg.norm(x, 2, 1) 1865 expected = 
torch.sqrt(x.pow(2).sum(1)) 1866 self.assertEqual(result, expected) 1867 1868 # fast 3-norm 1869 result = torch.linalg.norm(x, 3, 1) 1870 expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0) 1871 self.assertEqual(result, expected) 1872 1873 @skipCPUIfNoLapack 1874 @skipCUDAIfNoMagma 1875 # NumPy computes only in float64 and complex128 precisions 1876 # for float32 or complex64 results might be very different from float64 or complex128 1877 @dtypes(torch.float64, torch.complex128) 1878 def test_eig_numpy(self, device, dtype): 1879 def run_test(shape, *, symmetric=False): 1880 from torch.testing._internal.common_utils import random_symmetric_matrix 1881 1882 if not dtype.is_complex and symmetric: 1883 # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero 1884 # unlike NumPy the result is not cast to float32 or float64 dtype in this case 1885 a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device) 1886 else: 1887 a = make_tensor(shape, dtype=dtype, device=device) 1888 1889 actual = torch.linalg.eig(a) 1890 1891 # compare with NumPy 1892 # the eigenvalues are not necessarily ordered 1893 # so order of NumPy and PyTorch can be different 1894 expected = np.linalg.eig(a.cpu().numpy()) 1895 1896 # sort NumPy output 1897 ind = np.argsort(expected[0], axis=-1)[::-1] 1898 expected = (np.take_along_axis(expected[0], ind, axis=-1), np.take_along_axis(expected[1], ind[:, None], axis=-1)) 1899 1900 # sort PyTorch output 1901 # torch.argsort doesn't work with complex inputs, NumPy sorting on CPU is used instead 1902 # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble 1903 # RuntimeError: "sorting_kernel_method_name" not implemented for 'ComplexDouble' 1904 ind = np.argsort(actual[0].cpu().numpy(), axis=-1)[::-1] 1905 actual_np = [x.cpu().numpy() for x in actual] 1906 sorted_actual = ( 1907 np.take_along_axis(actual_np[0], ind, axis=-1), 1908 np.take_along_axis(actual_np[1], ind[:, None], axis=-1)) 1909 1910 self.assertEqual(expected[0], sorted_actual[0], exact_dtype=False) 1911 self.assertEqual(abs(expected[1]), abs(sorted_actual[1]), exact_dtype=False) 1912 1913 shapes = [(0, 0), # Empty matrix 1914 (5, 5), # Single matrix 1915 (0, 0, 0), (0, 5, 5), # Zero batch dimension tensors 1916 (2, 5, 5), # 3-dim tensors 1917 (2, 1, 5, 5)] # 4-dim tensors 1918 for shape in shapes: 1919 run_test(shape) 1920 run_test(shape, symmetric=True) 1921 1922 @onlyCUDA 1923 @skipCUDAIfNoMagma 1924 @dtypes(*floating_and_complex_types()) 1925 def test_eig_compare_backends(self, device, dtype): 1926 def run_test(shape, *, symmetric=False): 1927 from torch.testing._internal.common_utils import random_symmetric_matrix 1928 1929 if not dtype.is_complex and symmetric: 1930 # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero 1931 a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device) 1932 else: 1933 a = make_tensor(shape, dtype=dtype, device=device) 1934 1935 actual = torch.linalg.eig(a) 1936 1937 complementary_device = 'cpu' 1938 1939 # compare with CPU 1940 expected = torch.linalg.eig(a.to(complementary_device)) 1941 self.assertEqual(expected[0], actual[0]) 1942 self.assertEqual(expected[1], actual[1]) 1943 1944 shapes = [(0, 0), # Empty matrix 1945 (5, 5), # Single matrix 1946 (0, 0, 0), (0, 5, 5), # Zero batch dimension tensors 1947 (2, 5, 5), # 3-dim tensors 1948 (2, 1, 5, 5)] # 4-dim tensors 1949 for shape in shapes: 1950 run_test(shape) 1951 run_test(shape, 
symmetric=True) 1952 1953 @slowTest 1954 @onlyCUDA 1955 @skipCUDAIfNoMagma 1956 @dtypes(torch.float32) 1957 def test_eig_check_magma(self, device, dtype): 1958 # For CUDA inputs only matrices of size larger than 2048x2048 actually call MAGMA library 1959 shape = (2049, 2049) 1960 a = make_tensor(shape, dtype=dtype, device=device) 1961 w, v = torch.linalg.eig(a) 1962 # check correctness using eigendecomposition identity 1963 self.assertEqual(a.to(v.dtype) @ v, w * v, atol=1e-3, rtol=1e-3) 1964 1965 @skipCUDAIfNoMagma 1966 @skipCPUIfNoLapack 1967 @dtypes(*floating_and_complex_types()) 1968 def test_eig_errors_and_warnings(self, device, dtype): 1969 # eig requires the input to be at least 2 dimensional tensor 1970 a = make_tensor(2, dtype=dtype, device=device) 1971 with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"): 1972 torch.linalg.eig(a) 1973 1974 # eig requires a square matrix 1975 a = make_tensor((2, 3), dtype=dtype, device=device) 1976 with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): 1977 torch.linalg.eig(a) 1978 1979 # if out tensor with floating dtype is passed for complex output an error is thrown 1980 if not dtype.is_complex: 1981 # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i 1982 a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device) 1983 out0 = torch.empty(0, device=device, dtype=dtype) 1984 out1 = torch.empty(0, device=device, dtype=dtype) 1985 with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"): 1986 torch.linalg.eig(a, out=(out0, out1)) 1987 1988 out0 = torch.empty(0, device=device, dtype=torch.complex128) 1989 with self.assertRaisesRegex(RuntimeError, "Expected eigenvectors to be safely castable"): 1990 torch.linalg.eig(a, out=(out0, out1)) 1991 1992 # dtypes should be safely castable 1993 a = make_tensor((3, 3), dtype=dtype, device=device) 1994 out0 = torch.empty(0, dtype=torch.int, device=device) 1995 out1 = torch.empty(0, dtype=torch.int, device=device) 1996 with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"): 1997 torch.linalg.eig(a, out=(out0, out1)) 1998 1999 out0 = torch.empty(0, dtype=torch.complex128, device=device) 2000 with self.assertRaisesRegex(RuntimeError, "but got eigenvectors with dtype Int"): 2001 torch.linalg.eig(a, out=(out0, out1)) 2002 2003 # if non-empty out tensor with wrong shape is passed a warning is given 2004 a = make_tensor((3, 3), dtype=dtype, device=device) 2005 out0 = torch.empty(1, device=device, dtype=torch.complex128) 2006 out1 = torch.empty(1, device=device, dtype=torch.complex128) 2007 with warnings.catch_warnings(record=True) as w: 2008 # Trigger warning 2009 torch.linalg.eig(a, out=(out0, out1)) 2010 # Check warning occurs 2011 self.assertEqual(len(w), 2) 2012 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 2013 self.assertTrue("An output with one or more elements was resized" in str(w[-2].message)) 2014 2015 # device should match 2016 if torch.cuda.is_available(): 2017 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 2018 out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128) 2019 out_v = torch.empty(0, device=device, dtype=torch.complex128) 2020 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 2021 torch.linalg.eig(a, out=(out_w, out_v)) 2022 out_w = torch.empty(0, device=device, dtype=torch.complex128) 2023 out_v = torch.empty(0, device=wrong_device, 
dtype=torch.complex128) 2024 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 2025 torch.linalg.eig(a, out=(out_w, out_v)) 2026 2027 @skipCPUIfNoLapack 2028 @skipCUDAIfNoMagma 2029 @dtypes(*floating_and_complex_types()) 2030 def test_eig_with_nan(self, device, dtype): 2031 for val in [np.inf, np.nan]: 2032 for batch_dim in [(), (10,)]: 2033 a = make_tensor((*batch_dim, 5, 5), device=device, dtype=dtype) 2034 a[..., -1, -1] = val 2035 2036 with self.assertRaisesRegex(RuntimeError, "torch.linalg.eig: input tensor should not"): 2037 torch.linalg.eig(a) 2038 2039 @skipCPUIfNoLapack 2040 @skipCUDAIfNoMagma 2041 # NumPy computes only in float64 and complex128 precisions 2042 # for float32 or complex64 results might be very different from float64 or complex128 2043 @dtypes(torch.float64, torch.complex128) 2044 def test_eigvals_numpy(self, device, dtype): 2045 def run_test(shape, *, symmetric=False): 2046 from torch.testing._internal.common_utils import random_symmetric_matrix 2047 2048 if not dtype.is_complex and symmetric: 2049 # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero 2050 # unlike NumPy the result is not cast to float32 or float64 dtype in this case 2051 a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device) 2052 else: 2053 a = make_tensor(shape, dtype=dtype, device=device) 2054 2055 actual = torch.linalg.eigvals(a) 2056 2057 # compare with NumPy 2058 # the eigenvalues are not necessarily ordered 2059 # so order of NumPy and PyTorch can be different 2060 expected = np.linalg.eigvals(a.cpu().numpy()) 2061 2062 # sort NumPy output 2063 ind = np.argsort(expected, axis=-1)[::-1] 2064 expected = np.take_along_axis(expected, ind, axis=-1) 2065 2066 # sort PyTorch output 2067 # torch.argsort doesn't work with complex inputs, NumPy sorting on CPU is used instead 2068 # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble 2069 # RuntimeError: "sorting_kernel_method_name" not implemented for 'ComplexDouble' 2070 ind = np.argsort(actual.cpu().numpy(), axis=-1)[::-1] 2071 actual_np = actual.cpu().numpy() 2072 sorted_actual = np.take_along_axis(actual_np, ind, axis=-1) 2073 2074 self.assertEqual(expected, sorted_actual, exact_dtype=False) 2075 2076 shapes = [(0, 0), # Empty matrix 2077 (5, 5), # Single matrix 2078 (0, 0, 0), (0, 5, 5), # Zero batch dimension tensors 2079 (2, 5, 5), # 3-dim tensors 2080 (2, 1, 5, 5)] # 4-dim tensors 2081 for shape in shapes: 2082 run_test(shape) 2083 run_test(shape, symmetric=True) 2084 2085 @onlyCUDA 2086 @skipCUDAIfNoMagma 2087 @dtypes(*floating_and_complex_types()) 2088 def test_eigvals_compare_backends(self, device, dtype): 2089 def run_test(shape, *, symmetric=False): 2090 from torch.testing._internal.common_utils import random_symmetric_matrix 2091 2092 if not dtype.is_complex and symmetric: 2093 # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero 2094 a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device) 2095 else: 2096 a = make_tensor(shape, dtype=dtype, device=device) 2097 2098 actual = torch.linalg.eigvals(a) 2099 2100 complementary_device = 'cpu' 2101 2102 # compare with CPU 2103 expected = torch.linalg.eigvals(a.to(complementary_device)) 2104 self.assertEqual(expected, actual) 2105 2106 # check out= variant 2107 complex_dtype = dtype 2108 if not dtype.is_complex: 2109 complex_dtype = torch.complex128 if dtype == torch.float64 else torch.complex64 2110 out = 
torch.empty(0, dtype=complex_dtype, device=device)
            ans = torch.linalg.eigvals(a, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(expected.to(complex_dtype), out)

            # check non-contiguous out
            if a.numel() > 0:
                out = torch.empty(2 * shape[0], *shape[1:-1], dtype=complex_dtype, device=device)[::2]
                self.assertFalse(out.is_contiguous())
                ans = torch.linalg.eigvals(a, out=out)
                self.assertEqual(ans, out)
                self.assertEqual(expected.to(complex_dtype), out)

        shapes = [(0, 0),  # Empty matrix
                  (5, 5),  # Single matrix
                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
                  (2, 5, 5),  # 3-dim tensors
                  (2, 1, 5, 5)]  # 4-dim tensors
        for shape in shapes:
            run_test(shape)
            run_test(shape, symmetric=True)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_eigvals_errors_and_warnings(self, device, dtype):
        # eigvals requires the input to be at least a 2-dimensional tensor
        a = make_tensor(2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
            torch.linalg.eigvals(a)

        # eigvals requires a square matrix
        a = make_tensor((2, 3), dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.linalg.eigvals(a)

        # if an out tensor with floating dtype is passed for complex output an error is thrown
        if not dtype.is_complex:
            # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i
            a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
            out = torch.empty(0, device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
                torch.linalg.eigvals(a, out=out)

        # dtypes should be safely castable
        a = make_tensor((3, 3), dtype=dtype, device=device)
        out = torch.empty(0, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
            torch.linalg.eigvals(a, out=out)

        # if a non-empty out tensor with the wrong shape is passed a warning is given
        out = torch.empty(1, device=device, dtype=torch.complex128)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.eigvals(a, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.eigvals(a, out=out_w)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    def test_norm_old(self, device):
        def gen_error_message(input_size, p, keepdim, dim=None):
            return f"norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}"

        # The 'nuc' norm uses SVD, and thus its precision is much lower than that of the other norms.
        # test_svd takes @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4}),
        # and here we are doing the same thing for the nuc norm.
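        # (For reference: the nuclear norm is the sum of the singular values, so up to
        # numerical error x.norm('nuc') should match torch.linalg.svdvals(x).sum(-1),
        # which is why it inherits the lower accuracy of the SVD.)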
2185 class PrecisionContext: 2186 def __init__(self, test, norm): 2187 self.norm = norm 2188 self.saved_overrides = getattr(test, 'precision_overrides', None) 2189 self.target_test = test 2190 2191 def __enter__(self): 2192 if 'nuc' != self.norm: 2193 return None 2194 self.target_test.precision_overrides = {torch.float: 1e-4, torch.cfloat: 2e-4} 2195 return self.target_test.precision_overrides 2196 2197 def __exit__(self, type, value, tb) -> bool: 2198 if 'nuc' != self.norm: 2199 return True 2200 if self.saved_overrides is None: 2201 delattr(self.target_test, 'precision_overrides') 2202 else: 2203 self.target_test.precision_overrides = self.saved_overrides 2204 return True 2205 2206 for keepdim in [False, True]: 2207 # full reduction 2208 x = torch.randn(25, device=device) 2209 xn = x.cpu().numpy() 2210 for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3, 1.5]: 2211 res = x.norm(p, keepdim=keepdim).cpu() 2212 expected = np.linalg.norm(xn, p, keepdims=keepdim) 2213 self.assertEqual(res, expected, atol=1e-5, rtol=0, msg=gen_error_message(x.size(), p, keepdim)) 2214 2215 # one dimension 2216 x = torch.randn(25, 25, device=device) 2217 xn = x.cpu().numpy() 2218 for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3]: 2219 dim = 1 2220 res = x.norm(p, dim, keepdim=keepdim).cpu() 2221 expected = np.linalg.norm(xn, p, dim, keepdims=keepdim) 2222 msg = gen_error_message(x.size(), p, keepdim, dim) 2223 self.assertEqual(res.shape, expected.shape, msg=msg) 2224 self.assertEqual(res, expected, msg=msg) 2225 2226 # matrix norm 2227 for p in ['fro', 'nuc']: 2228 res = x.norm(p, keepdim=keepdim).cpu() 2229 expected = np.linalg.norm(xn, p, keepdims=keepdim) 2230 msg = gen_error_message(x.size(), p, keepdim) 2231 with PrecisionContext(self, p): 2232 self.assertEqual(res.shape, expected.shape, msg=msg) 2233 self.assertEqual(res, expected, msg=msg) 2234 2235 # zero dimensions 2236 x = torch.randn((), device=device) 2237 xn = x.cpu().numpy() 2238 res = x.norm(keepdim=keepdim).cpu() 2239 expected = np.linalg.norm(xn, keepdims=keepdim) 2240 msg = gen_error_message(x.size(), None, keepdim) 2241 self.assertEqual(res.shape, expected.shape, msg=msg) 2242 self.assertEqual(res, expected, msg=msg) 2243 2244 # larger tensor sanity check 2245 self.assertEqual( 2246 2 * torch.norm(torch.ones(10000), keepdim=keepdim), 2247 torch.norm(torch.ones(40000), keepdim=keepdim)) 2248 2249 # matrix norm with non-square >2-D tensors, all combinations of reduction dims 2250 x = torch.randn(5, 6, 7, 8, device=device) 2251 xn = x.cpu().numpy() 2252 for p in ['fro', 'nuc']: 2253 for dim in itertools.product(*[list(range(4))] * 2): 2254 if dim[0] == dim[1]: 2255 continue 2256 res = x.norm(p=p, dim=dim, keepdim=keepdim).cpu() 2257 expected = np.linalg.norm(xn, ord=p, axis=dim, keepdims=keepdim) 2258 msg = gen_error_message(x.size(), p, keepdim, dim) 2259 with PrecisionContext(self, p): 2260 self.assertEqual(res.shape, expected.shape, msg=msg) 2261 self.assertEqual(res, expected, msg=msg) 2262 2263 # Test that torch.norm with p=+/-inf propagates NaN 2264 def test_norm_old_nan_propagation(self, device): 2265 ords = [inf, -inf] 2266 for pair in itertools.product([0.0, nan, 1.0], repeat=2): 2267 x = torch.tensor(list(pair), device=device) 2268 for ord in ords: 2269 result = torch.norm(x, p=ord) 2270 result_check = torch.linalg.norm(x, ord=ord) 2271 self.assertEqual(result, result_check) 2272 2273 @skipCUDAIfNoMagma 2274 @skipCPUIfNoLapack 2275 def test_norm_complex_old(self, device): 2276 def gen_error_message(input_size, p, keepdim, dim=None): 2277 
return f"complex norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}" 2278 2279 for keepdim in [False, True]: 2280 # vector norm 2281 x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device) 2282 xn = x.cpu().numpy() 2283 for p in [0, 1, 2, 3, inf, -1, -2, -3, -inf]: 2284 res = x.norm(p, keepdim=keepdim).cpu() 2285 expected = np.linalg.norm(xn, p, keepdims=keepdim) 2286 msg = gen_error_message(x.size(), p, keepdim) 2287 self.assertEqual(res.shape, expected.shape, msg=msg) 2288 self.assertEqual(res, expected, msg=msg) 2289 2290 # matrix norm 2291 x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device) 2292 xn = x.cpu().numpy() 2293 for p in ['nuc', 'fro']: 2294 res = x.norm(p, keepdim=keepdim).cpu() 2295 expected = np.linalg.norm(xn, p, keepdims=keepdim) 2296 msg = gen_error_message(x.size(), p, keepdim) 2297 self.assertEqual(res.shape, expected.shape, msg=msg) 2298 self.assertEqual(res, expected, msg=msg, rtol=4e-6, atol=6e-4) 2299 2300 # Ensure torch.norm with p='fro' and p=2 give the same results for mutually supported input combinations 2301 @dtypes(torch.float) 2302 def test_norm_fro_2_equivalence_old(self, device, dtype): 2303 input_sizes = [ 2304 (0,), 2305 (10,), 2306 (0, 0), 2307 (4, 30), 2308 (0, 45), 2309 (100, 0), 2310 (45, 10, 23), 2311 (0, 23, 59), 2312 (23, 0, 37), 2313 (34, 58, 0), 2314 (0, 0, 348), 2315 (0, 3434, 0), 2316 (0, 0, 0), 2317 (5, 3, 8, 1, 3, 5)] 2318 2319 for input_size in input_sizes: 2320 a = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9) 2321 2322 # Try full reduction 2323 dim_settings = [None] 2324 2325 # Try all possible 1-D reductions 2326 dim_settings += list(range(-a.dim(), a.dim())) 2327 2328 def wrap_dim(dim, ndims): 2329 assert (dim < ndims) and (dim >= -ndims) 2330 if dim >= 0: 2331 return dim 2332 else: 2333 return dim + ndims 2334 2335 # Try all possible 2-D reductions 2336 dim_settings += [ 2337 (d0, d1) for d0, d1 in itertools.combinations(range(-a.dim(), a.dim()), 2) 2338 if wrap_dim(d0, a.dim()) != wrap_dim(d1, a.dim())] 2339 2340 for dim in dim_settings: 2341 for keepdim in [True, False]: 2342 a_norm_2 = torch.norm(a, p=2, dim=dim, keepdim=keepdim) 2343 a_norm_fro = torch.norm(a, p='fro', dim=dim, keepdim=keepdim) 2344 self.assertEqual(a_norm_fro, a_norm_2) 2345 2346 @skipIfTorchDynamo("Not a TorchDynamo suitable test") 2347 @skipCUDAIfNoMagma 2348 @skipCPUIfNoLapack 2349 def test_nuclear_norm_axes_small_brute_force_old(self, device): 2350 def check_single_nuclear_norm(x, axes): 2351 if self.device_type != 'cpu' and randrange(100) < 95: 2352 return # too many cpu <==> device copies 2353 2354 a = np.array(x.cpu(), copy=False) 2355 expected = np.linalg.norm(a, "nuc", axis=axes) 2356 2357 ans = torch.norm(x, "nuc", dim=axes) 2358 self.assertTrue(ans.is_contiguous()) 2359 self.assertEqual(ans.shape, expected.shape) 2360 self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True) 2361 2362 out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device) 2363 ans = torch.norm(x, "nuc", dim=axes, out=out) 2364 self.assertIs(ans, out) 2365 self.assertTrue(ans.is_contiguous()) 2366 self.assertEqual(ans.shape, expected.shape) 2367 self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True) 2368 2369 for n in range(1, 3): 2370 for m in range(1, 3): 2371 for axes in itertools.permutations([0, 1], 2): 2372 # 2d, inner dimensions C 2373 x = torch.randn(n, m, device=device) 2374 check_single_nuclear_norm(x, axes) 2375 2376 # 2d, inner 
dimensions Fortran 2377 x = torch.randn(m, n, device=device).mT 2378 check_single_nuclear_norm(x, axes) 2379 2380 # 2d, inner dimensions non-contiguous 2381 x = torch.randn(n, 2 * m, device=device)[:, ::2] 2382 check_single_nuclear_norm(x, axes) 2383 2384 # 2d, all dimensions non-contiguous 2385 x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2] 2386 check_single_nuclear_norm(x, axes) 2387 2388 for o in range(1, 3): 2389 for axes in itertools.permutations([0, 1, 2], 2): 2390 # 3d, inner dimensions C 2391 x = torch.randn(o, n, m, device=device) 2392 check_single_nuclear_norm(x, axes) 2393 2394 # 3d, inner dimensions Fortran 2395 x = torch.randn(o, m, n, device=device).mT 2396 check_single_nuclear_norm(x, axes) 2397 2398 # 3d, inner dimensions non-contiguous 2399 x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2] 2400 check_single_nuclear_norm(x, axes) 2401 2402 # 3d, all dimensions non-contiguous 2403 x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2] 2404 check_single_nuclear_norm(x, axes) 2405 2406 for r in range(1, 3): 2407 for axes in itertools.permutations([0, 1, 2, 3], 2): 2408 # 4d, inner dimensions C 2409 x = torch.randn(r, o, n, m, device=device) 2410 check_single_nuclear_norm(x, axes) 2411 2412 # 4d, inner dimensions Fortran 2413 x = torch.randn(r, o, n, m, device=device).mT 2414 check_single_nuclear_norm(x, axes) 2415 2416 # 4d, inner dimensions non-contiguous 2417 x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2] 2418 check_single_nuclear_norm(x, axes) 2419 2420 # 4d, all dimensions non-contiguous 2421 x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2] 2422 check_single_nuclear_norm(x, axes) 2423 2424 @skipCUDAIfNoMagma 2425 def test_nuclear_norm_exceptions_old(self, device): 2426 for lst in [], [1], [1, 2]: 2427 x = torch.tensor(lst, dtype=torch.double, device=device) 2428 for axes in (), (0,): 2429 self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) 2430 self.assertRaises(RuntimeError, torch.norm, x, "nuc", (0, 1)) 2431 2432 x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) 2433 self.assertRaisesRegex(RuntimeError, "must be different", torch.norm, x, "nuc", (0, 0)) 2434 self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) 2435 2436 @skipCUDAIfNoCusolver 2437 @skipCPUIfNoLapack 2438 @dtypes(torch.double, torch.cdouble) 2439 def test_svd_lowrank(self, device, dtype): 2440 from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix 2441 2442 def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): 2443 density = options.pop('density', 1) 2444 if isinstance(matrix_size, int): 2445 rows = columns = matrix_size 2446 else: 2447 rows, columns = matrix_size 2448 if density == 1: 2449 a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) 2450 a = a_input 2451 else: 2452 assert batches == () 2453 a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) 2454 a = a_input.to_dense() 2455 2456 q = min(*size) 2457 u, s, v = svd_lowrank(a_input, q=q, **options) 2458 2459 # check if u, s, v is a SVD 2460 u, s, v = u[..., :q], s[..., :q], v[..., :q] 2461 A = (u * s.unsqueeze(-2)).matmul(v.mH) 2462 self.assertEqual(A, a, rtol=1e-7, atol=2e-7) 2463 2464 # check if svd_lowrank produces same singular values as linalg.svdvals 2465 U, S, Vh = torch.linalg.svd(a, full_matrices=False) 2466 V = Vh.mH 2467 self.assertEqual(s, S) 2468 2469 if 
density == 1: 2470 # actual_rank is known only for dense inputs 2471 # 2472 # check if pairs (u, U) and (v, V) span the same 2473 # subspaces, respectively 2474 u, v = u[..., :actual_rank], v[..., :actual_rank] 2475 U, V = U[..., :actual_rank], V[..., :actual_rank] 2476 expected_ones = u.mH.matmul(U).det().abs() 2477 self.assertEqual(expected_ones, torch.ones_like(expected_ones)) 2478 self.assertEqual(v.mH.matmul(V).det().abs(), torch.ones_like(expected_ones)) 2479 2480 all_batches = [(), (1,), (3,), (2, 3)] 2481 for actual_rank, size, all_batches in [ # noqa: B020 2482 (2, (17, 4), all_batches), 2483 (4, (17, 4), all_batches), 2484 (4, (17, 17), all_batches), 2485 (10, (100, 40), all_batches), 2486 (7, (1000, 1000), [()]), 2487 ]: 2488 # dense input 2489 for batches in all_batches: 2490 run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) 2491 if size != size[::-1]: 2492 run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) 2493 2494 # sparse input 2495 for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: 2496 for density in [0.005, 0.1]: 2497 run_subtest(None, size, (), device, torch.svd_lowrank, density=density) 2498 2499 # jitting support 2500 jitted = torch.jit.script(torch.svd_lowrank) 2501 actual_rank, size, batches = 2, (17, 4), () 2502 run_subtest(actual_rank, size, batches, device, jitted) 2503 2504 @skipCUDAIfNoMagmaAndNoCusolver 2505 @skipCPUIfNoLapack 2506 @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4}) 2507 @setLinalgBackendsToDefaultFinally 2508 @dtypes(*floating_and_complex_types()) 2509 @serialTest() 2510 def test_svd(self, device, dtype): 2511 # tests linalg.svd, svd, linalg.svdvals 2512 make_arg = partial(make_tensor, dtype=dtype, device=device) 2513 2514 backends = ["default"] 2515 2516 if torch.device(device).type == 'cuda': 2517 if torch.cuda.has_magma: 2518 backends.append("magma") 2519 if has_cusolver() or has_hipsolver(): 2520 backends.append("cusolver") 2521 2522 ns = (12, 4, 2, 0) 2523 batches = ((), (0,), (1,), (2,), (2, 1), (0, 2)) 2524 drivers = (None, 'gesvd', 'gesvdj', 'gesvda') 2525 2526 for backend in backends: 2527 torch.backends.cuda.preferred_linalg_library(backend) 2528 2529 for batch, m, n, driver in product(batches, ns, ns, drivers): 2530 if not (backend == 'cusolver' or driver is None): 2531 # only test cases below and skip otherwise: 2532 # - backend == 'cusolver' (driver can be anything) 2533 # - backend != 'cusolver' (driver should only be None) 2534 continue 2535 2536 shape = batch + (m, n) 2537 k = min(m, n) 2538 A = make_arg(shape) 2539 U, S, Vh = torch.linalg.svd(A, full_matrices=False, driver=driver) 2540 self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ Vh, A) 2541 2542 U_f, S_f, Vh_f = torch.linalg.svd(A, full_matrices=True, driver=driver) 2543 self.assertEqual(S_f, S) 2544 self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ Vh_f[..., :k, :], A) 2545 2546 S_s = torch.linalg.svdvals(A, driver=driver) 2547 self.assertEqual(S_s, S) 2548 2549 U, S, V = torch.svd(A, some=True) 2550 self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ V.mH, A) 2551 2552 U_f, S_f, V_f = torch.svd(A, some=False) 2553 self.assertEqual(S_f, S) 2554 self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ V_f[..., :k].mH, A) 2555 2556 S_s = torch.svd(A, compute_uv=False).S 2557 self.assertEqual(S_s, S) 2558 2559 @skipCUDAIfNoMagmaAndNoCusolver 2560 @skipCPUIfNoLapack 2561 @dtypes(torch.complex128) 2562 def test_invariance_error_spectral_decompositions(self, device, dtype): 
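        # The factors of these decompositions are only defined up to a phase (and, for
        # repeated singular values or eigenvalues, a rotation of the corresponding subspace):
        # e.g. if (U, S, Vh) is an SVD of A, then so is (U @ D, S, D.conj() @ Vh) for any
        # diagonal unitary D. Quantities that depend on that arbitrary choice have no
        # well-defined gradient, so autograd is expected to raise the "ill-defined" error below.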
2563 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True) 2564 A = make_arg((3, 3)) 2565 with self.assertRaisesRegex(RuntimeError, "ill-defined"): 2566 U, _, Vh = torch.linalg.svd(A, full_matrices=False) 2567 (U + Vh).sum().abs().backward() 2568 2569 A = make_arg((3, 3)) 2570 with self.assertRaisesRegex(RuntimeError, "ill-defined"): 2571 V = torch.linalg.eig(A).eigenvectors 2572 V.sum().abs().backward() 2573 2574 A = make_arg((3, 3)) 2575 A = A + A.mH 2576 with self.assertRaisesRegex(RuntimeError, "ill-defined"): 2577 Q = torch.linalg.eigh(A).eigenvectors 2578 Q.sum().abs().backward() 2579 2580 @skipCUDAIfNoCusolver # MAGMA backend doesn't work in this case 2581 @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) 2582 @skipCPUIfNoLapack 2583 @dtypes(*floating_and_complex_types()) 2584 def test_svd_memory_allocation(self, device, dtype): 2585 # test for https://github.com/pytorch/pytorch/issues/61949 2586 # the problem was that tensors of incorrect size were allocated and then narrowed 2587 m = 3 2588 n = 2**20 2589 a = make_tensor((m, n), dtype=dtype, device=device) 2590 # the following should run without errors 2591 S = torch.linalg.svdvals(a) 2592 result = torch.linalg.svd(a, full_matrices=False) 2593 self.assertEqual(result.S, S) 2594 2595 def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): 2596 from torch.testing._internal.common_utils import random_hermitian_pd_matrix 2597 2598 b = torch.randn(*b_dims, dtype=dtype, device=device) 2599 A = random_hermitian_pd_matrix(*A_dims, dtype=dtype, device=device) 2600 L = torch.cholesky(A, upper=upper) 2601 return b, A, L 2602 2603 @skipCUDAIfNoMagma 2604 @skipCPUIfNoLapack 2605 @dtypes(*floating_and_complex_types()) 2606 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2607 torch.float64: 1e-8, torch.complex128: 1e-8}) 2608 def test_cholesky_solve(self, device, dtype): 2609 for (k, n), upper in itertools.product(zip([2, 3, 5], [3, 5, 7]), [True, False]): 2610 b, A, L = self.cholesky_solve_test_helper((n,), (n, k), upper, device, dtype) 2611 x = torch.cholesky_solve(b, L, upper=upper) 2612 self.assertEqual(b, np.matmul(A.cpu(), x.cpu())) 2613 2614 @skipCUDAIfNoMagma 2615 @skipCPUIfNoLapack 2616 @dtypes(*floating_and_complex_types()) 2617 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2618 torch.float64: 1e-8, torch.complex128: 1e-8}) 2619 def test_cholesky_solve_batched(self, device, dtype): 2620 def cholesky_solve_batch_helper(A_dims, b_dims, upper): 2621 b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype) 2622 x_exp_list = [] 2623 for i in range(b_dims[0]): 2624 x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper)) 2625 x_exp = torch.stack(x_exp_list) # Stacked output 2626 x_act = torch.cholesky_solve(b, L, upper=upper) # Actual output 2627 self.assertEqual(x_act, x_exp) # Equality check 2628 Ax = np.matmul(A.cpu(), x_act.cpu()) 2629 self.assertEqual(b, Ax) # Correctness check 2630 2631 for upper, batchsize in itertools.product([True, False], [1, 3, 4]): 2632 cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), upper) 2633 2634 @slowTest 2635 @skipCUDAIfNoMagma 2636 @skipCPUIfNoLapack 2637 @dtypes(*floating_and_complex_types()) 2638 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2639 torch.float64: 1e-8, torch.complex128: 1e-8}) 2640 def test_cholesky_solve_batched_many_batches(self, device, dtype): 2641 for A_dims, b_dims in zip([(5, 256, 256), (5,)], [(5, 10), (512, 512, 5, 10)]): 2642 
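            # torch.cholesky_solve(b, L) returns x with A @ x == b, where L is the Cholesky
            # factor of the Hermitian positive-definite A, roughly the sketch
            #   L = torch.linalg.cholesky(A); x = torch.cholesky_solve(b, L)
            # The checks below only verify this A @ x == b identity for large batches.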
for upper in [True, False]: 2643 b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype) 2644 x = torch.cholesky_solve(b, L, upper) 2645 Ax = torch.matmul(A, x) 2646 self.assertEqual(Ax, b.expand_as(Ax)) 2647 2648 @skipCUDAIfNoMagma 2649 @skipCPUIfNoLapack 2650 @dtypes(*floating_and_complex_types()) 2651 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2652 torch.float64: 1e-8, torch.complex128: 1e-8}) 2653 def test_cholesky_solve_batched_broadcasting(self, device, dtype): 2654 from numpy.linalg import solve 2655 from torch.testing._internal.common_utils import random_hermitian_pd_matrix 2656 2657 def run_test(A_dims, b_dims, upper): 2658 A_matrix_size = A_dims[-1] 2659 A_batch_dims = A_dims[:-2] 2660 A = random_hermitian_pd_matrix(A_matrix_size, *A_batch_dims, 2661 dtype=dtype, device='cpu') 2662 b = torch.randn(*b_dims, dtype=dtype, device='cpu') 2663 x_exp = torch.tensor(solve(A.numpy(), b.numpy()), dtype=dtype, device=device) 2664 A, b = A.to(dtype=dtype, device=device), b.to(dtype=dtype, device=device) 2665 L = torch.linalg.cholesky(A, upper=upper) 2666 x = torch.cholesky_solve(b, L, upper=upper) 2667 self.assertEqual(x, x_exp) 2668 # https://github.com/pytorch/pytorch/issues/42695 2669 x = torch.cholesky_solve(b, L, upper=upper, out=x) 2670 self.assertEqual(x, x_exp) 2671 2672 # test against numpy.linalg.solve 2673 for upper in [True, False]: 2674 run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), upper) # no broadcasting 2675 run_test((2, 1, 3, 4, 4), (4, 6), upper) # broadcasting b 2676 run_test((4, 4), (2, 1, 3, 4, 2), upper) # broadcasting A 2677 run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), upper) # broadcasting A & b 2678 2679 @skipCUDAIfNoMagma 2680 @skipCPUIfNoLapack 2681 @dtypes(*floating_and_complex_types()) 2682 def test_cholesky_solve_out_errors_and_warnings(self, device, dtype): 2683 # dtypes should be safely castable 2684 a = torch.eye(2, dtype=dtype, device=device) 2685 b = torch.randn(2, 1, dtype=dtype, device=device) 2686 out = torch.empty(0, dtype=torch.int, device=device) 2687 with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): 2688 torch.cholesky_solve(b, a, out=out) 2689 2690 # device should match 2691 if torch.cuda.is_available(): 2692 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 2693 out = torch.empty(0, dtype=dtype, device=wrong_device) 2694 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 2695 torch.cholesky_solve(b, a, out=out) 2696 2697 # if out tensor with wrong shape is passed a warning is given 2698 with warnings.catch_warnings(record=True) as w: 2699 out = torch.empty(1, dtype=dtype, device=device) 2700 # Trigger warning 2701 torch.cholesky_solve(b, a, out=out) 2702 # Check warning occurs 2703 self.assertEqual(len(w), 1) 2704 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 2705 2706 @skipCUDAIfNoMagma 2707 @skipCPUIfNoLapack 2708 @dtypes(torch.double) 2709 def test_cholesky_solve_backward(self, device, dtype): 2710 b_dims = (5, 2) 2711 L_dims = (5, 5) 2712 2713 for test_L_grad in (False, True): 2714 b = torch.randn(*b_dims, dtype=dtype, device=device, requires_grad=True) 2715 L = torch.randn(*L_dims, dtype=dtype, device=device, requires_grad=test_L_grad) 2716 if test_L_grad: 2717 torch.autograd.gradcheck(lambda b, L: torch.cholesky_solve(b, torch.tril(L), upper=False), (b, L)) 2718 else: 2719 torch.autograd.gradcheck(lambda b: torch.cholesky_solve(b, L, upper=False), (b,)) 2720 2721 @skipCUDAIfNoMagmaAndNoCusolver 
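    # The inverse tests below check both sides of the defining identity
    # (A @ inv(A) == inv(A) @ A == I) and compare against np.linalg.inv. They also exercise
    # torch.linalg.inv_ex, which does not raise by default; a rough usage sketch:
    #   inverse, info = torch.linalg.inv_ex(A)  # info != 0 marks non-invertible batch elements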
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_inverse(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_arg = partial(make_fullrank, device=device, dtype=dtype)

        def run_test(torch_inverse, matrix, batches, n):
            matrix_inverse = torch_inverse(matrix)

            # Compare against NumPy output
            # NumPy uses the 'gesv' LAPACK routine, solving the equation A A_inv = I
            # but PyTorch uses 'getrf' + 'getrs'. As such, there may be some element-wise differences
            expected = np.linalg.inv(matrix.cpu().numpy())
            self.assertEqual(matrix_inverse, expected, atol=self.precision, rtol=self.precision)

            # Additional correctness tests, check matrix*matrix_inverse == identity
            identity = torch.eye(n, dtype=dtype, device=device)
            self.assertEqual(identity.expand_as(matrix), np.matmul(matrix.cpu(), matrix_inverse.cpu()))
            self.assertEqual(identity.expand_as(matrix), np.matmul(matrix_inverse.cpu(), matrix.cpu()))

            # check the out= variant
            # prepare the expected out tensor (in batched column-major layout)
            matrix_inverse_out = torch.empty(*batches, n, n, dtype=dtype, device=device)
            matrix_inverse_out_t = matrix_inverse_out.mT.clone(memory_format=torch.contiguous_format)
            matrix_inverse_out = matrix_inverse_out_t.mT
            ans = torch_inverse(matrix, out=matrix_inverse_out)
            self.assertEqual(matrix_inverse_out, ans, atol=0, rtol=0)
            self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0)

            # batched matrices: 3+ dimensional tensors, check matrix_inverse same as single-inverse for each matrix
            if matrix.ndim > 2 and batches[0] != 0:
                expected_inv_list = []
                p = int(np.prod(batches))  # use `p` instead of -1, so that the test works for empty input as well
                for mat in matrix.contiguous().view(p, n, n):
                    expected_inv_list.append(torch_inverse(mat))
                expected_inv = torch.stack(expected_inv_list).view(*batches, n, n)
                if self.device_type == 'cuda' and dtype in [torch.float32, torch.complex64]:
                    # single-inverse is done using cuSOLVER, while batched inverse is done using MAGMA
                    # individual values can be significantly different for fp32, hence rather high rtol is used
                    # the important thing is that torch_inverse passes above checks with identity
                    self.assertEqual(matrix_inverse, expected_inv, atol=1e-1, rtol=1e-2)
                else:
                    self.assertEqual(matrix_inverse, expected_inv)

        # helper function for testing torch.linalg.inv_ex
        def test_inv_ex(input, out=None):
            if out is not None:
                info = torch.empty(0, dtype=torch.int32, device=device)
                return torch.linalg.inv_ex(input, out=(out, info)).inverse
            return torch.linalg.inv_ex(input).inverse

        for torch_inverse in [torch.inverse, torch.linalg.inv, test_inv_ex]:
            for batches, n in itertools.product(
                [[], [0], [2], [2, 1]],
                [0, 5]
            ):
                matrices = make_arg(*batches, n, n)
                run_test(torch_inverse, matrices, batches, n)

                # test non-contiguous input
                run_test(torch_inverse, matrices.mT, batches, n)
                if n > 0:
                    run_test(
                        torch_inverse,
                        make_arg(*batches, 2 * n, 2 * n)
                        .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n),
                        batches, n
                    )

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_inv_ex_info_device(self, device, dtype):
        A = torch.eye(3, 3, dtype=dtype, device=device)
        info = torch.linalg.inv_ex(A).info
        self.assertTrue(info.device == A.device)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_inv_ex_singular(self, device, dtype):
        # if the input matrix is not invertible, info with a positive integer is returned
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A[-1, -1] = 0  # Now A is singular
        info = torch.linalg.inv_ex(A).info
        self.assertEqual(info, 3)
        with self.assertRaisesRegex(torch.linalg.LinAlgError,
                                    r'diagonal element 3 is zero, the inversion could not be completed'):
            torch.linalg.inv_ex(A, check_errors=True)

        # if at least one matrix in the batch is not invertible,
        # batched info with a positive integer for the corresponding matrix is returned
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A = A.reshape((1, 3, 3))
        A = A.repeat(5, 1, 1)
        A[3, -2, -2] = 0  # Now A[3] is singular
        info = torch.linalg.inv_ex(A).info

        expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
        expected_info[3] = 2
        self.assertEqual(info, expected_info)
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The diagonal element 2 is zero'):
            torch.linalg.inv_ex(A, check_errors=True)

    @slowTest
    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
                        torch.float64: 1e-5, torch.complex128: 1e-5})
    def test_inverse_many_batches(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_arg = partial(make_fullrank, device=device, dtype=dtype)

        def test_inverse_many_batches_helper(torch_inverse, b, n):
            matrices = make_arg(b, n, n)
            matrices_inverse = torch_inverse(matrices)

            # Compare against NumPy output
            expected = np.linalg.inv(matrices.cpu().numpy())
            self.assertEqual(matrices_inverse, expected, atol=self.precision, rtol=1e-3)

        for torch_inverse in [torch.inverse, torch.linalg.inv]:
            test_inverse_many_batches_helper(torch_inverse, 5, 256)
            test_inverse_many_batches_helper(torch_inverse, 3, 512)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @onlyNativeDeviceTypes  # TODO: XLA doesn't raise exception
    @dtypes(*floating_and_complex_types())
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
    def test_inverse_errors(self, device, dtype):
        # inverse expects batches of square matrices as input
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.inverse(torch.randn(2, 3, 4, 3))

        # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch
        def run_test_singular_input(batch_dim, n):
            x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
            x[n, -1, -1] = 0
            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
                torch.inverse(x)

        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
            run_test_singular_input(*params)

    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
@skipCUDAIfNoMagmaAndNoCusolver 2872 @skipCPUIfNoLapack 2873 @onlyNativeDeviceTypes # TODO: XLA doesn't raise exception 2874 @dtypes(*floating_and_complex_types()) 2875 def test_inverse_errors_large(self, device, dtype): 2876 # Test batched inverse of singular matrices reports errors without crashing (gh-51930) 2877 x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device) 2878 x[:] = torch.eye(616, dtype=dtype, device=device) 2879 x[..., 10, 10] = 0 2880 with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 0\): The diagonal element 11 is zero'): 2881 torch.inverse(x) 2882 2883 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7}) 2884 @skipCUDAIfNoMagma 2885 @skipCPUIfNoLapack 2886 @dtypes(*floating_and_complex_types()) 2887 def test_pinv(self, device, dtype): 2888 from torch.testing._internal.common_utils import random_hermitian_pd_matrix 2889 2890 def run_test_main(A, hermitian): 2891 # Testing against definition for pseudo-inverses 2892 A_pinv = torch.linalg.pinv(A, hermitian=hermitian) 2893 np_A = A.cpu().numpy() 2894 np_A_pinv = A_pinv.cpu().numpy() 2895 if A.numel() > 0: 2896 self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=self.precision, rtol=self.precision) 2897 self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=self.precision, rtol=self.precision) 2898 self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1)) 2899 self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1)) 2900 else: 2901 self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2])) 2902 2903 # Check out= variant 2904 out = torch.empty_like(A_pinv) 2905 ans = torch.linalg.pinv(A, hermitian=hermitian, out=out) 2906 self.assertEqual(ans, out) 2907 self.assertEqual(ans, A_pinv) 2908 2909 def run_test_numpy(A, hermitian): 2910 # Check against NumPy output 2911 # Test float rcond, and specific value for each matrix 2912 rconds = [float(torch.rand(1)), ] 2913 # Test different types of rcond tensor 2914 for rcond_type in all_types(): 2915 rconds.append(torch.rand(A.shape[:-2], dtype=torch.double, device=device).to(rcond_type)) 2916 # Test broadcasting of rcond 2917 if A.ndim > 2: 2918 rconds.append(torch.rand(A.shape[-3], device=device)) 2919 for rcond in rconds: 2920 actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian) 2921 torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian) 2922 self.assertEqual(actual, torch_rtol) 2923 numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy() 2924 expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian) 2925 self.assertEqual(actual, expected, atol=self.precision, rtol=1e-5) 2926 2927 for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices 2928 (3, 2), (5, 3, 2), (2, 5, 3, 2), # fat matrices 2929 (2, 3), (5, 2, 3), (2, 5, 2, 3), # thin matrices 2930 (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices 2931 A = torch.randn(*sizes, dtype=dtype, device=device) 2932 hermitian = False 2933 run_test_main(A, hermitian) 2934 run_test_numpy(A, hermitian) 2935 2936 # Check hermitian = True 2937 for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices 2938 (0, 0), (3, 0, 0), ]: # zero numel square matrices 2939 A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device) 2940 hermitian = True 2941 run_test_main(A, hermitian) 2942 run_test_numpy(A, hermitian) 2943 2944 @skipCUDAIfNoMagma 2945 @skipCPUIfNoLapack 2946 
@dtypes(*floating_and_complex_types()) 2947 def test_pinv_errors_and_warnings(self, device, dtype): 2948 # pinv requires at least 2D tensor 2949 a = torch.randn(1, device=device, dtype=dtype) 2950 with self.assertRaisesRegex(RuntimeError, "expected a tensor with 2 or more dimensions"): 2951 torch.linalg.pinv(a) 2952 2953 # if non-empty out tensor with wrong shape is passed a warning is given 2954 a = torch.randn(3, 3, dtype=dtype, device=device) 2955 out = torch.empty(7, 7, dtype=dtype, device=device) 2956 with warnings.catch_warnings(record=True) as w: 2957 # Trigger warning 2958 torch.linalg.pinv(a, out=out) 2959 # Check warning occurs 2960 self.assertEqual(len(w), 1) 2961 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 2962 2963 # dtypes of out and input should be safely castable 2964 out = torch.empty_like(a).to(torch.int) 2965 with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): 2966 torch.linalg.pinv(a, out=out) 2967 2968 if torch.cuda.is_available(): 2969 # device of out and input should match 2970 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 2971 out = torch.empty_like(a).to(wrong_device) 2972 with self.assertRaisesRegex(RuntimeError, "Expected result and input tensors to be on the same device"): 2973 torch.linalg.pinv(a, out=out) 2974 2975 # device of rcond and input should match 2976 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 2977 rcond = torch.full((), 1e-2, device=wrong_device) 2978 with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): 2979 torch.linalg.pinv(a, rcond=rcond) 2980 2981 # rcond can't be complex 2982 rcond = torch.full((), 1j, device=device) 2983 with self.assertRaisesRegex(RuntimeError, "rcond tensor of complex type is not supported"): 2984 torch.linalg.pinv(a, rcond=rcond) 2985 2986 # atol can't be complex 2987 atol = torch.full((), 1j, device=device) 2988 with self.assertRaisesRegex(RuntimeError, "atol tensor of complex type is not supported"): 2989 torch.linalg.pinv(a, atol=atol) 2990 2991 # rtol can't be complex 2992 rtol = torch.full((), 1j, device=device) 2993 with self.assertRaisesRegex(RuntimeError, "rtol tensor of complex type is not supported"): 2994 torch.linalg.pinv(a, rtol=rtol) 2995 2996 @skipCUDAIfNoMagmaAndNoCusolver 2997 @skipCPUIfNoLapack 2998 @dtypes(*floating_and_complex_types()) 2999 @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882") 3000 def test_inv_errors_and_warnings(self, device, dtype): 3001 # inv expects batches of square matrices as input 3002 a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device) 3003 with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): 3004 torch.linalg.inv(a) 3005 3006 # inv requires the input to be at least 2 dimensional tensor 3007 a = torch.randn(2, device=device, dtype=dtype) 3008 with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"): 3009 torch.linalg.inv(a) 3010 3011 # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch 3012 def run_test_singular_input(batch_dim, n): 3013 a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1) 3014 a[n, -1, -1] = 0 3015 with self.assertRaisesRegex(torch.linalg.LinAlgError, rf"\(Batch element {n}\): The diagonal element 3 is zero"): 3016 torch.linalg.inv(a) 3017 3018 for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: 3019 run_test_singular_input(*params) 3020 3021 # dtypes should 
match
        a = torch.eye(2, dtype=dtype, device=device)
        out = torch.empty(0, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.inv(a, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.inv(a, out=out)

        # if out tensor with wrong shape is passed a warning is given
        with warnings.catch_warnings(record=True) as w:
            a = torch.eye(2, dtype=dtype, device=device)
            out = torch.empty(1, dtype=dtype, device=device)
            # Trigger warning
            torch.linalg.inv(a, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # if out tensor is in batched column major format but has the wrong shape a warning is given
        with warnings.catch_warnings(record=True) as w:
            a = torch.eye(2, dtype=dtype, device=device)
            out = torch.empty(3, 3, dtype=dtype, device=device)
            out = out.mT.clone(memory_format=torch.contiguous_format)
            out = out.mT
            self.assertTrue(out.mT.is_contiguous())
            # Trigger warning
            torch.linalg.inv(a, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

    def solve_test_helper(self, A_dims, b_dims, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_A = partial(make_fullrank, device=device, dtype=dtype)

        b = torch.randn(*b_dims, dtype=dtype, device=device)
        A = make_A(*A_dims)
        return b, A

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
    def test_solve(self, device, dtype):
        def run_test(n, batch, rhs):
            A_dims = (*batch, n, n)
            b_dims = (*batch, n, *rhs)
            b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)

            # Correctness test
            x = torch.linalg.solve(A, b)
            if rhs == ():
                Ax = np.matmul(A.cpu(), x.unsqueeze(-1).cpu())
                Ax.squeeze_(-1)
            else:
                Ax = np.matmul(A.cpu(), x.cpu())
            self.assertEqual(b.expand_as(Ax), Ax)

            # Check against NumPy
            expected = np.linalg.solve(A.cpu().numpy(), b.expand_as(x).cpu().numpy())
            self.assertEqual(x, expected)

        batches = [(), (0, ), (3, ), (2, 3)]
        ns = [0, 5, 32]
        nrhs = [(), (1, ), (5, )]
        for n, batch, rhs in itertools.product(ns, batches, nrhs):
            run_test(n, batch, rhs)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_solve_batched_broadcasting(self, device, dtype):
        from numpy.linalg import solve

        def run_test(A_dims, B_dims):
            A_matrix_size = A_dims[-1]
            A_batch_dims = A_dims[:-2]
            B, A = self.solve_test_helper(A_batch_dims + (A_matrix_size, A_matrix_size), B_dims, device, dtype)
            actual = torch.linalg.solve(A, B)
            expected = solve(A.cpu().numpy(), B.cpu().numpy())
            self.assertEqual(actual, expected)

        # test against numpy.linalg.solve
        run_test((5, 5), (2, 0, 5, 3))  # broadcasting with 0 batch dim
        run_test((2, 0, 5, 5), (5, 3))  # broadcasting with 0
batch dim 3111 run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting B 3112 run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A 3113 run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & B 3114 3115 @skipCUDAIfNoMagma 3116 @skipCPUIfNoLapack 3117 @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 3118 @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) 3119 def test_tensorsolve(self, device, dtype): 3120 def run_test(a_shape, dims): 3121 a = torch.randn(a_shape, dtype=dtype, device=device) 3122 b = torch.randn(a_shape[:2], dtype=dtype, device=device) 3123 result = torch.linalg.tensorsolve(a, b, dims=dims) 3124 expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) 3125 self.assertEqual(result, expected) 3126 3127 # check the out= variant 3128 out = torch.empty_like(result) 3129 ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out) 3130 self.assertEqual(ans, out) 3131 self.assertEqual(ans, result) 3132 3133 a_shapes = [(2, 3, 6), (3, 4, 4, 3)] 3134 dims = [None, (0, 2)] 3135 for a_shape, d in itertools.product(a_shapes, dims): 3136 run_test(a_shape, d) 3137 3138 @skipCUDAIfNoMagma 3139 @skipCPUIfNoLapack 3140 @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 3141 def test_tensorsolve_empty(self, device, dtype): 3142 # Check for empty inputs. NumPy does not work for these cases. 3143 a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device) 3144 b = torch.empty(a.shape[:2], dtype=dtype, device=device) 3145 x = torch.linalg.tensorsolve(a, b) 3146 self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b) 3147 3148 @skipCUDAIfNoMagma 3149 @skipCPUIfNoLapack 3150 @dtypes(torch.float32) 3151 def test_tensorsolve_errors_and_warnings(self, device, dtype): 3152 # tensorsolve expects the input that can be reshaped to a square matrix 3153 a = torch.eye(2 * 3 * 4, dtype=dtype, device=device).reshape((2 * 3, 4, 2, 3, 4)) 3154 b = torch.randn(8, 4, dtype=dtype, device=device) 3155 self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape)) 3156 with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'): 3157 torch.linalg.tensorsolve(a, b) 3158 3159 # if non-empty out tensor with wrong shape is passed a warning is given 3160 out = torch.empty_like(a) 3161 b = torch.randn(6, 4, dtype=dtype, device=device) 3162 with warnings.catch_warnings(record=True) as w: 3163 # Trigger warning 3164 torch.linalg.tensorsolve(a, b, out=out) 3165 # Check warning occurs 3166 self.assertEqual(len(w), 1) 3167 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 3168 3169 # dtypes should be safely castable 3170 out = torch.empty_like(a).to(torch.int) 3171 with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): 3172 torch.linalg.tensorsolve(a, b, out=out) 3173 3174 # device should match 3175 if torch.cuda.is_available(): 3176 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 3177 out = torch.empty(0, dtype=dtype, device=wrong_device) 3178 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 3179 torch.linalg.tensorsolve(a, b, out=out) 3180 3181 @skipCUDAIfNoMagma 3182 @skipCPUIfNoLapack 3183 @dtypes(*floating_and_complex_types()) 3184 @precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3}) 3185 def test_tensorinv(self, device, dtype): 3186 3187 def run_test(a_shape, ind): 3188 a = torch.randn(a_shape, dtype=dtype, device=device) 3189 a_numpy = a.cpu().numpy() 3190 result = torch.linalg.tensorinv(a, ind=ind) 3191 expected = 
np.linalg.tensorinv(a_numpy, ind=ind) 3192 self.assertEqual(result, expected) 3193 3194 # check the out= variant 3195 out = torch.empty_like(result) 3196 ans = torch.linalg.tensorinv(a, ind=ind, out=out) 3197 self.assertEqual(ans, out) 3198 self.assertEqual(ans, result) 3199 3200 # compare to NumPy output 3201 run_test((12, 3, 4), ind=1) 3202 run_test((3, 8, 24), ind=2) 3203 run_test((18, 3, 3, 2), ind=1) 3204 run_test((1, 4, 2, 2), ind=2) 3205 run_test((2, 3, 5, 30), ind=3) 3206 run_test((24, 2, 2, 3, 2), ind=1) 3207 run_test((3, 4, 2, 3, 2), ind=2) 3208 run_test((1, 2, 3, 2, 3), ind=3) 3209 run_test((3, 2, 1, 2, 12), ind=4) 3210 3211 @skipMeta # See https://github.com/pytorch/pytorch/issues/53739 3212 @skipCUDAIfNoMagma 3213 @skipCPUIfNoLapack 3214 @dtypes(*floating_and_complex_types()) 3215 def test_tensorinv_empty(self, device, dtype): 3216 for ind in range(1, 4): 3217 # Check for empty inputs. NumPy does not work for these cases. 3218 a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device) 3219 a_inv = torch.linalg.tensorinv(a, ind=ind) 3220 self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind]) 3221 3222 @skipMeta # See https://github.com/pytorch/pytorch/issues/53739 3223 @skipCUDAIfNoMagma 3224 @skipCPUIfNoLapack 3225 @dtypes(*floating_and_complex_types()) 3226 def test_tensorinv_errors_and_warnings(self, device, dtype): 3227 3228 def check_shape(a_shape, ind): 3229 # tensorinv requires the input to satisfy 3230 # prod(a.shape[ind:]) == prod(a.shape[:ind]) 3231 a = torch.randn(a_shape, dtype=dtype, device=device) 3232 with self.assertRaisesRegex(RuntimeError, "Expected self to satisfy the requirement"): 3233 torch.linalg.tensorinv(a, ind=ind) 3234 3235 def check_ind(a_shape, ind): 3236 a = torch.randn(a_shape, dtype=dtype, device=device) 3237 with self.assertRaisesRegex(RuntimeError, "Expected a strictly positive integer"): 3238 torch.linalg.tensorinv(a, ind=ind) 3239 3240 def check_out(a_shape, ind): 3241 # if non-empty out tensor with wrong shape is passed a warning is given 3242 a = torch.randn(a_shape, dtype=dtype, device=device) 3243 out = torch.empty_like(a) 3244 with warnings.catch_warnings(record=True) as w: 3245 # Trigger warning 3246 torch.linalg.tensorinv(a, ind=ind, out=out) 3247 # Check warning occurs 3248 self.assertEqual(len(w), 1) 3249 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 3250 3251 # dtypes should be safely castable 3252 out = torch.empty(0, dtype=torch.int, device=device) 3253 with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): 3254 torch.linalg.tensorinv(a, ind=ind, out=out) 3255 3256 # device should match 3257 if torch.cuda.is_available(): 3258 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 3259 out = torch.empty(0, dtype=dtype, device=wrong_device) 3260 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 3261 torch.linalg.tensorinv(a, ind=ind, out=out) 3262 3263 # test for invalid shape 3264 check_shape((2, 3, 4), ind=1) 3265 check_shape((1, 2, 3, 4), ind=3) 3266 3267 # test for invalid ind 3268 check_ind((12, 3, 4), ind=-1) 3269 check_ind((18, 3, 3, 2), ind=0) 3270 3271 # test for invalid out tensor 3272 check_out((12, 3, 4), ind=1) 3273 check_out((3, 8, 24), ind=2) 3274 3275 @skipCUDAIfNoMagma 3276 @skipCPUIfNoLapack 3277 @dtypes(*floating_and_complex_types()) 3278 def test_tensorinv_singular_input(self, device, dtype): 3279 3280 def check_singular_input(a_shape, ind): 3281 prod_ind_end = np.prod(a_shape[ind:]) 3282 a = 
torch.eye(prod_ind_end, dtype=dtype, device=device) 3283 a[-1, -1] = 0 # Now `a` is singular 3284 a = a.reshape(a_shape) 3285 with self.assertRaisesRegex(torch.linalg.LinAlgError, "The diagonal element"): 3286 torch.linalg.tensorinv(a, ind=ind) 3287 3288 # test for non-invertible input 3289 check_singular_input((12, 3, 4), ind=1) 3290 check_singular_input((3, 6, 18), ind=2) 3291 3292 def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn): 3293 def check(x, y): 3294 # Compare with numpy 3295 res = torch_fn(x, y) 3296 if x.dtype == torch.bfloat16: 3297 ref = torch.from_numpy(np.array(np_fn(x.cpu().float().numpy(), y.cpu().float().numpy()))) 3298 else: 3299 ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy()))) 3300 if res.dtype == torch.bfloat16: 3301 self.assertEqual(res.cpu(), ref.bfloat16()) 3302 else: 3303 self.assertEqual(res.cpu(), ref) 3304 3305 # Test out variant 3306 out = torch.empty_like(res) 3307 torch_fn(x, y, out=out) 3308 self.assertEqual(out, res) 3309 3310 # Empty 3311 x = torch.tensor([], dtype=dtype, device=device) 3312 y = torch.tensor([], dtype=dtype, device=device) 3313 check(x, y) 3314 3315 # Contiguous 3316 x = 0.1 * torch.randn(5000, dtype=dtype, device=device) 3317 y = 0.1 * torch.randn(5000, dtype=dtype, device=device) 3318 check(x, y) 3319 3320 # 0 strided 3321 y = 0.1 * torch.randn(1, dtype=dtype, device=device).expand(5000) 3322 check(x, y) 3323 3324 # 2 strided 3325 check(x[::2], y[::2]) 3326 3327 @dtypes(torch.float, torch.cfloat, torch.bfloat16, torch.float16) 3328 @dtypesIfCUDA(torch.float, torch.cfloat) 3329 @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5, torch.bfloat16: 1e-0}) 3330 def test_dot_vs_numpy(self, device, dtype): 3331 self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot) 3332 3333 @dtypes(torch.float, torch.cfloat) 3334 @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5}) 3335 def test_vdot_vs_numpy(self, device, dtype): 3336 self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot) 3337 3338 def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False): 3339 def check(x, y, regex): 3340 with self.assertRaisesRegex(RuntimeError, regex): 3341 torch_fn(x, y) 3342 3343 if complex_dtypes: 3344 x = torch.randn(1, dtype=torch.cfloat, device=device) 3345 y = torch.randn(3, dtype=torch.cdouble, device=device) 3346 else: 3347 x = torch.randn(1, dtype=torch.float, device=device) 3348 y = torch.randn(3, dtype=torch.double, device=device) 3349 3350 check(x, y, 'dot : expected both vectors to have same dtype') 3351 check(x.reshape(1, 1), y, '1D tensors expected') 3352 check(x.expand(9), y.to(x.dtype), 'inconsistent tensor size') 3353 3354 if self.device_type != 'cpu': 3355 x_cpu = x.expand(3).cpu() 3356 check(x_cpu, y.to(x.dtype), 'Expected all tensors to be on the same device') 3357 3358 @onlyNativeDeviceTypes 3359 def test_vdot_invalid_args(self, device): 3360 self._test_dot_vdot_invalid_args(device, torch.vdot) 3361 self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True) 3362 3363 @onlyNativeDeviceTypes 3364 def test_dot_invalid_args(self, device): 3365 self._test_dot_vdot_invalid_args(device, torch.dot) 3366 self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True) 3367 3368 @skipCUDAIfNoMagma 3369 @skipCPUIfNoLapack 3370 @dtypes(*floating_and_complex_types()) 3371 def test_matrix_rank(self, device, dtype): 3372 matrix_rank = torch.linalg.matrix_rank 3373 3374 def run_test(shape0, shape1, batch): 3375 a = torch.randn(*batch, shape0, 
shape1, dtype=dtype, device=device)
            rank_a = matrix_rank(a)

            self.assertEqual(rank_a, matrix_rank(a.mH))
            aaH = torch.matmul(a, a.mH)
            rank_aaH = matrix_rank(aaH)
            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
            self.assertEqual(rank_aaH, rank_aaH_hermitian)
            aHa = torch.matmul(a.mH, a)
            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))

            # check against NumPy
            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))

            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))

            # hermitian flag for NumPy was added in 1.14.0
            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
                self.assertEqual(rank_aaH_hermitian,
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
                self.assertEqual(matrix_rank(aaH, 0.01, True),
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))

            # check out= variant
            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
            ans = matrix_rank(a, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, rank_a)

        shapes = (3, 13)
        batches = ((), (0, ), (4, ), (3, 5, ))
        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
            run_test(shape0, shape1, batch)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_matrix_rank_atol(self, device, dtype):

        def run_test_atol(shape0, shape1, batch):
            a = make_tensor((*batch, shape0, shape1), dtype=dtype, device=device)
            # Check against NumPy output
            # Test float tol, and specific value for each matrix
            tolerances = [float(torch.rand(1)), ]
            # Test different types of tol tensor
            for tol_type in all_types():
                tolerances.append(make_tensor(a.shape[:-2], dtype=tol_type, device=device, low=0))
            # Test broadcasting of tol
            if a.ndim > 2:
                tolerances.append(make_tensor(a.shape[-3], dtype=torch.float32, device=device, low=0))
            for tol in tolerances:
                actual = torch.linalg.matrix_rank(a, atol=tol)
                actual_tol = torch.linalg.matrix_rank(a, tol=tol)
                self.assertEqual(actual, actual_tol)
                numpy_tol = tol if isinstance(tol, float) else tol.cpu().numpy()
                expected = np.linalg.matrix_rank(a.cpu().numpy(), tol=numpy_tol)
                self.assertEqual(actual, expected)

        shapes = (3, 13)
        batches = ((), (0, ), (4, ), (3, 5, ))
        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
            run_test_atol(shape0, shape1, batch)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float64)
    def test_matrix_rank_atol_rtol(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_arg = partial(make_fullrank, device=device, dtype=dtype)

        # creates a full-rank matrix (rank = n) with singular values in the range [2/3, 3/2]
        # the singular values are 1 + 1/2, 1 - 1/3, 1 + 1/4, 1 - 1/5, ...
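        # matrix_rank counts the singular values above max(atol, rtol * largest singular value);
        # with n = 9 below, the largest singular value is 1 + 1/2 = 1.5, which the expected ranks
        # in the assertions below rely on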
3449 n = 9 3450 a = make_arg(n, n) 3451 3452 # test float and tensor variants 3453 for tol_value in [0.81, torch.tensor(0.81, device=device)]: 3454 # using rtol (relative tolerance) takes into account the largest singular value (1.5 in this case) 3455 result = torch.linalg.matrix_rank(a, rtol=tol_value) 3456 self.assertEqual(result, 2) # there are 2 singular values above 1.5*0.81 = 1.215 3457 3458 # atol is used directly to compare with singular values 3459 result = torch.linalg.matrix_rank(a, atol=tol_value) 3460 self.assertEqual(result, 7) # there are 7 singular values above 0.81 3461 3462 # when both are specified the maximum tolerance is used 3463 result = torch.linalg.matrix_rank(a, atol=tol_value, rtol=tol_value) 3464 self.assertEqual(result, 2) # there are 2 singular values above max(0.81, 1.5*0.81) 3465 3466 @skipCUDAIfNoMagma 3467 @skipCPUIfNoLapack 3468 @skipCUDAVersionIn([(11, 6), (11, 7)]) # https://github.com/pytorch/pytorch/issues/75391 3469 @dtypes(*floating_and_complex_types()) 3470 def test_matrix_rank_empty(self, device, dtype): 3471 matrix_rank = torch.linalg.matrix_rank 3472 3473 # NumPy doesn't work for input with no elements 3474 def run_test(shape0, shape1, batch): 3475 a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device) 3476 rank_a = matrix_rank(a) 3477 expected = torch.zeros(batch, dtype=torch.int64, device=device) 3478 3479 self.assertEqual(rank_a, matrix_rank(a.mH)) 3480 3481 aaH = torch.matmul(a, a.mH) 3482 rank_aaH = matrix_rank(aaH) 3483 rank_aaH_hermitian = matrix_rank(aaH, hermitian=True) 3484 self.assertEqual(rank_aaH, rank_aaH_hermitian) 3485 3486 aHa = torch.matmul(a.mH, a) 3487 self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True)) 3488 3489 self.assertEqual(rank_a, expected) 3490 self.assertEqual(matrix_rank(a, 0.01), expected) 3491 3492 self.assertEqual(rank_aaH, expected) 3493 self.assertEqual(matrix_rank(aaH, 0.01), expected) 3494 3495 self.assertEqual(rank_aaH_hermitian, expected) 3496 self.assertEqual(matrix_rank(aaH, 0.01, True), expected) 3497 3498 batches = ((), (4, ), (3, 5, )) 3499 for batch in batches: 3500 run_test(0, 0, batch) 3501 run_test(0, 3, batch) 3502 run_test(3, 0, batch) 3503 3504 @skipCUDAIfNoMagma 3505 @skipCPUIfNoLapack 3506 @dtypes(*floating_and_complex_types()) 3507 def test_matrix_rank_out_errors_and_warnings(self, device, dtype): 3508 # dtypes should be safely castable 3509 a = torch.eye(2, dtype=dtype, device=device) 3510 out = torch.empty(0, dtype=torch.bool, device=device) 3511 with self.assertRaisesRegex(RuntimeError, "but got result with dtype Bool"): 3512 torch.linalg.matrix_rank(a, out=out) 3513 3514 # device should match 3515 if torch.cuda.is_available(): 3516 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 3517 out = torch.empty(0, dtype=dtype, device=wrong_device) 3518 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 3519 torch.linalg.matrix_rank(a, out=out) 3520 3521 # if out tensor with wrong shape is passed a warning is given 3522 with warnings.catch_warnings(record=True) as w: 3523 out = torch.empty(3, dtype=dtype, device=device) 3524 # Trigger warning 3525 torch.linalg.matrix_rank(a, out=out) 3526 # Check warning occurs 3527 self.assertEqual(len(w), 1) 3528 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 3529 3530 @skipCUDAIfNoMagma 3531 @skipCPUIfNoLapack 3532 @dtypes(*floating_and_complex_types()) 3533 def test_matrix_rank_basic(self, device, dtype): 3534 matrix_rank = torch.linalg.matrix_rank 
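        # sanity check: the identity matrix has full rank, and zeroing a single diagonal entry
        # drops the rank by exactly one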
3535 3536 a = torch.eye(10, dtype=dtype, device=device) 3537 self.assertEqual(matrix_rank(a).item(), 10) 3538 self.assertEqual(matrix_rank(a, hermitian=True).item(), 10) 3539 3540 a[5, 5] = 0 3541 self.assertEqual(matrix_rank(a).item(), 9) 3542 self.assertEqual(matrix_rank(a, hermitian=True).item(), 9) 3543 3544 @onlyNativeDeviceTypes 3545 @dtypes(torch.double) 3546 # This tests only the cases where torch.chain_matmul differs from torch.linalg.multi_dot which this is an "alias" for. 3547 def test_chain_matmul(self, device, dtype): 3548 # chain_matmul accepts a single input tensor while multi_dot does not 3549 t = make_tensor((2, 2), dtype=dtype, device=device) 3550 self.assertEqual(t, torch.chain_matmul(t)) 3551 with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"): 3552 torch.chain_matmul() 3553 3554 # chain_matmul expects all tensors to be 2D whereas multi_dot allows the first and last tensors to 3555 # be either 1D or 2D 3556 with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"): 3557 torch.chain_matmul(make_tensor(1, dtype=dtype, device=device), make_tensor(1, dtype=dtype, device=device)) 3558 3559 @onlyNativeDeviceTypes 3560 @dtypes(torch.double, torch.cdouble) 3561 def test_multi_dot(self, device, dtype): 3562 def check(*shapes): 3563 tensors = [make_tensor(shape, dtype=dtype, device=device) for shape in shapes] 3564 np_arrays = [tensor.cpu().numpy() for tensor in tensors] 3565 res = torch.linalg.multi_dot(tensors).cpu() 3566 ref = torch.from_numpy(np.array(np.linalg.multi_dot(np_arrays))) 3567 self.assertEqual(res, ref) 3568 3569 # test for inputs with empty dimensions 3570 check([0], [0]) 3571 check([2], [2, 0]) 3572 check([1, 0], [0]) 3573 check([0, 2], [2, 1]) 3574 check([2, 2], [2, 0]) 3575 check([2, 0], [0, 3]) 3576 check([0, 0], [0, 1]) 3577 check([4, 2], [2, 0], [0, 3], [3, 2]) 3578 3579 # test variable output shapes 3580 check([2], [2]) 3581 check([1, 2], [2]) 3582 check([2], [2, 1]) 3583 check([1, 2], [2, 1]) 3584 check([3, 2], [2, 4]) 3585 3586 # test multiple input tensors 3587 check([3], [3, 4], [4, 2], [2, 5], [5]) 3588 check([1, 2], [2, 2], [2, 3], [3, 1]) 3589 3590 # test large tensors 3591 check([10, 100], [100, 5], [5, 50]) 3592 check([10, 20], [20, 30], [30, 5]) 3593 3594 @onlyNativeDeviceTypes 3595 @dtypes(torch.float) 3596 def test_multi_dot_errors(self, device, dtype): 3597 def check(tensors, out, msg): 3598 with self.assertRaisesRegex(RuntimeError, msg): 3599 torch.linalg.multi_dot(tensors, out=out) 3600 3601 a = make_tensor(2, dtype=dtype, device=device) 3602 3603 check([], None, "expected at least 2 tensors") 3604 check([a], None, "expected at least 2 tensors") 3605 3606 check([torch.tensor(1, device=device, dtype=dtype), a], None, "the first tensor must be 1D or 2D") 3607 check([a, torch.tensor(1, device=device, dtype=dtype)], None, "the last tensor must be 1D or 2D") 3608 3609 check([a, a, a], None, "tensor 1 must be 2D") 3610 check([a, make_tensor((2, 2, 2), dtype=dtype, device=device), a], None, "tensor 1 must be 2D") 3611 3612 check([a, make_tensor(2, dtype=torch.double, device=device)], None, "all tensors must have be the same dtype") 3613 check([a, a], torch.empty(0, device=device, dtype=torch.double), "expected out tensor to have dtype") 3614 3615 if self.device_type == 'cuda': 3616 check([a, make_tensor(2, dtype=dtype, device="cpu")], None, "all tensors must be on the same device") 3617 check([a, a], torch.empty(0, dtype=dtype), "expected out tensor to be on device") 3618 3619 
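        # operands whose inner dimensions do not match cannot be multiplied and should be rejected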
check([a, make_tensor(3, dtype=dtype, device=device)], None, "cannot be multiplied") 3620 check([a, make_tensor((3, 2), dtype=dtype, device=device), a], None, "cannot be multiplied") 3621 3622 @precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6}) 3623 @skipCUDAIfNoCusolver 3624 @skipCPUIfNoLapack 3625 @dtypes(*floating_and_complex_types()) 3626 def test_qr(self, device, dtype): 3627 def run_test(tensor_dims, some): 3628 A = torch.randn(*tensor_dims, dtype=dtype, device=device) 3629 Q, R = torch.qr(A, some=some) 3630 3631 # Check0: Q[-2:] = (m, n_columns), R[-2:] = (n_columns, n) 3632 m, n = tensor_dims[-2:] 3633 n_columns = m if (not some) and m > n else min(m, n) 3634 self.assertEqual(Q.size(-2), m) 3635 self.assertEqual(R.size(-1), n) 3636 self.assertEqual(Q.size(-1), n_columns) 3637 3638 A_ = A.cpu().numpy() 3639 Q_ = Q.cpu().numpy() 3640 R_ = R.cpu().numpy() 3641 3642 # Check1: A = QR 3643 self.assertEqual(A_, np.matmul(Q_, R_)) 3644 3645 # Check2: A = QR (with out) 3646 Q_out, R_out = torch.full_like(Q, math.nan), torch.full_like(R, math.nan) 3647 torch.qr(A, some=some, out=(Q_out, R_out)) 3648 Q_out_ = Q_out.cpu().numpy() 3649 R_out_ = R_out.cpu().numpy() 3650 self.assertEqual(A_, np.matmul(Q_out_, R_out_)) 3651 3652 # Check3: Q == Q_out, R == R_out 3653 self.assertEqual(Q_, Q_out_) 3654 self.assertEqual(R_, R_out_) 3655 3656 # Check4: Q^{T}Q = I, triu(R) = R 3657 eye = torch.eye(n_columns, device=device, dtype=dtype).expand(Q.shape[:-2] + (n_columns, n_columns)).cpu().numpy() 3658 self.assertEqual(np.matmul(Q_.swapaxes(-1, -2).conj(), Q_), eye) 3659 self.assertEqual(R.triu(), R) 3660 3661 tensor_dims_list = [(0, 5), (0, 0), (5, 0), # Empty Tensors 3662 (2, 1, 0, 5), (2, 1, 0, 0), (2, 1, 5, 0), (2, 0, 5, 5), # Batched empty Tensors 3663 (3, 5), (5, 5), (5, 3), # Single matrix 3664 (7, 3, 5), (7, 5, 5), (7, 5, 3), # 3-dim Tensors 3665 (7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)] # 4-dim Tensors 3666 for tensor_dims, some in itertools.product(tensor_dims_list, [True, False]): 3667 run_test(tensor_dims, some) 3668 3669 @skipCUDAIfNoCusolver 3670 @skipCPUIfNoLapack 3671 @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 3672 def test_qr_vs_numpy(self, device, dtype): 3673 """ 3674 test torch.linalg.qr vs numpy.linalg.qr 3675 """ 3676 sizes_to_test = [ 3677 (7, 5), 3678 (5, 7), 3679 (5, 0), # empty 3680 (0, 5), # empty 3681 ] 3682 for size in sizes_to_test: 3683 t = torch.randn(size, device=device, dtype=dtype) 3684 np_t = t.cpu().numpy() 3685 for mode in ['reduced', 'complete']: 3686 exp_q, exp_r = np.linalg.qr(np_t, mode=mode) 3687 q, r = torch.linalg.qr(t, mode=mode) 3688 self.assertEqual(q, exp_q) 3689 self.assertEqual(r, exp_r) 3690 # 3691 # for mode='r' we need a special logic because numpy returns only r 3692 exp_r = np.linalg.qr(np_t, mode='r') 3693 q, r = torch.linalg.qr(t, mode='r') 3694 # check that q is empty 3695 self.assertEqual(q.shape, (0,)) 3696 self.assertEqual(q.dtype, t.dtype) 3697 self.assertEqual(q.device, t.device) 3698 # check r 3699 self.assertEqual(r, exp_r) 3700 3701 @skipCUDAIfNoCusolver 3702 @skipCPUIfNoLapack 3703 @dtypes(torch.float) 3704 def test_linalg_qr_autograd_errors(self, device, dtype): 3705 # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but 3706 # without 'q' you cannot compute the backward pass. Check that 3707 # linalg_qr_backward complains cleanly in that case. 
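        # first case: mode='r' (Q is discarded); a second case below checks mode='complete'
        # with more rows than columns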
3708 inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True) 3709 q, r = torch.linalg.qr(inp, mode='r') 3710 self.assertEqual(q.shape, (0,)) # empty tensor 3711 b = torch.sum(r) 3712 with self.assertRaisesRegex(RuntimeError, 3713 "The derivative of linalg.qr depends on Q"): 3714 b.backward() 3715 inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True) 3716 q, r = torch.linalg.qr(inp, mode='complete') 3717 b = torch.sum(r) 3718 with self.assertRaisesRegex(RuntimeError, 3719 "The QR decomposition is not differentiable when mode='complete' and nrows > ncols"): 3720 b.backward() 3721 3722 @skipCUDAIfNoCusolver 3723 @skipCPUIfNoLapack 3724 @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 3725 def test_qr_batched(self, device, dtype): 3726 """ 3727 test torch.linalg.qr vs numpy.linalg.qr. We need some special logic 3728 because numpy does not support batched qr 3729 """ 3730 def np_qr_batched(a, mode): 3731 """poor's man batched version of np.linalg.qr""" 3732 all_q = [] 3733 all_r = [] 3734 for matrix in a: 3735 result = np.linalg.qr(matrix, mode=mode) 3736 if mode == 'r': 3737 all_r.append(result) 3738 else: 3739 q, r = result 3740 all_q.append(q) 3741 all_r.append(r) 3742 if mode == 'r': 3743 return np.array(all_r) 3744 else: 3745 return np.array(all_q), np.array(all_r) 3746 3747 t = torch.randn((3, 7, 5), device=device, dtype=dtype) 3748 np_t = t.cpu().numpy() 3749 for mode in ['reduced', 'complete']: 3750 exp_q, exp_r = np_qr_batched(np_t, mode=mode) 3751 q, r = torch.linalg.qr(t, mode=mode) 3752 self.assertEqual(q, exp_q) 3753 self.assertEqual(r, exp_r) 3754 # for mode='r' we need a special logic because numpy returns only r 3755 exp_r = np_qr_batched(np_t, mode='r') 3756 q, r = torch.linalg.qr(t, mode='r') 3757 # check that q is empty 3758 self.assertEqual(q.shape, (0,)) 3759 self.assertEqual(q.dtype, t.dtype) 3760 self.assertEqual(q.device, t.device) 3761 # check r 3762 self.assertEqual(r, exp_r) 3763 3764 @skipCUDAIfNoCusolver 3765 @skipCPUIfNoLapack 3766 @dtypes(torch.float) 3767 def test_qr_error_cases(self, device, dtype): 3768 t1 = torch.randn(5, device=device, dtype=dtype) 3769 with self.assertRaisesRegex(RuntimeError, 'linalg.qr: The input tensor A must have at least 2 dimensions.'): 3770 torch.linalg.qr(t1) 3771 t2 = torch.randn((5, 7), device=device, dtype=dtype) 3772 with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"): 3773 torch.linalg.qr(t2, mode='hello') 3774 3775 def _check_einsum(self, *args, np_args=None): 3776 if np_args is None: 3777 np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args] 3778 ref = np.einsum(*np_args) 3779 res = torch.einsum(*args) 3780 self.assertEqual(ref, res) 3781 3782 # Check that the other variations for opt_einsum work too 3783 if TEST_OPT_EINSUM: 3784 with opt_einsum.flags(enabled=False): 3785 res = torch.einsum(*args) 3786 self.assertEqual(ref, res) 3787 3788 with opt_einsum.flags(enabled=True, strategy='greedy'): 3789 res = torch.einsum(*args) 3790 self.assertEqual(ref, res) 3791 3792 with opt_einsum.flags(enabled=True, strategy='optimal'): 3793 res = torch.einsum(*args) 3794 self.assertEqual(ref, res) 3795 3796 @dtypes(torch.double, torch.cdouble) 3797 def test_einsum(self, device, dtype): 3798 # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f 3799 x = make_tensor((5,), dtype=dtype, device=device) 3800 y = make_tensor((7,), dtype=dtype, device=device) 3801 A = make_tensor((3, 5), dtype=dtype, 
device=device)
        B = make_tensor((2, 5), dtype=dtype, device=device)
        C = make_tensor((2, 3, 5), dtype=dtype, device=device)
        D = make_tensor((2, 5, 7), dtype=dtype, device=device)
        E = make_tensor((7, 9), dtype=dtype, device=device)
        F = make_tensor((2, 3, 3, 5), dtype=dtype, device=device)
        G = make_tensor((5, 4, 6), dtype=dtype, device=device)
        H = make_tensor((4, 4), dtype=dtype, device=device)
        I = make_tensor((2, 3, 2), dtype=dtype, device=device)

        # Vector operations
        self._check_einsum('i->', x)  # sum
        self._check_einsum('i,i->', x, x)  # dot
        self._check_einsum('i,i->i', x, x)  # vector element-wise mul
        self._check_einsum('i,j->ij', x, y)  # outer

        # Matrix operations
        self._check_einsum("ij->ji", A)  # transpose
        self._check_einsum("ij->j", A)  # row sum
        self._check_einsum("ij->i", A)  # col sum
        self._check_einsum("ij,ij->ij", A, A)  # matrix element-wise mul
        self._check_einsum("ij,j->i", A, x)  # matrix vector multiplication
        self._check_einsum("ij,kj->ik", A, B)  # matmul
        self._check_einsum("ij,ab->ijab", A, E)  # matrix outer product

        # Tensor operations
        self._check_einsum("Aij,Ajk->Aik", C, D)  # batch matmul
        self._check_einsum("ijk,jk->i", C, A)  # tensor matrix contraction
        self._check_einsum("aij,jk->aik", D, E)  # tensor matrix contraction
        self._check_einsum("abCd,dFg->abCFg", F, G)  # tensor tensor contraction
        self._check_einsum("ijk,jk->ik", C, A)  # tensor matrix contraction with double indices
        self._check_einsum("ijk,jk->ij", C, A)  # tensor matrix contraction with double indices
        self._check_einsum("ijk,ik->j", C, B)  # non contiguous
        self._check_einsum("ijk,ik->jk", C, B)  # non contiguous with double indices

        # Test diagonals
        self._check_einsum("ii", H)  # trace
        self._check_einsum("ii->i", H)  # diagonal
        self._check_einsum('iji->j', I)  # non-contiguous trace
        self._check_einsum('ngrg...->nrg...', make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device))

        # Test ellipsis
        self._check_einsum("i...->...", H)
        self._check_einsum("ki,...k->i...", A.t(), B)
        self._check_einsum("k...,jk->...", A.t(), B)
        self._check_einsum('...ik, ...j -> ...ij', C, x)
        self._check_einsum('Bik,k...j->i...j', C, make_tensor((5, 3), dtype=dtype, device=device))
        self._check_einsum('i...j, ij...
-> ...ij', C, make_tensor((2, 5, 2, 3), dtype=dtype, device=device)) 3849 3850 # torch.bilinear with noncontiguous tensors 3851 l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True) 3852 r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True) 3853 w = make_tensor((15, 10, 20), dtype=dtype, device=device) 3854 self._check_einsum("bn,anm,bm->ba", l, w, r) 3855 3856 # with strided tensors 3857 self._check_einsum("bn,Anm,bm->bA", l[:, ::2], w[:, ::2, ::2], r[:, ::2]) 3858 3859 # test multiple inputs 3860 self._check_einsum("...,be,b...,beg,gi,bc...->bi...", A, B, C, D, E, F) 3861 3862 @dtypes(torch.double, torch.cdouble) 3863 def test_einsum_sublist_format(self, device, dtype): 3864 x = make_tensor((5,), dtype=dtype, device=device) 3865 y = make_tensor((7,), dtype=dtype, device=device) 3866 A = make_tensor((3, 5), dtype=dtype, device=device) 3867 B = make_tensor((2, 5), dtype=dtype, device=device) 3868 C = make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device) 3869 3870 self._check_einsum(x, [0]) 3871 self._check_einsum(x, [0], []) 3872 self._check_einsum(x, [0], y, [1], [0, 1]) 3873 self._check_einsum(A, [0, 1], [1, 0]) 3874 self._check_einsum(A, [0, 1], x, [1], [0]) 3875 self._check_einsum(A, [0, 1], B, [2, 1]) 3876 self._check_einsum(A, [0, 1], B, [2, 1], [0, 2]) 3877 self._check_einsum(C, [0, 1, 2, 1, Ellipsis], [0, 2, 1, Ellipsis]) 3878 self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0]) 3879 self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0], [1, Ellipsis]) 3880 self._check_einsum(A.t(), [0, Ellipsis], B, [1, 0], [Ellipsis]) 3881 3882 # torch.bilinear with noncontiguous tensors 3883 l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True) 3884 r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True) 3885 w = make_tensor((15, 10, 20), dtype=dtype, device=device) 3886 self._check_einsum(l, [40, 41], w, [2, 41, 50], r, [40, 50], [40, 2]) 3887 3888 @dtypes(torch.double, torch.cdouble) 3889 def test_einsum_random(self, device, dtype): 3890 def convert_label(label): 3891 if label == ...: 3892 return '...' 
3893 elif label < 26: 3894 return chr(ord('A') + label) 3895 else: 3896 return chr(ord('a') + label - 26) 3897 3898 def convert_sublist(sublist): 3899 return ''.join(convert_label(label) for label in sublist) 3900 3901 def test(n=10, # how many tests to generate 3902 n_labels=5, # how many labels available 3903 min_ops=1, max_ops=4, # min and max number of operands per test 3904 min_dims=1, max_dims=3, # min and max number of dimensions per operand 3905 min_size=1, max_size=8, # min and max size of each dimension 3906 max_out_dim=3, # max number of dimensions for the output 3907 enable_diagonals=True, # controls if labels can be repeated for diagonals 3908 ellipsis_prob=0.5, # probability of including ellipsis in operand 3909 broadcasting_prob=0.1): # probability of turning some dim sizes 1 for broadcasting 3910 3911 all_labels = torch.arange(52) 3912 3913 assert 0 <= n 3914 assert 0 <= n_labels < len(all_labels) 3915 assert 0 < min_ops <= max_ops 3916 assert 0 <= min_dims <= max_dims 3917 assert 0 <= min_size <= max_size 3918 assert 0 <= max_out_dim 3919 assert enable_diagonals or max_dims <= n_labels 3920 3921 for _ in range(n): 3922 3923 # Select a subset of labels for this test and give them random sizes 3924 possible_labels = all_labels[torch.randperm(len(all_labels))[:n_labels]] 3925 labels_size = torch.randint_like(all_labels, min_size, max_size + 1) 3926 ellipsis_shape = torch.randint(min_size, max_size + 1, (max_dims - min_dims,)) 3927 3928 operands = [] 3929 sublists = [] 3930 3931 ell_size = 0 3932 valid_labels = set() 3933 3934 # create random input operands 3935 for _ in range(random.randint(min_ops, max_ops)): 3936 n_dim = random.randint(min_dims, max_dims) 3937 labels_idx = torch.ones(len(possible_labels)).multinomial(n_dim, enable_diagonals) 3938 labels = possible_labels[labels_idx] 3939 valid_labels.update(labels.tolist()) 3940 shape = labels_size[labels] 3941 3942 # turn some dimensions to size 1 for testing broadcasting 3943 mask = Binomial(probs=broadcasting_prob).sample((n_dim,)) 3944 broadcast_labels = torch.unique(labels[mask == 1]) 3945 shape[(labels[..., None] == broadcast_labels).any(-1)] = 1 3946 3947 labels = labels.tolist() 3948 shape = shape.tolist() 3949 3950 # include ellipsis if not all dimensions were assigned a label already 3951 if n_dim < max_dims and torch.rand(1) < ellipsis_prob: 3952 ell_num_dim = random.randint(1, max_dims - n_dim) 3953 ell_size = max(ell_size, ell_num_dim) 3954 ell_shape = ellipsis_shape[-ell_num_dim:] 3955 # again, turn some dimensions to size 1 for broadcasting 3956 mask = Binomial(probs=broadcasting_prob).sample((ell_num_dim,)) 3957 ell_shape[mask == 1] = 1 3958 ell_index = random.randint(0, n_dim) 3959 shape[ell_index:ell_index] = ell_shape 3960 labels.insert(ell_index, ...) 
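                # materialize this operand with the chosen shape and remember its sublist of labels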
3961 3962 operands.append(make_tensor(shape, dtype=dtype, device=device)) 3963 sublists.append(labels) 3964 3965 # NumPy has a bug with the sublist format so for now we compare PyTorch sublist 3966 # implementation against the equation format implementation of NumPy 3967 # see https://github.com/numpy/numpy/issues/10926 3968 np_operands = [op.cpu().numpy() for op in operands] 3969 3970 # test equation format 3971 equation = ','.join(convert_sublist(l) for l in sublists) 3972 self._check_einsum(equation, *operands, np_args=(equation, *np_operands)) 3973 3974 # test sublist format 3975 args = list(itertools.chain.from_iterable(zip(operands, sublists))) 3976 self._check_einsum(*args, np_args=(equation, *np_operands)) 3977 3978 # generate an explicit output 3979 out_sublist = [] 3980 num_out_labels = max(0, random.randint(0, min(max_out_dim, len(valid_labels))) - ell_size) 3981 if num_out_labels > 0: 3982 out_labels_idx = torch.ones(len(valid_labels)).multinomial(num_out_labels) 3983 out_sublist = torch.tensor(list(valid_labels))[out_labels_idx].tolist() 3984 out_sublist.insert(random.randint(0, num_out_labels), ...) 3985 3986 # test equation format with explicit output 3987 equation += '->' + convert_sublist(out_sublist) 3988 self._check_einsum(equation, *operands, np_args=(equation, *np_operands)) 3989 3990 # test sublist format with explicit output 3991 args.append(out_sublist) 3992 self._check_einsum(*args, np_args=(equation, *np_operands)) 3993 3994 test(500) 3995 3996 def test_einsum_corner_cases(self, device): 3997 def check(equation, *operands, expected_output): 3998 tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple) 3999 else make_tensor(operand, dtype=torch.float32, device=device) for operand in operands] 4000 output = torch.einsum(equation, tensors) 4001 self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device)) 4002 4003 # Test equation variantions 4004 check(' ', 1, expected_output=1) 4005 check(' -> ', 1, expected_output=1) 4006 check(' , ', 2, 2, expected_output=4) 4007 check(' , , ', 2, 2, 2, expected_output=8) 4008 check(' , -> ', 2, 2, expected_output=4) 4009 check(' i ', [1], expected_output=[1]) 4010 check(' i -> ', [1], expected_output=1) 4011 check(' i -> i ', [1], expected_output=[1]) 4012 check(' i , i ', [2], [2], expected_output=4) 4013 check(' i , i -> i ', [2], [2], expected_output=[4]) 4014 4015 # Test tensors with 0 size dimensions 4016 check('i', [], expected_output=[]) 4017 check(' i j -> j', [[], []], expected_output=[]) 4018 check('ij->i', [[], []], expected_output=[0., 0.]) 4019 check(' i j k , k -> i j ', (3, 0, 6), (6,), expected_output=[[], [], []]) 4020 4021 # Test broadcasting 4022 check('i,j', [2], [1, 2], expected_output=[[2, 4]]) 4023 check('i,ij->ij', [1, 2], [[1, 2, 3], [2, 3, 4]], expected_output=[[1, 2, 3], [4, 6, 8]]) 4024 4025 # Test ellipsis broadcasting 4026 check('...', 1, expected_output=1) 4027 check('...->', 1, expected_output=1) 4028 check('...->...', 1, expected_output=1) 4029 check('...', [1], expected_output=[1]) 4030 check('...->', [1], expected_output=1) 4031 check('z...->z', [1], expected_output=[1]) 4032 check('Z...->...Z', [1], expected_output=[1]) 4033 check('...a->', [[2], [4]], expected_output=6) 4034 check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]]) 4035 4036 def test_einsum_error_cases(self, device): 4037 def check(*args, regex, exception=RuntimeError): 4038 with self.assertRaisesRegex(exception, r'einsum\(\):.*' + 
regex): 4039 torch.einsum(*args) 4040 4041 x = make_tensor((2,), dtype=torch.float32, device=device) 4042 y = make_tensor((2, 3), dtype=torch.float32, device=device) 4043 4044 check('', [], regex=r'at least one operand', exception=ValueError) 4045 check('. ..', [x], regex=r'found \'.\' for operand 0 that is not part of any ellipsis') 4046 check('... ...', [x], regex=r'found \'.\' for operand 0 for which an ellipsis was already found') 4047 check('1', [x], regex=r'invalid subscript given at index 0') 4048 check(',', [x], regex=r'fewer operands were provided than specified in the equation') 4049 check('', [x, x], regex=r'more operands were provided than specified in the equation') 4050 check('', [x], regex=r'the number of subscripts in the equation \(0\) does not match the number ' 4051 r'of dimensions \(1\) for operand 0 and no ellipsis was given') 4052 check('ai', [x], regex=r'the number of subscripts in the equation \(2\) does not match the number ' 4053 r'of dimensions \(1\) for operand 0 and no ellipsis was given') 4054 check('ai...', [x], regex=r'the number of subscripts in the equation \(2\) is more than the number ' 4055 r'of dimensions \(1\) for operand 0') 4056 check('a->... .', [x], regex=r'found \'.\' for output but an ellipsis \(...\) was already found') 4057 check('a->..', [x], regex=r'found \'.\' for output that is not part of any ellipsis \(...\)') 4058 check('a->1', [x], regex=r'invalid subscript given at index 3') 4059 check('a->aa', [x], regex=r'output subscript a appears more than once in the output') 4060 check('a->i', [x], regex=r'output subscript i does not appear in the equation for any input operand') 4061 check('aa', [y], regex=r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2') 4062 check('...,...', [x, y], regex=r'does not broadcast') 4063 check('a,a', [x, make_tensor((3,), dtype=torch.float32, device=device)], regex=r'does not broadcast') 4064 check('a, ba', [x, y], regex=r'subscript a has size 3 for operand 1 which does not broadcast with previously' 4065 r' seen size 2') 4066 4067 check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError) 4068 check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError) 4069 4070 def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_conditioned=False): 4071 make_arg = partial(make_tensor, dtype=dtype, device=device) 4072 make_fullrank = partial(make_fullrank_matrices_with_distinct_singular_values, dtype=dtype, device=device) 4073 b, n, k = shape 4074 for left, uni, expand_a, tr_a, conj_a, expand_b, tr_b, conj_b in product((True, False), repeat=8): 4075 # expand means that we generate a batch of matrices with a stride of zero in the batch dimension 4076 if (conj_a or conj_b) and not dtype.is_complex: 4077 continue 4078 # We just expand on the batch size 4079 if (expand_a or expand_b) and b == 1: 4080 continue 4081 4082 size_a = (b, n, n) if left else (b, k, k) 4083 size_b = (b, n, k) if not tr_b else (b, k, n) 4084 4085 # If expand_a or expand_b, we'll expand them to the correct size later 4086 if b == 1 or expand_a: 4087 size_a = size_a[1:] 4088 if b == 1 or expand_b: 4089 size_b = size_b[1:] 4090 4091 if well_conditioned: 4092 PLU = torch.linalg.lu(make_fullrank(*size_a)) 4093 if uni: 4094 # A = L from PLU 4095 A = PLU[1].transpose(-2, -1).contiguous() 4096 else: 4097 # A = U from PLU 4098 A = PLU[2].contiguous() 4099 else: 4100 A = make_arg(size_a) 4101 A.triu_() 4102 4103 diag = A.diagonal(0, -2, -1) 4104 if uni: 4105 
diag.fill_(1.) 4106 else: 4107 diag[diag.abs() < 1e-6] = 1. 4108 4109 B = make_arg(size_b) 4110 4111 if tr_a: 4112 A.transpose_(-2, -1) 4113 if tr_b: 4114 B.transpose_(-2, -1) 4115 if conj_a: 4116 A = A.conj() 4117 if conj_b: 4118 B = B.conj() 4119 if expand_a: 4120 A = A.expand(b, *size_a) 4121 if expand_b: 4122 B = B.expand(b, n, k) 4123 yield A, B, left, not tr_a, uni 4124 4125 def _test_linalg_solve_triangular(self, A, B, upper, left, uni): 4126 X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni) 4127 if left: 4128 self.assertEqual(A @ X, B) 4129 else: 4130 self.assertEqual(X @ A, B) 4131 out = B 4132 # B may be expanded 4133 if not B.is_contiguous() and not B.transpose(-2, -1).is_contiguous(): 4134 out = B.clone() 4135 torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni, out=out) 4136 self.assertEqual(X, out) 4137 4138 # Tolerances dictated by widest acceptable range on CPU before failure 4139 @dtypes(*floating_and_complex_types()) 4140 @precisionOverride({torch.float32: 1e-3 if TEST_WITH_ROCM else 1e-1, 4141 torch.float64: 1e-8, 4142 torch.complex64: 1e-1, 4143 torch.complex128: 1e-8}) 4144 def test_linalg_solve_triangular(self, device, dtype): 4145 # This exercises the API + BLAS CPU + batched cuBLAS 4146 ks = (3, 1, 0) 4147 ns = (5, 0) 4148 bs = (1, 2, 0) 4149 4150 gen_inputs = self._gen_shape_inputs_linalg_triangular_solve 4151 for b, n, k in product(bs, ns, ks): 4152 for A, B, left, upper, uni in gen_inputs((b, n, k), dtype, device, well_conditioned=True): 4153 self._test_linalg_solve_triangular(A, B, upper, left, uni) 4154 4155 @slowTest 4156 @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra") 4157 @onlyCUDA 4158 @skipCUDAIfNoMagma # Magma needed for the PLU decomposition 4159 @dtypes(*floating_and_complex_types()) 4160 @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2, 4161 torch.float64: 1e-8, torch.complex128: 1e-8}) 4162 def test_linalg_solve_triangular_large(self, device, dtype): 4163 # Exercises magma and cublas 4164 magma = (9, 513, 1) 4165 iterative_cublas = (2, 64, 1) 4166 4167 gen_inputs = self._gen_shape_inputs_linalg_triangular_solve 4168 for shape in (magma, iterative_cublas): 4169 for A, B, left, upper, uni in gen_inputs(shape, dtype, device, well_conditioned=True): 4170 self._test_linalg_solve_triangular(A, B, upper, left, uni) 4171 4172 @dtypes(*floating_and_complex_types()) 4173 @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2, 4174 torch.float64: 1e-8, torch.complex128: 1e-8}) 4175 def test_linalg_solve_triangular_broadcasting(self, device, dtype): 4176 make_arg = partial(make_tensor, dtype=dtype, device=device) 4177 4178 sizes = (((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)), 4179 ((2, 1, 3, 4, 4), (4, 6)), 4180 ((4, 4), (2, 1, 3, 4, 2)), 4181 ((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))) 4182 for size_A, size_B in sizes: 4183 for left, upper, uni in itertools.product([True, False], repeat=3): 4184 A = make_arg(size_A) 4185 if upper: 4186 A.triu_() 4187 else: 4188 A.tril_() 4189 diag = A.diagonal(0, -2, -1) 4190 if uni: 4191 diag.fill_(1.) 4192 else: 4193 diag[diag.abs() < 1e-6] = 1. 
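                # the diagonal fix-up above avoids near-zero pivots, so the triangular solve below
                # stays numerically stable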
4194 B = make_arg(size_B) 4195 if not left: 4196 B.transpose_(-2, -1) 4197 4198 X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni) 4199 if left: 4200 B_other = A @ X 4201 else: 4202 B_other = X @ A 4203 4204 self.assertEqual(*torch.broadcast_tensors(B, B_other)) 4205 4206 def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, 4207 device, dtype): 4208 triangle_function = torch.triu if upper else torch.tril 4209 b = torch.randn(*b_dims, dtype=dtype, device=device) 4210 A = torch.randn(*A_dims, dtype=dtype, device=device) 4211 # create positive definite matrix 4212 A = torch.matmul(A, A.mT) 4213 A_triangular = triangle_function(A) 4214 if unitriangular: 4215 A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) 4216 return b, A_triangular 4217 4218 @skipCUDAIfNoMagma 4219 @skipCPUIfNoLapack 4220 @skipIfTorchDynamo("flaky, needs investigation") 4221 @dtypes(*floating_and_complex_types()) 4222 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 4223 torch.float64: 1e-8, torch.complex128: 1e-8}) 4224 def test_triangular_solve(self, device, dtype): 4225 ks = [0, 1, 3] 4226 ns = [0, 5] 4227 for k, n, (upper, unitriangular, transpose) in itertools.product(ks, ns, 4228 itertools.product([True, False], repeat=3)): 4229 b, A = self.triangular_solve_test_helper((n, n), (n, k), upper, 4230 unitriangular, device, dtype) 4231 x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0] 4232 if transpose: 4233 self.assertEqual(b, np.matmul(A.t().cpu(), x.cpu())) 4234 else: 4235 self.assertEqual(b, np.matmul(A.cpu(), x.cpu())) 4236 4237 @skipCPUIfNoLapack 4238 @skipCUDAIfNoMagma 4239 @dtypes(*floating_and_complex_types()) 4240 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 4241 torch.float64: 1e-8, torch.complex128: 1e-8}) 4242 def test_triangular_solve_batched(self, device, dtype): 4243 def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose): 4244 b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, 4245 unitriangular, device, dtype) 4246 x_exp_list = [] 4247 for i in range(b_dims[0]): 4248 x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper, 4249 unitriangular=unitriangular, 4250 transpose=transpose)[0]) 4251 x_exp = torch.stack(x_exp_list) # Stacked output 4252 x_act = torch.triangular_solve(b, A, upper=upper, 4253 unitriangular=unitriangular, 4254 transpose=transpose)[0] # Actual output 4255 self.assertEqual(x_act, x_exp) # Equality check 4256 if transpose: 4257 A = A.mT 4258 4259 Ax = np.matmul(A.cpu(), x_act.cpu()) 4260 self.assertEqual(b, Ax) 4261 4262 def triangular_solve_zero_batch_helper(A_dims, b_dims, upper, unitriangular, transpose): 4263 b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, 4264 unitriangular, device, dtype) 4265 x = torch.triangular_solve(b, A, upper=upper, 4266 unitriangular=unitriangular, 4267 transpose=transpose)[0] 4268 self.assertTrue(x.shape == b.shape) 4269 4270 for upper, unitriangular, transpose in itertools.product([True, False], repeat=3): 4271 batchsize = 3 4272 triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), 4273 upper, unitriangular, transpose) 4274 4275 # test empty input 4276 triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 10), 4277 upper, unitriangular, transpose) 4278 triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 0), 4279 upper, unitriangular, transpose) 4280 4281 # test zero batch case 4282 batchsize = 0 4283 
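            # with no batches there is nothing to solve; the helper only verifies
            # that the solution shape matches b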
triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), 4284 upper, unitriangular, transpose) 4285 4286 4287 @slowTest 4288 @skipCUDAIfNoMagma 4289 @skipCPUIfNoLapack 4290 @dtypes(*floating_and_complex_types()) 4291 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 4292 torch.float64: 1e-8, torch.complex128: 1e-8}) 4293 def test_triangular_solve_batched_many_batches(self, device, dtype): 4294 for upper, transpose, unitriangular in itertools.product([True, False], repeat=3): 4295 # test batched A case 4296 b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1), 4297 upper, unitriangular, device, dtype) 4298 x, _ = torch.triangular_solve(b, A, 4299 upper=upper, transpose=transpose, unitriangular=unitriangular) 4300 if transpose: 4301 A = A.mT 4302 4303 Ax = torch.matmul(A, x) 4304 4305 rtol = 1e-2 if dtype in [torch.float32, torch.complex64] else self.precision 4306 self.assertEqual(Ax, b.expand_as(Ax), atol=self.precision, rtol=rtol) 4307 4308 # test batched b case 4309 b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1), 4310 upper, unitriangular, device, dtype) 4311 x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose, 4312 unitriangular=unitriangular) 4313 if transpose: 4314 A = A.mT 4315 4316 self.assertEqual(torch.matmul(A, x), b) 4317 4318 @skipCUDAIfNoMagma 4319 @skipCPUIfNoLapack 4320 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 4321 @skipIfTorchDynamo("flaky, needs investigation") 4322 @dtypes(*floating_and_complex_types()) 4323 def test_triangular_solve_batched_broadcasting(self, device, dtype): 4324 from scipy.linalg import solve_triangular as tri_solve 4325 4326 def scipy_tri_solve_batched(A, B, upper, trans, diag): 4327 batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2] 4328 single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:] 4329 expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A), 4330 torch.Size(batch_dims_B))) 4331 expand_A = np.broadcast_to(A, expand_dims + single_dim_A) 4332 expand_B = np.broadcast_to(B, expand_dims + single_dim_B) 4333 flat_A = expand_A.reshape((-1,) + single_dim_A) 4334 flat_B = expand_B.reshape((-1,) + single_dim_B) 4335 flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag) 4336 for a, b in zip(flat_A, flat_B)]) 4337 return flat_X.reshape(expand_B.shape) 4338 4339 def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): 4340 b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, 4341 unitriangular, device, dtype) 4342 x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(), 4343 upper, transpose, unitriangular)) 4344 x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] 4345 4346 self.assertEqual(x, x_exp.to(device)) 4347 4348 for upper, transpose, unitriangular in itertools.product([True, False], repeat=3): 4349 # test against scipy.linalg.solve_triangular 4350 run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular) # no broadcasting 4351 run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular) # broadcasting b 4352 run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A 4353 run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b 4354 4355 @onlyCUDA 4356 @dtypes(torch.float) 4357 def test_triangular_solve_large(self, device, dtype): 4358 # Repro for 
https://github.com/pytorch/pytorch/issues/79191 4359 A = torch.randn(1, 2, 2, device=device, dtype=dtype).tril_() 4360 B = torch.randn(1, 2, 524281, device=device, dtype=dtype) 4361 X = torch.linalg.solve_triangular(A, B, upper=False) 4362 self.assertEqual(A @ X, B) 4363 4364 @skipCUDAIfNoMagma 4365 @skipCPUIfNoLapack 4366 @dtypes(*floating_and_complex_types()) 4367 def test_triangular_solve_out_errors_and_warnings(self, device, dtype): 4368 # dtypes should be safely castable 4369 a = torch.eye(2, dtype=dtype, device=device) 4370 b = torch.randn(2, 1, dtype=dtype, device=device) 4371 out = torch.empty_like(b).to(torch.int) 4372 clone_a = torch.empty_like(a) 4373 with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"): 4374 torch.triangular_solve(b, a, out=(out, clone_a)) 4375 4376 out = torch.empty_like(b) 4377 clone_a = clone_a.to(torch.int) 4378 with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"): 4379 torch.triangular_solve(b, a, out=(out, clone_a)) 4380 4381 # device should match 4382 if torch.cuda.is_available(): 4383 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 4384 out = torch.empty(0, dtype=dtype, device=wrong_device) 4385 clone_a = torch.empty_like(a) 4386 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 4387 torch.triangular_solve(b, a, out=(out, clone_a)) 4388 out = torch.empty(0, dtype=dtype, device=device) 4389 clone_a = torch.empty_like(a).to(wrong_device) 4390 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 4391 torch.triangular_solve(b, a, out=(out, clone_a)) 4392 4393 # Trigger the WARN_ONCE deprecation error 4394 torch.triangular_solve(b, a) 4395 4396 # if out tensor with wrong shape is passed a warning is given 4397 with warnings.catch_warnings(record=True) as w: 4398 out = torch.empty(1, dtype=dtype, device=device) 4399 clone_a = torch.empty(1, dtype=dtype, device=device) 4400 # Trigger warning 4401 torch.triangular_solve(b, a, out=(out, clone_a)) 4402 # Check warning occurs 4403 self.assertEqual(len(w), 2) 4404 self.assertTrue("An output with one or more elements was resized" in str(w[0].message)) 4405 self.assertTrue("An output with one or more elements was resized" in str(w[1].message)) 4406 4407 4408 def check_single_matmul(self, x, y): 4409 4410 def assertEqual(answer, expected): 4411 if x.dtype.is_floating_point or x.dtype.is_complex: 4412 k = max(x.shape[-1], 1) # Scale the atol with the size of the matrix 4413 self.assertEqual(answer, expected, 4414 msg=f"{x.shape} x {y.shape} = {answer.shape}", 4415 atol=k * 5e-5, 4416 rtol=1e-4) 4417 else: 4418 self.assertEqual(answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}") 4419 4420 # test x @ y 4421 expected = np.matmul(x.cpu(), y.cpu()) 4422 ans = torch.matmul(x, y) 4423 self.assertTrue(ans.is_contiguous()) 4424 assertEqual(ans, expected) 4425 4426 # test out 4427 out = torch.empty_like(ans) 4428 ans = torch.matmul(x, y, out=out) 4429 self.assertIs(ans, out) 4430 self.assertTrue(ans.is_contiguous()) 4431 assertEqual(ans, expected) 4432 4433 def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3): 4434 """ 4435 Generates sequences of tuples (x, y) of with size(x) = x_dim and 4436 size(y) <= y_dim that are compatible wrt. 
matmul 4437 """ 4438 assert x_dim >= 1 4439 assert y_dim >= 2 4440 x = x_dim 4441 for y in range(1, y_dim + 1): 4442 for batch, mn in product(product(range(batch_size), repeat=max(x - 2, y - 2, 0)), 4443 product(range(matrix_size), repeat=min(y, 2))): 4444 if x == 1: 4445 size_x = mn[:1] 4446 size_y = batch + mn 4447 yield size_x, size_y 4448 else: 4449 for k in range(matrix_size): 4450 size_x = (k,) + mn[:1] 4451 if x > 2: 4452 size_x = batch[-(x - 2):] + size_x 4453 size_y = mn 4454 if y > 2: 4455 size_y = batch[-(y - 2):] + size_y 4456 yield size_x, size_y 4457 4458 @dtypesIfCUDA(torch.float, torch.complex64) # Integer matmul just supported on CPU 4459 @dtypes(torch.int64, torch.float, torch.complex64) 4460 @setBlasBackendsToDefaultFinally 4461 def test_matmul_small_brute_force_1d_Nd(self, device, dtype): 4462 for backend in ["cublas", "cublaslt"]: 4463 if torch.device(device).type == 'cuda': 4464 torch.backends.cuda.preferred_blas_library(backend) 4465 4466 make_arg = partial(make_tensor, device=device, dtype=dtype) 4467 4468 for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)): 4469 x = make_arg(size_x, noncontiguous=nctg_x) 4470 y = make_arg(size_y, noncontiguous=nctg_y) 4471 self.check_single_matmul(x, y) 4472 4473 @dtypesIfCUDA(torch.float, torch.complex64) # Integer matmul just supported on CPU 4474 @dtypes(torch.int64, torch.float, torch.complex64) 4475 @setBlasBackendsToDefaultFinally 4476 def test_matmul_small_brute_force_2d_Nd(self, device, dtype): 4477 for backend in ["cublas", "cublaslt"]: 4478 if torch.device(device).type == 'cuda': 4479 torch.backends.cuda.preferred_blas_library(backend) 4480 4481 make_arg = partial(make_tensor, device=device, dtype=dtype) 4482 4483 for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)): 4484 x = make_arg(size_x, noncontiguous=nctg_x) 4485 y = make_arg(size_y, noncontiguous=nctg_y) 4486 self.check_single_matmul(x, y) 4487 4488 @dtypesIfCUDA(torch.float, torch.complex64) # Integer matmul just supported on CPU 4489 @dtypes(torch.int64, torch.float, torch.complex64) 4490 @setBlasBackendsToDefaultFinally 4491 def test_matmul_small_brute_force_3d_Nd(self, device, dtype): 4492 for backend in ["cublas", "cublaslt"]: 4493 if torch.device(device).type == 'cuda': 4494 torch.backends.cuda.preferred_blas_library(backend) 4495 4496 make_arg = partial(make_tensor, device=device, dtype=dtype) 4497 4498 for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(3), (True, False), (True, False)): 4499 x = make_arg(size_x, noncontiguous=nctg_x) 4500 y = make_arg(size_y, noncontiguous=nctg_y) 4501 self.check_single_matmul(x, y) 4502 4503 @onlyCUDA 4504 @dtypes(*floating_types_and(torch.half)) 4505 def test_matmul_small_brute_force_tunableop(self, device, dtype): 4506 # disable tunableop buffer rotation for all tests everywhere, it can be slow 4507 import os 4508 os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0" 4509 set_tunableop_defaults() 4510 4511 torch.cuda.tunable.enable() 4512 # set these to single iterations to keep it short but still exercise the code 4513 torch.cuda.tunable.set_max_tuning_duration(1) 4514 torch.cuda.tunable.set_max_tuning_iterations(1) 4515 4516 make_arg = partial(make_tensor, device=device, dtype=dtype) 4517 4518 for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)): 4519 x = make_arg(size_x, noncontiguous=nctg_x) 4520 y = make_arg(size_y, noncontiguous=nctg_y) 
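            # Each (x, y) pair below is tuned with a single iteration (per the settings
            # above) and recorded, so the filename/write_file/read_file round-trip
            # checks that follow have results to work with.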
            self.check_single_matmul(x, y)

        filename1 = torch.cuda.tunable.get_filename()
        filename2 = "tunableop_results_tmp1.csv"
        filename3 = "tunableop_results_tmp2.csv"
        ordinal = torch.cuda.current_device()
        assert filename1 == f"tunableop_results{ordinal}.csv"
        assert len(torch.cuda.tunable.get_validators()) > 0
        validators = {}
        for key, value in torch.cuda.tunable.get_validators():
            validators[key] = value
        if torch.version.hip:
            assert "HIPBLASLT_VERSION" in validators
            assert re.match(r'^\d{3}-[a-z0-9]{8}$', validators["HIPBLASLT_VERSION"])
        assert len(torch.cuda.tunable.get_results()) > 0

        assert torch.cuda.tunable.write_file()  # use default filename
        assert torch.cuda.tunable.write_file(filename2)  # use custom, one-time filename
        torch.cuda.tunable.set_filename(filename3)
        assert torch.cuda.tunable.write_file()  # use previously set filename
        assert torch.cuda.tunable.read_file()  # use previously set filename, will ignore duplicates and return True

        with open(filename1) as file1:
            file1_contents = file1.read()
        with open(filename2) as file2:
            file2_contents = file2.read()
        with open(filename3) as file3:
            file3_contents = file3.read()
        assert file1_contents == file2_contents
        assert file1_contents == file3_contents

        # remove the files created above to avoid error 'Build left local git repository checkout dirty', ignore errors
        for filename in [filename1, filename2, filename3]:
            try:
                import os
                os.remove(filename)
            except FileNotFoundError:
                pass

        # disable TunableOp
        torch.cuda.tunable.enable(False)

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_bmm_tunableop_rocm(self, device, dtype):
        # buffer rotation (on by default) with strided batched gemm tunableop was causing a mem fault
        set_tunableop_defaults()
        torch.cuda.tunable.enable(True)
        torch.cuda.tunable.set_max_tuning_iterations(10)
        # the following 4 cases cover all previous failure cases and are here to catch regressions
        B = 16
        N = M = K = 256
        dtype = torch.bfloat16
        device = torch.device("cuda:0")
        # case 1
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        out = torch.bmm(i1, i2)
        # case 2
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i1 = torch.permute(i1, (1, 2, 0))
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        i2 = torch.permute(i2, (1, 0, 2))
        out = torch.bmm(i1, i2)
        # case 3
        i1 = torch.randn((N, B, M), device=device, dtype=dtype)
        i1 = torch.permute(i1, (1, 0, 2))
        i2 = torch.randn((M, B, K), device=device, dtype=dtype)
        i2 = torch.permute(i2, (1, 2, 0))
        out = torch.bmm(i1, i2)
        # case 4
        input_tensor = torch.rand((1920, 1, 100), device=device, dtype=dtype)
        input_tensor = torch.as_strided(
            input_tensor, size=(1920, 1, 100), stride=(100, 100, 1)
        )
        batch1_tensor = torch.rand((1920, 256, 512), device=device, dtype=dtype)
        batch1_tensor = torch.as_strided(
            batch1_tensor, size=(1920, 256, 512), stride=(512, 983040, 1)
        )
        batch2_tensor = torch.rand((1920, 512, 100), device=device, dtype=dtype)
        batch2_tensor = torch.as_strided(
            batch2_tensor, size=(1920, 512, 100), stride=(51200, 100, 1)
        )
        out = torch.baddbmm(input_tensor, batch1_tensor, batch2_tensor)
        # clean up, remove
any file that was generated 4607 try: 4608 import os 4609 filename = torch.cuda.tunable.get_filename() 4610 os.remove(filename) 4611 except FileNotFoundError: 4612 pass 4613 4614 # disable TunableOp 4615 torch.cuda.tunable.enable(False) 4616 4617 @onlyCUDA 4618 @skipCUDAIfNotRocm 4619 @dtypes(torch.float) 4620 def test_numeric_check_leak_tunableop_rocm(self, device, dtype): 4621 from torch.testing._internal.common_utils import CudaMemoryLeakCheck 4622 import os 4623 # run operator first without tuning to ensure all rocm libs are loaded, 4624 # otherwise false positive mem leak 4625 B = 16 4626 N = M = K = 256 4627 dtype = torch.bfloat16 4628 device = torch.device("cuda:0") 4629 i1 = torch.randn((B, N, M), device=device, dtype=dtype) 4630 i2 = torch.randn((B, M, K), device=device, dtype=dtype) 4631 out = torch.bmm(i1, i2) 4632 # enable tunableop numeric check via env variable. 4633 PYTORCH_TUNABLEOP_NUMERICAL_CHECK = "PYTORCH_TUNABLEOP_NUMERICAL_CHECK" 4634 prev_val = os.getenv(PYTORCH_TUNABLEOP_NUMERICAL_CHECK) 4635 try: 4636 os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = "1" 4637 torch.cuda.tunable.enable(True) 4638 ordinal = torch.cuda.current_device() 4639 filename = f"tunableop_results{ordinal}.csv" 4640 torch.cuda.tunable.set_filename(filename) 4641 iterations = torch.cuda.tunable.get_max_tuning_iterations() 4642 torch.cuda.tunable.set_max_tuning_iterations(10) 4643 with CudaMemoryLeakCheck(self): 4644 out = torch.bmm(i1, i2) 4645 torch.cuda.tunable.set_max_tuning_iterations(iterations) 4646 torch.cuda.tunable.enable(False) 4647 # clean up, remove any file that was generated 4648 try: 4649 os.remove(filename) 4650 except FileNotFoundError: 4651 pass 4652 finally: 4653 if prev_val is None: 4654 del os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] 4655 else: 4656 os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = prev_val 4657 4658 @onlyCUDA 4659 @skipCUDAIfNotRocm 4660 @dtypes(torch.float) 4661 def test_validator_tunableop_rocm(self, device, dtype): 4662 # Test that the validator on ROCM has exactly 5 lines 4663 # Format of the Validator is as follows: 4664 # Validator,PT_VERSION,X.Y.Z. 4665 # Validator,ROCBLAS_VERSION,X.Y,Z 4666 # Validator,HIPBLASLT_VERSION,X,Y.Z 4667 # Validator,ROCM_Version,X,Y.Z 4668 # Validator,GCN_ARCH_NAME,<architecutre name> 4669 validator_num_lines = 5 4670 4671 # Test in try-finally block to avoid leaking state 4672 # if test is interrupted. 4673 try: 4674 set_tunableop_defaults() 4675 torch.cuda.tunable.enable() 4676 # set these to single iterations to keep it short but still exercise the code 4677 torch.cuda.tunable.set_max_tuning_iterations(1) 4678 4679 N = M = K = 4 4680 A = torch.randn(N, K, device=device, dtype=dtype) 4681 B = torch.randn(K, M, device=device, dtype=dtype) 4682 C = torch.matmul(A, B) 4683 self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines) 4684 finally: 4685 # disable TunableOp 4686 torch.cuda.tunable.enable(False) 4687 4688 # clean up, remove any file that was generated 4689 try: 4690 import os 4691 filename = torch.cuda.tunable.get_filename() 4692 os.remove(filename) 4693 except FileNotFoundError: 4694 pass 4695 4696 @onlyCUDA 4697 @dtypes(torch.half) 4698 def test_minimum_tuning_iteration_tunableop(self, device, dtype): 4699 # Make sure that there is at least one tuning iteration under various scenarios 4700 4701 # Test in try-finally block to avoid leaking state 4702 # if test is interrupted. 
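        # The environment variables below are TunableOp knobs; the assertions check
        # that even a value of "0" still leaves at least one tuning iteration.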
        try:
            set_tunableop_defaults()
            torch.cuda.tunable.enable()
            # set these to single iterations to keep it short but still exercise the code
            torch.cuda.tunable.set_max_tuning_iterations(1)

            # Set tuning duration to zero milliseconds
            # Tune a single GEMM and verify that we get a new tuning result
            import os
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "0"
            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "30"  # reset to default

            # Reference number of results
            ref_num_results = len(torch.cuda.tunable.get_results())

            N = M = K = 8
            A = torch.randn(N, K, device=device, dtype=dtype)
            B = torch.randn(K, M, device=device, dtype=dtype)
            C = torch.matmul(A, B)

            # This stores the total number of cumulative results
            total_num_results = len(torch.cuda.tunable.get_results())

            # There must be a new tuning result
            self.assertEqual((total_num_results - ref_num_results), 1)

            # Set tuning iterations to zero
            # Tune a single GEMM and verify that we get a new tuning result
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "0"
            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "100"  # reset to default

            # Reference number of results
            ref_num_results = total_num_results

            N = M = K = 16
            A = torch.randn(N, K, device=device, dtype=dtype)
            B = torch.randn(K, M, device=device, dtype=dtype)
            C = torch.matmul(A, B)

            # This stores the total number of cumulative results
            total_num_results = len(torch.cuda.tunable.get_results())

            # There must be a new tuning result
            self.assertEqual((total_num_results - ref_num_results), 1)

        finally:
            # disable TunableOp
            torch.cuda.tunable.enable(False)

            # clean up, remove any file that was generated
            try:
                import os
                filename = torch.cuda.tunable.get_filename()
                os.remove(filename)
            except FileNotFoundError:
                pass

    @onlyCUDA
    @dtypes(torch.half)
    def test_matmul_check_entries_tunableop(self, device, dtype):
        # Tune a couple of matrix multiplies
        # Verify we get the correct number of results

        try:
            set_tunableop_defaults()
            torch.cuda.tunable.enable()
            # set these to single iterations to keep it short but still exercise the code
            torch.cuda.tunable.set_max_tuning_iterations(1)

            # Reference number of results
            ref_num_results = len(torch.cuda.tunable.get_results())

            # Execute matrix multiplies. The M list intentionally contains the same
            # value twice; the CSV file should only record unique GEMMs
            count_matmul = 4
            K = 64
            for M in [32, 64, 32]:
                for N in [32, 64]:
                    A = torch.randn(N, K, device=device, dtype=dtype)
                    B = torch.randn(K, M, device=device, dtype=dtype)
                    C = torch.matmul(A, B)

            # This stores the total number of cumulative results
            total_num_results = len(torch.cuda.tunable.get_results())

            # Take the difference to calculate the number of results from
            # this test and verify that it agrees with the number of
            # GEMMs.
            self.assertEqual((total_num_results - ref_num_results), count_matmul)

        finally:
            # disable TunableOp
            torch.cuda.tunable.enable(False)

            # clean up, remove any file that was generated
            try:
                import os
                filename = torch.cuda.tunable.get_filename()
                os.remove(filename)
            except FileNotFoundError:
                pass

    @dtypes(torch.float, torch.complex64)
    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
        a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
        b = torch.empty((4, 128, 512), device=device, dtype=dtype, requires_grad=True).transpose(-1, -2)
        c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)

        torch.matmul(a.detach(), b.detach(), out=c)

        with self.assertRaisesRegex(RuntimeError, "functions with out=... arguments don't support automatic differentiation"):
            torch.matmul(a, b, out=c)

        with torch.no_grad():
            torch.matmul(a, b, out=c)

    # 4GB should do, but we run tests in parallel in CI, so let's be generous
    @largeTensorTest('16GB', device='cuda')
    def test_large_bmm_mm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
        B = torch.randn([1024, 65536], device="cuda", requires_grad=True)
        G = torch.randn([1024, 2, 65536], device="cuda")

        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
        (A @ B).backward(G)

    # 4GB should do, but we run tests in parallel in CI, so let's be generous
    @largeTensorTest('16GB', device='cuda')
    def test_large_bmm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
        B = torch.randn([1, 1024, 65536], device="cuda", requires_grad=True)
        G = torch.randn([1024, 2, 65536], device="cuda")

        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
        (A @ B).backward(G)

    def test_linear_algebra_scalar_raises(self, device) -> None:
        m = torch.randn(5, 5, device=device)
        v = torch.randn(5, device=device)
        s = torch.tensor(7, device=device)
        self.assertRaises(RuntimeError, lambda: torch.mv(m, s))
        self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s))

    @dtypes(torch.float32, torch.complex64)
    def test_cross(self, device, dtype):
        x = torch.rand(100, 3, 100, dtype=dtype, device=device)
        y = torch.rand(100, 3, 100, dtype=dtype, device=device)
        res1 = torch.cross(x, y)
        res2 = torch.tensor((), dtype=dtype, device=device)
        torch.cross(x, y, out=res2)
        self.assertEqual(res1, res2)

    @dtypes(torch.float32, torch.complex64)
    def test_linalg_cross(self, device, dtype):
        x = torch.rand(100, 3, 100, dtype=dtype, device=device)
        y = torch.rand(100, 3, 100, dtype=dtype,
device=device) 4945 res1 = torch.linalg.cross(x, y, dim=1) 4946 res2 = torch.tensor((), dtype=dtype, device=device) 4947 torch.linalg.cross(x, y, dim=1, out=res2) 4948 self.assertEqual(res1, res2) 4949 4950 # test for broadcastable inputs 4951 x = torch.rand(1, 3, 2, dtype=dtype, device=device) 4952 y = torch.rand(4, 3, 1, dtype=dtype, device=device) 4953 res1 = torch.linalg.cross(x, y, dim=1) 4954 res2 = torch.tensor((), dtype=dtype, device=device) 4955 torch.linalg.cross(x, y, dim=1, out=res2) 4956 self.assertEqual(res1, res2) 4957 4958 @dtypes(torch.float32, torch.complex64) 4959 def test_cross_with_and_without_dim(self, device, dtype): 4960 x = torch.rand(100, 3, dtype=dtype, device=device) 4961 y = torch.rand(100, 3, dtype=dtype, device=device) 4962 res1 = torch.cross(x, y, dim=1) 4963 res2 = torch.cross(x, y, dim=-1) 4964 res3 = torch.cross(x, y) 4965 self.assertEqual(res1, res2) 4966 self.assertEqual(res1, res3) 4967 4968 @dtypes(torch.float32, torch.complex64) 4969 def test_linalg_cross_with_and_without_dim(self, device, dtype): 4970 x = torch.rand(100, 3, dtype=dtype, device=device) 4971 y = torch.rand(100, 3, dtype=dtype, device=device) 4972 res1 = torch.linalg.cross(x, y, dim=1) 4973 res2 = torch.linalg.cross(x, y, dim=-1) 4974 res3 = torch.linalg.cross(x, y) 4975 self.assertEqual(res1, res2) 4976 self.assertEqual(res1, res3) 4977 4978 def test_renorm(self, device): 4979 m1 = torch.randn(20, 20, device=device) # big enough to exercise vectorized path 4980 res1 = torch.tensor((), device=device) 4981 4982 def renorm(matrix, value, dim, max_norm): 4983 m1 = matrix.transpose(dim, 0).contiguous() 4984 # collapse non-dim dimensions. 4985 m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0)))) 4986 norms = m2.norm(value, 1, True) 4987 # clip 4988 new_norms = norms.clone() 4989 new_norms[torch.gt(norms, max_norm)] = max_norm 4990 new_norms.div_(norms.add_(1e-7)) 4991 # renormalize 4992 m1.mul_(new_norms.expand_as(m1)) 4993 return m1.transpose(dim, 0) 4994 4995 # note that the axis fed to torch.renorm is different (2~=1) 4996 maxnorm = m1.norm(2, 1).mean() 4997 m2 = renorm(m1, 2, 1, maxnorm) 4998 m1.renorm_(2, 1, maxnorm) 4999 self.assertEqual(m1, m2, atol=1e-5, rtol=0) 5000 self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), atol=1e-5, rtol=0) 5001 5002 m1 = torch.randn(3, 4, 5, device=device) 5003 m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) 5004 maxnorm = m2.norm(2, 0).mean() 5005 m2 = renorm(m2, 2, 1, maxnorm) 5006 m1.renorm_(2, 1, maxnorm) 5007 m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) 5008 self.assertEqual(m3, m2) 5009 self.assertEqual(m3.norm(2, 0), m2.norm(2, 0)) 5010 5011 @skipCPUIfNoLapack 5012 @skipCUDAIfNoCusolver 5013 @dtypes(*floating_and_complex_types()) 5014 def test_ormqr(self, device, dtype): 5015 5016 def run_test(batch, m, n, fortran_contiguous): 5017 A = make_tensor((*batch, m, n), dtype=dtype, device=device) 5018 reflectors, tau = torch.geqrf(A) 5019 if not fortran_contiguous: 5020 self.assertTrue(reflectors.mT.is_contiguous()) 5021 reflectors = reflectors.contiguous() 5022 5023 # Q is of size m x m 5024 Q, _ = torch.linalg.qr(A, mode='complete') 5025 C_right = make_tensor((*batch, m, n), dtype=dtype, device=device) 5026 C_left = make_tensor((*batch, n, m), dtype=dtype, device=device) 5027 5028 expected = Q @ C_right 5029 actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=False) 5030 self.assertEqual(expected, actual) 5031 5032 expected = C_left @ Q 5033 actual = torch.ormqr(reflectors, tau, 
C_left, left=False, transpose=False) 5034 self.assertEqual(expected, actual) 5035 5036 expected = Q.mH @ C_right 5037 actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=True) 5038 self.assertEqual(expected, actual) 5039 5040 expected = C_left @ Q.mH 5041 actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=True) 5042 self.assertEqual(expected, actual) 5043 5044 # if tau is all zeros then the implicit matrix Q is the identity matrix 5045 # so the actual result should be C_right in this case 5046 zero_tau = torch.zeros_like(tau) 5047 actual = torch.ormqr(reflectors, zero_tau, C_right, left=True, transpose=False) 5048 self.assertEqual(C_right, actual) 5049 5050 batches = [(), (0, ), (2, ), (2, 1)] 5051 ns = [5, 2, 0] 5052 for batch, (m, n), fortran_contiguous in product(batches, product(ns, ns), [True, False]): 5053 run_test(batch, m, n, fortran_contiguous) 5054 5055 @skipCPUIfNoLapack 5056 @skipCUDAIfNoCusolver 5057 @dtypes(*floating_and_complex_types()) 5058 def test_ormqr_errors_and_warnings(self, device, dtype): 5059 test_cases = [ 5060 # input1 size, input2 size, input3 size, error regex 5061 ((10,), (2,), (2,), r"input must have at least 2 dimensions"), 5062 ((2, 2), (2,), (2,), r"other must have at least 2 dimensions"), 5063 ((10, 6), (20,), (10, 6), r"other.shape\[-2\] must be greater than or equal to tau.shape\[-1\]"), 5064 ((6, 6), (5,), (5, 5), r"other.shape\[-2\] must be equal to input.shape\[-2\]"), 5065 ((1, 2, 2), (2, 2), (1, 2, 2), r"batch dimensions of tau to be equal to input.shape\[:-2\]"), 5066 ((1, 2, 2), (1, 2), (2, 2, 2), r"batch dimensions of other to be equal to input.shape\[:-2\]"), 5067 ] 5068 for a_size, tau_size, c_size, error_regex in test_cases: 5069 a = make_tensor(a_size, dtype=dtype, device=device) 5070 tau = make_tensor(tau_size, dtype=dtype, device=device) 5071 c = make_tensor(c_size, dtype=dtype, device=device) 5072 with self.assertRaisesRegex(RuntimeError, error_regex): 5073 torch.ormqr(a, tau, c) 5074 5075 def test_blas_empty(self, device): 5076 def fn(torchfn, *args, test_out=False, **kwargs): 5077 def call_torch_fn(*args, **kwargs): 5078 return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape 5079 for shape in args), **kwargs) 5080 result = call_torch_fn(*args, **kwargs) 5081 if not test_out: 5082 return result 5083 else: 5084 out = torch.full_like(result, math.nan) 5085 out1 = call_torch_fn(*args, **kwargs, out=out) 5086 return out 5087 5088 # mm, addmm 5089 self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) 5090 self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) 5091 self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) 5092 self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) 5093 self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))) 5094 self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True)) 5095 5096 self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) 5097 self.assertEqual((0, 1), fn(torch.addmm, (1, ), (0, 17), (17, 1)).shape) 5098 t = torch.randn((5, 6), device=device) 5099 self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6))) 5100 self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True)) 5101 5102 # mv, addmv 5103 self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) 5104 self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) 5105 self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) 5106 
self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True)) 5107 5108 self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) 5109 t = torch.randn((3,), device=device) 5110 self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,))) 5111 self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True)) 5112 5113 # bmm, baddbmm 5114 self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) 5115 self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) 5116 self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) 5117 self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))) 5118 self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True)) 5119 5120 self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape) 5121 self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape) 5122 self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape) 5123 self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape) 5124 c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5) 5125 self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)) # Issue #33467 5126 self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)) # Issue #33467 5127 5128 # addbmm 5129 self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) 5130 self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) 5131 t = torch.randn((5, 6), device=device) 5132 self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6))) 5133 self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True)) 5134 5135 # matmul 5136 self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,))) 5137 self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,), test_out=True)) 5138 self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) 5139 self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) 5140 self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) 5141 self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4))) 5142 self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True)) 5143 5144 # dot 5145 self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,))) 5146 self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True)) 5147 5148 @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, 5149 torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) 5150 @dtypesIfCUDA(*floating_and_complex_types_and( 5151 torch.half, 5152 *[torch.bfloat16] if SM53OrLater else [] 5153 )) 5154 @dtypes(*all_types_and_complex_and(torch.bfloat16)) 5155 def test_corner_cases_of_cublasltmatmul(self, device, dtype): 5156 # common case 5157 M = torch.randn(128, device=device).to(dtype) 5158 m1 = torch.randn(2048, 2400, device=device).to(dtype) 5159 m2 = torch.randn(128, 2400, device=device).to(dtype) 5160 torch.nn.functional.linear(m1, m2, M) 5161 # Ntrans_B has ld >> rows 5162 m1 = torch.rand([128, 2400]).to(dtype).to(device).t() 5163 m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340] 5164 M = torch.rand([128]).to(dtype).to(device) 5165 
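        # m2 above is a slice of a transposed matrix, so the operand handed to the
        # BLAS backend has a leading dimension (25272) far larger than its row count,
        # which is the "ld >> rows" corner case named in the comment above.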
torch.addmm(M, m2.t(), m1) 5166 # trans_A has ld >> rows 5167 m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t() 5168 m2 = torch.randn(2048, 2400, device=device).to(dtype) 5169 M = torch.rand([128]).to(dtype).to(device) 5170 torch.addmm(M, m2, m1) 5171 # large tensor dim > 65535 5172 M = torch.randn(16, device=device).to(dtype) 5173 m1 = torch.randn(32, 131071 , device=device).to(dtype) 5174 m2 = torch.randn(16, 131071, device=device).to(dtype) 5175 torch.nn.functional.linear(m1, m2, M) 5176 5177 @onlyCUDA 5178 @skipCUDAIfNotRocm 5179 @dtypes(*floating_types_and(torch.bfloat16, torch.half)) 5180 def test_hipblaslt_corner_cases_rocm(self, device, dtype): 5181 if dtype == torch.double: 5182 raise unittest.SkipTest("hipblasLt doesn't support doubles yet") 5183 5184 # enable hipblaslt path via env variable. 5185 import os 5186 DISABLE_ADDMM_HIP_LT = "DISABLE_ADDMM_HIP_LT" 5187 prev_val = os.getenv(DISABLE_ADDMM_HIP_LT) 5188 try: 5189 os.environ[DISABLE_ADDMM_HIP_LT] = "0" 5190 # common case 5191 M = torch.randn(128, device=device, dtype=dtype) 5192 m1 = torch.randn(2048, 2400, device=device, dtype=dtype) 5193 m2 = torch.randn(128, 2400, device=device, dtype=dtype) 5194 out1 = torch.nn.functional.linear(m1, m2, M) 5195 M_cpu = M.to('cpu') 5196 m1_cpu = m1.to('cpu') 5197 m2_cpu = m2.to('cpu') 5198 out1_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, M_cpu) 5199 self.assertTrue(torch.allclose(out1_cpu, out1.cpu(), rtol=1e-2, atol=1e-2)) 5200 5201 # common case without bias 5202 m1 = torch.randn(2048, 2400, device=device, dtype=dtype) 5203 m2 = torch.randn(128, 2400, device=device, dtype=dtype) 5204 out2 = torch.nn.functional.linear(m1, m2, bias=None) 5205 m1_cpu = m1.to('cpu') 5206 m2_cpu = m2.to('cpu') 5207 out2_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, bias=None) 5208 self.assertTrue(torch.allclose(out2_cpu, out2.cpu(), rtol=1e-2, atol=1e-2)) 5209 finally: 5210 if prev_val is None: 5211 del os.environ[DISABLE_ADDMM_HIP_LT] 5212 else: 5213 os.environ[DISABLE_ADDMM_HIP_LT] = prev_val 5214 5215 @dtypesIfCUDA(*floating_and_complex_types_and( 5216 torch.half, 5217 *[torch.bfloat16] if SM53OrLater else [] 5218 )) 5219 @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.half)) 5220 def test_blas_alpha_beta_empty(self, device, dtype): 5221 # This test is disabled on CUDA 9 due to: 5222 # See: https://github.com/pytorch/pytorch/issues/31006 5223 if dtype is torch.bfloat16 and self.device_type == 'xla': 5224 # TODO (@zasdfgbnm): this causes the following error on test 5225 # TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16: 5226 # 5227 # RuntimeError: _th_equal not supported on CPUType for BFloat16 5228 return 5229 # ensure beta is respected 5230 value = 11 5231 input = torch.full((2,), value, dtype=dtype, device=device) 5232 mat = torch.ones((2, 0), dtype=dtype, device=device) 5233 vec = torch.ones((0,), dtype=dtype, device=device) 5234 out = torch.empty((2,), dtype=dtype, device=device) 5235 if dtype.is_complex: 5236 alpha = 6 + 7j 5237 beta = 3 + 4j 5238 else: 5239 alpha = 6 5240 beta = 3 5241 self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device), 5242 torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta)) 5243 self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device), 5244 torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out)) 5245 5246 # torch.addmm 5247 input = torch.full((2, 3), value, dtype=dtype, device=device) 5248 mat2 = torch.ones((0, 3), dtype=dtype, device=device) 5249 
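        # With a zero-sized reduction dimension, mat @ mat2 contributes nothing, so
        # addmm must reduce to exactly beta * input (alpha has no visible effect),
        # mirroring the addmv checks above.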
out = torch.empty((2, 3), dtype=dtype, device=device) 5250 self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device), 5251 torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta)) 5252 self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device), 5253 torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out)) 5254 5255 @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16)) 5256 def test_blas_nan_out(self, device, dtype): 5257 # These functions should work correctly with NaN filled outputs, 5258 # but need special handling, see [NOTE: cpu_zero] 5259 b = 3 5260 n = 5 5261 m = 7 5262 p = 11 5263 5264 # torch.mv 5265 nm = torch.randn((m, n), device=device).t() 5266 _m = torch.randn((), device=device).expand(m) 5267 _m_out = torch.full((m,), float('nan'), device=device) 5268 self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out)) 5269 self.assertEqual(0, torch.isnan(torch.mv(nm, _m)).sum()) 5270 5271 # torch.mm 5272 mp = torch.randn((p, m), device=device).t() 5273 np_out = torch.full((n, p), float('nan'), device=device) 5274 self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out)) 5275 5276 # torch.bmm 5277 bnm = torch.randn((b, m, n), device=device).transpose(1, 2) 5278 bmp = torch.randn((b, p, m), device=device).transpose(1, 2) 5279 bnp_out = torch.full((b, n, p), float('nan'), device=device) 5280 self.assertEqual(torch.bmm(bnm, bmp), torch.bmm(bnm, bmp, out=bnp_out)) 5281 5282 @onlyCPU # not supported by CUBLAS 5283 def test_blas_mv_large_input(self, device): 5284 # This would previously fail if the allocated output had NaNs, see: 5285 # https://github.com/pytorch/pytorch/issues/31663 and [NOTE: cpu_zero] 5286 n = 3000 5287 m = 200 5288 5289 nm = torch.randn((m, n), device=device).t() 5290 _m = torch.randn((), device=device).expand(m) 5291 _m_out = torch.full((m,), 0., device=device) 5292 5293 self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out)) 5294 5295 @onlyCPU 5296 def test_renorm_ps(self, device): 5297 # full reduction 5298 x = torch.randn(5, 5) 5299 xn = x.numpy() 5300 for p in [1, 2, 3, 4, inf]: 5301 res = x.renorm(p, 1, 1) 5302 expected = x / x.norm(p, 0, keepdim=True).clamp(min=1) 5303 self.assertEqual(res, expected, msg=f"renorm failed for {p}-norm") 5304 5305 @skipCPUIfNoLapack 5306 @skipCUDAIfNoCusolver 5307 @dtypes(*floating_and_complex_types()) 5308 def test_householder_product(self, device, dtype): 5309 def generate_reflectors_and_tau(A): 5310 """ 5311 This function uses numpy.linalg.qr with mode "raw" to extract output of LAPACK's geqrf. 5312 There is torch.geqrf function but it doesn't work with complex-valued input. 
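            With mode="raw", numpy.linalg.qr returns the Householder reflectors packed
            in LAPACK's geqrf layout (transposed, Fortran order) together with tau,
            which is why the reflectors are transposed back below.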
5313 """ 5314 if A.numel() > 0: 5315 A_cpu = A.cpu() 5316 flattened_batch_shape = [-1, *A_cpu.shape[-2:]] 5317 reflectors = torch.empty_like(A_cpu).view(*flattened_batch_shape) 5318 tau_shape = [*A_cpu.shape[:-2], A_cpu.shape[-1]] 5319 tau = torch.empty(tau_shape, dtype=dtype).view(-1, A_cpu.shape[-1]) 5320 for A_i, reflectors_i, tau_i in zip(A_cpu.contiguous().view(*flattened_batch_shape), reflectors, tau): 5321 reflectors_tmp, tau_i[:] = map(torch.from_numpy, np.linalg.qr(A_i, mode='raw')) 5322 reflectors_i[:] = reflectors_tmp.T 5323 reflectors = reflectors.view(*A_cpu.shape) 5324 tau = tau.view(tau_shape) 5325 return reflectors.to(A.device), tau.to(A.device) 5326 5327 reflectors = torch.empty_like(A) 5328 tau = torch.empty(*A.shape[:-2], A.shape[-1], dtype=dtype, device=device) 5329 return reflectors, tau 5330 5331 def run_test(shape): 5332 A = torch.randn(*shape, dtype=dtype, device=device) 5333 reflectors, tau = generate_reflectors_and_tau(A) 5334 expected, _ = torch.linalg.qr(A) 5335 actual = torch.linalg.householder_product(reflectors, tau) 5336 # torch.linalg.qr does not work correctly for zero batch dimension tensors 5337 # see https://github.com/pytorch/pytorch/issues/50576 5338 if (A.numel() > 0): 5339 self.assertEqual(expected, actual) 5340 else: 5341 self.assertTrue(actual.shape == shape) 5342 5343 # if tau is empty and A is not the result should be a matrix with ones on the diagonal 5344 if (A.numel() > 0): 5345 tau_empty = torch.empty(*shape[:-2], 0, dtype=dtype, device=device) 5346 identity_mat = torch.zeros_like(reflectors) 5347 identity_mat.diagonal(dim1=-1, dim2=-2)[:] = 1 5348 actual = torch.linalg.householder_product(reflectors, tau_empty) 5349 self.assertEqual(actual, identity_mat) 5350 5351 out = torch.empty_like(A) 5352 ans = torch.linalg.householder_product(reflectors, tau, out=out) 5353 self.assertEqual(ans, out) 5354 if (A.numel() > 0): 5355 self.assertEqual(expected, out) 5356 5357 shapes = [(0, 0), (5, 0), # Empty matrix 5358 (5, 5), (5, 3), # Single matrix 5359 (0, 0, 0), (0, 5, 5), (0, 5, 3), # Zero batch dimension tensors 5360 (2, 5, 5), (2, 5, 3), # 3-dim tensors 5361 (2, 1, 5, 5), (2, 1, 5, 3)] # 4-dim tensors 5362 for shape in shapes: 5363 run_test(shape) 5364 5365 @skipCPUIfNoLapack 5366 @skipCUDAIfNoCusolver 5367 def test_householder_product_errors_and_warnings(self, device): 5368 test_cases = [ 5369 # input1 size, input2 size, error regex 5370 ((10,), (2,), r"input must have at least 2 dimensions"), 5371 ((10, 6), (20,), r"input.shape\[-1\] must be greater than or equal to tau.shape\[-1\]"), 5372 ((6, 10), (5,), r"input.shape\[-2\] must be greater than or equal to input.shape\[-1\]"), 5373 ] 5374 for a_size, tau_size, error_regex in test_cases: 5375 a = torch.rand(*a_size, device=device) 5376 tau = torch.rand(*tau_size, device=device) 5377 with self.assertRaisesRegex(RuntimeError, error_regex): 5378 torch.linalg.householder_product(a, tau) 5379 5380 # if out tensor with wrong shape is passed a warning is given 5381 reflectors = torch.randn(3, 3, device=device) 5382 tau = torch.randn(3, device=device) 5383 out = torch.empty(2, 3, device=device) 5384 with warnings.catch_warnings(record=True) as w: 5385 # Trigger warning 5386 torch.linalg.householder_product(reflectors, tau, out=out) 5387 # Check warning occurs 5388 self.assertEqual(len(w), 1) 5389 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 5390 5391 # dtypes should be safely castable 5392 out = torch.empty_like(reflectors).to(torch.int) 5393 with 
self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): 5394 torch.linalg.householder_product(reflectors, tau, out=out) 5395 5396 with self.assertRaisesRegex(RuntimeError, "tau dtype Int does not match input dtype"): 5397 torch.linalg.householder_product(reflectors, tau.to(torch.int)) 5398 5399 if torch.cuda.is_available(): 5400 # device of out and input should match 5401 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 5402 out = torch.empty_like(reflectors).to(wrong_device) 5403 with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): 5404 torch.linalg.householder_product(reflectors, tau, out=out) 5405 5406 # device of tau and input should match 5407 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 5408 tau = tau.to(wrong_device) 5409 with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): 5410 torch.linalg.householder_product(reflectors, tau) 5411 5412 @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2}) 5413 @skipCUDAIfNoMagmaAndNoCusolver 5414 @skipIfTorchDynamo("Runtime error with torch._C._linalg.linalg_lu_factor") 5415 @skipCPUIfNoLapack 5416 @dtypes(*floating_and_complex_types()) 5417 def test_linalg_lu_family(self, device, dtype): 5418 # Tests torch.lu 5419 # torch.linalg.lu_factor 5420 # torch.linalg.lu_factor_ex 5421 # torch.lu_unpack 5422 # torch.linalg.lu_solve 5423 # torch.linalg.solve 5424 make_arg_full = partial(make_fullrank_matrices_with_distinct_singular_values, device=device, dtype=dtype) 5425 make_arg = partial(make_tensor, device=device, dtype=dtype) 5426 5427 def run_test(A, pivot, singular, fn): 5428 k = min(A.shape[-2:]) 5429 batch = A.shape[:-2] 5430 check_errors = (fn == torch.linalg.lu_factor) 5431 if singular and check_errors: 5432 # It may or may not throw as the LU decomposition without pivoting 5433 # may still succeed for singular matrices 5434 try: 5435 LU, pivots = fn(A, pivot=pivot) 5436 except RuntimeError: 5437 return 5438 else: 5439 LU, pivots = fn(A, pivot=pivot)[:2] 5440 5441 self.assertEqual(LU.size(), A.shape) 5442 self.assertEqual(pivots.size(), batch + (k,)) 5443 5444 if not pivot: 5445 self.assertEqual(pivots, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(batch + (k, ))) 5446 5447 P, L, U = torch.lu_unpack(LU, pivots, unpack_pivots=pivot) 5448 5449 self.assertEqual(P @ L @ U if pivot else L @ U, A) 5450 5451 PLU = torch.linalg.lu(A, pivot=pivot) 5452 self.assertEqual(P, PLU.P) 5453 self.assertEqual(L, PLU.L) 5454 self.assertEqual(U, PLU.U) 5455 5456 if not singular and A.size(-2) == A.size(-1): 5457 nrhs = ((), (1,), (3,)) 5458 for left, rhs in product((True, False), nrhs): 5459 # Vector case when left = False is not allowed 5460 if not left and rhs == (): 5461 continue 5462 if left: 5463 shape_B = A.shape[:-1] + rhs 5464 else: 5465 shape_B = A.shape[:-2] + rhs + A.shape[-1:] 5466 B = make_arg(shape_B) 5467 5468 # Test linalg.lu_solve. 
It does not support vectors as rhs 5469 # See https://github.com/pytorch/pytorch/pull/74045#issuecomment-1112304913 5470 if rhs != (): 5471 for adjoint in (True, False): 5472 X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint) 5473 A_adj = A.mH if adjoint else A 5474 if left: 5475 self.assertEqual(B, A_adj @ X) 5476 else: 5477 self.assertEqual(B, X @ A_adj) 5478 5479 # Test linalg.solve 5480 X = torch.linalg.solve(A, B, left=left) 5481 X_ = X.unsqueeze(-1) if rhs == () else X 5482 B_ = B.unsqueeze(-1) if rhs == () else B 5483 if left: 5484 self.assertEqual(B_, A @ X_) 5485 else: 5486 self.assertEqual(B_, X_ @ A) 5487 5488 5489 sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0)) 5490 batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5)) 5491 # Non pivoting just implemented for CUDA 5492 pivots = (True, False) if self.device_type == "cuda" else (True,) 5493 fns = (partial(torch.lu, get_infos=True), torch.linalg.lu_factor, torch.linalg.lu_factor_ex) 5494 for ms, batch, pivot, singular, fn in itertools.product(sizes, batches, pivots, (True, False), fns): 5495 shape = batch + ms 5496 A = make_arg(shape) if singular else make_arg_full(*shape) 5497 # Just do one of them on singular matrices 5498 if A.numel() == 0 and not singular: 5499 continue 5500 run_test(A, pivot, singular, fn) 5501 5502 # Reproducer of a magma bug, 5503 # see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on 5504 # This is also a bug in cuSOLVER < 11.3 5505 if (dtype == torch.double 5506 and singular): 5507 A = torch.ones(batch + ms, dtype=dtype, device=device) 5508 run_test(A, pivot, singular, fn) 5509 5510 # Info should be positive for rank deficient matrices 5511 A = torch.ones(5, 3, 3, device=device) 5512 self.assertTrue((torch.linalg.lu_factor_ex(A, pivot=True).info >= 0).all()) 5513 5514 if self.device_type == 'cpu': 5515 # Error checking, no pivoting variant on CPU 5516 fns = [torch.lu, torch.linalg.lu_factor, torch.linalg.lu_factor_ex, torch.linalg.lu] 5517 for f in fns: 5518 with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'): 5519 f(torch.empty(1, 2, 2), pivot=False) 5520 5521 5522 @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2}) 5523 @skipCUDAIfNoMagmaAndNoCusolver 5524 @skipCPUIfNoLapack 5525 @setLinalgBackendsToDefaultFinally 5526 @dtypes(*floating_and_complex_types()) 5527 def test_linalg_lu_solve(self, device, dtype): 5528 make_arg = partial(make_tensor, dtype=dtype, device=device) 5529 5530 backends = ["default"] 5531 5532 if torch.device(device).type == 'cuda': 5533 if torch.cuda.has_magma: 5534 backends.append("magma") 5535 if has_cusolver(): 5536 backends.append("cusolver") 5537 5538 def gen_matrices(): 5539 rhs = 3 5540 ns = (5, 2, 0) 5541 batches = ((), (0,), (1,), (2,), (2, 1), (0, 2)) 5542 for batch, n in product(batches, ns): 5543 yield make_arg(batch + (n, n)), make_arg(batch + (n, rhs)) 5544 # Shapes to exercise all the paths 5545 shapes = ((1, 64), (2, 128), (1025, 2)) 5546 for b, n in shapes: 5547 yield make_arg((b, n, n)), make_arg((b, n, rhs)) 5548 5549 5550 for A, B in gen_matrices(): 5551 LU, pivots = torch.linalg.lu_factor(A) 5552 for backend in backends: 5553 torch.backends.cuda.preferred_linalg_library(backend) 5554 5555 for left, adjoint in product((True, False), repeat=2): 5556 B_left = B if left else B.mT 5557 X = torch.linalg.lu_solve(LU, pivots, B_left, left=left, adjoint=adjoint) 5558 A_adj = A.mH if adjoint else A 5559 if left: 5560 self.assertEqual(B_left, A_adj @ 
X) 5561 else: 5562 self.assertEqual(B_left, X @ A_adj) 5563 5564 5565 @onlyCPU 5566 @dtypes(*floating_and_complex_types()) 5567 def test_linalg_lu_cpu_errors(self, device, dtype): 5568 # Square tests 5569 sample = torch.randn(3, 2, 2, device=device, dtype=dtype) 5570 B = torch.randn(3, 2, 2, device=device, dtype=dtype) 5571 LU, pivots = torch.linalg.lu_factor(sample) 5572 5573 # This should run without issues 5574 torch.linalg.lu_solve(LU, pivots, B, adjoint=True) 5575 torch.lu_unpack(LU, pivots) 5576 5577 pivots[0] = 0 5578 with self.assertRaisesRegex(RuntimeError, r"greater or equal to 1"): 5579 torch.linalg.lu_solve(LU, pivots, B, adjoint=True) 5580 with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5581 torch.lu_unpack(LU, pivots) 5582 5583 pivots[0] = 3 5584 with self.assertRaisesRegex(RuntimeError, r"smaller or equal to LU.size\(-2\)"): 5585 torch.linalg.lu_solve(LU, pivots, B, adjoint=True) 5586 with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5587 torch.lu_unpack(LU, pivots) 5588 5589 # Rectangular tests 5590 sample = torch.randn(3, 4, 2, device=device, dtype=dtype) 5591 B = torch.randn(3, 4, 2, device=device, dtype=dtype) 5592 LU, pivots = torch.linalg.lu_factor(sample) 5593 5594 # This should run without issues 5595 torch.lu_unpack(LU, pivots) 5596 5597 pivots[0] = 0 5598 with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5599 torch.lu_unpack(LU, pivots) 5600 5601 pivots[0] = 5 5602 with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5603 torch.lu_unpack(LU, pivots) 5604 5605 5606 # Rectangular tests 5607 sample = torch.randn(2, 3, 5, device=device, dtype=dtype) 5608 B = torch.randn(2, 3, 5, device=device, dtype=dtype) 5609 LU, pivots = torch.linalg.lu_factor(sample) 5610 5611 # This should run without issues 5612 torch.lu_unpack(LU, pivots) 5613 5614 pivots[0] = 0 5615 with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5616 torch.lu_unpack(LU, pivots) 5617 5618 pivots[0] = 4 5619 with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5620 torch.lu_unpack(LU, pivots) 5621 5622 5623 @skipCPUIfNoLapack 5624 @skipCUDAIfNoMagma 5625 @dtypes(torch.double) 5626 def test_lu_unpack_check_input(self, device, dtype): 5627 x = torch.rand(5, 5, 5, device=device, dtype=dtype) 5628 lu_data, lu_pivots = torch.linalg.lu_factor(x) 5629 5630 with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"): 5631 torch.lu_unpack(lu_data, lu_pivots.long()) 5632 5633 # check that once the flags are unset, empty tensors are returned 5634 p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False) 5635 self.assertTrue(l.numel() == 0 and u.numel() == 0) 5636 p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False) 5637 self.assertTrue(p.numel() == 0) 5638 p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False, unpack_pivots=False) 5639 self.assertTrue(p.numel() == 0 and l.numel() == 0 and u.numel() == 0) 5640 5641 @skipCUDAIfNoMagma 5642 @skipCPUIfNoLapack 5643 @dtypes(torch.double) 5644 def test_lobpcg_basic(self, device, dtype): 5645 self._test_lobpcg_method(device, dtype, 'basic') 5646 5647 @skipCUDAIfNoCusolver 5648 @skipCPUIfNoLapack 5649 @dtypes(torch.double) 5650 def test_lobpcg_ortho(self, device, dtype): 5651 if torch.version.hip: 5652 torch.backends.cuda.preferred_linalg_library('magma') 5653 self._test_lobpcg_method(device, dtype, 'ortho') 5654 if torch.version.hip: 5655 torch.backends.cuda.preferred_linalg_library('default') 5656 5657
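    # NOTE: the helper below verifies the two defining properties of a converged
    # (generalized) LOBPCG solution: B-orthonormality of the eigenvector block,
    # i.e. qform(B, X) == I, and the block residual equation A @ X ~= B @ X @ diag(E)
    # (checked via qform(A, X) / E == I in the tracker and via matmul comparisons
    # on the returned (E, V) pairs).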
def _test_lobpcg_method(self, device, dtype, method): 5658 from torch.testing._internal.common_utils import random_symmetric_pd_matrix, random_sparse_pd_matrix 5659 from torch._linalg_utils import matmul, qform 5660 from torch._lobpcg import lobpcg 5661 5662 def test_tracker(worker): 5663 k = worker.iparams['k'] 5664 nc = worker.ivars['converged_count'] 5665 if k <= nc: 5666 tol = worker.fparams['tol'] 5667 rerr = worker.tvars['rerr'] 5668 X = worker.X 5669 E = worker.E 5670 B = worker.B 5671 A = worker.A 5672 dtype = X.dtype 5673 device = X.device 5674 5675 # Check convergence 5676 self.assertLessEqual(rerr[:k].max(), tol) 5677 5678 # Check B-orthogonality 5679 I = torch.eye(k, k, dtype=dtype, device=device) 5680 self.assertEqual(qform(B, X[:, :k]), I) 5681 5682 # Check block equation 5683 self.assertEqual(qform(A, X[:, :k]) / E[:k], I, atol=0.2, rtol=0) 5684 5685 orig_lobpcg = lobpcg 5686 5687 def lobpcg(*args, **kwargs): 5688 kwargs['tracker'] = test_tracker 5689 kwargs['niter'] = 1000 5690 kwargs['method'] = method 5691 kwargs['tol'] = 1e-8 5692 return orig_lobpcg(*args, **kwargs) 5693 prec = 5e-4 5694 5695 # check dense input 5696 mm = torch.matmul 5697 for batches in [(), (2,), (2, 3)]: 5698 for m, n, k in [ 5699 (9, 3, 1), 5700 (9, 3, 2), 5701 (9, 2, 2), 5702 (100, 15, 5), 5703 ]: 5704 # skip tests that are known to fail with the basic 5705 # LOBPCG method due to calling cholesky on singular 5706 # input 5707 if method == 'basic' and (m, n, k) in [(9, 2, 2), (100, 15, 5)]: 5708 continue 5709 A = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype) 5710 B = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype) 5711 5712 # classical eigenvalue problem, smallest eigenvalues 5713 E, V = lobpcg(A, k=k, n=n, largest=False) 5714 self.assertEqual(E.shape, batches + (k,)) 5715 self.assertEqual(V.shape, batches + (m, k)) 5716 self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) 5717 e = torch.linalg.eigvalsh(A) 5718 e_smallest = e[..., :k] 5719 self.assertEqual(E, e_smallest) 5720 5721 # classical eigenvalue problem, largest eigenvalues 5722 E, V = lobpcg(A, k=k, n=n, largest=True) 5723 e_largest, _ = torch.sort(e[..., -k:], descending=True) 5724 self.assertEqual(E, e_largest, atol=prec, rtol=0) 5725 self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) 5726 5727 # generalized eigenvalue problem, smallest eigenvalues 5728 E, V = lobpcg(A, B=B, k=k, n=n, largest=False) 5729 self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), atol=prec, rtol=0) 5730 5731 # generalized eigenvalue problem, largest eigenvalues 5732 E, V = lobpcg(A, B=B, k=k, n=n, largest=True) 5733 self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()), 5734 atol=prec, rtol=0) 5735 5736 # check sparse input 5737 for m, n, k, density in [ 5738 (5, 1, 1, 0.8), 5739 (9, 3, 2, 0.5), 5740 (100, 1, 1, 0.1), 5741 (1000, 7, 3, 0.01), 5742 ]: 5743 # skip tests that are known to fail with the basic LOBPCG 5744 # method due to insufficient accuracy 5745 if method == 'basic' and (m, n, k, density) in [(1000, 7, 3, 0.01)]: 5746 continue 5747 A = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype) 5748 B = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype) 5749 A_eigenvalues = torch.arange(1, m + 1, dtype=dtype) / m 5750 e_smallest = A_eigenvalues[..., :k] 5751 e_largest, _ = torch.sort(A_eigenvalues[..., -k:], descending=True) 5752 5753 # classical eigenvalue problem, smallest eigenvalues 5754
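            # (The exact comparisons against e_smallest / e_largest below assume that
            # random_sparse_pd_matrix constructs a matrix with the prescribed spectrum
            # {1/m, 2/m, ..., 1}, matching A_eigenvalues above.)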
E, V = lobpcg(A, k=k, n=n, largest=False) 5755 self.assertEqual(E, e_smallest) 5756 self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) 5757 5758 # classical eigenvalue problem, largest eigenvalues 5759 E, V = lobpcg(A, k=k, n=n, largest=True) 5760 self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) 5761 self.assertEqual(E, e_largest) 5762 5763 # generalized eigenvalue problem, smallest eigenvalues 5764 E, V = lobpcg(A, B=B, k=k, n=n, largest=False) 5765 self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), atol=prec, rtol=0) 5766 5767 # generalized eigenvalue problem, largest eigenvalues 5768 E, V = lobpcg(A, B=B, k=k, n=n, largest=True) 5769 self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()), 5770 atol=prec, rtol=0) 5771 5772 @skipCPUIfNoLapack 5773 @onlyCPU 5774 @dtypes(torch.double) 5775 def test_lobpcg_torchscript(self, device, dtype): 5776 from torch.testing._internal.common_utils import random_sparse_pd_matrix 5777 from torch._linalg_utils import matmul as mm 5778 5779 lobpcg = torch.jit.script(torch.lobpcg) 5780 5781 m = 500 5782 k = 5 5783 A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype) 5784 X1 = torch.randn((m, k), dtype=dtype, device=device) 5785 E1, V1 = lobpcg(A1, X=X1) 5786 eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max() 5787 self.assertLess(eq_err, 1e-6) 5788 5789 @unittest.skipIf(not TEST_SCIPY or (TEST_SCIPY and scipy.__version__ < '1.4.1'), "Scipy not found or older than 1.4.1") 5790 @skipCPUIfNoLapack 5791 @skipIfTorchDynamo("fails in tracing scipy.sparse.lobpcg") 5792 @onlyCPU 5793 @dtypes(torch.double) 5794 def test_lobpcg_scipy(self, device, dtype): 5795 """Compare torch and scipy.sparse.linalg implementations of lobpcg 5796 """ 5797 import time 5798 from torch.testing._internal.common_utils import random_sparse_pd_matrix 5799 from torch._linalg_utils import matmul as mm 5800 from scipy.sparse.linalg import lobpcg as scipy_lobpcg 5801 import scipy.sparse 5802 5803 def toscipy(A): 5804 if A.layout == torch.sparse_coo: 5805 values = A.coalesce().values().cpu().numpy().copy() 5806 indices = A.coalesce().indices().cpu().numpy().copy() 5807 return scipy.sparse.coo_matrix((values, (indices[0], indices[1])), A.shape) 5808 return A.cpu().numpy().copy() 5809 5810 niter = 1000 5811 repeat = 10 5812 m = 500 # size of the square matrix 5813 k = 7 # the number of requested eigenpairs 5814 A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype) 5815 B1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype) 5816 X1 = torch.randn((m, k), dtype=dtype, device=device) 5817 5818 A2 = toscipy(A1) 5819 B2 = toscipy(B1) 5820 X2 = toscipy(X1) 5821 5822 lambdas1 = [] 5823 5824 def tracker(worker): 5825 lambdas1.append(worker.E[:]) 5826 5827 tol = 1e-8 5828 # tol for scipy lobpcg will be chosen so that the number of 5829 # iterations will be equal or very close to pytorch lobpcg 5830 # (that is around 170-180) 5831 5832 # Standard eigenvalue problem 5833 E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol) 5834 E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=1.1 * tol) 5835 iters1 = len(lambdas1) 5836 iters2 = len(lambdas2) 5837 self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2)) 5838 5839 E2a, V2a = scipy_lobpcg(A2, X2, maxiter=niter, largest=False) 5840 5841 eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max() 5842 eq_err_scipy =
(abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max() 5843 self.assertLess(eq_err, 1e-6) # std 5844 self.assertLess(eq_err_scipy, 1e-6) # std 5845 5846 self.assertEqual(E1, torch.from_numpy(E2.copy())) 5847 5848 # Generalized eigenvalue problem 5849 lambdas1 = [] 5850 5851 def tracker(worker): 5852 lambdas1.append(worker.E[:]) 5853 5854 E1, V1 = torch.lobpcg(A1, B=B1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol) 5855 E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=39 * tol) 5856 E2a, V2a = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=False) 5857 iters1 = len(lambdas1) 5858 iters2 = len(lambdas2) 5859 self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2)) 5860 5861 eq_err = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max() 5862 eq_err_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max() 5863 self.assertLess(eq_err, 1e-6) # general 5864 self.assertLess(eq_err_scipy, 1e-6) # general 5865 5866 self.assertEqual(E1, torch.from_numpy(E2.copy())) 5867 5868 # Timings 5869 elapsed_ortho = 0 5870 elapsed_ortho_general = 0 5871 elapsed_scipy = 0 5872 elapsed_general_scipy = 0 5873 for i in range(repeat): 5874 start = time.time() 5875 torch.lobpcg(A1, X=X1, niter=niter, method='ortho', tol=tol) 5876 end = time.time() 5877 elapsed_ortho += end - start 5878 5879 start = time.time() 5880 torch.lobpcg(A1, X=X1, B=B1, niter=niter, method='ortho', tol=tol) 5881 end = time.time() 5882 elapsed_ortho_general += end - start 5883 5884 start = time.time() 5885 scipy_lobpcg(A2, X2, maxiter=niter, tol=1.1 * tol) 5886 end = time.time() 5887 elapsed_scipy += end - start 5888 5889 start = time.time() 5890 scipy_lobpcg(A2, X2, B=B2, maxiter=niter, tol=39 * tol) 5891 end = time.time() 5892 elapsed_general_scipy += end - start 5893 5894 elapsed_ortho_ms = 1000.0 * elapsed_ortho / repeat 5895 elapsed_ortho_general_ms = 1000.0 * elapsed_ortho_general / repeat 5896 elapsed_scipy_ms = 1000.0 * elapsed_scipy / repeat 5897 elapsed_general_scipy_ms = 1000.0 * elapsed_general_scipy / repeat 5898 5899 print(f''' 5900CPU timings: torch.lobpcg vs scipy.sparse.linalg.lobpcg 5901------------------------------------------------------- 5902 | standard | generalized | method 5903torch.lobpcg | {elapsed_ortho_ms:10.2f} | {elapsed_ortho_general_ms:10.2f} | ortho 5904scipy_lobpcg | {elapsed_scipy_ms:10.2f} | {elapsed_general_scipy_ms:10.2f} | N/A 5905-(input size: {m:4}, eigenpairs:{k:2}, units: ms per call)- 5906 ''') 5907 5908 # Handling of very small tolerance 5909 tol = 1e-100 5910 5911 lambdas1 = [] 5912 5913 def tracker(worker): 5914 lambdas1.append(worker.E[:]) 5915 5916 E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol) 5917 iters1 = len(lambdas1) 5918 eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max() 5919 5920 try: 5921 E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol) 5922 iters2 = len(lambdas2) 5923 eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max() 5924 except Exception as msg: 5925 print('Calling scipy_lobpcg failed [standard]:', msg) 5926 iters2 = -1 5927 eq_err_scipy = -1 5928 5929 lambdas1 = [] 5930 5931 def tracker(worker): 5932 lambdas1.append(worker.E[:]) 5933 5934 E1, V1 = torch.lobpcg(A1, X=X1, B=B1, niter=niter, largest=True, tracker=tracker, tol=tol) 5935 iters1_general = len(lambdas1) 5936 eq_err_general = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max() 5937 5938 try: 5939 E2, V2, lambdas2
= scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol) 5940 iters2_general = len(lambdas2) 5941 eq_err_general_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max() 5942 except Exception as msg: 5943 print('Calling scipy_lobpcg failed [generalized]:', msg) 5944 iters2_general = -1 5945 eq_err_general_scipy = -1 5946 5947 print(f'''\ 5948Handling of small tol={tol:6.0e}: torch.lobpcg vs scipy.sparse.linalg.lobpcg 5949---------------------------------------------------------------------------- 5950 | standard | generalized | niter | method 5951torch.lobpcg | {eq_err:10.2e} | {eq_err_general:10.2e} | {iters1:6} | ortho 5952scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:6} | N/A 5953---(input size: {m:4}, eigenpairs:{k:2}, units: relative error, maxiter={niter:4})--- 5954''') 5955 5956 def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None): 5957 dtype = t.dtype 5958 numpy_dtype = dtype 5959 if dtype in {torch.bfloat16, torch.half}: 5960 numpy_dtype = torch.float 5961 if dtype.is_complex: 5962 alpha = 0.9 + 0.3j if alpha is None else alpha 5963 beta = 0.5 + 0.6j if beta is None else beta 5964 else: 5965 alpha = 1.2 if alpha is None else alpha 5966 beta = 0.8 if beta is None else beta 5967 if activation == "gelu": 5968 res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True) 5969 else: 5970 res1 = f(t, m, v, alpha=alpha, beta=beta) 5971 res2 = torch.full_like(res1, math.nan) 5972 if transpose_out: 5973 res2 = res2.t().clone(memory_format=torch.contiguous_format).t() 5974 if activation == "gelu": 5975 f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True) 5976 else: 5977 f(t, m, v, alpha=alpha, beta=beta, out=res2) 5978 res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()) 5979 if beta != 0: 5980 res3 += (beta * t).to(numpy_dtype).cpu().numpy() 5981 if activation == "relu": 5982 res3 = res3 * (res3 > 0) 5983 elif activation == "gelu": 5984 res3_t = torch.from_numpy(res3).to(dtype) 5985 approximate = "tanh" if t.is_cuda else "none" 5986 res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate) 5987 res3 = res3_t.to(numpy_dtype).cpu().numpy() 5988 else: 5989 assert activation is None, f"unsupported activation {activation}" 5990 res3 = torch.from_numpy(res3).to(dtype) 5991 self.assertEqual(res1, res2) 5992 self.assertEqual(res1, res3) 5993 5994 @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4, torch.double: 1e-8, 5995 torch.cfloat: 1e-4, torch.cdouble: 1e-8}) 5996 @dtypesIfCUDA(*floating_and_complex_types_and( 5997 *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [], 5998 torch.half)) 5999 @dtypes(torch.bfloat16, torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble) 6000 def test_addmv(self, device, dtype): 6001 if IS_ARM64 and device == 'cpu' and dtype == torch.float16: 6002 raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438") 6003 # have to use torch.randn(...).to(bfloat16) instead of 6004 # torch.randn(..., dtype=bfloat16). randn does not support 6005 # bfloat16 yet. 
6006 # "*0.2" to reduce errors for low precision 6007 ts = [ 6008 0.2 * torch.randn(50, device=device).to(dtype), 6009 0.2 * torch.randn(1, device=device).to(dtype).expand(50), 6010 ] 6011 vs = [ 6012 0.2 * torch.randn(100, device=device).to(dtype), 6013 0.2 * torch.ones(1, device=device).to(dtype).expand(100), # to reduce errors for low precision 6014 ] 6015 ms = [ 6016 # 0d 6017 0.2 * torch.ones((), device=device).to(dtype).expand(50, 100), # to reduce errors for low precision 6018 # 1d 6019 0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100), 6020 # this initialization reduces errors for low precision for broadcasted matrices 6021 # by making sure that intermediate and result values are exactly representable 6022 # in low precision type 6023 0.2 * torch.randint(3, (50, 1), dtype=torch.float, device=device).to(dtype).expand(50, 100), 6024 # 2d 6025 0.2 * torch.randn((50, 100), device=device).to(dtype), 6026 0.2 * torch.randn((100, 50), device=device).to(dtype).t(), 6027 ] 6028 for m, v, t in itertools.product(ms, vs, ts): 6029 self._test_addmm_addmv(torch.addmv, t, m, v) 6030 # Test beta=0, t=nan 6031 t = torch.full((50,), math.nan, device=device).to(dtype) 6032 for m, v in itertools.product(ms, vs): 6033 self._test_addmm_addmv(torch.addmv, t, m, v, beta=0) 6034 6035 @dtypesIfCUDA(*floating_types_and(*[torch.bfloat16] if TEST_WITH_ROCM or 6036 SM53OrLater else [])) 6037 @dtypes(torch.float, torch.double) 6038 def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): 6039 # tests (o, s)*(s). o is output size, s is summed size. 6040 o = 5 6041 s = 3 6042 a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s) 6043 x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype) 6044 y_data = torch.ones(o, device=device, dtype=dtype) 6045 control = torch.tensor([15., 33., 51., 69., 87.], device=device, dtype=dtype) 6046 6047 def _test(row_major, incx, incy, lda_tail): 6048 if row_major: 6049 a_storage = torch.full((o, s + lda_tail), float('nan'), device=device, dtype=dtype) 6050 else: 6051 a_storage = torch.full((s, o + lda_tail), float('nan'), device=device, dtype=dtype).permute(1, 0) 6052 a = a_storage[:o, :s].copy_(a_data) 6053 6054 x_storage = torch.full((s, incx), float('nan'), device=device, dtype=dtype) 6055 x = x_storage[:, 0].copy_(x_data) 6056 6057 y_storage = torch.full((o, incy), float('nan'), device=device, dtype=dtype) 6058 y = y_storage[:, 0].copy_(y_data) 6059 6060 self._test_addmm_addmv(torch.addmv, y, a, x) 6061 6062 for row_major, incx, incy, lda_tail in itertools.product((False, True), (1, 2), (1, 2), (0, 1)): 6063 _test(row_major, incx, incy, lda_tail) 6064 6065 def _test_addmm_impl(self, func, activation, device, dtype): 6066 M = torch.randn(10, 25, device=device).to(dtype) 6067 m1 = torch.randn(10, 50, device=device).to(dtype) 6068 m2 = torch.randn(50, 25, device=device).to(dtype) 6069 self._test_addmm_addmv(func, M, m1, m2, activation=activation) 6070 6071 # vector-shaped bias and beta=1 result in epilogue fusion in CUDA 6072 V = torch.randn(25, device=device).to(dtype) 6073 self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation) 6074 6075 # Test 0-strided 6076 M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25) 6077 m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50) 6078 m2 = torch.randn(50, 25, device=device).to(dtype) 6079 self._test_addmm_addmv(func, M, m1, m2, activation=activation) 6080 6081 # Test beta=0, M=nan 6082 M = torch.full((10, 25), math.nan, device=device).to(dtype) 
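        # With beta == 0 the addmm input must be ignored entirely: the result is
        # alpha * (m1 @ m2), and the NaNs in M must not propagate (a naive
        # 0 * NaN + ... evaluation would poison the output).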
6083 m1 = torch.randn(10, 50, device=device).to(dtype) 6084 m2 = torch.randn(50, 25, device=device).to(dtype) 6085 self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation) 6086 6087 # Test transpose 6088 for t1, t2, t3, t4 in itertools.product([True, False], repeat=4): 6089 def maybe_transpose(cond, m): 6090 if not cond: 6091 return m 6092 return m.t().clone(memory_format=torch.contiguous_format).t() 6093 6094 M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) 6095 m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) 6096 m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) 6097 self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation) 6098 6099 if t1: 6100 # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1) 6101 self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,) 6102 6103 @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, 6104 torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) 6105 @dtypesIfMPS(torch.float32) 6106 @dtypesIfCUDA(*floating_and_complex_types_and( 6107 *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])) 6108 @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) 6109 @tf32_on_and_off(0.05) 6110 @bf32_on_and_off(0.05) 6111 def test_addmm(self, device, dtype): 6112 self._test_addmm_impl(torch.addmm, None, device, dtype) 6113 6114 @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2, 6115 torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) 6116 @dtypesIfCUDA(*floating_types_and( 6117 *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) 6118 @dtypes(*floating_types_and(torch.bfloat16)) 6119 @tf32_on_and_off(0.05) 6120 @bf32_on_and_off(0.05) 6121 def test_addmm_relu(self, device, dtype): 6122 self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) 6123 6124 @onlyCUDA 6125 @skipCUDAIfNotRocm 6126 @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2, 6127 torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) 6128 @dtypesIfCUDA(*floating_types_and( 6129 *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) 6130 @dtypes(*floating_types_and(torch.bfloat16)) 6131 @tf32_on_and_off(0.05) 6132 @bf32_on_and_off(0.05) 6133 def test_addmm_relu_tunableop_rocm(self, device, dtype): 6134 torch.cuda.tunable.enable(True) 6135 ordinal = torch.cuda.current_device() 6136 filename = f"tunableop_results{ordinal}.csv" 6137 torch.cuda.tunable.set_filename(filename) 6138 iterations = torch.cuda.tunable.get_max_tuning_iterations() 6139 torch.cuda.tunable.set_max_tuning_iterations(10) 6140 self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) 6141 # clean up, remove any file that was generated 6142 try: 6143 import os 6144 os.remove(filename) 6145 except FileNotFoundError: 6146 pass 6147 # reset back to prior settings 6148 torch.cuda.tunable.set_max_tuning_iterations(iterations) 6149 torch.cuda.tunable.enable(False) 6150 6151 @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2, 6152 torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) 6153 @dtypesIfCUDA(*floating_types_and( 6154 *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) 6155 @dtypes(*floating_types_and(torch.bfloat16)) 6156 @tf32_on_and_off(0.05) 6157 @bf32_on_and_off(0.05) 6158 def test_addmm_gelu(self, device, dtype): 6159 
self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) 6160 6161 @dtypes(torch.float, torch.double) 6162 @dtypesIfCUDA(*floating_and_complex_types()) 6163 @tf32_on_and_off(0.005) 6164 @bf32_on_and_off(0.005) 6165 def test_addmm_sizes(self, device, dtype): 6166 for m in [0, 1, 25]: 6167 for n in [0, 1, 10]: 6168 for k in [0, 1, 8]: 6169 M = torch.randn(n, m, device=device).to(dtype) 6170 m1 = torch.randn(n, k, device=device).to(dtype) 6171 m2 = torch.randn(k, m, device=device).to(dtype) 6172 self._test_addmm_addmv(torch.addmm, M, m1, m2) 6173 6174 m1 = torch.randn(n, k + 1, device=device).to(dtype) 6175 m2 = torch.randn(k, m, device=device).to(dtype) 6176 self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2)) 6177 self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2)) 6178 6179 @dtypes(torch.half) 6180 @onlyCUDA 6181 def test_addmm_baddbmm_overflow(self, device, dtype): 6182 orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction 6183 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 6184 inp = torch.zeros(128, 128, dtype=torch.half, device=device) 6185 mat1 = torch.ones(128, 1000, dtype=torch.half, device=device) * 100 6186 mat2 = torch.ones(1000, 128, dtype=torch.half, device=device) * 100 6187 out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.) 6188 # just check for no overflow on ROCM 6189 if TEST_WITH_ROCM: 6190 self.assertFalse(out.isinf().any()) 6191 else: 6192 self.assertTrue((out == 10000.).all()) 6193 inp = torch.zeros(3, 128, 128, dtype=torch.half, device=device) 6194 mat1 = torch.ones(3, 128, 1000, dtype=torch.half, device=device) * 100 6195 mat2 = torch.ones(3, 1000, 128, dtype=torch.half, device=device) * 100 6196 out = torch.baddbmm(inp, mat1, mat2, alpha=0.001, beta=0.) 
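        # Each reduction sums 1000 products of 100 * 100, i.e. 1e7, which is far beyond
        # the fp16 maximum of 65504; only the alpha=0.001-scaled result (10000) is
        # representable. The exact equality checks below therefore rely on the
        # accumulation happening in fp32 (reduced-precision fp16 reductions are
        # disabled above).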
6197 if TEST_WITH_ROCM: 6198 self.assertFalse(out.isinf().any()) 6199 else: 6200 self.assertTrue((out == 10000.).all()) 6201 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig 6202 6203 @dtypes(torch.float) 6204 def test_baddbmm_nan_input_with_zero_beta(self, device, dtype): 6205 for shape in [[3, 2, 2], [2, 20, 20]]: 6206 mat1, mat2 = (torch.randn(shape, dtype=dtype, device=device) for _ in range(2)) 6207 inputs = [torch.randn(shape, dtype=dtype, device=device), 6208 torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)] 6209 outs = [None, torch.randn(shape, dtype=dtype, device=device), 6210 torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)] 6211 options = itertools.product(inputs, outs) 6212 for input, out in options: 6213 y_ref = torch.bmm(mat1, mat2) 6214 y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out) 6215 self.assertEqual(y_ref, y) 6216 6217 @dtypes(torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64) 6218 def test_baddbmm_input_dtypes_compatibility(self, device, dtype): 6219 batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) 6220 batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) 6221 input_tensor = torch.rand((1, 2, 2), device=device).to(dtype) 6222 if dtype != torch.float32: 6223 with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"): 6224 y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0) 6225 else: 6226 out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan) 6227 y_ref = torch.bmm(batch1, batch2) 6228 y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out) 6229 self.assertEqual(out, y_ref) 6230 6231 6232 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") 6233 @onlyCUDA 6234 def test_matmul_45724(self, device): 6235 # https://github.com/pytorch/pytorch/issues/45724 6236 a = torch.rand(65537, 22, 64, device=device, dtype=torch.half) 6237 b = torch.rand(65537, 64, 22, device=device, dtype=torch.half) 6238 c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device) 6239 cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half() 6240 torch.matmul(a, b, out=c) 6241 self.assertEqual(c, cpu_result) 6242 6243 @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") 6244 @unittest.skipIf(SM90OrLater and not TEST_WITH_ROCM, "Expected failure on sm90") 6245 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") 6246 @onlyCUDA 6247 @parametrize("k", [16, 32]) 6248 @parametrize("n", [16, 32]) 6249 @parametrize("use_transpose_a", [True, False]) 6250 @parametrize("use_transpose_b", [True, False]) 6251 def test__int_mm(self, device, k, n, use_transpose_a, use_transpose_b): 6252 def genf_int_float(x, y, use_transpose): 6253 if use_transpose: 6254 x, y = y, x 6255 x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device) 6256 x_float = x_int8.to(torch.float32) 6257 if use_transpose: 6258 return x_int8.t(), x_float.t() 6259 return x_int8, x_float 6260 6261 def _test(m, k, n, transpose_a, transpose_b, test_equal=True): 6262 a_int8, a_float = genf_int_float(m, k, transpose_a) 6263 b_int8, b_float = genf_int_float(k, n, transpose_b) 6264 c_int32 = torch._int_mm(a_int8, b_int8) 6265 self.assertTrue(c_int32.dtype is torch.int32) 6266 self.assertEqual(c_int32.device, torch.device(device)) 6267 if test_equal: 6268 self.assertEqual(c_int32.float(), torch.mm(a_float, b_float)) 6269 else: 6270 self.assertNotEqual(c_int32.float(), torch.mm(a_float, b_float)) 6271 
c_int32_result = c_int32.new_empty(c_int32.size()) 6272 # Checking out variant 6273 torch._int_mm(a_int8, b_int8, out=c_int32_result) 6274 if test_equal: 6275 self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float)) 6276 else: 6277 self.assertNotEqual(c_int32_result.float(), torch.mm(a_float, b_float)) 6278 6279 # NOTE: We're just exercising terrible failures here. 6280 version = _get_torch_cuda_version() 6281 SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) 6282 SM70 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 0) 6283 SM75 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 5) 6284 6285 if TEST_WITH_ROCM: 6286 _test(17, k, n, use_transpose_a, use_transpose_b, True) 6287 elif version >= (11, 7): 6288 if not use_transpose_a and use_transpose_b: 6289 if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)): 6290 _test(17, k, n, use_transpose_a, use_transpose_b, version > (11, 7)) 6291 else: 6292 with self.assertRaisesRegex(RuntimeError, 6293 "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"): 6294 _test(17, k, n, use_transpose_a, use_transpose_b) 6295 6296 if use_transpose_a and not use_transpose_b: 6297 with self.assertRaisesRegex(RuntimeError, 6298 "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"): 6299 _test(17, k, n, use_transpose_a, use_transpose_b) 6300 6301 if use_transpose_a and use_transpose_b: 6302 with self.assertRaisesRegex(RuntimeError, 6303 "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"): 6304 _test(17, k, n, use_transpose_a, use_transpose_b) 6305 6306 if not use_transpose_a and not use_transpose_b: 6307 if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)): 6308 _test(17, k, n, use_transpose_a, use_transpose_b) 6309 else: 6310 with self.assertRaisesRegex(RuntimeError, 6311 "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"): 6312 _test(17, k, n, use_transpose_a, use_transpose_b) 6313 else: 6314 with self.assertRaisesRegex(RuntimeError, "_int_mm_out_cuda not compiled for CUDA"): 6315 _test(17, k, n, use_transpose_a, use_transpose_b, False) 6316 6317 @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") 6318 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") 6319 @onlyCUDA 6320 def test__int_mm_errors(self, device): 6321 if TEST_WITH_ROCM: 6322 self.skipTest("_int_mm not compiled for ROCM") 6323 6324 version = _get_torch_cuda_version() 6325 if version < (11, 7): 6326 self.skipTest("_int_mm only compiled for CUDA 11.7") 6327 6328 def genf_int(x, y): 6329 return torch.empty((x, y), dtype=torch.int8, device=device) 6330 6331 def _gen_pair(m, k, n): 6332 return genf_int(m, k), genf_int(k, n) 6333 6334 self.assertRaisesRegex(RuntimeError, 6335 r"self.size\(0\) needs to be greater than 16, but got 16", 6336 lambda: torch._int_mm(*_gen_pair(16, 8, 32))) 6337 self.assertRaisesRegex(RuntimeError, 6338 r"self.size\(1\) needs to be greater than 0 and a multiple of 8, but got 7", 6339 lambda: torch._int_mm(*_gen_pair(17, 7, 32))) 6340 self.assertRaisesRegex(RuntimeError, 6341 r"self.size\(1\) needs to match mat2.size\(0\) but got 8 and 7", 6342 lambda: torch._int_mm(genf_int(17, 8), genf_int(7, 32))) 6343 self.assertRaisesRegex(RuntimeError, 6344 r"mat2.size\(1\) needs to be greater than 0 and a multiple of 8, but got 31", 6345 lambda: torch._int_mm(*_gen_pair(17, 8, 31))) 6346 self.assertRaisesRegex(RuntimeError, 6347 r"expected scalar type Char but found Float", 6348 lambda: 
torch._int_mm(genf_int(17, 8).float(), genf_int(8, 32))) 6349 self.assertRaisesRegex(RuntimeError, 6350 r"expected scalar type Char but found Float", 6351 lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32).float())) 6352 self.assertRaisesRegex(RuntimeError, 6353 r"Expected result dtype to be of type kInt but got float", 6354 lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 32).float())) 6355 self.assertRaisesRegex(RuntimeError, 6356 r"Expected result.size\(0\) to be 17 but got 15", 6357 lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(15, 32).int())) 6358 self.assertRaisesRegex(RuntimeError, 6359 r"Expected result.size\(0\) to be 17 but got 16", 6360 lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int())) 6361 6362 @onlyCPU 6363 @parametrize("m", [0, 8, 17]) 6364 @parametrize("k", [0, 16, 32]) 6365 @parametrize("n", [16, 32]) 6366 @parametrize("use_transpose_a", [True, False]) 6367 @parametrize("use_transpose_b", [True, False]) 6368 @parametrize("non_contig_type", [0, 1, 2]) 6369 def test__int_mm_cpu(self, device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type): 6370 # non_contig_type: 6371 # 0: the whole data buffer is contiguous (can be transposed) 6372 # 1: stride of one dimension is 1, but the whole buffer is not contiguous 6373 # 2: Neither stride is 1 6374 6375 def genf_int_float(x, y, use_transpose, non_contig_type): 6376 if use_transpose: 6377 x, y = y, x 6378 if non_contig_type != 0: 6379 y = y * 2 6380 x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device) 6381 x_float = x_int8.to(torch.float32) 6382 if non_contig_type == 1: 6383 x_int8 = x_int8[:, : y // 2] 6384 x_float = x_float[:, : y // 2] 6385 elif non_contig_type == 2: 6386 x_int8 = x_int8[:, ::2] 6387 x_float = x_float[:, ::2] 6388 if use_transpose: 6389 return x_int8.t(), x_float.t() 6390 return x_int8, x_float 6391 6392 if non_contig_type != 0 and (m == 0 or k == 0): 6393 return 6394 a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type) 6395 b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type) 6396 c_int32 = torch._int_mm(a_int8, b_int8) 6397 self.assertTrue(c_int32.dtype is torch.int32) 6398 self.assertEqual(c_int32.device, torch.device(device)) 6399 self.assertEqual(c_int32.float(), torch.mm(a_float, b_float)) 6400 c_int32_result = c_int32.new_empty(c_int32.size()) 6401 # Checking out variant 6402 torch._int_mm(a_int8, b_int8, out=c_int32_result) 6403 self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float)) 6404 6405 @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") 6406 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") 6407 @onlyNativeDeviceTypes 6408 def test__convert_weight_to_int4pack(self, device): 6409 # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead 6410 test_list = [((64, 32), 2), ((64, 48), 2), ((64, 64), 2), ((256, 128), 4), ((256, 128), 8)] 6411 if self.device_type == 'cuda' and not SM80OrLater: 6412 self.skipTest("requires SM80 or later") 6413 6414 if TEST_WITH_ROCM: 6415 if not CDNA2OrLater(): 6416 self.skipTest("_int4_mm is supported only for CDNA2 or later") 6417 6418 torch.manual_seed(1) 6419 for shape, innerKTiles in test_list: 6420 b = torch.rand(shape, dtype=torch.bfloat16, device=device) 6421 b_uint8, _ = _group_quantize_tensor(b, n_bit=4, q_group_size=32) 6422 b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=innerKTiles) 6423 b_int4pack_meta = 
torch._convert_weight_to_int4pack(b_uint8.to(device="meta"), innerKTiles=innerKTiles) 6424 self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape) 6425 6426 @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") 6427 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") 6428 @onlyNativeDeviceTypes 6429 @parametrize("m", [32, 64]) 6430 @parametrize("k", [32, 64]) 6431 @parametrize("n", [48, 64]) 6432 def test__int4_mm(self, device, m, k, n): 6433 if self.device_type == 'cuda' and not SM80OrLater: 6434 self.skipTest("requires SM80 or later") 6435 6436 if TEST_WITH_ROCM: 6437 if not CDNA2OrLater(): 6438 self.skipTest("_int4_mm is supported only for CDNA2 or later") 6439 6440 q_group = 32 6441 inner_k_tiles = 2 6442 6443 torch.manual_seed(1) 6444 a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device) 6445 b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device) 6446 6447 def convert_weight_to_int4pack(b): 6448 b_uint8, b_scales_and_zeros = _group_quantize_tensor( 6449 b, n_bit=4, q_group_size=q_group 6450 ) 6451 b_int4pack = torch._convert_weight_to_int4pack( 6452 b_uint8, inner_k_tiles 6453 ) 6454 6455 return b_int4pack, b_scales_and_zeros 6456 6457 def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): 6458 return torch._weight_int4pack_mm( 6459 a, b_int4pack, q_group, b_scales_and_zeros 6460 ) 6461 6462 b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16) 6463 6464 for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []): 6465 a = a_bf16.to(dtype=dtype) 6466 b = b_bf16.to(dtype=dtype) 6467 b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype) 6468 ref = torch.mm(a, b) 6469 res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros) 6470 6471 mean_err = ((res - ref).abs() / ref).mean() 6472 self.assertTrue(mean_err < 0.05) 6473 6474 6475 @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") 6476 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") 6477 @onlyNativeDeviceTypes 6478 @parametrize("m", [32, 64]) 6479 @parametrize("k", [32, 64]) 6480 @parametrize("n", [48, 64]) 6481 def test_compile_int4_mm(self, device, m, k, n): 6482 if self.device_type == 'cuda' and not SM80OrLater: 6483 self.skipTest("requires SM80 or later") 6484 6485 if TEST_WITH_ROCM: 6486 if not CDNA2OrLater(): 6487 self.skipTest("_int4_mm is supported only for CDNA2 or later") 6488 6489 q_group = 32 6490 inner_k_tiles = 2 6491 6492 torch.manual_seed(1) 6493 a = torch.rand((m, k), dtype=torch.bfloat16, device=device) 6494 b = torch.rand((k, n), dtype=torch.bfloat16, device=device) 6495 6496 b_int32, b_scales_and_zeros = _group_quantize_tensor( 6497 b, n_bit=4, q_group_size=q_group 6498 ) 6499 6500 @torch.compile 6501 def int4_mm(a, b_int32, b_scales_and_zeros): 6502 b_int4pack = torch._convert_weight_to_int4pack( 6503 b_int32, inner_k_tiles 6504 ) 6505 return torch._weight_int4pack_mm( 6506 a, b_int4pack, q_group, b_scales_and_zeros 6507 ) 6508 6509 res = int4_mm(a, b_int32, b_scales_and_zeros) 6510 ref = torch.mm(a, b) 6511 6512 mean_err = ((res - ref).abs() / ref).mean() 6513 self.assertTrue(mean_err < 0.05) 6514 6515 @onlyCPU 6516 @parametrize("m", [32, 64]) 6517 @parametrize("k", [32, 64]) 6518 @parametrize("n", [48, 64]) 6519 def test__int8_mm(self, device, m, k, n): 6520 torch.manual_seed(1) 6521 a = torch.rand((m, k), dtype=torch.bfloat16, device=device) 6522 b = torch.rand((n, k), dtype=torch.bfloat16, device=device) 6523 6524 def convert_weight_to_int8pack(b): 6525 b_int8pack, b_scales, _ = 
_dynamically_quantize_per_channel( 6526 b, -128, 127, torch.int8 6527 ) 6528 return b_int8pack, b_scales 6529 6530 def weight_int8pack_mm(a, b_int8pack, b_scales): 6531 return torch._weight_int8pack_mm( 6532 a, b_int8pack, b_scales 6533 ) 6534 6535 b_int8pack, b_scales = convert_weight_to_int8pack(b) 6536 res = weight_int8pack_mm(a, b_int8pack, b_scales) 6537 ref = torch.mm(a, b.transpose(0, 1)) 6538 6539 mean_err = ((res - ref).abs() / ref).mean() 6540 self.assertTrue(mean_err < 0.05) 6541 6542 @onlyCPU 6543 @parametrize("m", [32, 64]) 6544 @parametrize("k", [32, 64]) 6545 @parametrize("n", [48, 64]) 6546 def test_compile_int8_mm(self, device, m, k, n): 6547 torch.manual_seed(1) 6548 a = torch.rand((m, k), dtype=torch.bfloat16, device=device) 6549 b = torch.rand((n, k), dtype=torch.bfloat16, device=device) 6550 6551 b_int8pack, b_scales, _ = _dynamically_quantize_per_channel( 6552 b, -128, 127, torch.int8 6553 ) 6554 6555 @torch.compile 6556 def int8_mm(a, b_int8pack, b_scales): 6557 return torch._weight_int8pack_mm( 6558 a, b_int8pack, b_scales 6559 ) 6560 6561 res = int8_mm(a, b_int8pack, b_scales) 6562 ref = torch.mm(a, b.transpose(0, 1)) 6563 6564 mean_err = ((res - ref).abs() / ref).mean() 6565 self.assertTrue(mean_err < 0.05) 6566 6567 @onlyCPU 6568 @parametrize("m", [32, 35, 36, 40, 64]) 6569 @parametrize("k", [32, 35, 36, 40, 64]) 6570 # NOTE: This is intended to cover fp16_gemv_trans in 6571 # BlasKernel.cpp. Currently, bounds being divisible by 32, 8-but-not-32, and 4-but-not-8 6572 # all matter. 6573 def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k): 6574 torch.manual_seed(1) 6575 a = torch.rand((m, k), dtype=torch.half, device=device) 6576 b = torch.rand((1, k), dtype=torch.half, device=device) 6577 6578 prev = torch._C._get_cpu_allow_fp16_reduced_precision_reduction() 6579 try: 6580 torch._C._set_cpu_allow_fp16_reduced_precision_reduction(False) 6581 ref = torch.mm(a, b.t()) 6582 try: 6583 torch._C._set_cpu_allow_fp16_reduced_precision_reduction(True) 6584 except RuntimeError as e: 6585 raise unittest.SkipTest from e 6586 res = torch.mm(a, b.t()) 6587 torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2) 6588 finally: 6589 torch._C._set_cpu_allow_fp16_reduced_precision_reduction(prev) 6590 6591 @slowTest 6592 @onlyNativeDeviceTypes 6593 # bfloat16 doesn't have sufficient precision to pass this test 6594 @dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble) 6595 @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble) 6596 @tf32_on_and_off(0.01) 6597 @bf32_on_and_off(0.01) 6598 def test_mm(self, device, dtype): 6599 def _test_mm(n, m, p, dtype, genf): 6600 # helper function 6601 def matrixmultiply(mat1, mat2): 6602 n = mat1.size(0) 6603 m = mat1.size(1) 6604 p = mat2.size(1) 6605 dtype_ = torch.float if dtype == torch.half else dtype 6606 if dtype == torch.half: 6607 mat1 = mat1.float() 6608 mat2 = mat2.float() 6609 res = torch.zeros(n, p, dtype=dtype_, device=device) 6610 for i, j in iter_indices(res): 6611 res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m)) 6612 return res.half() if dtype == torch.half else res 6613 6614 # contiguous case 6615 mat1 = genf(n, m) 6616 mat2 = genf(m, p) 6617 res = torch.mm(mat1, mat2) 6618 6619 res2 = matrixmultiply(mat1, mat2) 6620 self.assertEqual(res, res2) 6621 6622 # non contiguous case 1 6623 mat1 = genf(n, m) 6624 mat2 = genf(p, m).t() 6625 res = torch.mm(mat1, mat2) 6626 6627 res2 = matrixmultiply(mat1, mat2) 6628 self.assertEqual(res, 
res2) 6629 6630 # non contiguous case 2 6631 mat1 = genf(m, n).t() 6632 mat2 = genf(m, p) 6633 res = torch.mm(mat1, mat2) 6634 6635 res2 = matrixmultiply(mat1, mat2) 6636 self.assertEqual(res, res2) 6637 6638 # non contiguous case 3 6639 mat1 = genf(m, n).t() 6640 mat2 = genf(p, m).t() 6641 res = torch.mm(mat1, mat2) 6642 6643 res2 = matrixmultiply(mat1, mat2) 6644 self.assertEqual(res, res2) 6645 6646 # test with zero stride 6647 mat1 = genf(n, m) 6648 mat2 = genf(m, 1).expand(m, p) 6649 res = torch.mm(mat1, mat2) 6650 6651 res2 = matrixmultiply(mat1, mat2) 6652 self.assertEqual(res, res2) 6653 6654 # explicitly exercise the _out variant in torch.mm(). 6655 # contiguous case 6656 mat1 = genf(n, m) 6657 mat2 = genf(m, p) 6658 res = genf(n, p) 6659 torch.mm(mat1, mat2, out=res) 6660 6661 res2 = matrixmultiply(mat1, mat2) 6662 self.assertEqual(res, res2) 6663 6664 # explicitly exercise the _out variant in torch.mm(). 6665 # non contiguous case 3 6666 mat1 = genf(m, n).t() 6667 mat2 = genf(p, m).t() 6668 res = genf(n, p) 6669 torch.mm(mat1, mat2, out=res) 6670 6671 res2 = matrixmultiply(mat1, mat2) 6672 self.assertEqual(res, res2) 6673 6674 def genf_int(x, y): 6675 return torch.randint(0, 100, (x, y), dtype=dtype, device=device) 6676 6677 def genf_bfloat(x, y): 6678 return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1 6679 6680 def genf_float(x, y): 6681 return torch.randn(x, y, dtype=dtype, device=device) 6682 6683 def genf_Half(x, y): 6684 return torch.randn(x, y, dtype=dtype, device=device) 6685 6686 for (n, m, p) in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]: 6687 if (dtype == torch.int32) or (dtype == torch.int64): 6688 genf = genf_int 6689 elif (dtype == torch.bfloat16): 6690 genf = genf_bfloat 6691 elif (dtype == torch.half): 6692 genf = genf_Half 6693 else: 6694 genf = genf_float 6695 6696 _test_mm(n, m, p, dtype, genf) 6697 6698 @onlyNativeDeviceTypes 6699 def test_mm_bmm_non_memory_dense(self, device): 6700 def _slice(tensor, fn): 6701 return fn(tensor)[..., ::2] 6702 A = torch.randn(3, 6, dtype=torch.cfloat, device=device) 6703 B = torch.randn(3, 3, dtype=torch.cfloat, device=device) 6704 out = torch.empty(3, 3, device=device, dtype=torch.complex64).t() 6705 out1 = torch.empty(3, 3, device=device, dtype=torch.complex64).t() 6706 A_conj = _slice(A, torch.conj) 6707 A_conj_physical = _slice(A, torch.conj_physical) 6708 6709 self.assertEqual(torch.mm(A_conj, B, out=out), torch.mm(A_conj_physical, B, out=out)) 6710 self.assertEqual(torch.mm(A_conj.t(), B, out=out), torch.mm(A_conj_physical.t(), B, out=out)) 6711 6712 Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device) 6713 Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device) 6714 Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3) 6715 out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).mT 6716 6717 Ab_conj = _slice(Ab, torch.conj) 6718 Ab_conj_physical = _slice(Ab, torch.conj_physical) 6719 6720 def t_b(tensor): 6721 return tensor.mT 6722 6723 self.assertEqual(torch.bmm(Ab_conj, Bb, out=out_b), torch.bmm(Ab_conj_physical, Bb, out=out_b)) 6724 self.assertEqual(torch.bmm(t_b(Ab_conj), Bb, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb, out=out_b)) 6725 6726 # test broadcasting 6727 self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b)) 6728 self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b)) 6729 6730 @onlyNativeDeviceTypes 6731 def 
test_mm_conjtranspose(self, device): 6732 A = torch.randn(3, 3, dtype=torch.cfloat, device=device) 6733 B = torch.randn(3, 3, dtype=torch.cfloat, device=device) 6734 6735 # A conjtranspose 6736 out1 = torch.mm(A.t().conj(), B) 6737 out1_ref = torch.mm(A.t().conj_physical(), B) 6738 self.assertEqual(out1, out1_ref) 6739 6740 # B conjtranspose 6741 out1 = torch.mm(A, B.t().conj()) 6742 out1_ref = torch.mm(A, B.t().conj_physical()) 6743 self.assertEqual(out1, out1_ref) 6744 6745 # A&B conjtranspose 6746 out1 = torch.mm(A.t().conj(), B.t().conj()) 6747 out1_ref = torch.mm(A.t().conj_physical(), B.t().conj_physical()) 6748 self.assertEqual(out1, out1_ref) 6749 6750 @onlyNativeDeviceTypes 6751 def test_mm_empty_inputs_mixed_dtype_errors(self, device): 6752 a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device) 6753 b = torch.randn(10, 20, dtype=torch.float32, device=device) 6754 with self.assertRaisesRegex(RuntimeError, "expected .* and .* to have the same dtype, but got:"): 6755 torch.mm(a, b) 6756 6757 @onlyNativeDeviceTypes 6758 @dtypes(torch.float32, torch.float64) 6759 def test_strided_mm_bmm(self, device, dtype): 6760 # Tests strided view case with stride smaller than corresponding dimension size 6761 x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], dtype=dtype, device=device) 6762 new_shape = [2, 2, 2] 6763 new_stride = [3, 1, 1] 6764 sx = torch.as_strided(x, size=new_shape, stride=new_stride) 6765 6766 torch_fn = lambda x: torch.bmm(x, x) # noqa: E731 6767 np_fn = lambda x: np.matmul(x, x) # noqa: E731 6768 self.compare_with_numpy(torch_fn, np_fn, sx) 6769 6770 torch_fn = lambda x: torch.mm(x, x) # noqa: E731 6771 self.compare_with_numpy(torch_fn, np_fn, sx[0]) 6772 6773 @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) 6774 @onlyNativeDeviceTypes 6775 @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) 6776 @tf32_on_and_off(0.05) 6777 @bf32_on_and_off(0.05) 6778 def test_bmm(self, device, dtype): 6779 if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: 6780 # cuBLAS does not guarantee BFloat16 support on SM < 53. 
6781 # So on PyTorch, we consider BFloat16 support on SM < 53 as 6782 # undefined behavior 6783 return 6784 6785 batch_sizes = [1, 10] 6786 M, N, O = 23, 15, 12 6787 numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 6788 6789 is_supported = True 6790 if dtype == torch.bfloat16 and self.device_type == 'cuda': 6791 is_supported = TEST_WITH_ROCM or SM53OrLater 6792 6793 if not is_supported: 6794 for num_batches in batch_sizes: 6795 b1 = torch.randn(num_batches, M, N, device=device).to(dtype) 6796 b2 = torch.randn(num_batches, N, O, device=device).to(dtype) 6797 self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", 6798 lambda: torch.bmm(b1, b2)) 6799 return 6800 6801 def invert_perm(p): 6802 d = {x: i for i, x in enumerate(p)} 6803 return (d[0], d[1], d[2]) 6804 6805 def generate_inputs(num_batches): 6806 # transposed tensors 6807 for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2): 6808 b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1) 6809 b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1) 6810 b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) 6811 b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) 6812 yield b1, b2 6813 # broadcasting tensors 6814 for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6): 6815 shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1) 6816 shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1) 6817 b1 = make_tensor(shape1, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, M, N) 6818 b2 = make_tensor(shape2, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, N, O) 6819 yield b1, b2 6820 # zero-sized tensors 6821 for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): 6822 shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) 6823 shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) 6824 b1 = torch.randn(shape1, dtype=dtype, device=device) 6825 b2 = torch.randn(shape2, dtype=dtype, device=device) 6826 yield b1, b2 6827 6828 for num_batches in batch_sizes: 6829 for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))): 6830 res1 = torch.bmm(b1, b2) 6831 res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \ 6832 .permute(perm3).contiguous().permute(invert_perm(perm3)) 6833 torch.bmm(b1, b2, out=res2) 6834 expect = torch.from_numpy( 6835 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) 6836 self.assertEqual(expect, res1) 6837 self.assertEqual(expect, res2) 6838 6839 if self.device_type == 'cuda': 6840 # check that mixed arguments are rejected 6841 self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu())) 6842 self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2)) 6843 self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu())) 6844 6845 def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): 6846 getattr(out_tensor, func + "_")(b1, b2) 6847 self.assertEqual(out_tensor, ref) 6848 res3 = out_tensor.clone() 6849 6850 with self.assertWarnsOnceRegex( 6851 UserWarning, f"This overload of {func}_ is deprecated"): 6852 getattr(out_tensor, func + "_")(1, b1, b2) 6853 self.assertEqual(out_tensor, ref * 2) 6854 getattr(res3, func + "_")(b1, b2, beta=1) 6855 self.assertEqual(out_tensor, res3) 6856 6857 with
self.assertWarnsOnceRegex( 6858 UserWarning, f"This overload of {func}_ is deprecated"): 6859 getattr(out_tensor, func + "_")(1., .5, b1, b2) 6860 self.assertEqual(out_tensor, ref * 2.5) 6861 getattr(res3, func + "_")(b1, b2, beta=1., alpha=.5) 6862 self.assertEqual(out_tensor, res3) 6863 6864 with self.assertWarnsOnceRegex( 6865 UserWarning, f"This overload of {func} is deprecated"): 6866 self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2)) 6867 6868 res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5) 6869 self.assertEqual(res4, ref * 3) 6870 6871 nan = torch.full_like(out_tensor, math.nan) 6872 res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1) 6873 self.assertEqual(res5, ref) 6874 6875 if b1.is_complex(): 6876 res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1j, alpha=.5j) 6877 self.assertEqual(res6, out_tensor * .1j + .5j * ref) 6878 else: 6879 res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1, alpha=.5) 6880 self.assertEqual(res6, out_tensor * .1 + .5 * ref) 6881 6882 res7 = torch.full_like(out_tensor, math.nan) 6883 getattr(torch, func)(nan, b1, b2, beta=0, out=res7) 6884 self.assertEqual(res7, ref) 6885 6886 @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) 6887 @onlyNativeDeviceTypes 6888 @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) 6889 @tf32_on_and_off(0.05) 6890 @bf32_on_and_off(0.05) 6891 def test_addbmm(self, device, dtype): 6892 if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: 6893 # cuBLAS does not guarantee BFloat16 support on SM < 53. 6894 # So on PyTorch, we consider BFloat16 support on SM < 53 as 6895 # undefined behavior 6896 return 6897 6898 num_batches = 2 6899 M, N, O = 16, 17, 18 6900 6901 is_supported = True 6902 if dtype == torch.bfloat16: 6903 if self.device_type == 'cpu': 6904 self.precision = 1 # 43 vs 43.75 6905 else: 6906 is_supported = TEST_WITH_ROCM or SM53OrLater 6907 6908 if not is_supported: 6909 b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) 6910 b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) 6911 t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1) 6912 self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", 6913 lambda: torch.addbmm(t, b1, b2)) 6914 return 6915 6916 def invert_perm(p): 6917 d = {x: i for i, x in enumerate(p)} 6918 return (d[0], d[1], d[2]) 6919 6920 def generate_tensor(): 6921 numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 6922 # transposed tensors 6923 for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2): 6924 for perm3 in itertools.permutations((0, 1)): 6925 b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) * 0.1 6926 b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) * 0.1 6927 b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) 6928 b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) 6929 ref = torch.from_numpy( 6930 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() 6931 ).to(device=device, dtype=dtype).sum(0) 6932 out_tensor = torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3) 6933 yield b1, b2, ref, out_tensor 6934 # broadcasting tensors 6935 for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): 6936 shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) 6937 shape2 = (num_batches if s4 else
1, N if s5 else 1, O if s6 else 1) 6938 b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N) * 0.1 6939 b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O) * 0.1 6940 ref = torch.from_numpy( 6941 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() 6942 ).to(device=device, dtype=dtype).sum(0) 6943 out_tensor = torch.zeros_like(ref) 6944 yield b1, b2, ref, out_tensor 6945 # zero-sized tensors 6946 for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): 6947 shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) 6948 shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) 6949 b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) * 0.1 6950 b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) * 0.1 6951 ref = torch.from_numpy( 6952 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() 6953 ).to(device=device, dtype=dtype).sum(0) 6954 out_tensor = torch.zeros_like(ref) 6955 yield b1, b2, ref, out_tensor 6956 6957 for b1, b2, ref, out_tensor in generate_tensor(): 6958 self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor) 6959 6960 @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5}) 6961 @onlyNativeDeviceTypes 6962 @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) 6963 @tf32_on_and_off(0.05) 6964 @bf32_on_and_off(0.05) 6965 def test_baddbmm(self, device, dtype): 6966 if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: 6967 # cuBLAS does not guarantee BFloat16 support on SM < 53. 6968 # So on PyTorch, we consider BFloat16 support on SM < 53 as 6969 # undefined behavior 6970 return 6971 6972 num_batches = 10 6973 M, N, O = 12, 8, 50 6974 6975 is_supported = True 6976 if dtype == torch.bfloat16 and self.device_type == 'cuda': 6977 is_supported = TEST_WITH_ROCM or SM53OrLater 6978 6979 if not is_supported: 6980 b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) 6981 b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) 6982 t = make_tensor((num_batches, M, O), dtype=dtype, device=device, low=-1, high=1) 6983 self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", 6984 lambda: torch.baddbmm(t, b1, b2)) 6985 return 6986 6987 def invert_perm(p): 6988 d = {x: i for i, x in enumerate(p)} 6989 return (d[0], d[1], d[2]) 6990 6991 def generate_tensor(): 6992 numpy_dtype = dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32 6993 # transposed tensors 6994 for perm1, perm2, perm3 in itertools.product(itertools.permutations((0, 1, 2)), repeat=3): 6995 b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) 6996 b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) 6997 b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) 6998 b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) 6999 ref = torch.from_numpy( 7000 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) 7001 out_tensor = torch.zeros_like(ref) 7002 out_tensor = out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3)) 7003 yield b1, b2, ref, out_tensor 7004 # broadcasting tensors 7005 for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): 7006 shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) 7007 shape2 = (num_batches if s4 else 1, N if
s5 else 1, O if s6 else 1) 7008 b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N) 7009 b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O) 7010 ref = torch.from_numpy( 7011 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) 7012 out_tensor = torch.zeros_like(ref) 7013 yield b1, b2, ref, out_tensor 7014 # zero-sized tensors 7015 for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): 7016 shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) 7017 shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) 7018 b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2) 7019 b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2) 7020 ref = torch.from_numpy( 7021 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) 7022 out_tensor = torch.zeros_like(ref) 7023 yield b1, b2, ref, out_tensor 7024 7025 for b1, b2, ref, out_tensor in generate_tensor(): 7026 self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor) 7027 7028 @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3}) 7029 @skipCUDAIfNoMagma 7030 @skipCPUIfNoLapack 7031 @dtypes(*floating_and_complex_types()) 7032 def test_pinverse(self, device, dtype): 7033 make_fullrank = make_fullrank_matrices_with_distinct_singular_values 7034 make_arg = partial(make_fullrank, device=device, dtype=dtype) 7035 7036 def run_test(M): 7037 # Testing against definition for pseudo-inverses 7038 MPI = torch.pinverse(M) 7039 MPI_ = MPI.cpu().numpy() 7040 M_ = M.cpu().numpy() 7041 if M.numel() > 0: 7042 self.assertEqual(M_, np.matmul(np.matmul(M_, MPI_), M_)) 7043 self.assertEqual(MPI_, np.matmul(np.matmul(MPI_, M_), MPI_)) 7044 self.assertEqual(np.matmul(M_, MPI_), np.matmul(M_, MPI_).swapaxes(-2, -1).conj()) 7045 self.assertEqual(np.matmul(MPI_, M_), np.matmul(MPI_, M_).swapaxes(-2, -1).conj()) 7046 else: 7047 self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2])) 7048 for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5), # square matrices 7049 (3, 2), (5, 3, 2), (7, 5, 3, 2), # fat matrices 7050 (2, 3), (5, 2, 3), (7, 5, 2, 3), # thin matrices 7051 (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices 7052 M = torch.randn(*sizes, dtype=dtype, device=device) 7053 run_test(M) 7054 7055 # Test inverse and pseudo-inverse for invertible matrix 7056 for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]: 7057 matsize = sizes[-1] 7058 batchdims = sizes[:-2] 7059 M = make_arg(*batchdims, matsize, matsize) 7060 self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M), 7061 atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix') 7062 7063 @skipCPUIfNoLapack 7064 @skipCUDAIfNoMagmaAndNoCusolver 7065 @dtypes(torch.double, torch.cdouble) 7066 def test_matrix_power_non_negative(self, device, dtype): 7067 def check(*size): 7068 t = make_tensor(size, dtype=dtype, device=device) 7069 for n in range(8): 7070 res = torch.linalg.matrix_power(t, n) 7071 ref = np.linalg.matrix_power(t.cpu().numpy(), n) 7072 self.assertEqual(res.cpu(), torch.from_numpy(ref)) 7073 7074 check(0, 0) 7075 check(1, 1) 7076 check(5, 5) 7077 check(0, 3, 3) 7078 check(2, 3, 3) 7079 7080 @skipCPUIfNoLapack 7081 @skipCUDAIfNoMagmaAndNoCusolver 7082 @dtypes(torch.double, torch.cdouble) 7083 def test_matrix_power_negative(self, device, dtype): 7084 make_fullrank = 
make_fullrank_matrices_with_distinct_singular_values 7085 make_arg = partial(make_fullrank, device=device, dtype=dtype) 7086 7087 def check(*size): 7088 t = make_arg(*size) 7089 for n in range(-7, 0): 7090 res = torch.linalg.matrix_power(t, n) 7091 ref = np.linalg.matrix_power(t.cpu().numpy(), n) 7092 self.assertEqual(res.cpu(), torch.from_numpy(ref)) 7093 7094 check(0, 0) 7095 check(5, 5) 7096 check(2, 0, 0) 7097 check(0, 3, 3) 7098 check(2, 3, 3) 7099 check(2, 3, 5, 5) 7100 7101 @skipCUDAIfNoMagma 7102 @skipCPUIfNoLapack 7103 @dtypes(torch.float, torch.complex64) 7104 def test_linalg_matrix_exp_utils(self, device, dtype): 7105 # test linear combination 7106 def run_test(coeff_shape, data_shape): 7107 coeffs = torch.rand(*coeff_shape, device=device, dtype=torch.float) 7108 x = torch.rand(coeff_shape[1], *data_shape, device=device, dtype=dtype) 7109 7110 res1 = torch._compute_linear_combination(x, coeffs) 7111 res2 = (x.unsqueeze(0) * coeffs.view(*coeff_shape, *([1] * len(data_shape)))).sum(1) 7112 self.assertEqual(res1, res2, atol=1e-5, rtol=0.0) 7113 7114 # check `out=` version 7115 res3 = torch.zeros(coeff_shape[0], *data_shape, device=device, dtype=dtype) 7116 torch._compute_linear_combination(x, coeffs, out=res3) 7117 self.assertEqual(res1, res3, atol=1e-5, rtol=0.0) 7118 7119 res4 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype) 7120 torch._compute_linear_combination(x, coeffs, out=res4) 7121 self.assertEqual(res1, res4 - 1.0, atol=1e-5, rtol=0.0) 7122 7123 res5 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype) 7124 res5_clone = res5.clone() 7125 torch._compute_linear_combination(x, coeffs, out=res5) 7126 self.assertEqual(res1, res5 - res5_clone, atol=1e-5, rtol=0.0) 7127 7128 run_test([1, 3], [2, 2]) 7129 run_test([3, 1], [2, 2]) 7130 run_test([1, 10], [10, 10]) 7131 run_test([10, 1], [10, 10]) 7132 run_test([5, 3], [2, 2]) 7133 run_test([5, 3], [100, 100]) 7134 run_test([3, 4], [3, 3, 3]) 7135 run_test([3, 4], [3, 3, 3, 3]) 7136 7137 # Regression test for https://github.com/pytorch/pytorch/issues/94124 7138 with self.assertRaises(RuntimeError): 7139 x = torch.rand([], device=device, dtype=dtype) 7140 coeffs = torch.rand([2, 2], device=device, dtype=dtype) 7141 res = torch._compute_linear_combination(x, coeffs) 7142 7143 @onlyCPU 7144 @skipCPUIfNoLapack 7145 @dtypes(torch.complex64) 7146 def test_linalg_matrix_exp_no_warnings(self, device, dtype): 7147 # this tests https://github.com/pytorch/pytorch/issues/80948 7148 with freeze_rng_state(): 7149 torch.manual_seed(42) 7150 tens = 0.5 * torch.randn(10, 3, 3, dtype=dtype, device=device) 7151 tens = (0.5 * (tens.transpose(-1, -2) + tens)) 7152 with warnings.catch_warnings(record=True) as w: 7153 tens.imag = torch.matrix_exp(tens.imag) 7154 self.assertFalse(len(w)) 7155 7156 @skipCUDAIfNoMagma 7157 @skipCPUIfNoLapack 7158 @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) 7159 def test_linalg_matrix_exp_boundary_cases(self, device, dtype): 7160 expm = torch.linalg.matrix_exp 7161 7162 with self.assertRaisesRegex(RuntimeError, "Expected a floating point or complex tensor"): 7163 expm(torch.randn(3, 3).type(torch.int)) 7164 7165 with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"): 7166 expm(torch.randn(3)) 7167 7168 with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): 7169 expm(torch.randn(3, 2, 1)) 7170 7171 # check 1x1 matrices 7172 x = torch.randn(3, 3, 1, 1) 7173 self.assertEqual(expm(x), x.exp()) 7174 7175 
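    # For 1x1 inputs the matrix exponential coincides with the scalar exponential,
    # which is why the boundary-case check above can use elementwise `x.exp()` as
    # its reference. A minimal standalone sketch of that identity (illustrative
    # values only, not part of the test suite):
    #
    #     a = torch.tensor([[0.5]])
    #     torch.testing.assert_close(torch.linalg.matrix_exp(a), a.exp())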
@skipCUDAIfNoMagma 7176 @skipCPUIfNoLapack 7177 @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) 7178 def test_linalg_matrix_exp_perverse_nan_values(self, device, dtype): 7179 expm = torch.linalg.matrix_exp 7180 7181 def with_nan(x): 7182 x[0, 0, 0] = torch.nan 7183 return x 7184 7185 # Check small batches 7186 x = with_nan(torch.randn(1, 1, 1)) 7187 self.assertTrue(torch.isnan(expm(x)).any()) 7188 x = with_nan(torch.randn(1, 2, 2)) 7189 for v in [1, 2, 3, 4, 5, 6, 7, 8, 9, 100, 1000]: 7190 self.assertTrue(torch.isnan(expm(x / v)).any()) 7191 7192 # Check large batches 7193 x = with_nan(torch.randn(2, 2, 2)) 7194 self.assertTrue(torch.isnan(expm(x)).any()) 7195 x = with_nan(torch.randn(4096, 2, 2)) 7196 self.assertTrue(torch.isnan(expm(x)).any()) 7197 7198 @slowTest 7199 @skipCUDAIfNoMagma 7200 @skipCPUIfNoLapack 7201 @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 7202 def test_linalg_matrix_exp_analytic(self, device, dtype): 7203 expm = torch.linalg.matrix_exp 7204 # check zero matrix 7205 x = torch.zeros(20, 20, dtype=dtype, device=device) 7206 self.assertTrue((expm(x) == torch.eye(20, 20, dtype=dtype, device=device)).all().item()) 7207 7208 def normalize_to_1_operator_norm(sample, desired_norm): 7209 sample_norm, _ = sample.abs().sum(-2).max(-1) 7210 sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1) 7211 return sample_to_1_norm * desired_norm 7212 7213 def gen_good_cond_number_matrices(*n): 7214 """ 7215 Generates a diagonally-dominant matrix 7216 with the eigenvalues centered at 1 7217 and the radii at most (n[-1] - 1) / (n[-2] ** 2) 7218 """ 7219 identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n) 7220 x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2) 7221 x = (x - x * identity) + identity 7222 return x 7223 7224 def run_test(*n): 7225 if dtype == torch.float: 7226 thetas = [ 7227 1.192092800768788e-07, # deg 1 7228 5.978858893805233e-04, # deg 2 7229 5.116619363445086e-02, # deg 4 7230 5.800524627688768e-01, # deg 8 7231 1.461661507209034e+00, # deg 12 7232 3.010066362817634e+00 # deg 18 7233 ] 7234 else: # if torch.double 7235 thetas = [ 7236 2.220446049250313e-16, # deg 1 7237 2.580956802971767e-08, # deg 2 7238 3.397168839976962e-04, # deg 4 7239 4.991228871115323e-02, # deg 8 7240 2.996158913811580e-01, # deg 12 7241 1.090863719290036e+00 # deg 18 7242 ] 7243 7244 # generate input 7245 q = gen_good_cond_number_matrices(*n) 7246 q_ = q.cpu().numpy() 7247 qinv = torch.inverse(q) 7248 qinv_ = qinv.cpu().numpy() 7249 d = torch.randn(n[:-1], dtype=dtype, device=device) 7250 x = torch.from_numpy( 7251 np.matmul(q_, np.matmul(torch.diag_embed(d).cpu().numpy(), qinv_))).to(device) 7252 x_norm, _ = x.abs().sum(-2).max(-1) 7253 7254 # test against the analytic solution for whatever norm was generated 7255 mexp = expm(x) 7256 mexp_analytic = np.matmul( 7257 q_, 7258 np.matmul( 7259 torch.diag_embed(d.exp()).cpu().numpy(), 7260 qinv_ 7261 ) 7262 ) 7263 self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0) 7264 7265 # generate norms to test different degree expansions 7266 sample_norms = [] 7267 for i in range(len(thetas) - 1): 7268 sample_norms.append(0.5 * (thetas[i] + thetas[i + 1])) 7269 sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2] 7270 7271 # normalize matrices to each sampled norm 7272 for sample_norm in sample_norms: 7273 x_normalized = normalize_to_1_operator_norm(x, sample_norm) 7274 7275 mexp = expm(x_normalized) 7276 mexp_analytic = np.matmul( 7277 q_, 7278 np.matmul( 7279 torch.diag_embed((d /
x_norm.unsqueeze(-1) * sample_norm).exp()).cpu().numpy(), 7280 qinv_ 7281 ) 7282 ) 7283 self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0) 7284 7285 # single matrix 7286 run_test(2, 2) 7287 run_test(3, 3) 7288 run_test(4, 4) 7289 run_test(5, 5) 7290 run_test(100, 100) 7291 run_test(200, 200) 7292 7293 # small batch of matrices 7294 run_test(3, 2, 2) 7295 run_test(3, 3, 3) 7296 run_test(3, 4, 4) 7297 run_test(3, 5, 5) 7298 run_test(3, 100, 100) 7299 run_test(3, 200, 200) 7300 7301 # large batch of matrices 7302 run_test(3, 3, 2, 2) 7303 run_test(3, 3, 3, 3) 7304 run_test(3, 3, 4, 4) 7305 run_test(3, 3, 5, 5) 7306 run_test(3, 3, 100, 100) 7307 run_test(3, 3, 200, 200) 7308 7309 @skipCUDAIfNoMagma 7310 @skipCPUIfNoLapack 7311 @dtypes(torch.float, torch.double) 7312 def test_linalg_matrix_exp_batch(self, device, dtype): 7313 7314 def run_test(*n): 7315 tensors_batch = torch.zeros(n, dtype=dtype, device=device) 7316 tensors_batch = tensors_batch.view(-1, n[-2], n[-1]) 7317 7318 num_matrices = tensors_batch.size(0) 7319 tensors_list = [] 7320 for i in range(num_matrices): 7321 tensors_list.append(torch.randn(n[-2], n[-1], dtype=dtype, device=device)) 7322 7323 for i in range(num_matrices): 7324 tensors_batch[i, ...] = tensors_list[i] 7325 7326 tensors_exp_map = (torch.linalg.matrix_exp(x) for x in tensors_list) 7327 tensors_exp_batch = torch.linalg.matrix_exp(tensors_batch) 7328 7329 for i, tensor_exp in enumerate(tensors_exp_map): 7330 self.assertEqual(tensors_exp_batch[i, ...], tensor_exp) 7331 7332 # small batch of matrices 7333 run_test(3, 2, 2) 7334 run_test(3, 3, 3) 7335 run_test(3, 4, 4) 7336 run_test(3, 5, 5) 7337 7338 # large batch of matrices 7339 run_test(3, 3, 2, 2) 7340 run_test(3, 3, 3, 3) 7341 run_test(3, 3, 4, 4) 7342 run_test(3, 3, 5, 5) 7343 7344 @skipCUDAIfNoMagma 7345 @skipCPUIfNoLapack 7346 @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 7347 def test_linalg_matrix_exp_compare_with_taylor(self, device, dtype): 7348 7349 def normalize_to_1_operator_norm(sample, desired_norm): 7350 sample_norm, _ = sample.abs().sum(-2).max(-1) 7351 sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1) 7352 return sample_to_1_norm * desired_norm 7353 7354 def gen_good_cond_number_matrices(*n): 7355 """ 7356 Generates a diagonally-dominant matrix 7357 with the eigenvalues centered at 1 7358 and the radii at most (n[-1] - 1) / (n[-2] ** 2) 7359 """ 7360 identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n) 7361 x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2) 7362 x = (x - x * identity) + identity 7363 return x 7364 7365 def get_taylor_approximation(a, deg): 7366 a_ = a.cpu().numpy() 7367 identity = torch.eye(a.size(-2), a.size(-1), dtype=dtype, device=device).expand_as(a) 7368 res = identity.cpu().numpy() 7369 taylor_term = identity.cpu().numpy() 7370 7371 for i in range(1, deg + 1): 7372 taylor_term = np.matmul(a_, taylor_term) / i 7373 res = res + taylor_term 7374 7375 return res 7376 7377 def scale_square(a, deg): 7378 if a.abs().pow(2).sum().sqrt() < 1.0: 7379 return get_taylor_approximation(a, 12) 7380 else: 7381 s = int(torch.log2(a.abs().pow(2).sum().sqrt()).ceil().item()) 7382 b = a / (2 ** s) 7383 b = get_taylor_approximation(b, 18) 7384 for _ in range(s): 7385 b = np.matmul(b, b) 7386 return torch.from_numpy(b).to(a.device) 7387 7388 def run_test(*n): 7389 degs = [1, 2, 4, 8, 12, 18] 7390 if dtype == torch.float: 7391 thetas = [ 7392 1.192092800768788e-07, # deg 1 7393 5.978858893805233e-04, # deg 2 7394
5.116619363445086e-02, # deg 4 7395 5.800524627688768e-01, # deg 8 7396 1.461661507209034e+00, # deg 12 7397 3.010066362817634e+00 # deg 18 7398 ] 7399 else: # if torch.double 7400 thetas = [ 7401 2.220446049250313e-16, # deg 1 7402 2.580956802971767e-08, # deg 2 7403 3.397168839976962e-04, # deg 4 7404 4.991228871115323e-02, # deg 8 7405 2.996158913811580e-01, # deg 12 7406 1.090863719290036e+00 # deg 18 7407 ] 7408 7409 # generate norms to test different degree expansions 7410 sample_norms = [] 7411 for i in range(len(thetas) - 1): 7412 sample_norms.append(0.5 * (thetas[i] + thetas[i + 1])) 7413 sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2] 7414 degs = [degs[0]] + degs 7415 7416 for sample_norm, deg in zip(sample_norms, degs): 7417 x = gen_good_cond_number_matrices(*n) 7418 x = normalize_to_1_operator_norm(x, sample_norm) 7419 7420 mexp = torch.linalg.matrix_exp(x) 7421 mexp_taylor = scale_square(x, deg) 7422 7423 self.assertEqual(mexp, mexp_taylor, atol=1e-2, rtol=0.0) 7424 7425 # single matrix 7426 run_test(2, 2) 7427 run_test(3, 3) 7428 run_test(4, 4) 7429 run_test(5, 5) 7430 7431 # small batch of matrices 7432 run_test(3, 2, 2) 7433 run_test(3, 3, 3) 7434 run_test(3, 4, 4) 7435 run_test(3, 5, 5) 7436 7437 # large batch of matrices 7438 run_test(3, 3, 2, 2) 7439 run_test(3, 3, 3, 3) 7440 run_test(3, 3, 4, 4) 7441 run_test(3, 3, 5, 5) 7442 7443 @skipCUDAIfNoMagma 7444 @skipCPUIfNoLapack 7445 @dtypes(*floating_and_complex_types()) 7446 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 7447 torch.float64: 1e-8, torch.complex128: 1e-8}) 7448 def test_slogdet(self, device, dtype): 7449 from torch.testing._internal.common_utils import (random_hermitian_matrix, random_hermitian_psd_matrix, 7450 random_hermitian_pd_matrix, random_square_matrix_of_rank) 7451 7452 # mat_chars denotes matrix characteristics 7453 # possible values are: hermitian, hermitian_psd, hermitian_pd, singular, non_singular 7454 def run_test(matsize, batchdims, mat_chars): 7455 num_matrices = np.prod(batchdims) 7456 list_of_matrices = [] 7457 if num_matrices != 0: 7458 for idx in range(num_matrices): 7459 mat_type = idx % len(mat_chars) 7460 if mat_chars[mat_type] == 'hermitian': 7461 list_of_matrices.append(random_hermitian_matrix(matsize, dtype=dtype, device=device)) 7462 elif mat_chars[mat_type] == 'hermitian_psd': 7463 list_of_matrices.append(random_hermitian_psd_matrix(matsize, dtype=dtype, device=device)) 7464 elif mat_chars[mat_type] == 'hermitian_pd': 7465 list_of_matrices.append(random_hermitian_pd_matrix(matsize, dtype=dtype, device=device)) 7466 elif mat_chars[mat_type] == 'singular': 7467 list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device)) 7468 elif mat_chars[mat_type] == 'non_singular': 7469 list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device)) 7470 full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize)) 7471 else: 7472 full_tensor = torch.randn(*batchdims, matsize, matsize, dtype=dtype, device=device) 7473 7474 actual_value = torch.linalg.slogdet(full_tensor) 7475 expected_value = np.linalg.slogdet(full_tensor.cpu().numpy()) 7476 self.assertEqual(expected_value[0], actual_value[0], atol=self.precision, rtol=self.precision) 7477 self.assertEqual(expected_value[1], actual_value[1], atol=self.precision, rtol=self.precision) 7478 7479 # test out=variant 7480 sign_out = torch.empty_like(actual_value[0]) 7481 logabsdet_out = torch.empty_like(actual_value[1]) 7482 ans = 
torch.linalg.slogdet(full_tensor, out=(sign_out, logabsdet_out)) 7483 self.assertEqual(ans[0], sign_out) 7484 self.assertEqual(ans[1], logabsdet_out) 7485 self.assertEqual(sign_out, actual_value[0]) 7486 self.assertEqual(logabsdet_out, actual_value[1]) 7487 7488 for matsize, batchdims in itertools.product([0, 3, 5], [(0,), (3,), (5, 3)]): 7489 run_test(matsize, batchdims, mat_chars=['hermitian_pd']) 7490 run_test(matsize, batchdims, mat_chars=['singular']) 7491 run_test(matsize, batchdims, mat_chars=['non_singular']) 7492 run_test(matsize, batchdims, mat_chars=['hermitian', 'hermitian_pd', 'hermitian_psd']) 7493 run_test(matsize, batchdims, mat_chars=['singular', 'non_singular']) 7494 7495 @skipCUDAIfNoMagma 7496 @skipCPUIfNoLapack 7497 @dtypes(*floating_and_complex_types()) 7498 def test_slogdet_errors_and_warnings(self, device, dtype): 7499 # slogdet requires the input to be a square matrix or batch of square matrices 7500 a = torch.randn(2, 3, device=device, dtype=dtype) 7501 with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): 7502 torch.linalg.slogdet(a) 7503 7504 # slogdet requires the input to be at least 2 dimensional tensor 7505 a = torch.randn(2, device=device, dtype=dtype) 7506 with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'): 7507 torch.linalg.slogdet(a) 7508 7509 a = torch.randn(2, 2, device=device, dtype=torch.bfloat16) 7510 with self.assertRaisesRegex(RuntimeError, r'Low precision dtypes not supported'): 7511 torch.linalg.slogdet(a) 7512 7513 # if non-empty out tensor with wrong shape is passed a warning is given 7514 a = torch.randn(2, 3, 3, device=device, dtype=dtype) 7515 sign_out = torch.empty(1, device=device, dtype=dtype) 7516 real_dtype = a.real.dtype if dtype.is_complex else dtype 7517 logabsdet_out = torch.empty(1, device=device, dtype=real_dtype) 7518 with warnings.catch_warnings(record=True) as w: 7519 # Trigger warning 7520 torch.linalg.slogdet(a, out=(sign_out, logabsdet_out)) 7521 # Check warning occurs 7522 self.assertEqual(len(w), 1) 7523 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 7524 7525 # device should match 7526 if torch.cuda.is_available(): 7527 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 7528 sign_out = torch.empty(0, device=wrong_device, dtype=dtype) 7529 logabsdet_out = torch.empty(0, device=wrong_device, dtype=real_dtype) 7530 with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 7531 torch.linalg.slogdet(a, out=(sign_out, logabsdet_out)) 7532 7533 # FIXME One of the backends of lu_factor fails in windows. 
I haven't investigated which or why 7534 # https://github.com/pytorch/pytorch/issues/75225 7535 @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") 7536 @skipCUDAIfNoCusolver 7537 @skipCPUIfNoLapack 7538 @dtypes(torch.double) 7539 def test_det_logdet_slogdet(self, device, dtype): 7540 def reference_slogdet(M): 7541 sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy()) 7542 return M.new_tensor(sdet), M.new_tensor(logabsdet) 7543 7544 def test_single_det(M, target, desc): 7545 target_sdet, target_logabsdet = target 7546 7547 det = M.det() 7548 logdet = M.logdet() 7549 sdet, logabsdet = M.slogdet() 7550 linalg_sdet, linalg_logabsdet = torch.linalg.slogdet(M) 7551 7552 # Test det 7553 self.assertEqual(det, target_sdet * target_logabsdet.exp(), 7554 atol=1e-6, rtol=0, msg=f'{desc} (det)') 7555 7556 # Test slogdet 7557 # Compare the overall value rather than individual parts because of 7558 # precision issues when det is near zero. 7559 self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(), 7560 atol=1e-6, rtol=0, msg=f'{desc} (slogdet)') 7561 self.assertEqual(linalg_sdet * linalg_logabsdet.exp(), target_sdet * target_logabsdet.exp(), 7562 atol=1e-6, rtol=0, msg=f'{desc} (linalg_slogdet)') 7563 7564 # Test logdet 7565 # Compare logdet against our own pytorch slogdet because they should 7566 # be consistent, while it may behave slightly differently with other 7567 # slogdet implementations when det is near zero due to precision 7568 # issues. 7569 if sdet.item() < 0: 7570 self.assertTrue(logdet.item() != logdet.item(), f'{desc} (logdet negative case)') 7571 else: 7572 self.assertEqual(logdet.exp(), target_logabsdet.exp(), 7573 atol=1e-6, rtol=0, msg=f'{desc} (logdet non-negative case)') 7574 7575 eye = torch.eye(5, dtype=dtype, device=device) 7576 test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity') 7577 # Testing bug in #34061 (https://github.com/pytorch/pytorch/issues/34061) 7578 for n in range(250, 551, 100): 7579 mat = torch.randn(n, n, dtype=dtype, device=device) 7580 q, _ = torch.qr(mat) 7581 ref_det, ref_logabsdet = reference_slogdet(q) 7582 test_single_det(q, (ref_det, ref_logabsdet), 'orthogonal') 7583 7584 def test(M): 7585 assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5' 7586 M = M.to(device) 7587 7588 ref_M_sdet, ref_M_logabsdet = reference_slogdet(M) 7589 7590 test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic') 7591 if ref_M_logabsdet.exp().item() >= 1e-6: # skip singular 7592 M_inv = M.inverse() 7593 test_single_det(M_inv, reference_slogdet(M_inv), 'inverse') 7594 7595 test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose') 7596 7597 for x in [0, 2, 4]: 7598 for scale in [-2, -0.1, 0, 10]: 7599 if scale > 0: 7600 target = ref_M_sdet, ref_M_logabsdet + math.log(scale) 7601 elif scale == 0: 7602 target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) 7603 else: 7604 target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale) 7605 7606 # dim 0 7607 M_clone = M.clone() 7608 M_clone[:, x] *= scale 7609 test_single_det(M_clone, target, 'scale a row') 7610 # dim 1 7611 M_clone = M.clone() 7612 M_clone[x, :] *= scale 7613 test_single_det(M_clone, target, 'scale a column') 7614 7615 for x1, x2 in [(0, 3), (4, 1), (3, 2)]: 7616 assert x1 != x2, 'x1 and x2 need to be different for this test' 7617 target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) 7618 # dim 0 7619 M_clone = M.clone() 7620 M_clone[:, x2] =
M_clone[:, x1] 7621 test_single_det(M_clone, target, 'two rows are same') 7622 # dim 1 7623 M_clone = M.clone() 7624 M_clone[x2, :] = M_clone[x1, :] 7625 test_single_det(M_clone, target, 'two columns are same') 7626 7627 for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]: 7628 det_scale = scale1 * scale2 * -1 7629 if det_scale > 0: 7630 target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale) 7631 elif det_scale == 0: 7632 target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) 7633 else: 7634 target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale) 7635 7636 # dim 0 7637 M_clone = M.clone() 7638 t = M_clone[:, x1] * scale1 7639 M_clone[:, x1] += M_clone[:, x2] * scale2 7640 M_clone[:, x2] = t 7641 test_single_det(M_clone, target, 'exchanging rows') 7642 # dim 1 7643 M_clone = M.clone() 7644 t = M_clone[x1, :] * scale1 7645 M_clone[x1, :] += M_clone[x2, :] * scale2 7646 M_clone[x2, :] = t 7647 test_single_det(M_clone, target, 'exchanging columns') 7648 7649 def get_random_mat_scale(n): 7650 # For matrices with values i.i.d. with 0 mean, unit variance, and 7651 # subexponential tail, we have: 7652 # E[log det(A^2)] \approx log((n-1)!) 7653 # 7654 # Notice: 7655 # log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)] 7656 # 7657 # So: 7658 # stddev[det(A)] >= sqrt( (n-1)! ) 7659 # 7660 # We use this as an intuitive guideline to scale random generated 7661 # matrices so our closeness tests can work more robustly: 7662 # scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n)) 7663 # 7664 # source: https://arxiv.org/pdf/1112.0752.pdf 7665 7666 # TODO: technically we need subexponential distn for this to hold, 7667 # but we mostly use gaussian entries below. Consider switching 7668 # to Chi-sq if this turns out not stable enough, since Chi-sq 7669 # is easy enough to sample from. 7670 return math.factorial(n - 1) ** (-1.0 / (2 * n)) 7671 7672 for n in [5, 10, 25]: 7673 scale = get_random_mat_scale(n) 7674 test(torch.randn(n, n, dtype=dtype, device=device) * scale) 7675 r = torch.randn(n, n, dtype=dtype, device=device) * scale 7676 # symmetric psd 7677 test(r.mm(r.t())) 7678 # symmetric pd 7679 r = torch.randn(n, n, dtype=dtype, device=device) * scale 7680 test(r.mm(r.t()) + torch.eye(n, dtype=dtype, device=device) * 1e-6) 7681 # symmetric 7682 r = torch.randn(n, n, dtype=dtype, device=device) * scale 7683 for i in range(n): 7684 for j in range(i): 7685 r[i, j] = r[j, i] 7686 test(r) 7687 # non-contiguous 7688 test((torch.randn(n, n, n + 1, dtype=dtype, device=device) * scale)[:, 2, 1:]) 7689 # det = 0 7690 r = torch.randn(n, n, dtype=dtype, device=device) * scale 7691 u, s, v = r.svd() 7692 if reference_slogdet(u)[0] < 0: 7693 u = -u 7694 if reference_slogdet(v)[0] < 0: 7695 v = -v 7696 s[0] *= -1 7697 s[-1] = 0 7698 test(u.mm(s.diag()).mm(v)) 7699 7700 # Small values to test numerical stability. Note that we don't scale 7701 # this matrix. 7702 r = torch.randn(512, 512, dtype=dtype, device=device) 7703 u, s, v = r.svd() 7704 s.fill_(1. 
/ (100 * s.numel())) 7705 test(u.mm(s.diag()).mm(v)) 7706 7707 @skipCUDAIfNoMagma 7708 @skipCPUIfNoLapack 7709 @dtypes(torch.double) 7710 def test_det_logdet_slogdet_batched(self, device, dtype): 7711 from torch.testing._internal.common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix, 7712 random_symmetric_pd_matrix, random_square_matrix_of_rank) 7713 7714 # mat_chars denotes matrix characteristics 7715 # possible values are: sym, sym_psd, sym_pd, sing, non_sym 7716 def run_test(matsize, batchdims, mat_chars): 7717 num_matrices = reduce(operator.mul, batchdims, 1) 7718 list_of_matrices = [] 7719 7720 for idx in range(num_matrices): 7721 mat_type = idx % len(mat_chars) 7722 if mat_chars[mat_type] == 'sym': 7723 list_of_matrices.append(random_symmetric_matrix(matsize, dtype=dtype, device=device)) 7724 elif mat_chars[mat_type] == 'sym_psd': 7725 list_of_matrices.append(random_symmetric_psd_matrix(matsize, dtype=dtype, device=device)) 7726 elif mat_chars[mat_type] == 'sym_pd': 7727 list_of_matrices.append(random_symmetric_pd_matrix(matsize, dtype=dtype, device=device)) 7728 elif mat_chars[mat_type] == 'sing': 7729 list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device)) 7730 elif mat_chars[mat_type] == 'non_sing': 7731 list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device)) 7732 full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize)) 7733 # Scaling adapted from `get_random_mat_scale` in _test_det_logdet_slogdet 7734 full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize))) 7735 7736 for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]: 7737 expected_value = [] 7738 actual_value = fn(full_tensor) 7739 for full_idx in itertools.product(*(list(range(x)) for x in batchdims)): 7740 expected_value.append(fn(full_tensor[full_idx])) 7741 7742 if fn == torch.slogdet or fn == torch.linalg.slogdet: 7743 sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims) 7744 expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims) 7745 self.assertEqual(sign_value, actual_value[0]) 7746 self.assertEqual(expected_value, actual_value[1]) 7747 else: 7748 expected_value = torch.stack(expected_value, dim=0).reshape(batchdims) 7749 self.assertEqual(actual_value, expected_value) 7750 7751 for matsize, batchdims in itertools.product([3, 5], [(3,), (5, 3)]): 7752 run_test(matsize, batchdims, mat_chars=['sym_pd']) 7753 run_test(matsize, batchdims, mat_chars=['sing']) 7754 run_test(matsize, batchdims, mat_chars=['non_sing']) 7755 run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd']) 7756 run_test(matsize, batchdims, mat_chars=['sing', 'non_sing']) 7757 7758 @skipCUDAIfNoMagma 7759 @skipCPUIfNoLapack 7760 @dtypes(*floating_and_complex_types()) 7761 def test_cholesky_inverse(self, device, dtype): 7762 from torch.testing._internal.common_utils import random_hermitian_pd_matrix 7763 7764 def run_test(shape, batch, upper, contiguous): 7765 A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) 7766 if A.numel() > 0 and not contiguous: 7767 A = A.mT 7768 self.assertFalse(A.is_contiguous()) 7769 L = torch.linalg.cholesky(A) 7770 expected_inverse = torch.inverse(A) 7771 L = L.mH if upper else L 7772 actual_inverse = torch.cholesky_inverse(L, upper) 7773 self.assertEqual(actual_inverse, expected_inverse) 7774 7775 shapes = (0, 3, 5) 7776 batches = ((), (0,), (3, ), (2, 2)) 7777 for 
shape, batch, upper, contiguous in list(itertools.product(shapes, batches, (True, False), (True, False))): 7778 run_test(shape, batch, upper, contiguous) 7779 7780 # check the out= variant 7781 A = random_hermitian_pd_matrix(3, 2, dtype=dtype, device=device) 7782 L = torch.linalg.cholesky(A) 7783 7784 # There are two code paths currently for the out= variant 7785 # 1. When 'out' tensor is in Fortran (column-major) memory format 7786 # then the fast route is taken and the storage is reused directly in the computations 7787 # 2. When 'out' tensor is not in Fortran format then a temporary tensor is allocated internally 7788 # and the result is copied from the temporary tensor to 'out' tensor 7789 7790 # This test checks the first code path 7791 out = torch.empty_like(A) 7792 out_t = out.mT.clone(memory_format=torch.contiguous_format) 7793 out = out_t.mT 7794 ans = torch.cholesky_inverse(L, out=out) 7795 self.assertEqual(ans, out) 7796 expected = torch.inverse(A) 7797 self.assertEqual(expected, out) 7798 7799 # This test checks the second code path 7800 out = torch.empty_like(A) 7801 ans = torch.cholesky_inverse(L, out=out) 7802 self.assertEqual(ans, out) 7803 expected = torch.inverse(A) 7804 self.assertEqual(expected, out) 7805 7806 @skipCUDAIfNoMagma 7807 @skipCPUIfNoLapack 7808 @dtypes(*floating_and_complex_types()) 7809 def test_cholesky_inverse_errors_and_warnings(self, device, dtype): 7810 # cholesky_inverse requires the input to be at least 2 dimensional tensor 7811 a = torch.randn(2, device=device, dtype=dtype) 7812 with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"): 7813 torch.cholesky_inverse(a) 7814 7815 # cholesky_inverse requires a square matrix 7816 a = torch.randn(2, 3, device=device, dtype=dtype) 7817 with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): 7818 torch.cholesky_inverse(a) 7819 7820 # if non-empty out tensor with wrong shape is passed a warning is given 7821 a = torch.randn(3, 3, device=device, dtype=dtype) 7822 out = torch.empty(2, 3, device=device, dtype=dtype) 7823 with warnings.catch_warnings(record=True) as w: 7824 # Trigger warning 7825 torch.cholesky_inverse(a, out=out) 7826 # Check warning occurs 7827 self.assertEqual(len(w), 1) 7828 self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 7829 7830 # dtypes should be safely castable 7831 out = torch.empty(*a.shape, dtype=torch.int, device=device) 7832 with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): 7833 torch.cholesky_inverse(a, out=out) 7834 7835 # device should match 7836 if torch.cuda.is_available(): 7837 wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 7838 out = torch.empty(0, device=wrong_device, dtype=dtype) 7839 with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): 7840 torch.cholesky_inverse(a, out=out) 7841 7842 # cholesky_inverse raises an error for invalid inputs on CPU 7843 # for example if at least one diagonal element is zero 7844 a = torch.randn(3, 3, device=device, dtype=dtype) 7845 a[1, 1] = 0 7846 if self.device_type == 'cpu': 7847 with self.assertRaisesRegex(torch.linalg.LinAlgError, r"cholesky_inverse: The diagonal element 2 is zero"): 7848 torch.cholesky_inverse(a) 7849 # cholesky_inverse on GPU does not raise an error for this case 7850 elif self.device_type == 'cuda': 7851 out = torch.cholesky_inverse(a) 7852 self.assertTrue(out.isinf().any() or out.isnan().any()) 7853 7854 def _select_broadcastable_dims(self, 
dims_full=None): 7855 # select full dimensionality 7856 if dims_full is None: 7857 dims_full = [] 7858 ndims = random.randint(1, 4) 7859 dims_full = [random.randint(1, 8) for _ in range(ndims)] 7860 else: 7861 ndims = len(dims_full) 7862 7863 # select actual dimensions for ops: 7864 # larger: full ndims, individual sizes may be reduced 7865 # smaller: possibly reduced ndims, sizes may be reduced 7866 smaller_ndims = random.randint(1, ndims) 7867 dims_small = [] 7868 dims_large = [] 7869 for i in range(ndims - 1, -1, -1): 7870 j = random.randint(1, 3) 7871 if j == 1: # no reduced singleton dimension 7872 ds = dims_full[i] 7873 dl = dims_full[i] 7874 elif j == 2: # larger may have reduced singleton dimension 7875 ds = dims_full[i] 7876 dl = 1 if len(dims_small) < smaller_ndims else dims_full[i] 7877 elif j == 3: # smaller may have reduced singleton dimension 7878 ds = 1 7879 dl = dims_full[i] 7880 dims_large = [dl] + dims_large 7881 if len(dims_small) < smaller_ndims: 7882 dims_small = [ds] + dims_small 7883 return (dims_small, dims_large, dims_full) 7884 7885 def test_broadcast_fused_matmul(self, device): 7886 fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"] 7887 7888 for fn in fns: 7889 batch_dim = random.randint(1, 8) 7890 n_dim = random.randint(1, 8) 7891 m_dim = random.randint(1, 8) 7892 p_dim = random.randint(1, 8) 7893 7894 def dims_full_for_fn(): 7895 if fn == "baddbmm": 7896 return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) 7897 elif fn == "addbmm": 7898 return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) 7899 elif fn == "addmm": 7900 return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]) 7901 elif fn == "addmv": 7902 return ([n_dim], [n_dim, m_dim], [m_dim]) 7903 elif fn == "addr": 7904 return ([n_dim, m_dim], [n_dim], [m_dim]) 7905 else: 7906 raise AssertionError("unknown function") 7907 7908 (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn() 7909 (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full) 7910 7911 t0_small = torch.randn(*t0_dims_small, device=device).float() 7912 t1 = torch.randn(*t1_dims, device=device).float() 7913 t2 = torch.randn(*t2_dims, device=device).float() 7914 7915 t0_full = t0_small.expand(*t0_dims_full).to(device) 7916 7917 fntorch = getattr(torch, fn) 7918 r0 = fntorch(t0_small, t1, t2) 7919 r1 = fntorch(t0_full, t1, t2) 7920 self.assertEqual(r0, r1) 7921 7922 @tf32_on_and_off(0.001) 7923 @bf32_on_and_off(0.001) 7924 def test_broadcast_batched_matmul(self, device): 7925 n_dim = random.randint(1, 8) 7926 m_dim = random.randint(1, 8) 7927 p_dim = random.randint(1, 8) 7928 full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))] 7929 (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims) 7930 7931 def verify_batched_matmul(full_lhs, one_dimensional): 7932 if not one_dimensional: 7933 lhs_dims = [n_dim, m_dim] 7934 rhs_dims = [m_dim, p_dim] 7935 result_dims = [n_dim, p_dim] 7936 else: 7937 lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim] 7938 rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim] 7939 result_dims = [n_dim] if full_lhs else [p_dim] 7940 7941 lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim] 7942 rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1] 7943 full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims 7944 dim0_dims = rhs_dims if full_lhs else lhs_dims 7945 small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims) 7946 7947 small = torch.randn(*(small_dims), 
device=device).float() 7948 dim0 = torch.randn(*(dim0_dims), device=device).float() 7949 full = torch.randn(*(full_batch_dims + full_mat_dims), device=device).float() 7950 if not one_dimensional: 7951 (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,)) 7952 else: 7953 (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,)) 7954 7955 def maybe_squeeze_result(l, r, result): 7956 if len(lhs_dims) == 1 and l.dim() != 1: 7957 return result.squeeze(-2) 7958 elif len(rhs_dims) == 1 and r.dim() != 1: 7959 return result.squeeze(-1) 7960 else: 7961 return result 7962 7963 for lhs in lhsTensors: 7964 lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims))) 7965 lhs_expanded_matmul_fn = lhs_expanded.matmul 7966 for rhs in rhsTensors: 7967 rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)). 7968 expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims)))) 7969 truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded)) 7970 for l in (lhs, lhs_expanded): 7971 for r in (rhs, rhs_expanded): 7972 l_matmul_fn = l.matmul 7973 result = maybe_squeeze_result(l, r, l_matmul_fn(r)) 7974 self.assertEqual(truth, result) 7975 # test torch.matmul function as well 7976 torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r)) 7977 self.assertEqual(truth, torch_result) 7978 # test torch.matmul with out 7979 out = torch.zeros_like(torch_result) 7980 torch.matmul(l, r, out=out) 7981 self.assertEqual(truth, maybe_squeeze_result(l, r, out)) 7982 7983 # compare to bmm 7984 bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims), 7985 rhs_expanded.contiguous().view(-1, *rhs_mat_dims))) 7986 self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims)) 7987 7988 for indices in itertools.product((True, False), repeat=2): 7989 verify_batched_matmul(*indices) 7990 7991 def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype): 7992 make_fullrank = make_fullrank_matrices_with_distinct_singular_values 7993 make_A = partial(make_fullrank, device=device, dtype=dtype) 7994 7995 b = torch.randn(*b_dims, dtype=dtype, device=device) 7996 A = make_A(*A_dims) 7997 LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A) 7998 self.assertEqual(info, torch.zeros_like(info)) 7999 return b, A, LU_data, LU_pivots 8000 8001 @skipCPUIfNoLapack 8002 @skipCUDAIfNoMagmaAndNoCusolver 8003 @dtypes(*floating_and_complex_types()) 8004 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 8005 torch.float64: 1e-8, torch.complex128: 1e-8}) 8006 def test_lu_solve(self, device, dtype): 8007 def sub_test(pivot): 8008 for k, n in zip([2, 3, 5], [3, 5, 7]): 8009 b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n, n), (n, k), pivot, device, dtype) 8010 x = torch.lu_solve(b, LU_data, LU_pivots) 8011 self.assertEqual(b, np.matmul(A.cpu(), x.cpu())) 8012 8013 sub_test(True) 8014 if self.device_type == 'cuda': 8015 sub_test(False) 8016 8017 @skipCPUIfNoLapack 8018 @skipCUDAIfNoMagmaAndNoCusolver 8019 @dtypes(*floating_and_complex_types()) 8020 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 8021 torch.float64: 1e-8, torch.complex128: 1e-8}) 8022 def test_lu_solve_batched(self, device, dtype): 8023 def sub_test(pivot): 8024 def lu_solve_batch_test_helper(A_dims, b_dims, pivot): 8025 b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype) 8026 x_exp_list = [] 8027 for i in range(b_dims[0]): 8028 
x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i])) 8029 x_exp = torch.stack(x_exp_list) # Stacked output 8030 x_act = torch.lu_solve(b, LU_data, LU_pivots) # Actual output 8031 self.assertEqual(x_exp, x_act) # Equality check 8032 Ax = np.matmul(A.cpu(), x_act.cpu()) 8033 self.assertEqual(b, Ax) 8034 8035 for batchsize in [1, 3, 4]: 8036 lu_solve_batch_test_helper((batchsize, 5, 5), (batchsize, 5, 10), pivot) 8037 8038 # Tests tensors with 0 elements 8039 b = torch.randn(3, 0, 3, dtype=dtype, device=device) 8040 A = torch.randn(3, 0, 0, dtype=dtype, device=device) 8041 LU_data, LU_pivots = torch.linalg.lu_factor(A) 8042 self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots)) 8043 8044 sub_test(True) 8045 if self.device_type == 'cuda': 8046 sub_test(False) 8047 8048 @slowTest 8049 @skipCPUIfNoLapack 8050 @skipCUDAIfNoMagmaAndNoCusolver 8051 @dtypes(*floating_and_complex_types()) 8052 def test_lu_solve_batched_many_batches(self, device, dtype): 8053 def run_test(A_dims, b_dims): 8054 b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype) 8055 x = torch.lu_solve(b, LU_data, LU_pivots) 8056 Ax = torch.matmul(A, x) 8057 self.assertEqual(Ax, b.expand_as(Ax)) 8058 8059 run_test((65536, 5, 5), (65536, 5, 10)) 8060 run_test((262144, 5, 5), (262144, 5, 10)) 8061 8062 @skipCPUIfNoLapack 8063 @skipCUDAIfNoMagmaAndNoCusolver 8064 @dtypes(*floating_and_complex_types()) 8065 def test_lu_solve_batched_broadcasting(self, device, dtype): 8066 make_fullrank = make_fullrank_matrices_with_distinct_singular_values 8067 make_A = partial(make_fullrank, device=device, dtype=dtype) 8068 8069 def run_test(A_dims, b_dims, pivot=True): 8070 A_matrix_size = A_dims[-1] 8071 A_batch_dims = A_dims[:-2] 8072 A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size) 8073 b = make_tensor(b_dims, dtype=dtype, device=device) 8074 x_exp = np.linalg.solve(A.cpu(), b.cpu()) 8075 LU_data, LU_pivots = torch.linalg.lu_factor(A) 8076 x = torch.lu_solve(b, LU_data, LU_pivots) 8077 self.assertEqual(x, x_exp) 8078 8079 # test against numpy.linalg.solve 8080 run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)) # no broadcasting 8081 run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting b 8082 run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A 8083 run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & b 8084 8085 @onlyCUDA 8086 @skipCUDAIfNoMagma 8087 @dtypes(*floating_and_complex_types()) 8088 # this tests https://github.com/pytorch/pytorch/issues/36921 8089 def test_lu_solve_large_matrices(self, device, dtype): 8090 def run_test(A_dims, b_dims): 8091 b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype) 8092 x = torch.lu_solve(b, LU_data, LU_pivots) 8093 Ax = torch.matmul(A, x) 8094 self.assertEqual(Ax, b.expand_as(Ax)) 8095 8096 run_test((1, 1), (1, 1, 1025)) 8097 8098 @skipCUDAIfNoCusolver 8099 @skipCPUIfNoLapack 8100 def test_pca_lowrank(self, device): 8101 from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix 8102 8103 dtype = torch.double 8104 8105 def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **options): 8106 density = options.pop('density', 1) 8107 use_svd_lowrank = options.pop('use_svd_lowrank', False) 8108 if isinstance(matrix_size, int): 8109 rows = columns = matrix_size 8110 else: 8111 rows, columns = matrix_size 8112 if density == 1: 8113 a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) 8114 a = a_input 8115 else: 8116 
a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) 8117 a = a_input.to_dense() 8118 8119 if use_svd_lowrank: 8120 m = a_input.mean(dim=-2, keepdim=True) 8121 u, s, v = pca(a_input, q=guess_rank, M=m, **options) 8122 else: 8123 u, s, v = pca(a_input, q=guess_rank, **options) 8124 8125 self.assertEqual(s.shape[-1], guess_rank) 8126 self.assertEqual(u.shape[-2], rows) 8127 self.assertEqual(u.shape[-1], guess_rank) 8128 self.assertEqual(v.shape[-1], guess_rank) 8129 self.assertEqual(v.shape[-2], columns) 8130 8131 A1 = u.matmul(s.diag_embed()).matmul(v.mT) 8132 ones_m1 = torch.ones(batches + (rows, 1), dtype=a.dtype, device=device) 8133 c = a.sum(axis=-2) / rows 8134 c = c.reshape(batches + (1, columns)) 8135 A2 = a - ones_m1.matmul(c) 8136 self.assertEqual(A1, A2) 8137 8138 if density == 1: 8139 # actual rank is known only for dense input 8140 detect_rank = (s.abs() > 1e-5).sum(axis=-1) 8141 self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank) 8142 S = torch.linalg.svdvals(A2) 8143 self.assertEqual(s[..., :actual_rank], S[..., :actual_rank]) 8144 8145 all_batches = [(), (1,), (3,), (2, 3)] 8146 for actual_rank, size, all_batches in [ # noqa: B020 8147 (2, (17, 4), all_batches), 8148 (2, (100, 4), all_batches), 8149 (6, (100, 40), all_batches), 8150 (12, (1000, 1000), [()]), 8151 ]: 8152 for batches in all_batches: 8153 for guess_rank in [ 8154 actual_rank, 8155 actual_rank + 2, 8156 actual_rank + 6, 8157 ]: 8158 if guess_rank <= min(*size): 8159 run_subtest(guess_rank, actual_rank, size, batches, device, torch.pca_lowrank) 8160 run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.pca_lowrank) 8161 run_subtest(guess_rank, actual_rank, size, batches, device, torch.svd_lowrank, use_svd_lowrank=True) 8162 run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.svd_lowrank, use_svd_lowrank=True) 8163 8164 # sparse input 8165 for guess_rank, size in [ 8166 (4, (17, 4)), (4, (4, 17)), (16, (17, 17)), 8167 (21, (100, 40)), (20, (40, 100)), (600, (1000, 1000))]: 8168 for density in [0.005, 0.1]: 8169 run_subtest(guess_rank, None, size, (), device, torch.pca_lowrank, density=density) 8170 8171 # jitting support 8172 jitted = torch.jit.script(torch.pca_lowrank) 8173 guess_rank, actual_rank, size, batches = 2, 2, (17, 4), () 8174 run_subtest(guess_rank, actual_rank, size, batches, device, jitted) 8175 8176 # Ensure that nuclear_norm's out variant gives the same result as the non-out 8177 @onlyNativeDeviceTypes 8178 @skipCUDAIfNoMagma 8179 @skipCPUIfNoLapack 8180 @dtypes(torch.float32, torch.float64) 8181 def test_nuclear_norm_out(self, device, dtype): 8182 test_cases = [ 8183 # input size, dim 8184 ((25, 25), None), 8185 ((25, 25), (0, 1)), 8186 ((25, 25), (1, 0)), 8187 ((25, 25, 25), (2, 0)), 8188 ((25, 25, 25), (0, 1)), 8189 ] 8190 for keepdim in [False, True]: 8191 for input_size, dim in test_cases: 8192 msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}' 8193 x = torch.randn(*input_size, device=device, dtype=dtype) 8194 result_out = torch.empty(0, device=device, dtype=dtype) 8195 if dim is None: 8196 result = torch.nuclear_norm(x, keepdim=keepdim) 8197 torch.nuclear_norm(x, keepdim=keepdim, out=result_out) 8198 else: 8199 result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim) 8200 torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out) 8201 self.assertEqual(result, result_out, msg=msg) 8202 8203 @skipCUDAIfNoMagmaAndNoCusolver 8204 @skipCPUIfNoLapack 8205 
@dtypes(*floating_and_complex_types()) 8206 def test_geqrf(self, device, dtype): 8207 8208 def run_test(shape): 8209 # numpy.linalg.qr with mode = 'raw' computes the same operation as torch.geqrf 8210 # so this test compares against that function 8211 A = make_tensor(shape, dtype=dtype, device=device) 8212 8213 # numpy.linalg.qr doesn't work with batched input 8214 m, n = A.shape[-2:] 8215 tau_size = "n" if m > n else "m" 8216 np_dtype = A.cpu().numpy().dtype 8217 ot = [np_dtype, np_dtype] 8218 numpy_geqrf_batched = np.vectorize( 8219 lambda x: np.linalg.qr(x, mode='raw'), 8220 otypes=ot, 8221 signature=f'(m,n)->(n,m),({tau_size})') 8222 8223 expected = numpy_geqrf_batched(A.cpu()) 8224 actual = torch.geqrf(A) 8225 8226 # numpy.linalg.qr returns transposed result 8227 self.assertEqual(expected[0].swapaxes(-2, -1), actual[0]) 8228 self.assertEqual(expected[1], actual[1]) 8229 8230 batches = [(), (0, ), (2, ), (2, 1)] 8231 ns = [5, 2, 0] 8232 for batch, (m, n) in product(batches, product(ns, ns)): 8233 run_test((*batch, m, n)) 8234 8235 @skipCUDAIfNoMagma 8236 @skipCPUIfNoLapack 8237 def test_lapack_empty(self, device): 8238 # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here. 8239 # The LAPACK functions themselves generally do NOT work with zero sized dimensions, although 8240 # numpy/sci often has a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing" 8241 # (e.g. lu). We often name our functions identically to the lapack function, so it will take work 8242 # to name / migrate-to better wrappers. 8243 def fn(torchfn, *args): 8244 return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape 8245 for shape in args)) 8246 8247 # inverse, pinverse 8248 self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape) 8249 self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape) 8250 self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape) 8251 self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape) 8252 8253 # det, logdet, slogdet 8254 self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0))) 8255 self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0))) 8256 self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)), 8257 fn(torch.slogdet, (0, 0))) 8258 8259 @tf32_on_and_off(0.005) 8260 @bf32_on_and_off(0.005) 8261 def test_tensordot(self, device): 8262 a = torch.arange(60., device=device).reshape(3, 4, 5) 8263 b = torch.arange(24., device=device).reshape(4, 3, 2) 8264 c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() 8265 cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), 8266 axes=([1, 0], [0, 1]))) 8267 self.assertEqual(c, cn) 8268 8269 cout = torch.zeros((5, 2), device=device) 8270 torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu() 8271 self.assertEqual(c, cout) 8272 8273 a = torch.randn(2, 3, 4, 5, device=device) 8274 b = torch.randn(4, 5, 6, 7, device=device) 8275 c = torch.tensordot(a, b, dims=2).cpu() 8276 cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), 8277 axes=2)) 8278 8279 with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"): 8280 torch.tensordot(a, b, dims=-1) 8281 8282 self.assertEqual(c, cn) 8283 c = torch.tensordot(a, b).cpu() 8284 cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) 8285 self.assertEqual(c, cn) 8286 8287 a = torch.tensordot(torch.tensor(0.), torch.tensor(0.), 0) 8288 an = torch.from_numpy(np.tensordot(np.zeros((), 
dtype=np.float32), np.zeros((), dtype=np.float32), 0)) 8289 self.assertEqual(a, an) 8290 8291 @skipCUDAIfNoCusolver 8292 @skipCUDAIfNoMagma 8293 @skipCPUIfNoLapack 8294 @skipIfTorchDynamo("flaky, needs investigation") 8295 @dtypes(*floating_and_complex_types()) 8296 def test_ldl_factor(self, device, dtype): 8297 from torch.testing._internal.common_utils import random_hermitian_pd_matrix 8298 8299 def run_test(shape, batch, hermitian): 8300 A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) 8301 actual_factors, actual_pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian) 8302 actual_L = torch.tril(actual_factors, diagonal=-1) 8303 actual_L.diagonal(0, -2, -1).fill_(1.0) 8304 8305 # This test is designed only for inputs with 1x1 block diagonal matrix D. 8306 # That is for positive definite input matrices, the pivots tensor is always > 0. 8307 # If negative pivots are encountered, it means that the input matrix is not positive definite. 8308 # And matrix D is a 2x2 block diagonal matrix. 8309 self.assertTrue((actual_pivots > 0).all()) 8310 8311 # Construct a 1x1 block diagonal matrix D from factors. 8312 actual_D = torch.diag_embed(actual_factors.diagonal(0, -2, -1)) 8313 8314 def T(x): 8315 return x.mH if hermitian else x.mT 8316 A_reconstructed = actual_L @ actual_D @ T(actual_L) 8317 8318 def symmetric(A): 8319 return A.tril() + A.tril(-1).mT 8320 8321 self.assertEqual(symmetric(A) if not hermitian else A, A_reconstructed) 8322 8323 # Now test against SciPy implementation 8324 if TEST_SCIPY: 8325 from scipy.linalg import ldl as scipy_ldl 8326 A_np = A.cpu().numpy() 8327 np_dtype = A_np.dtype 8328 scipy_ldl_batched = np.vectorize( 8329 lambda x: scipy_ldl(x, hermitian=hermitian, lower=True), 8330 otypes=[np_dtype, np_dtype, np.dtype('int64')], 8331 signature='(m,m)->(m,m),(m,m),(m)') 8332 8333 expected = scipy_ldl_batched(A_np) 8334 expected_L, expected_D, expected_pivots = expected 8335 8336 if expected_pivots.ndim > 1: 8337 permuted_expected_L = np.stack( 8338 [expected_L[i][expected_pivots[i], :] for i in range(expected_pivots.shape[0])] 8339 ) 8340 else: 8341 permuted_expected_L = expected_L[expected_pivots, :] 8342 self.assertEqual(actual_L, permuted_expected_L) 8343 self.assertEqual(actual_D, expected_D) 8344 else: 8345 self.assertEqual(actual_factors.shape, A.shape) 8346 self.assertEqual(actual_pivots.shape, A.shape[:-1]) 8347 self.assertEqual(info.shape, A.shape[:-2]) 8348 8349 # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+ 8350 magma_254_available = self.device_type == 'cuda' and _get_magma_version() >= (2, 5, 4) 8351 hermitians = (True, False) if dtype.is_complex and (self.device_type == 'cpu' or magma_254_available) else (False,) 8352 8353 shapes = (5,) 8354 batches = ((), (4,),) 8355 for shape, batch, hermitian in itertools.product(shapes, batches, hermitians): 8356 run_test(shape, batch, hermitian) 8357 8358 @skipCUDAIfNoCusolver 8359 @skipCUDAIfNoMagma 8360 @skipCPUIfNoLapack 8361 @skipCUDAIfRocm 8362 @skipCUDAIf(_get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1") 8363 @dtypes(*floating_and_complex_types()) 8364 def test_ldl_solve(self, device, dtype): 8365 from torch.testing._internal.common_utils import random_hermitian_pd_matrix 8366 8367 def run_test(shape, batch, nrhs, hermitian): 8368 A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) 8369 B = make_tensor((*A.shape[:-1], nrhs), dtype=dtype, device=device) 8370 factors, pivots, info = 
torch.linalg.ldl_factor_ex(A, hermitian=hermitian) 8371 X = torch.linalg.ldl_solve(factors, pivots, B, hermitian=hermitian) 8372 8373 def symmetric(A): 8374 return A.tril() + A.tril(-1).mT 8375 8376 # verify A @ X == B 8377 expected_B = symmetric(A) @ X if not hermitian else A @ X 8378 self.assertEqual(B, expected_B) 8379 8380 # hermitian=True is not supported on CUDA yet 8381 hermitians = (True, False) if dtype.is_complex and self.device_type == 'cpu' else (False,) 8382 8383 shapes = (5,) 8384 batches = ((), (4,), (2, 2)) 8385 nrhss = (1, 7) 8386 for shape, batch, nrhs, hermitian in itertools.product(shapes, batches, nrhss, hermitians): 8387 run_test(shape, batch, nrhs, hermitian) 8388 8389 @onlyCUDA 8390 @skipCUDAIfNoMagma 8391 @skipCUDAIfNoCusolver 8392 @setLinalgBackendsToDefaultFinally 8393 def test_preferred_linalg_library(self): 8394 # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions. 8395 x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double) 8396 8397 torch.backends.cuda.preferred_linalg_library('cusolver') 8398 out1 = torch.linalg.inv(x) 8399 8400 torch.backends.cuda.preferred_linalg_library('magma') 8401 out2 = torch.linalg.inv(x) 8402 8403 torch.backends.cuda.preferred_linalg_library('default') 8404 # Although linalg preferred flags doesn't affect CPU currently, 8405 # we set this to make sure the flag can switch back to default normally. 8406 out_ref = torch.linalg.inv(x.cpu()) 8407 8408 self.assertEqual(out_ref, out1.cpu()) 8409 self.assertEqual(out1, out2) 8410 8411 @onlyCUDA 8412 @unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device") 8413 @setBlasBackendsToDefaultFinally 8414 def test_preferred_blas_library(self): 8415 # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions. 8416 m1 = torch.randint(2, 5, (2048, 2400), device='cuda', dtype=torch.float) 8417 m2 = torch.randint(2, 5, (128, 2400), device='cuda', dtype=torch.float) 8418 8419 torch.backends.cuda.preferred_blas_library('cublaslt') 8420 out1 = torch.nn.functional.linear(m1, m2) 8421 8422 torch.backends.cuda.preferred_blas_library('cublas') 8423 out2 = torch.nn.functional.linear(m1, m2) 8424 8425 # Although blas preferred flags doesn't affect CPU currently, 8426 # we set this to make sure the flag can switch back to default normally. 
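        # (There is no explicit switch back to the 'default' backend in this test,
        # unlike the linalg variant above; the @setBlasBackendsToDefaultFinally
        # decorator is expected to restore the default BLAS backend once the test
        # finishes.)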

    @onlyCUDA
    @unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device")
    @setBlasBackendsToDefaultFinally
    def test_preferred_blas_library(self):
        # The main purpose of this test is to make sure these "backend" calls work
        # normally without raising exceptions.
        m1 = torch.randint(2, 5, (2048, 2400), device='cuda', dtype=torch.float)
        m2 = torch.randint(2, 5, (128, 2400), device='cuda', dtype=torch.float)

        torch.backends.cuda.preferred_blas_library('cublaslt')
        out1 = torch.nn.functional.linear(m1, m2)

        torch.backends.cuda.preferred_blas_library('cublas')
        out2 = torch.nn.functional.linear(m1, m2)

        # Although the blas preferred flags don't affect CPU currently,
        # we set this to make sure the flag can switch back to default normally.
        out_ref = torch.nn.functional.linear(m1.cpu(), m2.cpu())

        self.assertEqual(out1, out2)
        self.assertEqual(out_ref, out2.cpu())

    def test_permute_matmul(self):
        a = torch.ones([2, 5, 24, 24])
        b = torch.ones([3, 2, 5, 24, 24])
        c = a.permute(0, 1, 3, 2).matmul(b)
        self.assertEqual([c.min(), c.max(), c.sum()], [24, 24, 414720])

    def test_lower_precision_accumulation_with_ref_path(self):
        # fixes https://github.com/pytorch/pytorch/issues/95125
        # and https://github.com/pytorch/pytorch/issues/83863
        # for bf16 accumulation in the gemm ref path
        # (a single-op sketch of this comparison follows this test)
        def check_correctness(fn, dtype, *args):
            expected = fn(*args).to(dtype=dtype)
            with torch.backends.mkldnn.flags(enabled=False):
                def test():
                    lower_args = (arg.to(dtype=dtype) for arg in args)
                    tmp_result = fn(*lower_args)
                    return tmp_result
                c = test()
                assert (torch.all(c == expected)), "Incorrect result with\n" \
                    f"expected: {expected}\n" \
                    f"got: {c}\n"

        # test matmul
        for dtype in [torch.bfloat16, torch.half]:
            for transa in [True, False]:
                for transb in [True, False]:
                    a = torch.ones(300, 300)
                    b = torch.ones(300, 300)
                    if transa:
                        a = a.transpose(0, 1).contiguous().transpose(0, 1)
                    if transb:
                        b = b.transpose(0, 1).contiguous().transpose(0, 1)
                    check_correctness(torch.matmul, dtype, a, b)
        # test bmm
        a = torch.ones(1, 1, 300)
        b = torch.ones(1, 300, 1)
        check_correctness(torch.bmm, torch.bfloat16, a, b)
        check_correctness(torch.bmm, torch.half, a, b)
        # test baddbmm
        a = torch.ones(1, 1, 300)
        b = torch.ones(1, 300, 1)
        c = torch.ones(1, 1, 1)
        check_correctness(torch.baddbmm, torch.bfloat16, c, a, b)
        check_correctness(torch.baddbmm, torch.half, c, a, b)
        # test mv/addmv
        for dtype in [torch.bfloat16, torch.half]:
            for trans in [True, False]:
                c = torch.ones(300) * -300
                a = torch.ones(300, 300)
                if trans:
                    a = a.transpose(0, 1).contiguous().transpose(0, 1)
                b = torch.ones(300)
                check_correctness(torch.mv, dtype, a, b)
                check_correctness(torch.addmv, dtype, c, a, b)
        # test dot
        a = torch.ones(300)
        b = torch.ones(300)
        check_correctness(torch.dot, torch.bfloat16, a, b)
        check_correctness(torch.dot, torch.half, a, b)
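
    # Illustrative sketch only (not collected by the test runner): the pattern used
    # by check_correctness above, written out for a single matmul. The reference
    # result is computed in float32 and cast down afterwards, then compared against
    # running the whole op in the reduced dtype; with all-ones inputs the exact
    # product equals the inner dimension, so any accumulation error is immediately
    # visible. The shapes below are arbitrary illustration values.
    def _low_precision_accumulation_sketch(self, dtype=torch.bfloat16):
        a = torch.ones(128, 32)
        b = torch.ones(32, 64)
        expected = torch.matmul(a, b).to(dtype)          # accumulate in float32, cast at the end
        actual = torch.matmul(a.to(dtype), b.to(dtype))  # run entirely in the reduced dtype
        return torch.equal(actual, expected)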

    @dtypes(torch.float, torch.half, torch.bfloat16)
    @parametrize("transpose_a", [True, False])
    @parametrize("transpose_b", [True, False])
    @parametrize("alpha", [0.0, 0.2, 1.0])
    @parametrize("beta", [0.0, 0.5, 1.0])
    def test_addmm_mv(self, device, dtype, transpose_a, transpose_b, alpha, beta):
        def gen_mat(w, h, use_transpose: bool = False):
            if not use_transpose:
                return torch.rand(w, h, dtype=dtype, device=device)
            return torch.rand(h, w, dtype=dtype, device=device).t()

        # Regression tests for https://github.com/pytorch/pytorch/issues/136299
        # Should only expose problems on aarch64, but let's be thorough
        m, n, k = 1, 8, 32
        A = gen_mat(m, k, transpose_a)
        B = gen_mat(k, n, transpose_b)
        C = torch.ones(m, n, dtype=dtype, device=device)
        rc = torch.addmm(C, A, B, alpha=alpha, beta=beta)
        ref = alpha * A @ B + beta * C
        self.assertEqual(rc, ref)

    @dtypes(torch.float, torch.double)
    @precisionOverride({torch.float32: 1e-4})
    def test_1_sized_with_0_strided(self, device, dtype):
        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
        b = make_tensor((8, 64, 512), dtype=dtype, device=device)
        b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
        res = torch.bmm(a_strided, b_strided)
        expect = torch.from_numpy(
            a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(device=device, dtype=dtype)
        self.assertEqual(expect, res)


instantiate_device_type_tests(TestLinalg, globals())

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()