1# Owner(s): ["module: dynamo"] 2 3""" Test functions for linalg module 4 5""" 6import functools 7import itertools 8import os 9import subprocess 10import sys 11import textwrap 12import traceback 13from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest 14 15import numpy 16import pytest 17from numpy.linalg.linalg import _multi_dot_matrix_chain_order 18from pytest import raises as assert_raises 19 20from torch.testing._internal.common_utils import ( 21 instantiate_parametrized_tests, 22 parametrize, 23 run_tests, 24 slowTest as slow, 25 TEST_WITH_TORCHDYNAMO, 26 TestCase, 27 xpassIfTorchDynamo, 28) 29 30 31# If we are going to trace through these, we should use NumPy 32# If testing on eager mode, we use torch._numpy 33if TEST_WITH_TORCHDYNAMO: 34 import numpy as np 35 from numpy import ( 36 array, 37 asarray, 38 atleast_2d, 39 cdouble, 40 csingle, 41 dot, 42 double, 43 identity, 44 inf, 45 linalg, 46 matmul, 47 single, 48 swapaxes, 49 ) 50 from numpy.linalg import LinAlgError, matrix_power, matrix_rank, multi_dot, norm 51 from numpy.testing import ( # assert_raises_regex, HAS_LAPACK64, IS_WASM 52 assert_, 53 assert_allclose, 54 assert_almost_equal, 55 assert_array_equal, 56 assert_equal, 57 suppress_warnings, 58 ) 59 60else: 61 import torch._numpy as np 62 from torch._numpy import ( 63 array, 64 asarray, 65 atleast_2d, 66 cdouble, 67 csingle, 68 dot, 69 double, 70 identity, 71 inf, 72 linalg, 73 matmul, 74 single, 75 swapaxes, 76 ) 77 from torch._numpy.linalg import ( 78 LinAlgError, 79 matrix_power, 80 matrix_rank, 81 multi_dot, 82 norm, 83 ) 84 from torch._numpy.testing import ( 85 assert_, 86 assert_allclose, 87 assert_almost_equal, 88 assert_array_equal, 89 assert_equal, 90 suppress_warnings, 91 ) 92 93 94skip = functools.partial(skipif, True) 95 96IS_WASM = False 97HAS_LAPACK64 = False 98 99 100def consistent_subclass(out, in_): 101 # For ndarray subclass input, our output should have the same subclass 102 # (non-ndarray input gets converted to 
ndarray). 103 return type(out) is (type(in_) if isinstance(in_, np.ndarray) else np.ndarray) 104 105 106old_assert_almost_equal = assert_almost_equal 107 108 109def assert_almost_equal(a, b, single_decimal=6, double_decimal=12, **kw): 110 if asarray(a).dtype.type in (single, csingle): 111 decimal = single_decimal 112 else: 113 decimal = double_decimal 114 old_assert_almost_equal(a, b, decimal=decimal, **kw) 115 116 117def get_real_dtype(dtype): 118 return {single: single, double: double, csingle: single, cdouble: double}[dtype] 119 120 121def get_complex_dtype(dtype): 122 return {single: csingle, double: cdouble, csingle: csingle, cdouble: cdouble}[dtype] 123 124 125def get_rtol(dtype): 126 # Choose a safe rtol 127 if dtype in (single, csingle): 128 return 1e-5 129 else: 130 return 1e-11 131 132 133# used to categorize tests 134all_tags = { 135 "square", 136 "nonsquare", 137 "hermitian", # mutually exclusive 138 "generalized", 139 "size-0", 140 "strided", # optional additions 141} 142 143 144class LinalgCase: 145 def __init__(self, name, a, b, tags=None): 146 """ 147 A bundle of arguments to be passed to a test case, with an identifying 148 name, the operands a and b, and a set of tags to filter the tests 149 """ 150 if tags is None: 151 tags = set() 152 assert_(isinstance(name, str)) 153 self.name = name 154 self.a = a 155 self.b = b 156 self.tags = frozenset(tags) # prevent shared tags 157 158 def check(self, do): 159 """ 160 Run the function `do` on this test case, expanding arguments 161 """ 162 do(self.a, self.b, tags=self.tags) 163 164 def __repr__(self): 165 return f"<LinalgCase: {self.name}>" 166 167 168def apply_tag(tag, cases): 169 """ 170 Add the given tag (a string) to each of the cases (a list of LinalgCase 171 objects) 172 """ 173 assert tag in all_tags, "Invalid tag" 174 for case in cases: 175 case.tags = case.tags | {tag} 176 return cases 177 178 179# 180# Base test cases 181# 182 183np.random.seed(1234) 184 185CASES = [] 186 187# square test cases 
# NOTE: the np.random.rand draws in the case lists below consume the RNG
# stream seeded by np.random.seed(1234) earlier in this module, so the test
# data depends on the order of these definitions.
CASES += apply_tag(
    "square",
    [
        LinalgCase(
            "single",
            array([[1.0, 2.0], [3.0, 4.0]], dtype=single),
            array([2.0, 1.0], dtype=single),
        ),
        LinalgCase(
            "double",
            array([[1.0, 2.0], [3.0, 4.0]], dtype=double),
            array([2.0, 1.0], dtype=double),
        ),
        LinalgCase(
            "double_2",
            array([[1.0, 2.0], [3.0, 4.0]], dtype=double),
            array([[2.0, 1.0, 4.0], [3.0, 4.0, 6.0]], dtype=double),
        ),
        LinalgCase(
            "csingle",
            array([[1.0 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]], dtype=csingle),
            array([2.0 + 1j, 1.0 + 2j], dtype=csingle),
        ),
        LinalgCase(
            "cdouble",
            array([[1.0 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]], dtype=cdouble),
            array([2.0 + 1j, 1.0 + 2j], dtype=cdouble),
        ),
        LinalgCase(
            "cdouble_2",
            array([[1.0 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]], dtype=cdouble),
            array(
                [[2.0 + 1j, 1.0 + 2j, 1 + 3j], [1 - 2j, 1 - 3j, 1 - 6j]], dtype=cdouble
            ),
        ),
        LinalgCase(
            "0x0",
            np.empty((0, 0), dtype=double),
            np.empty((0,), dtype=double),
            tags={"size-0"},
        ),
        LinalgCase("8x8", np.random.rand(8, 8), np.random.rand(8)),
        LinalgCase("1x1", np.random.rand(1, 1), np.random.rand(1)),
        # plain nested lists: exercises conversion of non-ndarray input
        LinalgCase("nonarray", [[1, 2], [3, 4]], [2, 1]),
    ],
)

# non-square test-cases (m != n, including the empty 0x4 and 4x0 shapes)
CASES += apply_tag(
    "nonsquare",
    [
        LinalgCase(
            "single_nsq_1",
            array([[1.0, 2.0, 3.0], [3.0, 4.0, 6.0]], dtype=single),
            array([2.0, 1.0], dtype=single),
        ),
        LinalgCase(
            "single_nsq_2",
            array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=single),
            array([2.0, 1.0, 3.0], dtype=single),
        ),
        LinalgCase(
            "double_nsq_1",
            array([[1.0, 2.0, 3.0], [3.0, 4.0, 6.0]], dtype=double),
            array([2.0, 1.0], dtype=double),
        ),
        LinalgCase(
            "double_nsq_2",
            array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=double),
            array([2.0, 1.0, 3.0], dtype=double),
        ),
        LinalgCase(
            "csingle_nsq_1",
            array(
                [[1.0 + 1j, 2.0 + 2j, 3.0 - 3j], [3.0 - 5j, 4.0 + 9j, 6.0 + 2j]],
                dtype=csingle,
            ),
            array([2.0 + 1j, 1.0 + 2j], dtype=csingle),
        ),
        LinalgCase(
            "csingle_nsq_2",
            array(
                [[1.0 + 1j, 2.0 + 2j], [3.0 - 3j, 4.0 - 9j], [5.0 - 4j, 6.0 + 8j]],
                dtype=csingle,
            ),
            array([2.0 + 1j, 1.0 + 2j, 3.0 - 3j], dtype=csingle),
        ),
        LinalgCase(
            "cdouble_nsq_1",
            array(
                [[1.0 + 1j, 2.0 + 2j, 3.0 - 3j], [3.0 - 5j, 4.0 + 9j, 6.0 + 2j]],
                dtype=cdouble,
            ),
            array([2.0 + 1j, 1.0 + 2j], dtype=cdouble),
        ),
        LinalgCase(
            "cdouble_nsq_2",
            array(
                [[1.0 + 1j, 2.0 + 2j], [3.0 - 3j, 4.0 - 9j], [5.0 - 4j, 6.0 + 8j]],
                dtype=cdouble,
            ),
            array([2.0 + 1j, 1.0 + 2j, 3.0 - 3j], dtype=cdouble),
        ),
        LinalgCase(
            "cdouble_nsq_1_2",
            array(
                [[1.0 + 1j, 2.0 + 2j, 3.0 - 3j], [3.0 - 5j, 4.0 + 9j, 6.0 + 2j]],
                dtype=cdouble,
            ),
            array([[2.0 + 1j, 1.0 + 2j], [1 - 1j, 2 - 2j]], dtype=cdouble),
        ),
        LinalgCase(
            "cdouble_nsq_2_2",
            array(
                [[1.0 + 1j, 2.0 + 2j], [3.0 - 3j, 4.0 - 9j], [5.0 - 4j, 6.0 + 8j]],
                dtype=cdouble,
            ),
            array(
                [[2.0 + 1j, 1.0 + 2j], [1 - 1j, 2 - 2j], [1 - 1j, 2 - 2j]],
                dtype=cdouble,
            ),
        ),
        LinalgCase("8x11", np.random.rand(8, 11), np.random.rand(8)),
        LinalgCase("1x5", np.random.rand(1, 5), np.random.rand(1)),
        LinalgCase("5x1", np.random.rand(5, 1), np.random.rand(5)),
        LinalgCase("0x4", np.random.rand(0, 4), np.random.rand(0), tags={"size-0"}),
        LinalgCase("4x0", np.random.rand(4, 0), np.random.rand(4), tags={"size-0"}),
    ],
)

# hermitian test-cases; every one has b=None, so only `a` is used
CASES += apply_tag(
    "hermitian",
    [
        LinalgCase("hsingle", array([[1.0, 2.0], [2.0, 1.0]], dtype=single), None),
        LinalgCase("hdouble", array([[1.0, 2.0], [2.0, 1.0]], dtype=double), None),
        LinalgCase(
            "hcsingle", array([[1.0, 2 + 3j], [2 - 3j, 1]], dtype=csingle), None
        ),
        LinalgCase(
            "hcdouble", array([[1.0, 2 + 3j], [2 - 3j, 1]], dtype=cdouble), None
        ),
        LinalgCase("hempty", np.empty((0, 0), dtype=double), None, tags={"size-0"}),
        LinalgCase("hnonarray", [[1, 2], [2, 1]], None),
        # NOTE(review): despite the name, this operand is a plain `array`
        # (no matrix subclass) in this port.
        LinalgCase("matrix_b_only", array([[1.0, 2.0], [2.0, 1.0]]), None),
        LinalgCase("hmatrix_1x1", np.random.rand(1, 1), None),
    ],
)


#
# Gufunc test cases
#
def _make_generalized_cases():
    """Build stacked ("generalized") variants of the ndarray-valued CASES.

    For every case whose ``a`` is an ndarray, two new cases tagged
    "generalized" (plus the original tags) are returned:

    * ``<name>_tile3``: ``a`` stacked as ``[a, 2*a, 3*a]`` and ``b`` (if not
      None) as ``[b, 7*b, 6*b]``, adding a leading axis of length 3.
    * ``<name>_tile213``: six copies of ``a`` (and ``b``) reshaped to
      ``(3, 2) + original_shape``.

    Cases with non-ndarray operands (nested lists) are skipped.
    """
    new_cases = []

    for case in CASES:
        if not isinstance(case.a, np.ndarray):
            continue

        a = np.stack([case.a, 2 * case.a, 3 * case.a])
        if case.b is None:
            b = None
        else:
            b = np.stack([case.b, 7 * case.b, 6 * case.b])
        new_case = LinalgCase(
            case.name + "_tile3", a, b, tags=case.tags | {"generalized"}
        )
        new_cases.append(new_case)

        a = np.array([case.a] * 2 * 3).reshape((3, 2) + case.a.shape)
        if case.b is None:
            b = None
        else:
            b = np.array([case.b] * 2 * 3).reshape((3, 2) + case.b.shape)
        new_case = LinalgCase(
            case.name + "_tile213", a, b, tags=case.tags | {"generalized"}
        )
        new_cases.append(new_case)

    return new_cases


CASES += _make_generalized_cases()


#
# Test different routines against the above cases
#
class LinalgTestCase:
    """Mixin base: subclasses define ``do(a, b, tags)``; the ``test_*`` methods
    of the tag-specific subclasses below call ``check_cases`` to run it over a
    filtered subset of ``TEST_CASES``."""

    TEST_CASES = CASES

    def check_cases(self, require=None, exclude=None):
        """
        Run ``self.do`` on each case that has all of the tags in ``require``
        and none of the tags in ``exclude``.

        Any exception raised for a case is re-raised as an AssertionError whose
        message names the case and includes the original traceback.
        """
        if require is None:
            require = set()
        if exclude is None:
            exclude = set()
        for case in self.TEST_CASES:
            # filter by require and exclude
            if case.tags & require != require:
                continue
            if case.tags & exclude:
                continue

            try:
                case.check(self.do)
            except Exception as e:
                msg = f"In test case: {case!r}\n\n"
                msg += traceback.format_exc()
                raise AssertionError(msg) from e


# The mixins below select case subsets by tag; the "generalized" variants are
# marked @slow.
class LinalgSquareTestCase(LinalgTestCase):
    def test_sq_cases(self):
        self.check_cases(require={"square"}, exclude={"generalized", "size-0"})

    def test_empty_sq_cases(self):
        self.check_cases(require={"square", "size-0"}, exclude={"generalized"})


class LinalgNonsquareTestCase(LinalgTestCase):
    def test_nonsq_cases(self):
        self.check_cases(require={"nonsquare"}, exclude={"generalized", "size-0"})

    def test_empty_nonsq_cases(self):
        self.check_cases(require={"nonsquare", "size-0"}, exclude={"generalized"})


class HermitianTestCase(LinalgTestCase):
    def test_herm_cases(self):
        self.check_cases(require={"hermitian"}, exclude={"generalized", "size-0"})

    def test_empty_herm_cases(self):
        self.check_cases(require={"hermitian", "size-0"}, exclude={"generalized"})


class LinalgGeneralizedSquareTestCase(LinalgTestCase):
    @slow
    def test_generalized_sq_cases(self):
        self.check_cases(require={"generalized", "square"}, exclude={"size-0"})

    @slow
    def test_generalized_empty_sq_cases(self):
        self.check_cases(require={"generalized", "square", "size-0"})


class LinalgGeneralizedNonsquareTestCase(LinalgTestCase):
    @slow
    def test_generalized_nonsq_cases(self):
        self.check_cases(require={"generalized", "nonsquare"}, exclude={"size-0"})

    @slow
    def test_generalized_empty_nonsq_cases(self):
        self.check_cases(require={"generalized", "nonsquare", "size-0"})


class HermitianGeneralizedTestCase(LinalgTestCase):
    @slow
    def test_generalized_herm_cases(self):
        self.check_cases(require={"generalized", "hermitian"}, exclude={"size-0"})

    @slow
    def test_generalized_empty_herm_cases(self):
        # NOTE(review): "none" is not in all_tags, so this exclude filters
        # nothing; it behaves like exclude=None.
        self.check_cases(
            require={"generalized", "hermitian", "size-0"}, exclude={"none"}
        )


def dot_generalized(a, b):
    """``dot`` that also understands stacked operands.

    If ``a`` has 3 or more dims, ``dot`` is applied independently to each
    ``a[c]``/``b[c]`` over the leading ``a.ndim - 2`` axes. ``b`` must then have
    the same ndim as ``a`` (stacked matrices) or one fewer (stacked vectors);
    anything else raises ValueError. Lower-dimensional input is passed straight
    to ``dot``.

    NOTE(review): only ``a`` goes through ``asarray``; in the stacked path ``b``
    must already expose ``.ndim``/``.shape``.
    """
    a = asarray(a)
    if a.ndim >= 3:
        if a.ndim == b.ndim:
            # matrix x matrix
            new_shape = a.shape[:-1] + b.shape[-1:]
        elif a.ndim == b.ndim + 1:
            # matrix x vector
            new_shape = a.shape[:-1]
        else:
            raise ValueError("Not implemented...")
        r = np.empty(new_shape, dtype=np.common_type(a, b))
        for c in itertools.product(*map(range, a.shape[:-2])):
            r[c] = dot(a[c], b[c])
        return r
    else:
        return dot(a, b)


def identity_like_generalized(a):
    """Return identity matrices shaped like ``a``.

    For 3+-dim input this is ``identity(a.shape[-2])`` broadcast into an array
    of ``a``'s shape and dtype; otherwise it is ``identity(a.shape[0])``.
    """
    a = asarray(a)
    if a.ndim >= 3:
        r = np.empty(a.shape, dtype=a.dtype)
        r[...] = identity(a.shape[-2])
        return r
    else:
        return identity(a.shape[0])


class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
    # kept apart from TestSolve for use for testing with matrices.
    def do(self, a, b, tags):
        # a @ x must reproduce b, and x must keep b's array (sub)class.
        x = linalg.solve(a, b)
        assert_almost_equal(b, dot_generalized(a, x))
        assert_(consistent_subclass(x, b))


@instantiate_parametrized_tests
class TestSolve(SolveCases, TestCase):
    @parametrize("dtype", [single, double, csingle, cdouble])
    def test_types(self, dtype):
        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
        assert_equal(linalg.solve(x, x).dtype, dtype)

    # `skip` is unconditional (skipIf(True, ...)); kept for reference.
    @skip(reason="subclass")
    def test_0_size(self):
        class ArraySubclass(np.ndarray):
            pass

        # Test system of 0x0 matrices
        a = np.arange(8).reshape(2, 2, 2)
        b = np.arange(6).reshape(1, 2, 3).view(ArraySubclass)

        expected = linalg.solve(a, b)[:, 0:0, :]
        result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0, :])
        assert_array_equal(result, expected)
        assert_(isinstance(result, ArraySubclass))

        # Test errors for non-square and only b's dimension being 0
        assert_raises(linalg.LinAlgError, linalg.solve, a[:, 0:0, 0:1], b)
        assert_raises(ValueError, linalg.solve, a, b[:, 0:0, :])

        # Test broadcasting error
        b = np.arange(6).reshape(1, 3, 2)  # broadcasting error
        assert_raises(ValueError, linalg.solve, a, b)
        assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])

        # Test zero "single equations" with 0x0 matrices.
        b = np.arange(2).reshape(1, 2).view(ArraySubclass)
        expected = linalg.solve(a, b)[:, 0:0]
        result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0])
        assert_array_equal(result, expected)
        assert_(isinstance(result, ArraySubclass))

        b = np.arange(3).reshape(1, 3)
        assert_raises(ValueError, linalg.solve, a, b)
        assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
        assert_raises(ValueError, linalg.solve, a[:, 0:0, 0:0], b)

    @skip(reason="subclass")
    def test_0_size_k(self):
        # test zero multiple equation (K=0) case.
        class ArraySubclass(np.ndarray):
            pass

        a = np.arange(4).reshape(1, 2, 2)
        b = np.arange(6).reshape(3, 2, 1).view(ArraySubclass)

        expected = linalg.solve(a, b)[:, :, 0:0]
        result = linalg.solve(a, b[:, :, 0:0])
        assert_array_equal(result, expected)
        assert_(isinstance(result, ArraySubclass))

        # test both zero.
        expected = linalg.solve(a, b)[:, 0:0, 0:0]
        result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0, 0:0])
        assert_array_equal(result, expected)
        assert_(isinstance(result, ArraySubclass))


class InvCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
    def do(self, a, b, tags):
        # a @ inv(a) must be the (possibly stacked) identity.
        a_inv = linalg.inv(a)
        assert_almost_equal(dot_generalized(a, a_inv), identity_like_generalized(a))
        assert_(consistent_subclass(a_inv, a))


@instantiate_parametrized_tests
class TestInv(InvCases, TestCase):
    @parametrize("dtype", [single, double, csingle, cdouble])
    def test_types(self, dtype):
        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
        assert_equal(linalg.inv(x).dtype, dtype)

    @skip(reason="subclass")
    def test_0_size(self):
        # Check that all kinds of 0-sized arrays work
        class ArraySubclass(np.ndarray):
            pass

        a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
        res = linalg.inv(a)
        assert_(res.dtype.type is np.float64)
        assert_equal(a.shape, res.shape)
        assert_(isinstance(res, ArraySubclass))

        a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
        res = linalg.inv(a)
        assert_(res.dtype.type is np.complex64)
        assert_equal(a.shape, res.shape)
        assert_(isinstance(res, ArraySubclass))


class EigvalsCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
    def do(self, a, b, tags):
        # eigvals must agree with the eigenvalues returned by eig.
        ev = linalg.eigvals(a)
        evalues, evectors = linalg.eig(a)
        assert_almost_equal(ev, evalues)


@instantiate_parametrized_tests
class TestEigvals(EigvalsCases, TestCase):
    @parametrize("dtype", [single, double, csingle, cdouble])
    def test_types(self, dtype):
        # symmetric real input keeps its dtype ...
        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
        assert_equal(linalg.eigvals(x).dtype, dtype)
        # ... while this input's eigenvalues come back in the complex dtype
        x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
        assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype))

    @skip(reason="subclass")
    def test_0_size(self):
        # Check that all kinds of 0-sized arrays work
        class ArraySubclass(np.ndarray):
            pass

        a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
        res = linalg.eigvals(a)
        assert_(res.dtype.type is np.float64)
        assert_equal((0, 1), res.shape)
        # This is just for documentation, it might make sense to change:
        assert_(isinstance(res, np.ndarray))

        a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
        res = linalg.eigvals(a)
        assert_(res.dtype.type is np.complex64)
        assert_equal((0,), res.shape)
        # This is just for documentation, it might make sense to change:
        assert_(isinstance(res, np.ndarray))


class EigCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
    def do(self, a, b, tags):
        # A @ V == V * w: each column of V is an eigenvector; `[..., None, :]`
        # broadcasts the eigenvalues across the rows of V.
        evalues, evectors = linalg.eig(a)
        assert_allclose(
            dot_generalized(a, evectors),
            np.asarray(evectors) * np.asarray(evalues)[..., None, :],
            rtol=get_rtol(evalues.dtype),
        )
        assert_(consistent_subclass(evectors, a))


@instantiate_parametrized_tests
class TestEig(EigCases, TestCase):
    @parametrize("dtype", [single, double, csingle, cdouble])
    def test_types(self, dtype):
        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
        w, v = np.linalg.eig(x)
        assert_equal(w.dtype, dtype)
        assert_equal(v.dtype, dtype)

        x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
        w, v = np.linalg.eig(x)
        assert_equal(w.dtype, get_complex_dtype(dtype))
        assert_equal(v.dtype, get_complex_dtype(dtype))

    @skip(reason="subclass")
    def test_0_size(self):
        # Check that all kinds of 0-sized arrays work
        class ArraySubclass(np.ndarray):
            pass

        a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
        res, res_v = linalg.eig(a)
        assert_(res_v.dtype.type is np.float64)
        assert_(res.dtype.type is np.float64)
        assert_equal(a.shape, res_v.shape)
        assert_equal((0, 1), res.shape)
        # This is just for documentation, it might make sense to change:
        # NOTE(review): this checks `a`, not `res`, so it is always true here.
        assert_(isinstance(a, np.ndarray))

        a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
        res, res_v = linalg.eig(a)
        assert_(res_v.dtype.type is np.complex64)
        assert_(res.dtype.type is np.complex64)
        assert_equal(a.shape, res_v.shape)
        assert_equal((0,), res.shape)
        # This is just for documentation, it might make sense to change:
        # NOTE(review): as above, `a` rather than `res` is checked.
        assert_(isinstance(a, np.ndarray))


@instantiate_parametrized_tests
class SVDBaseTests:
    # Shared dtype checks; TestSVDHermitian overrides `hermitian` to True.
    hermitian = False

    @parametrize("dtype", [single, double, csingle, cdouble])
    def test_types(self, dtype):
        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
        # NOTE(review): this first call does not pass hermitian=self.hermitian.
        u, s, vh = linalg.svd(x)
        assert_equal(u.dtype, dtype)
        assert_equal(s.dtype, get_real_dtype(dtype))
        assert_equal(vh.dtype, dtype)
        s = linalg.svd(x, compute_uv=False, hermitian=self.hermitian)
        assert_equal(s.dtype, get_real_dtype(dtype))


class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
    def do(self, a, b, tags):
        # Reconstruct a from the reduced SVD (second positional arg is
        # full_matrices=False): (u * s) @ vt must equal a.
        u, s, vt = linalg.svd(a, False)
        assert_allclose(
            a,
            dot_generalized(
                np.asarray(u) * np.asarray(s)[..., None, :], np.asarray(vt)
            ),
            rtol=get_rtol(u.dtype),
        )
        assert_(consistent_subclass(u, a))
        assert_(consistent_subclass(vt, a))


class TestSVD(SVDCases, SVDBaseTests, TestCase):
    def test_empty_identity(self):
        """Empty input should put an identity matrix in u or vh"""
        x = np.empty((4, 0))
        u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
        assert_equal(u.shape, (4, 4))
        assert_equal(vh.shape, (0, 0))
        assert_equal(u, np.eye(4))

        x = np.empty((0, 4))
        u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
        assert_equal(u.shape, (0, 0))
        assert_equal(vh.shape, (4, 4))
        assert_equal(vh, np.eye(4))


class SVDHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
    def do(self, a, b, tags):
        # Same reconstruction check as SVDCases, using the hermitian=True path.
        u, s, vt = linalg.svd(a, False, hermitian=True)
        assert_allclose(
            a,
            dot_generalized(
                np.asarray(u) * np.asarray(s)[..., None, :], np.asarray(vt)
            ),
            rtol=get_rtol(u.dtype),
        )

        # Conjugate transpose over the last two axes.
        def hermitian(mat):
            axes = list(range(mat.ndim))
            axes[-1], axes[-2] = axes[-2], axes[-1]
            return np.conj(np.transpose(mat, axes=axes))

        # u and vt must be unitary (all hermitian cases are square).
        assert_almost_equal(
            np.matmul(u, hermitian(u)), np.broadcast_to(np.eye(u.shape[-1]), u.shape)
        )
        assert_almost_equal(
            np.matmul(vt, hermitian(vt)),
            np.broadcast_to(np.eye(vt.shape[-1]), vt.shape),
        )
        # Singular values must come back in descending order.
        assert_equal(np.sort(s), np.flip(s, -1))
        assert_(consistent_subclass(u, a))
        assert_(consistent_subclass(vt, a))


class TestSVDHermitian(SVDHermitianCases, SVDBaseTests, TestCase):
    hermitian = True


class CondCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
    # cond(x, p) for p in (None, 2, -2) is checked against singular-value
    # ratios; p in (1, -1, inf, -inf, "fro") against explicit norms of x and
    # inv(x). Size-0 input must raise LinAlgError.

    def do(self, a, b, tags):
        c = asarray(a)  # a might be a matrix
        if "size-0" in tags:
            assert_raises(LinAlgError, linalg.cond, c)
            return

        # +-2 norms
        s = linalg.svd(c, compute_uv=False)
        assert_almost_equal(
            linalg.cond(a), s[..., 0] / s[..., -1], single_decimal=5, double_decimal=11
        )
        assert_almost_equal(
            linalg.cond(a, 2),
            s[..., 0] / s[..., -1],
            single_decimal=5,
            double_decimal=11,
        )
        assert_almost_equal(
            linalg.cond(a, -2),
            s[..., -1] / s[..., 0],
            single_decimal=5,
            double_decimal=11,
        )

        # Other norms
        cinv = np.linalg.inv(c)
        assert_almost_equal(
            linalg.cond(a, 1),
            abs(c).sum(-2).max(-1) * abs(cinv).sum(-2).max(-1),
            single_decimal=5,
            double_decimal=11,
        )
        assert_almost_equal(
            linalg.cond(a, -1),
            abs(c).sum(-2).min(-1) * abs(cinv).sum(-2).min(-1),
            single_decimal=5,
            double_decimal=11,
        )
        assert_almost_equal(
            linalg.cond(a, np.inf),
            abs(c).sum(-1).max(-1) * abs(cinv).sum(-1).max(-1),
            single_decimal=5,
            double_decimal=11,
        )
        assert_almost_equal(
            linalg.cond(a, -np.inf),
            abs(c).sum(-1).min(-1) * abs(cinv).sum(-1).min(-1),
            single_decimal=5,
            double_decimal=11,
        )
        assert_almost_equal(
            linalg.cond(a, "fro"),
            np.sqrt((abs(c) ** 2).sum(-1).sum(-1) * (abs(cinv) ** 2).sum(-1).sum(-1)),
            single_decimal=5,
            double_decimal=11,
        )


class TestCond(CondCases, TestCase):
    def test_basic_nonsvd(self):
        # Smoketest the non-svd norms
        A = array([[1.0, 0, 1], [0, -2.0, 0], [0, 0, 3.0]])
        assert_almost_equal(linalg.cond(A, inf), 4)
        assert_almost_equal(linalg.cond(A, -inf), 2 / 3)
        assert_almost_equal(linalg.cond(A, 1), 4)
        assert_almost_equal(linalg.cond(A, -1), 0.5)
        assert_almost_equal(linalg.cond(A, "fro"), np.sqrt(265 / 12))

    def test_singular(self):
        # Singular matrices have infinite condition number for
        # positive norms, and negative norms shouldn't raise
        # exceptions
        As = [np.zeros((2, 2)), np.ones((2, 2))]
        p_pos = [None, 1, 2, "fro"]
        p_neg = [-1, -2]
        for A, p in itertools.product(As, p_pos):
            # Inversion may not hit exact infinity, so just check the
            # number is large
            assert_(linalg.cond(A, p) > 1e15)
        for A, p in itertools.product(As, p_neg):
            linalg.cond(A, p)

    @skip(reason="NP_VER: fails on CI")  # (
    #    True, run=False, reason="Platform/LAPACK-dependent failure, see gh-18914"
    # )
    def test_nan(self):
        # nans should be passed through, not converted to infs
        ps = [None, 1, -1, 2, -2, "fro"]
        p_pos = [None, 1, 2, "fro"]

        A = np.ones((2, 2))
        A[0, 1] = np.nan
        for p in ps:
            c = linalg.cond(A, p)
            assert_(isinstance(c, np.float64))
            assert_(np.isnan(c))

        A = np.ones((3, 2, 2))
        A[1, 0, 1] = np.nan
        for p in ps:
            c = linalg.cond(A, p)
            assert_(np.isnan(c[1]))
            if p in p_pos:
                assert_(c[0] > 1e15)
                assert_(c[2] > 1e15)
            else:
                assert_(not np.isnan(c[0]))
                assert_(not np.isnan(c[2]))

    def test_stacked_singular(self):
        # Check behavior when only some of the stacked matrices are
        # singular
        np.random.seed(1234)
        A = np.random.rand(2, 2, 2, 2)
        A[0, 0] = 0
        A[1, 1] = 0

        for p in (None, 1, 2, "fro", -1, -2):
            c = linalg.cond(A, p)
            assert_equal(c[0, 0], np.inf)
            assert_equal(c[1, 1], np.inf)
            assert_(np.isfinite(c[0, 1]))
            assert_(np.isfinite(c[1, 0]))


class PinvCases(
    LinalgSquareTestCase,
    LinalgNonsquareTestCase,
    LinalgGeneralizedSquareTestCase,
    LinalgGeneralizedNonsquareTestCase,
):
    def do(self, a, b, tags):
        a_ginv = linalg.pinv(a)
        # `a @ a_ginv == I` does not hold if a is singular, so check the
        # Moore-Penrose identity a @ pinv(a) @ a == a instead.
        dot = dot_generalized
        assert_almost_equal(
            dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11
        )
        assert_(consistent_subclass(a_ginv, a))


class TestPinv(PinvCases, TestCase):
    pass


class PinvHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
    def do(self, a, b, tags):
        a_ginv = linalg.pinv(a, hermitian=True)
        # `a @ a_ginv == I` does not hold if a is singular, so check the
        # Moore-Penrose identity a @ pinv(a) @ a == a instead.
        dot = dot_generalized
        assert_almost_equal(
            dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11
        )
        assert_(consistent_subclass(a_ginv, a))


class TestPinvHermitian(PinvHermitianCases, TestCase):
    pass


class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
    def do(self, a, b, tags):
        # det and sign * exp(logdet) must both equal the product of the
        # eigenvalues.
        d = linalg.det(a)
        (s, ld) = linalg.slogdet(a)
        # Eigenvalues are computed in double precision: `double` for real
        # single/double input, `cdouble` for anything else (complex or integer).
        if asarray(a).dtype.type in (single, double):
            ad = asarray(a).astype(double)
        else:
            ad = asarray(a).astype(cdouble)
        ev = linalg.eigvals(ad)
        assert_almost_equal(d, np.prod(ev, axis=-1))
        assert_almost_equal(s * np.exp(ld), np.prod(ev, axis=-1), single_decimal=5)

        # A nonzero sign has magnitude 1; where the sign is 0 the log-abs-det
        # must be -inf.
        s = np.atleast_1d(s)
        ld = np.atleast_1d(ld)
        m = s != 0
        assert_almost_equal(np.abs(s[m]), 1)
        assert_equal(ld[~m], -inf)


@instantiate_parametrized_tests
class TestDet(DetCases, TestCase):
    def test_zero(self):
        # NB: comment out tests of type(det) == double : we return zero-dim arrays
        assert_equal(linalg.det([[0.0]]), 0.0)
        #    assert_equal(type(linalg.det([[0.0]])), double)
        assert_equal(linalg.det([[0.0j]]), 0.0)
        #    assert_equal(type(linalg.det([[0.0j]])), cdouble)

        assert_equal(linalg.slogdet([[0.0]]), (0.0, -inf))
        #    assert_equal(type(linalg.slogdet([[0.0]])[0]), double)
        #    assert_equal(type(linalg.slogdet([[0.0]])[1]), double)
        assert_equal(linalg.slogdet([[0.0j]]), (0.0j, -inf))

        #    assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble)
        #    assert_equal(type(linalg.slogdet([[0.0j]])[1]), double)

    @parametrize("dtype", [single, double, csingle, cdouble])
    def test_types(self, dtype):
        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
        assert_equal(np.linalg.det(x).dtype, dtype)
        # slogdet returns (sign, logabsdet): here `ph` is the sign (input
        # dtype) and `s` the log-abs-det (real dtype of the same precision).
        ph, s = np.linalg.slogdet(x)
        assert_equal(s.dtype, get_real_dtype(dtype))
        assert_equal(ph.dtype, dtype)
def test_0_size(self): 965 a = np.zeros((0, 0), dtype=np.complex64) 966 res = linalg.det(a) 967 assert_equal(res, 1.0) 968 assert_(res.dtype.type is np.complex64) 969 res = linalg.slogdet(a) 970 assert_equal(res, (1, 0)) 971 assert_(res[0].dtype.type is np.complex64) 972 assert_(res[1].dtype.type is np.float32) 973 974 a = np.zeros((0, 0), dtype=np.float64) 975 res = linalg.det(a) 976 assert_equal(res, 1.0) 977 assert_(res.dtype.type is np.float64) 978 res = linalg.slogdet(a) 979 assert_equal(res, (1, 0)) 980 assert_(res[0].dtype.type is np.float64) 981 assert_(res[1].dtype.type is np.float64) 982 983 984class LstsqCases(LinalgSquareTestCase, LinalgNonsquareTestCase): 985 def do(self, a, b, tags): 986 arr = np.asarray(a) 987 m, n = arr.shape 988 u, s, vt = linalg.svd(a, False) 989 x, residuals, rank, sv = linalg.lstsq(a, b, rcond=-1) 990 if m == 0: 991 assert_((x == 0).all()) 992 if m <= n: 993 assert_almost_equal(b, dot(a, x), single_decimal=5) 994 assert_equal(rank, m) 995 else: 996 assert_equal(rank, n) 997 # assert_almost_equal(sv, sv.__array_wrap__(s)) 998 if rank == n and m > n: 999 expect_resids = (np.asarray(abs(np.dot(a, x) - b)) ** 2).sum(axis=0) 1000 expect_resids = np.asarray(expect_resids) 1001 if np.asarray(b).ndim == 1: 1002 expect_resids = expect_resids.reshape( 1003 1, 1004 ) 1005 assert_equal(residuals.shape, expect_resids.shape) 1006 else: 1007 expect_resids = np.array([]) # .view(type(x)) 1008 assert_almost_equal(residuals, expect_resids, single_decimal=5) 1009 assert_(np.issubdtype(residuals.dtype, np.floating)) 1010 assert_(consistent_subclass(x, b)) 1011 assert_(consistent_subclass(residuals, b)) 1012 1013 1014@instantiate_parametrized_tests 1015class TestLstsq(LstsqCases, TestCase): 1016 @xpassIfTorchDynamo # (reason="Lstsq: we use the future default =None") 1017 def test_future_rcond(self): 1018 a = np.array( 1019 [ 1020 [0.0, 1.0, 0.0, 1.0, 2.0, 0.0], 1021 [0.0, 2.0, 0.0, 0.0, 1.0, 0.0], 1022 [1.0, 0.0, 1.0, 0.0, 0.0, 4.0], 1023 [0.0, 0.0, 
0.0, 2.0, 3.0, 0.0], 1024 ] 1025 ).T 1026 1027 b = np.array([1, 0, 0, 0, 0, 0]) 1028 with suppress_warnings() as sup: 1029 w = sup.record(FutureWarning, "`rcond` parameter will change") 1030 x, residuals, rank, s = linalg.lstsq(a, b) 1031 assert_(rank == 4) 1032 x, residuals, rank, s = linalg.lstsq(a, b, rcond=-1) 1033 assert_(rank == 4) 1034 x, residuals, rank, s = linalg.lstsq(a, b, rcond=None) 1035 assert_(rank == 3) 1036 # Warning should be raised exactly once (first command) 1037 assert_(len(w) == 1) 1038 1039 @parametrize( 1040 "m, n, n_rhs", 1041 [ 1042 (4, 2, 2), 1043 (0, 4, 1), 1044 (0, 4, 2), 1045 (4, 0, 1), 1046 (4, 0, 2), 1047 # (4, 2, 0), # Intel MKL ERROR: Parameter 4 was incorrect on entry to DLALSD. 1048 (0, 0, 0), 1049 ], 1050 ) 1051 def test_empty_a_b(self, m, n, n_rhs): 1052 a = np.arange(m * n).reshape(m, n) 1053 b = np.ones((m, n_rhs)) 1054 x, residuals, rank, s = linalg.lstsq(a, b, rcond=None) 1055 if m == 0: 1056 assert_((x == 0).all()) 1057 assert_equal(x.shape, (n, n_rhs)) 1058 assert_equal(residuals.shape, ((n_rhs,) if m > n else (0,))) 1059 if m > n and n_rhs > 0: 1060 # residuals are exactly the squared norms of b's columns 1061 r = b - np.dot(a, x) 1062 assert_almost_equal(residuals, (r * r).sum(axis=-2)) 1063 assert_equal(rank, min(m, n)) 1064 assert_equal(s.shape, (min(m, n),)) 1065 1066 def test_incompatible_dims(self): 1067 # use modified version of docstring example 1068 x = np.array([0, 1, 2, 3]) 1069 y = np.array([-1, 0.2, 0.9, 2.1, 3.3]) 1070 A = np.vstack([x, np.ones(len(x))]).T 1071 # with assert_raises_regex(LinAlgError, "Incompatible dimensions"): 1072 with assert_raises((RuntimeError, LinAlgError)): 1073 linalg.lstsq(A, y, rcond=None) 1074 1075 1076# @xfail #(reason="no block()") 1077@skip # FIXME: otherwise fails in setUp calling np.block 1078@instantiate_parametrized_tests 1079class TestMatrixPower(TestCase): 1080 def setUp(self): 1081 self.rshft_0 = np.eye(4) 1082 self.rshft_1 = self.rshft_0[[3, 0, 1, 2]] 1083 
self.rshft_2 = self.rshft_0[[2, 3, 0, 1]] 1084 self.rshft_3 = self.rshft_0[[1, 2, 3, 0]] 1085 self.rshft_all = [self.rshft_0, self.rshft_1, self.rshft_2, self.rshft_3] 1086 self.noninv = array([[1, 0], [0, 0]]) 1087 self.stacked = np.block([[[self.rshft_0]]] * 2) 1088 # FIXME the 'e' dtype might work in future 1089 self.dtnoinv = [object, np.dtype("e"), np.dtype("g"), np.dtype("G")] 1090 1091 @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) 1092 def test_large_power(self, dt): 1093 rshft = self.rshft_1.astype(dt) 1094 assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 0), self.rshft_0) 1095 assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 1), self.rshft_1) 1096 assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 2), self.rshft_2) 1097 assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 3), self.rshft_3) 1098 1099 @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) 1100 def test_power_is_zero(self, dt): 1101 def tz(M): 1102 mz = matrix_power(M, 0) 1103 assert_equal(mz, identity_like_generalized(M)) 1104 assert_equal(mz.dtype, M.dtype) 1105 1106 for mat in self.rshft_all: 1107 tz(mat.astype(dt)) 1108 if dt != object: 1109 tz(self.stacked.astype(dt)) 1110 1111 @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) 1112 def test_power_is_one(self, dt): 1113 def tz(mat): 1114 mz = matrix_power(mat, 1) 1115 assert_equal(mz, mat) 1116 assert_equal(mz.dtype, mat.dtype) 1117 1118 for mat in self.rshft_all: 1119 tz(mat.astype(dt)) 1120 if dt != object: 1121 tz(self.stacked.astype(dt)) 1122 1123 @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) 1124 def test_power_is_two(self, dt): 1125 def tz(mat): 1126 mz = matrix_power(mat, 2) 1127 mmul = matmul if mat.dtype != object else dot 1128 assert_equal(mz, mmul(mat, mat)) 1129 assert_equal(mz.dtype, mat.dtype) 1130 1131 for mat in self.rshft_all: 1132 tz(mat.astype(dt)) 1133 if dt != object: 1134 tz(self.stacked.astype(dt)) 1135 1136 @parametrize("dt", [np.dtype(c) for c in 
"?bBhilefdFD"]) 1137 def test_power_is_minus_one(self, dt): 1138 def tz(mat): 1139 invmat = matrix_power(mat, -1) 1140 mmul = matmul if mat.dtype != object else dot 1141 assert_almost_equal(mmul(invmat, mat), identity_like_generalized(mat)) 1142 1143 for mat in self.rshft_all: 1144 if dt not in self.dtnoinv: 1145 tz(mat.astype(dt)) 1146 1147 @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) 1148 def test_exceptions_bad_power(self, dt): 1149 mat = self.rshft_0.astype(dt) 1150 assert_raises(TypeError, matrix_power, mat, 1.5) 1151 assert_raises(TypeError, matrix_power, mat, [1]) 1152 1153 @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) 1154 def test_exceptions_non_square(self, dt): 1155 assert_raises(LinAlgError, matrix_power, np.array([1], dt), 1) 1156 assert_raises(LinAlgError, matrix_power, np.array([[1], [2]], dt), 1) 1157 assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2), dt), 1) 1158 1159 @skipif(IS_WASM, reason="fp errors don't work in wasm") 1160 @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) 1161 def test_exceptions_not_invertible(self, dt): 1162 if dt in self.dtnoinv: 1163 return 1164 mat = self.noninv.astype(dt) 1165 assert_raises(LinAlgError, matrix_power, mat, -1) 1166 1167 1168class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase): 1169 def do(self, a, b, tags): 1170 pytest.xfail(reason="sort complex") 1171 # note that eigenvalue arrays returned by eig must be sorted since 1172 # their order isn't guaranteed. 
1173 ev = linalg.eigvalsh(a, "L") 1174 evalues, evectors = linalg.eig(a) 1175 evalues.sort(axis=-1) 1176 assert_allclose(ev, evalues, rtol=get_rtol(ev.dtype)) 1177 1178 ev2 = linalg.eigvalsh(a, "U") 1179 assert_allclose(ev2, evalues, rtol=get_rtol(ev.dtype)) 1180 1181 1182@instantiate_parametrized_tests 1183class TestEigvalsh(TestCase): 1184 @parametrize("dtype", [single, double, csingle, cdouble]) 1185 def test_types(self, dtype): 1186 x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) 1187 w = np.linalg.eigvalsh(x) 1188 assert_equal(w.dtype, get_real_dtype(dtype)) 1189 1190 def test_invalid(self): 1191 x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) 1192 assert_raises((RuntimeError, ValueError), np.linalg.eigvalsh, x, UPLO="lrong") 1193 assert_raises((RuntimeError, ValueError), np.linalg.eigvalsh, x, "lower") 1194 assert_raises((RuntimeError, ValueError), np.linalg.eigvalsh, x, "upper") 1195 1196 def test_UPLO(self): 1197 Klo = np.array([[0, 0], [1, 0]], dtype=np.double) 1198 Kup = np.array([[0, 1], [0, 0]], dtype=np.double) 1199 tgt = np.array([-1, 1], dtype=np.double) 1200 rtol = get_rtol(np.double) 1201 1202 # Check default is 'L' 1203 w = np.linalg.eigvalsh(Klo) 1204 assert_allclose(w, tgt, rtol=rtol) 1205 # Check 'L' 1206 w = np.linalg.eigvalsh(Klo, UPLO="L") 1207 assert_allclose(w, tgt, rtol=rtol) 1208 # Check 'l' 1209 w = np.linalg.eigvalsh(Klo, UPLO="l") 1210 assert_allclose(w, tgt, rtol=rtol) 1211 # Check 'U' 1212 w = np.linalg.eigvalsh(Kup, UPLO="U") 1213 assert_allclose(w, tgt, rtol=rtol) 1214 # Check 'u' 1215 w = np.linalg.eigvalsh(Kup, UPLO="u") 1216 assert_allclose(w, tgt, rtol=rtol) 1217 1218 def test_0_size(self): 1219 # Check that all kinds of 0-sized arrays work 1220 # class ArraySubclass(np.ndarray): 1221 # pass 1222 a = np.zeros((0, 1, 1), dtype=np.int_) # .view(ArraySubclass) 1223 res = linalg.eigvalsh(a) 1224 assert_(res.dtype.type is np.float64) 1225 assert_equal((0, 1), res.shape) 1226 # This is just for documentation, it might make 
sense to change: 1227 assert_(isinstance(res, np.ndarray)) 1228 1229 a = np.zeros((0, 0), dtype=np.complex64) # .view(ArraySubclass) 1230 res = linalg.eigvalsh(a) 1231 assert_(res.dtype.type is np.float32) 1232 assert_equal((0,), res.shape) 1233 # This is just for documentation, it might make sense to change: 1234 assert_(isinstance(res, np.ndarray)) 1235 1236 1237class TestEighCases(HermitianTestCase, HermitianGeneralizedTestCase): 1238 def do(self, a, b, tags): 1239 pytest.xfail(reason="sort complex") 1240 # note that eigenvalue arrays returned by eig must be sorted since 1241 # their order isn't guaranteed. 1242 ev, evc = linalg.eigh(a) 1243 evalues, evectors = linalg.eig(a) 1244 evalues.sort(axis=-1) 1245 assert_almost_equal(ev, evalues) 1246 1247 assert_allclose( 1248 dot_generalized(a, evc), 1249 np.asarray(ev)[..., None, :] * np.asarray(evc), 1250 rtol=get_rtol(ev.dtype), 1251 ) 1252 1253 ev2, evc2 = linalg.eigh(a, "U") 1254 assert_almost_equal(ev2, evalues) 1255 1256 assert_allclose( 1257 dot_generalized(a, evc2), 1258 np.asarray(ev2)[..., None, :] * np.asarray(evc2), 1259 rtol=get_rtol(ev.dtype), 1260 err_msg=repr(a), 1261 ) 1262 1263 1264@instantiate_parametrized_tests 1265class TestEigh(TestCase): 1266 @parametrize("dtype", [single, double, csingle, cdouble]) 1267 def test_types(self, dtype): 1268 x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) 1269 w, v = np.linalg.eigh(x) 1270 assert_equal(w.dtype, get_real_dtype(dtype)) 1271 assert_equal(v.dtype, dtype) 1272 1273 def test_invalid(self): 1274 x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) 1275 assert_raises((RuntimeError, ValueError), np.linalg.eigh, x, UPLO="lrong") 1276 assert_raises((RuntimeError, ValueError), np.linalg.eigh, x, "lower") 1277 assert_raises((RuntimeError, ValueError), np.linalg.eigh, x, "upper") 1278 1279 def test_UPLO(self): 1280 Klo = np.array([[0, 0], [1, 0]], dtype=np.double) 1281 Kup = np.array([[0, 1], [0, 0]], dtype=np.double) 1282 tgt = np.array([-1, 1], 
dtype=np.double) 1283 rtol = get_rtol(np.double) 1284 1285 # Check default is 'L' 1286 w, v = np.linalg.eigh(Klo) 1287 assert_allclose(w, tgt, rtol=rtol) 1288 # Check 'L' 1289 w, v = np.linalg.eigh(Klo, UPLO="L") 1290 assert_allclose(w, tgt, rtol=rtol) 1291 # Check 'l' 1292 w, v = np.linalg.eigh(Klo, UPLO="l") 1293 assert_allclose(w, tgt, rtol=rtol) 1294 # Check 'U' 1295 w, v = np.linalg.eigh(Kup, UPLO="U") 1296 assert_allclose(w, tgt, rtol=rtol) 1297 # Check 'u' 1298 w, v = np.linalg.eigh(Kup, UPLO="u") 1299 assert_allclose(w, tgt, rtol=rtol) 1300 1301 def test_0_size(self): 1302 # Check that all kinds of 0-sized arrays work 1303 # class ArraySubclass(np.ndarray): 1304 # pass 1305 a = np.zeros((0, 1, 1), dtype=np.int_) # .view(ArraySubclass) 1306 res, res_v = linalg.eigh(a) 1307 assert_(res_v.dtype.type is np.float64) 1308 assert_(res.dtype.type is np.float64) 1309 assert_equal(a.shape, res_v.shape) 1310 assert_equal((0, 1), res.shape) 1311 # This is just for documentation, it might make sense to change: 1312 assert_(isinstance(a, np.ndarray)) 1313 1314 a = np.zeros((0, 0), dtype=np.complex64) # .view(ArraySubclass) 1315 res, res_v = linalg.eigh(a) 1316 assert_(res_v.dtype.type is np.complex64) 1317 assert_(res.dtype.type is np.float32) 1318 assert_equal(a.shape, res_v.shape) 1319 assert_equal((0,), res.shape) 1320 # This is just for documentation, it might make sense to change: 1321 assert_(isinstance(a, np.ndarray)) 1322 1323 1324class _TestNormBase: 1325 dt = None 1326 dec = None 1327 1328 @staticmethod 1329 def check_dtype(x, res): 1330 if issubclass(x.dtype.type, np.inexact): 1331 assert_equal(res.dtype, x.real.dtype) 1332 else: 1333 # For integer input, don't have to test float precision of output. 
1334 assert_(issubclass(res.dtype.type, np.floating)) 1335 1336 1337class _TestNormGeneral(_TestNormBase): 1338 def test_empty(self): 1339 assert_equal(norm([]), 0.0) 1340 assert_equal(norm(array([], dtype=self.dt)), 0.0) 1341 assert_equal(norm(atleast_2d(array([], dtype=self.dt))), 0.0) 1342 1343 def test_vector_return_type(self): 1344 a = np.array([1, 0, 1]) 1345 1346 exact_types = "Bbhil" # np.typecodes["AllInteger"] 1347 inexact_types = "efdFD" # np.typecodes["AllFloat"] 1348 1349 all_types = exact_types + inexact_types 1350 1351 for each_type in all_types: 1352 at = a.astype(each_type) 1353 1354 if each_type == np.dtype("float16"): 1355 # FIXME: move looping to parametrize, add decorators=[xfail] 1356 # pytest.xfail("float16**float64 => float64 (?)") 1357 raise SkipTest("float16**float64 => float64 (?)") 1358 1359 an = norm(at, -np.inf) 1360 self.check_dtype(at, an) 1361 assert_almost_equal(an, 0.0) 1362 1363 with suppress_warnings() as sup: 1364 sup.filter(RuntimeWarning, "divide by zero encountered") 1365 an = norm(at, -1) 1366 self.check_dtype(at, an) 1367 assert_almost_equal(an, 0.0) 1368 1369 an = norm(at, 0) 1370 self.check_dtype(at, an) 1371 assert_almost_equal(an, 2) 1372 1373 an = norm(at, 1) 1374 self.check_dtype(at, an) 1375 assert_almost_equal(an, 2.0) 1376 1377 an = norm(at, 2) 1378 self.check_dtype(at, an) 1379 assert_almost_equal(an, an.dtype.type(2.0) ** an.dtype.type(1.0 / 2.0)) 1380 1381 an = norm(at, 4) 1382 self.check_dtype(at, an) 1383 assert_almost_equal(an, an.dtype.type(2.0) ** an.dtype.type(1.0 / 4.0)) 1384 1385 an = norm(at, np.inf) 1386 self.check_dtype(at, an) 1387 assert_almost_equal(an, 1.0) 1388 1389 def test_vector(self): 1390 a = [1, 2, 3, 4] 1391 b = [-1, -2, -3, -4] 1392 c = [-1, 2, -3, 4] 1393 1394 def _test(v): 1395 np.testing.assert_almost_equal(norm(v), 30**0.5, decimal=self.dec) 1396 np.testing.assert_almost_equal(norm(v, inf), 4.0, decimal=self.dec) 1397 np.testing.assert_almost_equal(norm(v, -inf), 1.0, 
decimal=self.dec) 1398 np.testing.assert_almost_equal(norm(v, 1), 10.0, decimal=self.dec) 1399 np.testing.assert_almost_equal(norm(v, -1), 12.0 / 25, decimal=self.dec) 1400 np.testing.assert_almost_equal(norm(v, 2), 30**0.5, decimal=self.dec) 1401 np.testing.assert_almost_equal( 1402 norm(v, -2), ((205.0 / 144) ** -0.5), decimal=self.dec 1403 ) 1404 np.testing.assert_almost_equal(norm(v, 0), 4, decimal=self.dec) 1405 1406 for v in ( 1407 a, 1408 b, 1409 c, 1410 ): 1411 _test(v) 1412 1413 for v in ( 1414 array(a, dtype=self.dt), 1415 array(b, dtype=self.dt), 1416 array(c, dtype=self.dt), 1417 ): 1418 _test(v) 1419 1420 def test_axis(self): 1421 # Vector norms. 1422 # Compare the use of `axis` with computing the norm of each row 1423 # or column separately. 1424 A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt) 1425 for order in [None, -1, 0, 1, 2, 3, np.inf, -np.inf]: 1426 expected0 = [norm(A[:, k], ord=order) for k in range(A.shape[1])] 1427 assert_almost_equal(norm(A, ord=order, axis=0), expected0) 1428 expected1 = [norm(A[k, :], ord=order) for k in range(A.shape[0])] 1429 assert_almost_equal(norm(A, ord=order, axis=1), expected1) 1430 1431 # Matrix norms. 1432 B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4) 1433 nd = B.ndim 1434 for order in [None, -2, 2, -1, 1, np.inf, -np.inf, "fro"]: 1435 for axis in itertools.combinations(range(-nd, nd), 2): 1436 row_axis, col_axis = axis 1437 if row_axis < 0: 1438 row_axis += nd 1439 if col_axis < 0: 1440 col_axis += nd 1441 if row_axis == col_axis: 1442 assert_raises( 1443 (RuntimeError, ValueError), norm, B, ord=order, axis=axis 1444 ) 1445 else: 1446 n = norm(B, ord=order, axis=axis) 1447 1448 # The logic using k_index only works for nd = 3. 1449 # This has to be changed if nd is increased. 
1450 k_index = nd - (row_axis + col_axis) 1451 if row_axis < col_axis: 1452 expected = [ 1453 norm(B[:].take(k, axis=k_index), ord=order) 1454 for k in range(B.shape[k_index]) 1455 ] 1456 else: 1457 expected = [ 1458 norm(B[:].take(k, axis=k_index).T, ord=order) 1459 for k in range(B.shape[k_index]) 1460 ] 1461 assert_almost_equal(n, expected) 1462 1463 def test_keepdims(self): 1464 A = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4) 1465 1466 allclose_err = "order {0}, axis = {1}" 1467 shape_err = "Shape mismatch found {0}, expected {1}, order={2}, axis={3}" 1468 1469 # check the order=None, axis=None case 1470 expected = norm(A, ord=None, axis=None) 1471 found = norm(A, ord=None, axis=None, keepdims=True) 1472 assert_allclose( 1473 np.squeeze(found), expected, err_msg=allclose_err.format(None, None) 1474 ) 1475 expected_shape = (1, 1, 1) 1476 assert_( 1477 found.shape == expected_shape, 1478 shape_err.format(found.shape, expected_shape, None, None), 1479 ) 1480 1481 # Vector norms. 1482 for order in [None, -1, 0, 1, 2, 3, np.inf, -np.inf]: 1483 for k in range(A.ndim): 1484 expected = norm(A, ord=order, axis=k) 1485 found = norm(A, ord=order, axis=k, keepdims=True) 1486 assert_allclose( 1487 np.squeeze(found), expected, err_msg=allclose_err.format(order, k) 1488 ) 1489 expected_shape = list(A.shape) 1490 expected_shape[k] = 1 1491 expected_shape = tuple(expected_shape) 1492 assert_( 1493 found.shape == expected_shape, 1494 shape_err.format(found.shape, expected_shape, order, k), 1495 ) 1496 1497 # Matrix norms. 
1498 for order in [None, -2, 2, -1, 1, np.inf, -np.inf, "fro", "nuc"]: 1499 for k in itertools.permutations(range(A.ndim), 2): 1500 expected = norm(A, ord=order, axis=k) 1501 found = norm(A, ord=order, axis=k, keepdims=True) 1502 assert_allclose( 1503 np.squeeze(found), expected, err_msg=allclose_err.format(order, k) 1504 ) 1505 expected_shape = list(A.shape) 1506 expected_shape[k[0]] = 1 1507 expected_shape[k[1]] = 1 1508 expected_shape = tuple(expected_shape) 1509 assert_( 1510 found.shape == expected_shape, 1511 shape_err.format(found.shape, expected_shape, order, k), 1512 ) 1513 1514 1515class _TestNorm2D(_TestNormBase): 1516 # Define the part for 2d arrays separately, so we can subclass this 1517 # and run the tests using np.matrix in matrixlib.tests.test_matrix_linalg. 1518 1519 def test_matrix_empty(self): 1520 assert_equal(norm(np.array([[]], dtype=self.dt)), 0.0) 1521 1522 def test_matrix_return_type(self): 1523 a = np.array([[1, 0, 1], [0, 1, 1]]) 1524 1525 exact_types = "Bbhil" # np.typecodes["AllInteger"] 1526 1527 # float32, complex64, float64, complex128 types are the only types 1528 # allowed by `linalg`, which performs the matrix operations used 1529 # within `norm`. 
1530 inexact_types = "fdFD" 1531 1532 all_types = exact_types + inexact_types 1533 1534 for each_type in all_types: 1535 at = a.astype(each_type) 1536 1537 an = norm(at, -np.inf) 1538 self.check_dtype(at, an) 1539 assert_almost_equal(an, 2.0) 1540 1541 with suppress_warnings() as sup: 1542 sup.filter(RuntimeWarning, "divide by zero encountered") 1543 an = norm(at, -1) 1544 self.check_dtype(at, an) 1545 assert_almost_equal(an, 1.0) 1546 1547 an = norm(at, 1) 1548 self.check_dtype(at, an) 1549 assert_almost_equal(an, 2.0) 1550 1551 an = norm(at, 2) 1552 self.check_dtype(at, an) 1553 assert_almost_equal(an, 3.0 ** (1.0 / 2.0)) 1554 1555 an = norm(at, -2) 1556 self.check_dtype(at, an) 1557 assert_almost_equal(an, 1.0) 1558 1559 an = norm(at, np.inf) 1560 self.check_dtype(at, an) 1561 assert_almost_equal(an, 2.0) 1562 1563 an = norm(at, "fro") 1564 self.check_dtype(at, an) 1565 assert_almost_equal(an, 2.0) 1566 1567 an = norm(at, "nuc") 1568 self.check_dtype(at, an) 1569 # Lower bar needed to support low precision floats. 1570 # They end up being off by 1 in the 7th place. 
1571 np.testing.assert_almost_equal(an, 2.7320508075688772, decimal=6) 1572 1573 def test_matrix_2x2(self): 1574 A = np.array([[1, 3], [5, 7]], dtype=self.dt) 1575 assert_almost_equal(norm(A), 84**0.5) 1576 assert_almost_equal(norm(A, "fro"), 84**0.5) 1577 assert_almost_equal(norm(A, "nuc"), 10.0) 1578 assert_almost_equal(norm(A, inf), 12.0) 1579 assert_almost_equal(norm(A, -inf), 4.0) 1580 assert_almost_equal(norm(A, 1), 10.0) 1581 assert_almost_equal(norm(A, -1), 6.0) 1582 assert_almost_equal(norm(A, 2), 9.1231056256176615) 1583 assert_almost_equal(norm(A, -2), 0.87689437438234041) 1584 1585 assert_raises((RuntimeError, ValueError), norm, A, "nofro") 1586 assert_raises((RuntimeError, ValueError), norm, A, -3) 1587 assert_raises((RuntimeError, ValueError), norm, A, 0) 1588 1589 def test_matrix_3x3(self): 1590 # This test has been added because the 2x2 example 1591 # happened to have equal nuclear norm and induced 1-norm. 1592 # The 1/10 scaling factor accommodates the absolute tolerance 1593 # used in assert_almost_equal. 1594 A = (1 / 10) * np.array([[1, 2, 3], [6, 0, 5], [3, 2, 1]], dtype=self.dt) 1595 assert_almost_equal(norm(A), (1 / 10) * 89**0.5) 1596 assert_almost_equal(norm(A, "fro"), (1 / 10) * 89**0.5) 1597 assert_almost_equal(norm(A, "nuc"), 1.3366836911774836) 1598 assert_almost_equal(norm(A, inf), 1.1) 1599 assert_almost_equal(norm(A, -inf), 0.6) 1600 assert_almost_equal(norm(A, 1), 1.0) 1601 assert_almost_equal(norm(A, -1), 0.4) 1602 assert_almost_equal(norm(A, 2), 0.88722940323461277) 1603 assert_almost_equal(norm(A, -2), 0.19456584790481812) 1604 1605 def test_bad_args(self): 1606 # Check that bad arguments raise the appropriate exceptions. 
1607 1608 A = np.array([[1, 2, 3], [4, 5, 6]], dtype=self.dt) 1609 B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4) 1610 1611 # Using `axis=<integer>` or passing in a 1-D array implies vector 1612 # norms are being computed, so also using `ord='fro'` 1613 # or `ord='nuc'` or any other string raises a ValueError. 1614 assert_raises((RuntimeError, ValueError), norm, A, "fro", 0) 1615 assert_raises((RuntimeError, ValueError), norm, A, "nuc", 0) 1616 assert_raises((RuntimeError, ValueError), norm, [3, 4], "fro", None) 1617 assert_raises((RuntimeError, ValueError), norm, [3, 4], "nuc", None) 1618 assert_raises((RuntimeError, ValueError), norm, [3, 4], "test", None) 1619 1620 # Similarly, norm should raise an exception when ord is any finite 1621 # number other than 1, 2, -1 or -2 when computing matrix norms. 1622 for order in [0, 3]: 1623 assert_raises((RuntimeError, ValueError), norm, A, order, None) 1624 assert_raises((RuntimeError, ValueError), norm, A, order, (0, 1)) 1625 assert_raises((RuntimeError, ValueError), norm, B, order, (1, 2)) 1626 1627 # Invalid axis 1628 assert_raises((IndexError, np.AxisError), norm, B, None, 3) 1629 assert_raises((IndexError, np.AxisError), norm, B, None, (2, 3)) 1630 assert_raises((RuntimeError, ValueError), norm, B, None, (0, 1, 2)) 1631 1632 1633class _TestNorm(_TestNorm2D, _TestNormGeneral): 1634 pass 1635 1636 1637class TestNorm_NonSystematic(TestCase): 1638 def test_intmin(self): 1639 # Non-regression test: p-norm of signed integer would previously do 1640 # float cast and abs in the wrong order. 1641 x = np.array([-(2**31)], dtype=np.int32) 1642 old_assert_almost_equal(norm(x, ord=3), 2**31, decimal=5) 1643 1644 1645# Separate definitions so we can use them for matrix tests. 
# Concrete dtype/precision settings for the norm test mixins.  ``dt`` is the
# dtype the mixin tests build their arrays with (``dtype=self.dt``) and
# ``dec`` is the ``decimal=`` precision handed to assert_almost_equal.
class _TestNormDoubleBase(_TestNormBase, TestCase):
    dt = np.double
    dec = 12


class _TestNormSingleBase(_TestNormBase, TestCase):
    dt = np.float32
    dec = 6


class _TestNormInt64Base(_TestNormBase, TestCase):
    dt = np.int64
    dec = 12


# Runnable norm test classes: the _TestNorm mixin supplies the test methods,
# the *Base class supplies ``dt``/``dec`` and the TestCase machinery.
class TestNormDouble(_TestNorm, _TestNormDoubleBase):
    pass


class TestNormSingle(_TestNorm, _TestNormSingleBase):
    pass


class TestNormInt64(_TestNorm, _TestNormInt64Base):
    pass


class TestMatrixRank(TestCase):
    """Tests for ``matrix_rank`` on full-rank, deficient and stacked inputs."""

    def test_matrix_rank(self):
        # Full rank matrix
        assert_equal(4, matrix_rank(np.eye(4)))
        # rank deficient matrix: zero out the last diagonal entry
        deficient = np.eye(4)
        deficient[-1, -1] = 0.0
        assert_equal(matrix_rank(deficient), 3)
        # All zeros - zero rank
        assert_equal(matrix_rank(np.zeros((4, 4))), 0)
        # 1 dimension - rank 1 unless all 0
        assert_equal(matrix_rank([1, 0, 0, 0]), 1)
        assert_equal(matrix_rank(np.zeros((4,))), 0)
        # accepts array-like
        assert_equal(matrix_rank([1]), 1)
        # greater than 2 dimensions treated as stacked matrices
        ms = np.array([deficient, np.eye(4), np.zeros((4, 4))])
        assert_equal(matrix_rank(ms), np.array([3, 4, 0]))
        # works on scalar
        assert_equal(matrix_rank(1), 1)

    def test_symmetric_rank(self):
        assert_equal(4, matrix_rank(np.eye(4), hermitian=True))
        assert_equal(1, matrix_rank(np.ones((4, 4)), hermitian=True))
        assert_equal(0, matrix_rank(np.zeros((4, 4)), hermitian=True))
        # rank deficient matrix
        deficient = np.eye(4)
        deficient[-1, -1] = 0.0
        assert_equal(3, matrix_rank(deficient, hermitian=True))
        # manually supplied tolerance: the 1e-8 singular value counts only
        # when tol is below it
        deficient[-1, -1] = 1e-8
        assert_equal(4, matrix_rank(deficient, hermitian=True, tol=0.99e-8))
        assert_equal(3, matrix_rank(deficient, hermitian=True, tol=1.01e-8))

    def test_reduced_rank(self):
        # Test matrices with reduced rank
        # rng = np.random.RandomState(20120714)
        np.random.seed(20120714)
        for _ in range(100):
            # Make a rank deficient matrix: column 0 is a linear combination
            X = np.random.normal(size=(40, 10))
            X[:, 0] = X[:, 1] + X[:, 2]
            # Assert that matrix_rank detected deficiency
            assert_equal(matrix_rank(X), 9)
            # A second dependent column drops the rank once more
            X[:, 3] = X[:, 4] + X[:, 5]
            assert_equal(matrix_rank(X), 8)


def _leading_version_tuple(version, n=2):
    """Return the first ``n`` dot-separated components of ``version`` as ints.

    Only the leading digits of each component are used (``"0rc1"`` -> 0) and a
    component without leading digits counts as 0.  This gives a numeric
    ordering, unlike comparing version strings lexicographically, which
    misorders e.g. ``"1.9"`` and ``"1.22"``.
    """
    parts = []
    for piece in version.split(".")[:n]:
        digits = "".join(itertools.takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


@instantiate_parametrized_tests
class TestQR(TestCase):
    """Tests for ``linalg.qr`` in its 'complete', 'reduced', 'r' and 'raw' modes."""

    def check_qr(self, a):
        # This test expects the argument `a` to be an ndarray or
        # a subclass of an ndarray of inexact type.
        a_type = type(a)
        a_dtype = a.dtype
        m, n = a.shape
        k = min(m, n)

        # mode == 'complete': q is (m, m), r is (m, n)
        q, r = linalg.qr(a, mode="complete")
        assert_(q.dtype == a_dtype)
        assert_(r.dtype == a_dtype)
        assert_(isinstance(q, a_type))
        assert_(isinstance(r, a_type))
        assert_(q.shape == (m, m))
        assert_(r.shape == (m, n))
        assert_almost_equal(dot(q, r), a, single_decimal=5)
        assert_almost_equal(dot(q.T.conj(), q), np.eye(m))
        assert_almost_equal(np.triu(r), r)

        # mode == 'reduced': q is (m, k), r is (k, n)
        q1, r1 = linalg.qr(a, mode="reduced")
        assert_(q1.dtype == a_dtype)
        assert_(r1.dtype == a_dtype)
        assert_(isinstance(q1, a_type))
        assert_(isinstance(r1, a_type))
        assert_(q1.shape == (m, k))
        assert_(r1.shape == (k, n))
        assert_almost_equal(dot(q1, r1), a, single_decimal=5)
        assert_almost_equal(dot(q1.T.conj(), q1), np.eye(k))
        assert_almost_equal(np.triu(r1), r1)

        # mode == 'r': only r, which must match the reduced-mode r
        r2 = linalg.qr(a, mode="r")
        assert_(r2.dtype == a_dtype)
        assert_(isinstance(r2, a_type))
        assert_almost_equal(r2, r1)

    @xpassIfTorchDynamo  # (reason="torch does not allow qr(..., mode='raw'")
    @parametrize("m, n", [(3, 0), (0, 3), (0, 0)])
    def test_qr_empty(self, m, n):
        k = min(m, n)
        a = np.empty((m, n))

        self.check_qr(a)

        # 'raw' results are in FORTRAN order, so h comes back transposed
        h, tau = np.linalg.qr(a, mode="raw")
        assert_equal(h.dtype, np.double)
        assert_equal(tau.dtype, np.double)
        assert_equal(h.shape, (n, m))
        assert_equal(tau.shape, (k,))

    @xpassIfTorchDynamo  # (reason="torch does not allow qr(..., mode='raw'")
    def test_mode_raw(self):
        # The factorization is not unique and varies between libraries,
        # so it is not possible to check against known values. Functional
        # testing is a possibility, but awaits the exposure of more
        # of the functions in lapack_lite. Consequently, this test is
        # very limited in scope. Note that the results are in FORTRAN
        # order, hence the h arrays are transposed.
        a = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.double)

        # Test double
        h, tau = linalg.qr(a, mode="raw")
        assert_(h.dtype == np.double)
        assert_(tau.dtype == np.double)
        assert_(h.shape == (2, 3))
        assert_(tau.shape == (2,))

        h, tau = linalg.qr(a.T, mode="raw")
        assert_(h.dtype == np.double)
        assert_(tau.dtype == np.double)
        assert_(h.shape == (3, 2))
        assert_(tau.shape == (2,))

    def test_mode_all_but_economic(self):
        a = np.array([[1, 2], [3, 4]])
        b = np.array([[1, 2], [3, 4], [5, 6]])
        # real float32 / float64 inputs: square, tall and wide
        for dt in "fd":
            m1 = a.astype(dt)
            m2 = b.astype(dt)
            self.check_qr(m1)
            self.check_qr(m2)
            self.check_qr(m2.T)

        # complex counterparts of the same shapes
        for dt in "fd":
            m1 = 1 + 1j * a.astype(dt)
            m2 = 1 + 1j * b.astype(dt)
            self.check_qr(m1)
            self.check_qr(m2)
            self.check_qr(m2.T)

    def check_qr_stacked(self, a):
        # This test expects the argument `a` to be an ndarray or
        # a subclass of an ndarray of inexact type.
        # Leading dimensions are batch dimensions; the last two are (m, n).
        a_type = type(a)
        a_dtype = a.dtype
        m, n = a.shape[-2:]
        k = min(m, n)

        # mode == 'complete'
        q, r = linalg.qr(a, mode="complete")
        assert_(q.dtype == a_dtype)
        assert_(r.dtype == a_dtype)
        assert_(isinstance(q, a_type))
        assert_(isinstance(r, a_type))
        assert_(q.shape[-2:] == (m, m))
        assert_(r.shape[-2:] == (m, n))
        assert_almost_equal(matmul(q, r), a, single_decimal=5)
        # every q in the stack must be unitary
        I_mat = np.identity(q.shape[-1])
        stack_I_mat = np.broadcast_to(I_mat, q.shape[:-2] + (q.shape[-1],) * 2)
        assert_almost_equal(matmul(swapaxes(q, -1, -2).conj(), q), stack_I_mat)
        assert_almost_equal(np.triu(r[..., :, :]), r)

        # mode == 'reduced'
        q1, r1 = linalg.qr(a, mode="reduced")
        assert_(q1.dtype == a_dtype)
        assert_(r1.dtype == a_dtype)
        assert_(isinstance(q1, a_type))
        assert_(isinstance(r1, a_type))
        assert_(q1.shape[-2:] == (m, k))
        assert_(r1.shape[-2:] == (k, n))
        assert_almost_equal(matmul(q1, r1), a, single_decimal=5)
        I_mat = np.identity(q1.shape[-1])
        stack_I_mat = np.broadcast_to(I_mat, q1.shape[:-2] + (q1.shape[-1],) * 2)
        assert_almost_equal(matmul(swapaxes(q1, -1, -2).conj(), q1), stack_I_mat)
        assert_almost_equal(np.triu(r1[..., :, :]), r1)

        # mode == 'r'
        r2 = linalg.qr(a, mode="r")
        assert_(r2.dtype == a_dtype)
        assert_(isinstance(r2, a_type))
        assert_almost_equal(r2, r1)

    # Compare versions numerically: a plain string comparison such as
    # ``numpy.__version__ < "1.22"`` orders "1.9" after "1.22".
    @skipif(
        _leading_version_tuple(numpy.__version__) < (1, 22),
        reason="NP_VER: fails on CI with numpy 1.21.2",
    )
    @parametrize("size", [(3, 4), (4, 3), (4, 4), (3, 0), (0, 3)])
    @parametrize("outer_size", [(2, 2), (2,), (2, 3, 4)])
    @parametrize("dt", [np.single, np.double, np.csingle, np.cdouble])
    def test_stacked_inputs(self, outer_size, size, dt):
        A = np.random.normal(size=outer_size + size).astype(dt)
        B = np.random.normal(size=outer_size + size).astype(dt)
        self.check_qr_stacked(A)
        self.check_qr_stacked(A + 1.0j * B)
@instantiate_parametrized_tests
class TestCholesky(TestCase):
    """Tests for ``linalg.cholesky``: the A = L L^H property and empty inputs."""

    # TODO: are there no other tests for cholesky?

    @parametrize("shape", [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)])
    @parametrize("dtype", (np.float32, np.float64, np.complex64, np.complex128))
    def test_basic_property(self, shape, dtype):
        # Check A = L L^H
        np.random.seed(1)
        a = np.random.randn(*shape)
        if np.issubdtype(dtype, np.complexfloating):
            a = a + 1j * np.random.randn(*shape)

        # Axis permutation that swaps the last two axes, i.e. a (batched)
        # transpose that also works for the stacked (3, 10, 10) case.
        t = list(range(len(shape)))
        t[-2:] = -1, -2

        # a^H a is Hermitian and, for random a, positive definite, so the
        # Cholesky factorization exists.
        a = np.matmul(a.transpose(t).conj(), a)
        a = np.asarray(a, dtype=dtype)

        c = np.linalg.cholesky(a)

        b = np.matmul(c, c.transpose(t).conj())
        atol = 500 * a.shape[0] * np.finfo(dtype).eps
        assert_allclose(b, a, atol=atol, err_msg=f"{shape} {dtype}\n{a}\n{c}")

    def test_0_size(self):
        # class ArraySubclass(np.ndarray):
        #     pass
        a = np.zeros((0, 1, 1), dtype=np.int_)  # .view(ArraySubclass)
        res = linalg.cholesky(a)
        assert_equal(a.shape, res.shape)
        # Integer input is promoted to float64.
        assert_(res.dtype.type is np.float64)
        # for documentation purpose:
        assert_(isinstance(res, np.ndarray))

        a = np.zeros((1, 0, 0), dtype=np.complex64)  # .view(ArraySubclass)
        res = linalg.cholesky(a)
        assert_equal(a.shape, res.shape)
        # Complex input keeps its dtype.
        assert_(res.dtype.type is np.complex64)
        assert_(isinstance(res, np.ndarray))


class TestMisc(TestCase):
    """Miscellaneous linalg checks: byte order, batched error reporting,
    LAPACK xerbla linkage and a float32 ``dot`` regression."""

    @xpassIfTorchDynamo  # (reason="endianness")
    def test_byteorder_check(self):
        # Byte order check should pass for native order
        if sys.byteorder == "little":
            native = "<"
        else:
            native = ">"

        for dtt in (np.float32, np.float64):
            arr = np.eye(4, dtype=dtt)
            n_arr = arr.newbyteorder(native)
            sw_arr = arr.newbyteorder("S").byteswap()
            assert_equal(arr.dtype.byteorder, "=")
            for routine in (linalg.inv, linalg.det, linalg.pinv):
                # Normal call
                res = routine(arr)
                # Native but not '='
                assert_array_equal(res, routine(n_arr))
                # Swapped
                assert_array_equal(res, routine(sw_arr))

    @pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm")
    def test_generalized_raise_multiloop(self):
        # It should raise an error even if the error doesn't occur in the
        # last iteration of the ufunc inner loop

        invertible = np.array([[1, 2], [3, 4]])
        non_invertible = np.array([[1, 1], [1, 1]])

        x = np.zeros([4, 4, 2, 2])[1::2]
        x[...] = invertible
        # Only the first matrix of the batch is singular.
        x[0, 0] = non_invertible

        assert_raises(np.linalg.LinAlgError, np.linalg.inv, x)

    def test_xerbla_override(self):
        # Check that our xerbla has been successfully linked in. If it is not,
        # the default xerbla routine is called, which prints a message to stdout
        # and may, or may not, abort the process depending on the LAPACK package.

        XERBLA_OK = 255

        try:
            pid = os.fork()
        except (OSError, AttributeError):
            # fork failed, or not running on POSIX
            raise SkipTest("Not POSIX or fork failed.") from None

        if pid == 0:
            # Child process. os._exit() ends the process immediately without
            # running `finally` blocks, so the `finally` below only fires when
            # the child falls through or something unexpected raised. Either
            # way the forked child must never return into the test runner.
            try:
                # child; close i/o file handles
                os.close(1)
                os.close(0)
                # Avoid producing core files.
                import resource

                resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
                # These calls may abort.
                try:
                    np.linalg.lapack_lite.xerbla()
                except ValueError:
                    pass
                except Exception:
                    os._exit(os.EX_CONFIG)

                try:
                    a = np.array([[1.0]])
                    np.linalg.lapack_lite.dorgqr(
                        1, 1, 1, a, 0, a, a, 0, 0
                    )  # <- invalid value
                except ValueError as e:
                    if "DORGQR parameter number 5" in str(e):
                        # success, reuse error code to mark success as
                        # FORTRAN STOP returns as success.
                        os._exit(XERBLA_OK)

                # Did not abort, but our xerbla was not linked in.
                os._exit(os.EX_CONFIG)
            finally:
                os._exit(os.EX_CONFIG)
        else:
            # parent: wait for *this* child only, not any child process.
            _, status = os.waitpid(pid, 0)
            # A child killed by a signal (e.g. LAPACK abort) has no exit code.
            if not (os.WIFEXITED(status) and os.WEXITSTATUS(status) == XERBLA_OK):
                raise SkipTest("Numpy xerbla not linked in.")

    @pytest.mark.skipif(IS_WASM, reason="Cannot start subprocess")
    @slow
    def test_sdot_bug_8577(self):
        # Regression test that loading certain other libraries does not
        # result in wrong results in float32 linear algebra.
        #
        # There's a bug gh-8577 on OSX that can trigger this, and perhaps
        # there are also other situations in which it occurs.
        #
        # Do the check in a separate process.

        bad_libs = ["PyQt5.QtWidgets", "IPython"]

        template = textwrap.dedent(
            """
        import sys
        {before}
        try:
            import {bad_lib}
        except ImportError:
            sys.exit(0)
        {after}
        x = np.ones(2, dtype=np.float32)
        sys.exit(0 if np.allclose(x.dot(x), 2.0) else 1)
        """
        )

        for bad_lib in bad_libs:
            code = template.format(
                before="import numpy as np", after="", bad_lib=bad_lib
            )
            subprocess.check_call([sys.executable, "-c", code])

            # Swapped import order
            code = template.format(
                after="import numpy as np", before="", bad_lib=bad_lib
            )
            subprocess.check_call([sys.executable, "-c", code])


class TestMultiDot(TestCase):
    """Tests for ``multi_dot`` covering its two-argument, three-argument and
    dynamic-programming (four or more arguments) code paths."""

    def test_basic_function_with_three_arguments(self):
        # multi_dot with three arguments uses a fast hand coded algorithm to
        # determine the optimal order. Therefore test it separately.
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))

        assert_almost_equal(multi_dot([A, B, C]), A.dot(B).dot(C))
        assert_almost_equal(multi_dot([A, B, C]), np.dot(A, np.dot(B, C)))

    def test_basic_function_with_two_arguments(self):
        # separate code path with two arguments
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))

        assert_almost_equal(multi_dot([A, B]), A.dot(B))
        assert_almost_equal(multi_dot([A, B]), np.dot(A, B))

    def test_basic_function_with_dynamic_programming_optimization(self):
        # multi_dot with four or more arguments uses the dynamic programming
        # optimization and therefore deserves a separate test
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))
        D = np.random.random((2, 1))
        assert_almost_equal(multi_dot([A, B, C, D]), A.dot(B).dot(C).dot(D))

    def test_vector_as_first_argument(self):
        # The first argument can be 1-D
        A1d = np.random.random(2)  # 1-D
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))
        D = np.random.random((2, 2))

        # the result should be 1-D
        assert_equal(multi_dot([A1d, B, C, D]).shape, (2,))

    def test_vector_as_last_argument(self):
        # The last argument can be 1-D
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))
        D1d = np.random.random(2)  # 1-D

        # the result should be 1-D
        assert_equal(multi_dot([A, B, C, D1d]).shape, (6,))

    def test_vector_as_first_and_last_argument(self):
        # The first and last arguments can be 1-D
        A1d = np.random.random(2)  # 1-D
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))
        D1d = np.random.random(2)  # 1-D

        # the result should be a scalar
        assert_equal(multi_dot([A1d, B, C, D1d]).shape, ())

    def test_three_arguments_and_out(self):
        # multi_dot with three arguments uses a fast hand coded algorithm to
        # determine the optimal order. Therefore test it separately.
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))

        out = np.zeros((6, 2))
        ret = multi_dot([A, B, C], out=out)
        # The result must be written into, and returned as, `out` itself.
        assert out is ret
        assert_almost_equal(out, A.dot(B).dot(C))
        assert_almost_equal(out, np.dot(A, np.dot(B, C)))

    def test_two_arguments_and_out(self):
        # separate code path with two arguments
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))
        out = np.zeros((6, 6))
        ret = multi_dot([A, B], out=out)
        assert out is ret
        assert_almost_equal(out, A.dot(B))
        assert_almost_equal(out, np.dot(A, B))

    def test_dynamic_programming_optimization_and_out(self):
        # multi_dot with four or more arguments uses the dynamic programming
        # optimization and therefore deserves a separate test
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))
        D = np.random.random((2, 1))
        out = np.zeros((6, 1))
        ret = multi_dot([A, B, C, D], out=out)
        assert out is ret
        assert_almost_equal(out, A.dot(B).dot(C).dot(D))

    def test_dynamic_programming_logic(self):
        # Test for the dynamic programming part
        # This test is directly taken from Cormen page 376.
        arrays = [
            np.random.random((30, 35)),
            np.random.random((35, 15)),
            np.random.random((15, 5)),
            np.random.random((5, 10)),
            np.random.random((10, 20)),
            np.random.random((20, 25)),
        ]
        # Expected cost table (m) and split table (s) from the textbook.
        m_expected = np.array(
            [
                [0.0, 15750.0, 7875.0, 9375.0, 11875.0, 15125.0],
                [0.0, 0.0, 2625.0, 4375.0, 7125.0, 10500.0],
                [0.0, 0.0, 0.0, 750.0, 2500.0, 5375.0],
                [0.0, 0.0, 0.0, 0.0, 1000.0, 3500.0],
                [0.0, 0.0, 0.0, 0.0, 0.0, 5000.0],
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            ]
        )
        s_expected = np.array(
            [
                [0, 1, 1, 3, 3, 3],
                [0, 0, 2, 3, 3, 3],
                [0, 0, 0, 3, 3, 3],
                [0, 0, 0, 0, 4, 5],
                [0, 0, 0, 0, 0, 5],
                [0, 0, 0, 0, 0, 0],
            ],
            dtype=int,
        )
        s_expected -= 1  # Cormen uses 1-based index, python does not.

        s, m = _multi_dot_matrix_chain_order(arrays, return_costs=True)

        # Only the upper triangular part (without the diagonal) is interesting.
        assert_almost_equal(np.triu(s[:-1, 1:]), np.triu(s_expected[:-1, 1:]))
        assert_almost_equal(np.triu(m), np.triu(m_expected))

    def test_too_few_input_arrays(self):
        assert_raises((RuntimeError, ValueError), multi_dot, [])
        assert_raises((RuntimeError, ValueError), multi_dot, [np.random.random((3, 3))])


@instantiate_parametrized_tests
class TestTensorinv(TestCase):
    """Tests for ``linalg.tensorinv``: shape validation, ``ind`` limits and
    consistency with ``tensorsolve``."""

    @parametrize(
        "arr, ind",
        [
            (np.ones((4, 6, 8, 2)), 2),
            (np.ones((3, 3, 2)), 1),
        ],
    )
    def test_non_square_handling(self, arr, ind):
        with assert_raises((LinAlgError, RuntimeError)):
            linalg.tensorinv(arr, ind=ind)

    @parametrize(
        "shape, ind",
        [
            # examples from docstring
            ((4, 6, 8, 3), 2),
            ((24, 8, 3), 1),
        ],
    )
    def test_tensorinv_shape(self, shape, ind):
        a = np.eye(24).reshape(shape)
        ainv = linalg.tensorinv(a=a, ind=ind)
        # The inverse swaps the leading `ind` axes with the trailing ones.
        expected = a.shape[ind:] + a.shape[:ind]
        actual = ainv.shape
        assert_equal(actual, expected)

    @parametrize(
        "ind",
        [
            0,
            -2,
        ],
    )
    def test_tensorinv_ind_limit(self, ind):
        a = np.eye(24).reshape(4, 6, 8, 3)
        with assert_raises((ValueError, RuntimeError)):
            linalg.tensorinv(a=a, ind=ind)

    def test_tensorinv_result(self):
        # mimic a docstring example
        a = np.eye(24).reshape(24, 8, 3)
        ainv = linalg.tensorinv(a, ind=1)
        b = np.ones(24)
        assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))


@instantiate_parametrized_tests
class TestTensorsolve(TestCase):
    """Tests for ``linalg.tensorsolve``: non-square rejection and solution
    correctness (including empty shapes)."""

    @parametrize(
        "a, axes",
        [
            (np.ones((4, 6, 8, 2)), None),
            (np.ones((3, 3, 2)), (0, 2)),
        ],
    )
    def test_non_square_handling(self, a, axes):
        # Build the right-hand side outside the `assert_raises` block so only
        # the solve itself is expected to raise; a failure while constructing
        # `b` must not be mistaken for the expected error.
        b = np.ones(a.shape[:2])
        with assert_raises((LinAlgError, RuntimeError)):
            linalg.tensorsolve(a, b, axes=axes)

    @skipif(numpy.__version__ < "1.22", reason="NP_VER: fails on CI with numpy 1.21.2")
    @parametrize(
        "shape",
        [(2, 3, 6), (3, 4, 4, 3), (0, 3, 3, 0)],
    )
    def test_tensorsolve_result(self, shape):
        a = np.random.randn(*shape)
        b = np.ones(a.shape[:2])
        x = np.linalg.tensorsolve(a, b)
        assert_allclose(np.tensordot(a, x, axes=len(x.shape)), b)


class TestMisc2(TestCase):
    """Unsupported-dtype handling and (skipped) 64-bit BLAS/LAPACK checks."""

    @xpassIfTorchDynamo  # (reason="TODO")
    def test_unsupported_commontype(self):
        # linalg gracefully handles unsupported type
        arr = np.array([[1, -2], [2, 5]], dtype="float16")
        # with assert_raises_regex(TypeError, "unsupported in linalg"):
        with assert_raises(TypeError):
            linalg.cholesky(arr)

    # @slow
    # @pytest.mark.xfail(not HAS_LAPACK64, run=False,
    #                    reason="Numpy not compiled with 64-bit BLAS/LAPACK")
    # @requires_memory(free_bytes=16e9)
    @skip(reason="Bad memory reports lead to OOM in ci testing")
    def test_blas64_dot(self):
        # 2**32 elements overflow a 32-bit BLAS integer index.
        n = 2**32
        a = np.zeros([1, n], dtype=np.float32)
        b = np.ones([1, 1], dtype=np.float32)
        a[0, -1] = 1
        c = np.dot(b, a)
        assert_equal(c[0, -1], 1)

    @skip(reason="lapack-lite specific")
    @xfail  # (
    #    not HAS_LAPACK64, reason="Numpy not compiled with 64-bit BLAS/LAPACK"
    # )
    def test_blas64_geqrf_lwork_smoketest(self):
        # Smoke test LAPACK geqrf lwork call with 64-bit integers
        dtype = np.float64
        lapack_routine = np.linalg.lapack_lite.dgeqrf

        m = 2**32 + 1
        n = 2**32 + 1
        lda = m

        # Dummy arrays, not referenced by the lapack routine, so don't
        # need to be of the right size
        a = np.zeros([1, 1], dtype=dtype)
        work = np.zeros([1], dtype=dtype)
        tau = np.zeros([1], dtype=dtype)

        # Size query (lwork == -1 asks LAPACK for the optimal work size)
        results = lapack_routine(m, n, a, lda, tau, work, -1, 0)
        assert_equal(results["info"], 0)
        assert_equal(results["m"], m)
        assert_equal(results["n"], n)

        # Should result in an integer of a reasonable size
        lwork = int(work.item())
        assert_(2**32 < lwork < 2**42)


if __name__ == "__main__":
    run_tests()