xref: /aosp_15_r20/external/pytorch/test/torch_np/numpy_tests/linalg/test_linalg.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3""" Test functions for linalg module
4
5"""
6import functools
7import itertools
8import os
9import subprocess
10import sys
11import textwrap
12import traceback
13from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest
14
15import numpy
16import pytest
17from numpy.linalg.linalg import _multi_dot_matrix_chain_order
18from pytest import raises as assert_raises
19
20from torch.testing._internal.common_utils import (
21    instantiate_parametrized_tests,
22    parametrize,
23    run_tests,
24    slowTest as slow,
25    TEST_WITH_TORCHDYNAMO,
26    TestCase,
27    xpassIfTorchDynamo,
28)
29
30
31# If we are going to trace through these, we should use NumPy
32# If testing on eager mode, we use torch._numpy
33if TEST_WITH_TORCHDYNAMO:
34    import numpy as np
35    from numpy import (
36        array,
37        asarray,
38        atleast_2d,
39        cdouble,
40        csingle,
41        dot,
42        double,
43        identity,
44        inf,
45        linalg,
46        matmul,
47        single,
48        swapaxes,
49    )
50    from numpy.linalg import LinAlgError, matrix_power, matrix_rank, multi_dot, norm
51    from numpy.testing import (  # assert_raises_regex, HAS_LAPACK64, IS_WASM
52        assert_,
53        assert_allclose,
54        assert_almost_equal,
55        assert_array_equal,
56        assert_equal,
57        suppress_warnings,
58    )
59
60else:
61    import torch._numpy as np
62    from torch._numpy import (
63        array,
64        asarray,
65        atleast_2d,
66        cdouble,
67        csingle,
68        dot,
69        double,
70        identity,
71        inf,
72        linalg,
73        matmul,
74        single,
75        swapaxes,
76    )
77    from torch._numpy.linalg import (
78        LinAlgError,
79        matrix_power,
80        matrix_rank,
81        multi_dot,
82        norm,
83    )
84    from torch._numpy.testing import (
85        assert_,
86        assert_allclose,
87        assert_almost_equal,
88        assert_array_equal,
89        assert_equal,
90        suppress_warnings,
91    )
92
93
94skip = functools.partial(skipif, True)
95
96IS_WASM = False
97HAS_LAPACK64 = False
98
99
100def consistent_subclass(out, in_):
101    # For ndarray subclass input, our output should have the same subclass
102    # (non-ndarray input gets converted to ndarray).
103    return type(out) is (type(in_) if isinstance(in_, np.ndarray) else np.ndarray)
104
105
106old_assert_almost_equal = assert_almost_equal
107
108
109def assert_almost_equal(a, b, single_decimal=6, double_decimal=12, **kw):
110    if asarray(a).dtype.type in (single, csingle):
111        decimal = single_decimal
112    else:
113        decimal = double_decimal
114    old_assert_almost_equal(a, b, decimal=decimal, **kw)
115
116
117def get_real_dtype(dtype):
118    return {single: single, double: double, csingle: single, cdouble: double}[dtype]
119
120
121def get_complex_dtype(dtype):
122    return {single: csingle, double: cdouble, csingle: csingle, cdouble: cdouble}[dtype]
123
124
125def get_rtol(dtype):
126    # Choose a safe rtol
127    if dtype in (single, csingle):
128        return 1e-5
129    else:
130        return 1e-11
131
132
133# used to categorize tests
134all_tags = {
135    "square",
136    "nonsquare",
137    "hermitian",  # mutually exclusive
138    "generalized",
139    "size-0",
140    "strided",  # optional additions
141}
142
143
144class LinalgCase:
145    def __init__(self, name, a, b, tags=None):
146        """
147        A bundle of arguments to be passed to a test case, with an identifying
148        name, the operands a and b, and a set of tags to filter the tests
149        """
150        if tags is None:
151            tags = set()
152        assert_(isinstance(name, str))
153        self.name = name
154        self.a = a
155        self.b = b
156        self.tags = frozenset(tags)  # prevent shared tags
157
158    def check(self, do):
159        """
160        Run the function `do` on this test case, expanding arguments
161        """
162        do(self.a, self.b, tags=self.tags)
163
164    def __repr__(self):
165        return f"<LinalgCase: {self.name}>"
166
167
168def apply_tag(tag, cases):
169    """
170    Add the given tag (a string) to each of the cases (a list of LinalgCase
171    objects)
172    """
173    assert tag in all_tags, "Invalid tag"
174    for case in cases:
175        case.tags = case.tags | {tag}
176    return cases
177
178
179#
180# Base test cases
181#
182
183np.random.seed(1234)
184
185CASES = []
186
187# square test cases
188CASES += apply_tag(
189    "square",
190    [
191        LinalgCase(
192            "single",
193            array([[1.0, 2.0], [3.0, 4.0]], dtype=single),
194            array([2.0, 1.0], dtype=single),
195        ),
196        LinalgCase(
197            "double",
198            array([[1.0, 2.0], [3.0, 4.0]], dtype=double),
199            array([2.0, 1.0], dtype=double),
200        ),
201        LinalgCase(
202            "double_2",
203            array([[1.0, 2.0], [3.0, 4.0]], dtype=double),
204            array([[2.0, 1.0, 4.0], [3.0, 4.0, 6.0]], dtype=double),
205        ),
206        LinalgCase(
207            "csingle",
208            array([[1.0 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]], dtype=csingle),
209            array([2.0 + 1j, 1.0 + 2j], dtype=csingle),
210        ),
211        LinalgCase(
212            "cdouble",
213            array([[1.0 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]], dtype=cdouble),
214            array([2.0 + 1j, 1.0 + 2j], dtype=cdouble),
215        ),
216        LinalgCase(
217            "cdouble_2",
218            array([[1.0 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]], dtype=cdouble),
219            array(
220                [[2.0 + 1j, 1.0 + 2j, 1 + 3j], [1 - 2j, 1 - 3j, 1 - 6j]], dtype=cdouble
221            ),
222        ),
223        LinalgCase(
224            "0x0",
225            np.empty((0, 0), dtype=double),
226            np.empty((0,), dtype=double),
227            tags={"size-0"},
228        ),
229        LinalgCase("8x8", np.random.rand(8, 8), np.random.rand(8)),
230        LinalgCase("1x1", np.random.rand(1, 1), np.random.rand(1)),
231        LinalgCase("nonarray", [[1, 2], [3, 4]], [2, 1]),
232    ],
233)
234
235# non-square test-cases
236CASES += apply_tag(
237    "nonsquare",
238    [
239        LinalgCase(
240            "single_nsq_1",
241            array([[1.0, 2.0, 3.0], [3.0, 4.0, 6.0]], dtype=single),
242            array([2.0, 1.0], dtype=single),
243        ),
244        LinalgCase(
245            "single_nsq_2",
246            array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=single),
247            array([2.0, 1.0, 3.0], dtype=single),
248        ),
249        LinalgCase(
250            "double_nsq_1",
251            array([[1.0, 2.0, 3.0], [3.0, 4.0, 6.0]], dtype=double),
252            array([2.0, 1.0], dtype=double),
253        ),
254        LinalgCase(
255            "double_nsq_2",
256            array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=double),
257            array([2.0, 1.0, 3.0], dtype=double),
258        ),
259        LinalgCase(
260            "csingle_nsq_1",
261            array(
262                [[1.0 + 1j, 2.0 + 2j, 3.0 - 3j], [3.0 - 5j, 4.0 + 9j, 6.0 + 2j]],
263                dtype=csingle,
264            ),
265            array([2.0 + 1j, 1.0 + 2j], dtype=csingle),
266        ),
267        LinalgCase(
268            "csingle_nsq_2",
269            array(
270                [[1.0 + 1j, 2.0 + 2j], [3.0 - 3j, 4.0 - 9j], [5.0 - 4j, 6.0 + 8j]],
271                dtype=csingle,
272            ),
273            array([2.0 + 1j, 1.0 + 2j, 3.0 - 3j], dtype=csingle),
274        ),
275        LinalgCase(
276            "cdouble_nsq_1",
277            array(
278                [[1.0 + 1j, 2.0 + 2j, 3.0 - 3j], [3.0 - 5j, 4.0 + 9j, 6.0 + 2j]],
279                dtype=cdouble,
280            ),
281            array([2.0 + 1j, 1.0 + 2j], dtype=cdouble),
282        ),
283        LinalgCase(
284            "cdouble_nsq_2",
285            array(
286                [[1.0 + 1j, 2.0 + 2j], [3.0 - 3j, 4.0 - 9j], [5.0 - 4j, 6.0 + 8j]],
287                dtype=cdouble,
288            ),
289            array([2.0 + 1j, 1.0 + 2j, 3.0 - 3j], dtype=cdouble),
290        ),
291        LinalgCase(
292            "cdouble_nsq_1_2",
293            array(
294                [[1.0 + 1j, 2.0 + 2j, 3.0 - 3j], [3.0 - 5j, 4.0 + 9j, 6.0 + 2j]],
295                dtype=cdouble,
296            ),
297            array([[2.0 + 1j, 1.0 + 2j], [1 - 1j, 2 - 2j]], dtype=cdouble),
298        ),
299        LinalgCase(
300            "cdouble_nsq_2_2",
301            array(
302                [[1.0 + 1j, 2.0 + 2j], [3.0 - 3j, 4.0 - 9j], [5.0 - 4j, 6.0 + 8j]],
303                dtype=cdouble,
304            ),
305            array(
306                [[2.0 + 1j, 1.0 + 2j], [1 - 1j, 2 - 2j], [1 - 1j, 2 - 2j]],
307                dtype=cdouble,
308            ),
309        ),
310        LinalgCase("8x11", np.random.rand(8, 11), np.random.rand(8)),
311        LinalgCase("1x5", np.random.rand(1, 5), np.random.rand(1)),
312        LinalgCase("5x1", np.random.rand(5, 1), np.random.rand(5)),
313        LinalgCase("0x4", np.random.rand(0, 4), np.random.rand(0), tags={"size-0"}),
314        LinalgCase("4x0", np.random.rand(4, 0), np.random.rand(4), tags={"size-0"}),
315    ],
316)
317
318# hermitian test-cases
319CASES += apply_tag(
320    "hermitian",
321    [
322        LinalgCase("hsingle", array([[1.0, 2.0], [2.0, 1.0]], dtype=single), None),
323        LinalgCase("hdouble", array([[1.0, 2.0], [2.0, 1.0]], dtype=double), None),
324        LinalgCase(
325            "hcsingle", array([[1.0, 2 + 3j], [2 - 3j, 1]], dtype=csingle), None
326        ),
327        LinalgCase(
328            "hcdouble", array([[1.0, 2 + 3j], [2 - 3j, 1]], dtype=cdouble), None
329        ),
330        LinalgCase("hempty", np.empty((0, 0), dtype=double), None, tags={"size-0"}),
331        LinalgCase("hnonarray", [[1, 2], [2, 1]], None),
332        LinalgCase("matrix_b_only", array([[1.0, 2.0], [2.0, 1.0]]), None),
333        LinalgCase("hmatrix_1x1", np.random.rand(1, 1), None),
334    ],
335)
336
337
338#
339# Gufunc test cases
340#
341def _make_generalized_cases():
342    new_cases = []
343
344    for case in CASES:
345        if not isinstance(case.a, np.ndarray):
346            continue
347
348        a = np.stack([case.a, 2 * case.a, 3 * case.a])
349        if case.b is None:
350            b = None
351        else:
352            b = np.stack([case.b, 7 * case.b, 6 * case.b])
353        new_case = LinalgCase(
354            case.name + "_tile3", a, b, tags=case.tags | {"generalized"}
355        )
356        new_cases.append(new_case)
357
358        a = np.array([case.a] * 2 * 3).reshape((3, 2) + case.a.shape)
359        if case.b is None:
360            b = None
361        else:
362            b = np.array([case.b] * 2 * 3).reshape((3, 2) + case.b.shape)
363        new_case = LinalgCase(
364            case.name + "_tile213", a, b, tags=case.tags | {"generalized"}
365        )
366        new_cases.append(new_case)
367
368    return new_cases
369
370
371CASES += _make_generalized_cases()
372
373
374#
375# Test different routines against the above cases
376#
377class LinalgTestCase:
378    TEST_CASES = CASES
379
380    def check_cases(self, require=None, exclude=None):
381        """
382        Run func on each of the cases with all of the tags in require, and none
383        of the tags in exclude
384        """
385        if require is None:
386            require = set()
387        if exclude is None:
388            exclude = set()
389        for case in self.TEST_CASES:
390            # filter by require and exclude
391            if case.tags & require != require:
392                continue
393            if case.tags & exclude:
394                continue
395
396            try:
397                case.check(self.do)
398            except Exception as e:
399                msg = f"In test case: {case!r}\n\n"
400                msg += traceback.format_exc()
401                raise AssertionError(msg) from e
402
403
404class LinalgSquareTestCase(LinalgTestCase):
405    def test_sq_cases(self):
406        self.check_cases(require={"square"}, exclude={"generalized", "size-0"})
407
408    def test_empty_sq_cases(self):
409        self.check_cases(require={"square", "size-0"}, exclude={"generalized"})
410
411
412class LinalgNonsquareTestCase(LinalgTestCase):
413    def test_nonsq_cases(self):
414        self.check_cases(require={"nonsquare"}, exclude={"generalized", "size-0"})
415
416    def test_empty_nonsq_cases(self):
417        self.check_cases(require={"nonsquare", "size-0"}, exclude={"generalized"})
418
419
420class HermitianTestCase(LinalgTestCase):
421    def test_herm_cases(self):
422        self.check_cases(require={"hermitian"}, exclude={"generalized", "size-0"})
423
424    def test_empty_herm_cases(self):
425        self.check_cases(require={"hermitian", "size-0"}, exclude={"generalized"})
426
427
428class LinalgGeneralizedSquareTestCase(LinalgTestCase):
429    @slow
430    def test_generalized_sq_cases(self):
431        self.check_cases(require={"generalized", "square"}, exclude={"size-0"})
432
433    @slow
434    def test_generalized_empty_sq_cases(self):
435        self.check_cases(require={"generalized", "square", "size-0"})
436
437
438class LinalgGeneralizedNonsquareTestCase(LinalgTestCase):
439    @slow
440    def test_generalized_nonsq_cases(self):
441        self.check_cases(require={"generalized", "nonsquare"}, exclude={"size-0"})
442
443    @slow
444    def test_generalized_empty_nonsq_cases(self):
445        self.check_cases(require={"generalized", "nonsquare", "size-0"})
446
447
448class HermitianGeneralizedTestCase(LinalgTestCase):
449    @slow
450    def test_generalized_herm_cases(self):
451        self.check_cases(require={"generalized", "hermitian"}, exclude={"size-0"})
452
453    @slow
454    def test_generalized_empty_herm_cases(self):
455        self.check_cases(
456            require={"generalized", "hermitian", "size-0"}, exclude={"none"}
457        )
458
459
460def dot_generalized(a, b):
461    a = asarray(a)
462    if a.ndim >= 3:
463        if a.ndim == b.ndim:
464            # matrix x matrix
465            new_shape = a.shape[:-1] + b.shape[-1:]
466        elif a.ndim == b.ndim + 1:
467            # matrix x vector
468            new_shape = a.shape[:-1]
469        else:
470            raise ValueError("Not implemented...")
471        r = np.empty(new_shape, dtype=np.common_type(a, b))
472        for c in itertools.product(*map(range, a.shape[:-2])):
473            r[c] = dot(a[c], b[c])
474        return r
475    else:
476        return dot(a, b)
477
478
479def identity_like_generalized(a):
480    a = asarray(a)
481    if a.ndim >= 3:
482        r = np.empty(a.shape, dtype=a.dtype)
483        r[...] = identity(a.shape[-2])
484        return r
485    else:
486        return identity(a.shape[0])
487
488
489class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
490    # kept apart from TestSolve for use for testing with matrices.
491    def do(self, a, b, tags):
492        x = linalg.solve(a, b)
493        assert_almost_equal(b, dot_generalized(a, x))
494        assert_(consistent_subclass(x, b))
495
496
497@instantiate_parametrized_tests
498class TestSolve(SolveCases, TestCase):
499    @parametrize("dtype", [single, double, csingle, cdouble])
500    def test_types(self, dtype):
501        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
502        assert_equal(linalg.solve(x, x).dtype, dtype)
503
504    @skip(reason="subclass")
505    def test_0_size(self):
506        class ArraySubclass(np.ndarray):
507            pass
508
509        # Test system of 0x0 matrices
510        a = np.arange(8).reshape(2, 2, 2)
511        b = np.arange(6).reshape(1, 2, 3).view(ArraySubclass)
512
513        expected = linalg.solve(a, b)[:, 0:0, :]
514        result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0, :])
515        assert_array_equal(result, expected)
516        assert_(isinstance(result, ArraySubclass))
517
518        # Test errors for non-square and only b's dimension being 0
519        assert_raises(linalg.LinAlgError, linalg.solve, a[:, 0:0, 0:1], b)
520        assert_raises(ValueError, linalg.solve, a, b[:, 0:0, :])
521
522        # Test broadcasting error
523        b = np.arange(6).reshape(1, 3, 2)  # broadcasting error
524        assert_raises(ValueError, linalg.solve, a, b)
525        assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
526
527        # Test zero "single equations" with 0x0 matrices.
528        b = np.arange(2).reshape(1, 2).view(ArraySubclass)
529        expected = linalg.solve(a, b)[:, 0:0]
530        result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0])
531        assert_array_equal(result, expected)
532        assert_(isinstance(result, ArraySubclass))
533
534        b = np.arange(3).reshape(1, 3)
535        assert_raises(ValueError, linalg.solve, a, b)
536        assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
537        assert_raises(ValueError, linalg.solve, a[:, 0:0, 0:0], b)
538
539    @skip(reason="subclass")
540    def test_0_size_k(self):
541        # test zero multiple equation (K=0) case.
542        class ArraySubclass(np.ndarray):
543            pass
544
545        a = np.arange(4).reshape(1, 2, 2)
546        b = np.arange(6).reshape(3, 2, 1).view(ArraySubclass)
547
548        expected = linalg.solve(a, b)[:, :, 0:0]
549        result = linalg.solve(a, b[:, :, 0:0])
550        assert_array_equal(result, expected)
551        assert_(isinstance(result, ArraySubclass))
552
553        # test both zero.
554        expected = linalg.solve(a, b)[:, 0:0, 0:0]
555        result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0, 0:0])
556        assert_array_equal(result, expected)
557        assert_(isinstance(result, ArraySubclass))
558
559
560class InvCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
561    def do(self, a, b, tags):
562        a_inv = linalg.inv(a)
563        assert_almost_equal(dot_generalized(a, a_inv), identity_like_generalized(a))
564        assert_(consistent_subclass(a_inv, a))
565
566
567@instantiate_parametrized_tests
568class TestInv(InvCases, TestCase):
569    @parametrize("dtype", [single, double, csingle, cdouble])
570    def test_types(self, dtype):
571        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
572        assert_equal(linalg.inv(x).dtype, dtype)
573
574    @skip(reason="subclass")
575    def test_0_size(self):
576        # Check that all kinds of 0-sized arrays work
577        class ArraySubclass(np.ndarray):
578            pass
579
580        a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
581        res = linalg.inv(a)
582        assert_(res.dtype.type is np.float64)
583        assert_equal(a.shape, res.shape)
584        assert_(isinstance(res, ArraySubclass))
585
586        a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
587        res = linalg.inv(a)
588        assert_(res.dtype.type is np.complex64)
589        assert_equal(a.shape, res.shape)
590        assert_(isinstance(res, ArraySubclass))
591
592
593class EigvalsCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
594    def do(self, a, b, tags):
595        ev = linalg.eigvals(a)
596        evalues, evectors = linalg.eig(a)
597        assert_almost_equal(ev, evalues)
598
599
600@instantiate_parametrized_tests
601class TestEigvals(EigvalsCases, TestCase):
602    @parametrize("dtype", [single, double, csingle, cdouble])
603    def test_types(self, dtype):
604        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
605        assert_equal(linalg.eigvals(x).dtype, dtype)
606        x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
607        assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype))
608
609    @skip(reason="subclass")
610    def test_0_size(self):
611        # Check that all kinds of 0-sized arrays work
612        class ArraySubclass(np.ndarray):
613            pass
614
615        a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
616        res = linalg.eigvals(a)
617        assert_(res.dtype.type is np.float64)
618        assert_equal((0, 1), res.shape)
619        # This is just for documentation, it might make sense to change:
620        assert_(isinstance(res, np.ndarray))
621
622        a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
623        res = linalg.eigvals(a)
624        assert_(res.dtype.type is np.complex64)
625        assert_equal((0,), res.shape)
626        # This is just for documentation, it might make sense to change:
627        assert_(isinstance(res, np.ndarray))
628
629
630class EigCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
631    def do(self, a, b, tags):
632        evalues, evectors = linalg.eig(a)
633        assert_allclose(
634            dot_generalized(a, evectors),
635            np.asarray(evectors) * np.asarray(evalues)[..., None, :],
636            rtol=get_rtol(evalues.dtype),
637        )
638        assert_(consistent_subclass(evectors, a))
639
640
641@instantiate_parametrized_tests
642class TestEig(EigCases, TestCase):
643    @parametrize("dtype", [single, double, csingle, cdouble])
644    def test_types(self, dtype):
645        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
646        w, v = np.linalg.eig(x)
647        assert_equal(w.dtype, dtype)
648        assert_equal(v.dtype, dtype)
649
650        x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
651        w, v = np.linalg.eig(x)
652        assert_equal(w.dtype, get_complex_dtype(dtype))
653        assert_equal(v.dtype, get_complex_dtype(dtype))
654
655    @skip(reason="subclass")
656    def test_0_size(self):
657        # Check that all kinds of 0-sized arrays work
658        class ArraySubclass(np.ndarray):
659            pass
660
661        a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
662        res, res_v = linalg.eig(a)
663        assert_(res_v.dtype.type is np.float64)
664        assert_(res.dtype.type is np.float64)
665        assert_equal(a.shape, res_v.shape)
666        assert_equal((0, 1), res.shape)
667        # This is just for documentation, it might make sense to change:
668        assert_(isinstance(a, np.ndarray))
669
670        a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
671        res, res_v = linalg.eig(a)
672        assert_(res_v.dtype.type is np.complex64)
673        assert_(res.dtype.type is np.complex64)
674        assert_equal(a.shape, res_v.shape)
675        assert_equal((0,), res.shape)
676        # This is just for documentation, it might make sense to change:
677        assert_(isinstance(a, np.ndarray))
678
679
680@instantiate_parametrized_tests
681class SVDBaseTests:
682    hermitian = False
683
684    @parametrize("dtype", [single, double, csingle, cdouble])
685    def test_types(self, dtype):
686        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
687        u, s, vh = linalg.svd(x)
688        assert_equal(u.dtype, dtype)
689        assert_equal(s.dtype, get_real_dtype(dtype))
690        assert_equal(vh.dtype, dtype)
691        s = linalg.svd(x, compute_uv=False, hermitian=self.hermitian)
692        assert_equal(s.dtype, get_real_dtype(dtype))
693
694
695class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
696    def do(self, a, b, tags):
697        u, s, vt = linalg.svd(a, False)
698        assert_allclose(
699            a,
700            dot_generalized(
701                np.asarray(u) * np.asarray(s)[..., None, :], np.asarray(vt)
702            ),
703            rtol=get_rtol(u.dtype),
704        )
705        assert_(consistent_subclass(u, a))
706        assert_(consistent_subclass(vt, a))
707
708
709class TestSVD(SVDCases, SVDBaseTests, TestCase):
710    def test_empty_identity(self):
711        """Empty input should put an identity matrix in u or vh"""
712        x = np.empty((4, 0))
713        u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
714        assert_equal(u.shape, (4, 4))
715        assert_equal(vh.shape, (0, 0))
716        assert_equal(u, np.eye(4))
717
718        x = np.empty((0, 4))
719        u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
720        assert_equal(u.shape, (0, 0))
721        assert_equal(vh.shape, (4, 4))
722        assert_equal(vh, np.eye(4))
723
724
725class SVDHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
726    def do(self, a, b, tags):
727        u, s, vt = linalg.svd(a, False, hermitian=True)
728        assert_allclose(
729            a,
730            dot_generalized(
731                np.asarray(u) * np.asarray(s)[..., None, :], np.asarray(vt)
732            ),
733            rtol=get_rtol(u.dtype),
734        )
735
736        def hermitian(mat):
737            axes = list(range(mat.ndim))
738            axes[-1], axes[-2] = axes[-2], axes[-1]
739            return np.conj(np.transpose(mat, axes=axes))
740
741        assert_almost_equal(
742            np.matmul(u, hermitian(u)), np.broadcast_to(np.eye(u.shape[-1]), u.shape)
743        )
744        assert_almost_equal(
745            np.matmul(vt, hermitian(vt)),
746            np.broadcast_to(np.eye(vt.shape[-1]), vt.shape),
747        )
748        assert_equal(np.sort(s), np.flip(s, -1))
749        assert_(consistent_subclass(u, a))
750        assert_(consistent_subclass(vt, a))
751
752
753class TestSVDHermitian(SVDHermitianCases, SVDBaseTests, TestCase):
754    hermitian = True
755
756
757class CondCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
758    # cond(x, p) for p in (None, 2, -2)
759
760    def do(self, a, b, tags):
761        c = asarray(a)  # a might be a matrix
762        if "size-0" in tags:
763            assert_raises(LinAlgError, linalg.cond, c)
764            return
765
766        # +-2 norms
767        s = linalg.svd(c, compute_uv=False)
768        assert_almost_equal(
769            linalg.cond(a), s[..., 0] / s[..., -1], single_decimal=5, double_decimal=11
770        )
771        assert_almost_equal(
772            linalg.cond(a, 2),
773            s[..., 0] / s[..., -1],
774            single_decimal=5,
775            double_decimal=11,
776        )
777        assert_almost_equal(
778            linalg.cond(a, -2),
779            s[..., -1] / s[..., 0],
780            single_decimal=5,
781            double_decimal=11,
782        )
783
784        # Other norms
785        cinv = np.linalg.inv(c)
786        assert_almost_equal(
787            linalg.cond(a, 1),
788            abs(c).sum(-2).max(-1) * abs(cinv).sum(-2).max(-1),
789            single_decimal=5,
790            double_decimal=11,
791        )
792        assert_almost_equal(
793            linalg.cond(a, -1),
794            abs(c).sum(-2).min(-1) * abs(cinv).sum(-2).min(-1),
795            single_decimal=5,
796            double_decimal=11,
797        )
798        assert_almost_equal(
799            linalg.cond(a, np.inf),
800            abs(c).sum(-1).max(-1) * abs(cinv).sum(-1).max(-1),
801            single_decimal=5,
802            double_decimal=11,
803        )
804        assert_almost_equal(
805            linalg.cond(a, -np.inf),
806            abs(c).sum(-1).min(-1) * abs(cinv).sum(-1).min(-1),
807            single_decimal=5,
808            double_decimal=11,
809        )
810        assert_almost_equal(
811            linalg.cond(a, "fro"),
812            np.sqrt((abs(c) ** 2).sum(-1).sum(-1) * (abs(cinv) ** 2).sum(-1).sum(-1)),
813            single_decimal=5,
814            double_decimal=11,
815        )
816
817
818class TestCond(CondCases, TestCase):
819    def test_basic_nonsvd(self):
820        # Smoketest the non-svd norms
821        A = array([[1.0, 0, 1], [0, -2.0, 0], [0, 0, 3.0]])
822        assert_almost_equal(linalg.cond(A, inf), 4)
823        assert_almost_equal(linalg.cond(A, -inf), 2 / 3)
824        assert_almost_equal(linalg.cond(A, 1), 4)
825        assert_almost_equal(linalg.cond(A, -1), 0.5)
826        assert_almost_equal(linalg.cond(A, "fro"), np.sqrt(265 / 12))
827
828    def test_singular(self):
829        # Singular matrices have infinite condition number for
830        # positive norms, and negative norms shouldn't raise
831        # exceptions
832        As = [np.zeros((2, 2)), np.ones((2, 2))]
833        p_pos = [None, 1, 2, "fro"]
834        p_neg = [-1, -2]
835        for A, p in itertools.product(As, p_pos):
836            # Inversion may not hit exact infinity, so just check the
837            # number is large
838            assert_(linalg.cond(A, p) > 1e15)
839        for A, p in itertools.product(As, p_neg):
840            linalg.cond(A, p)
841
842    @skip(reason="NP_VER: fails on CI")  # (
843    #    True, run=False, reason="Platform/LAPACK-dependent failure, see gh-18914"
844    # )
845    def test_nan(self):
846        # nans should be passed through, not converted to infs
847        ps = [None, 1, -1, 2, -2, "fro"]
848        p_pos = [None, 1, 2, "fro"]
849
850        A = np.ones((2, 2))
851        A[0, 1] = np.nan
852        for p in ps:
853            c = linalg.cond(A, p)
854            assert_(isinstance(c, np.float64))
855            assert_(np.isnan(c))
856
857        A = np.ones((3, 2, 2))
858        A[1, 0, 1] = np.nan
859        for p in ps:
860            c = linalg.cond(A, p)
861            assert_(np.isnan(c[1]))
862            if p in p_pos:
863                assert_(c[0] > 1e15)
864                assert_(c[2] > 1e15)
865            else:
866                assert_(not np.isnan(c[0]))
867                assert_(not np.isnan(c[2]))
868
869    def test_stacked_singular(self):
870        # Check behavior when only some of the stacked matrices are
871        # singular
872        np.random.seed(1234)
873        A = np.random.rand(2, 2, 2, 2)
874        A[0, 0] = 0
875        A[1, 1] = 0
876
877        for p in (None, 1, 2, "fro", -1, -2):
878            c = linalg.cond(A, p)
879            assert_equal(c[0, 0], np.inf)
880            assert_equal(c[1, 1], np.inf)
881            assert_(np.isfinite(c[0, 1]))
882            assert_(np.isfinite(c[1, 0]))
883
884
885class PinvCases(
886    LinalgSquareTestCase,
887    LinalgNonsquareTestCase,
888    LinalgGeneralizedSquareTestCase,
889    LinalgGeneralizedNonsquareTestCase,
890):
891    def do(self, a, b, tags):
892        a_ginv = linalg.pinv(a)
893        # `a @ a_ginv == I` does not hold if a is singular
894        dot = dot_generalized
895        assert_almost_equal(
896            dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11
897        )
898        assert_(consistent_subclass(a_ginv, a))
899
900
901class TestPinv(PinvCases, TestCase):
902    pass
903
904
905class PinvHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
906    def do(self, a, b, tags):
907        a_ginv = linalg.pinv(a, hermitian=True)
908        # `a @ a_ginv == I` does not hold if a is singular
909        dot = dot_generalized
910        assert_almost_equal(
911            dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11
912        )
913        assert_(consistent_subclass(a_ginv, a))
914
915
916class TestPinvHermitian(PinvHermitianCases, TestCase):
917    pass
918
919
920class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
921    def do(self, a, b, tags):
922        d = linalg.det(a)
923        (s, ld) = linalg.slogdet(a)
924        if asarray(a).dtype.type in (single, double):
925            ad = asarray(a).astype(double)
926        else:
927            ad = asarray(a).astype(cdouble)
928        ev = linalg.eigvals(ad)
929        assert_almost_equal(d, np.prod(ev, axis=-1))
930        assert_almost_equal(s * np.exp(ld), np.prod(ev, axis=-1), single_decimal=5)
931
932        s = np.atleast_1d(s)
933        ld = np.atleast_1d(ld)
934        m = s != 0
935        assert_almost_equal(np.abs(s[m]), 1)
936        assert_equal(ld[~m], -inf)
937
938
939@instantiate_parametrized_tests
940class TestDet(DetCases, TestCase):
941    def test_zero(self):
942        # NB: comment out tests of type(det) == double : we return zero-dim arrays
943        assert_equal(linalg.det([[0.0]]), 0.0)
944        #    assert_equal(type(linalg.det([[0.0]])), double)
945        assert_equal(linalg.det([[0.0j]]), 0.0)
946        #    assert_equal(type(linalg.det([[0.0j]])), cdouble)
947
948        assert_equal(linalg.slogdet([[0.0]]), (0.0, -inf))
949        #    assert_equal(type(linalg.slogdet([[0.0]])[0]), double)
950        #    assert_equal(type(linalg.slogdet([[0.0]])[1]), double)
951        assert_equal(linalg.slogdet([[0.0j]]), (0.0j, -inf))
952
953    #    assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble)
954    #    assert_equal(type(linalg.slogdet([[0.0j]])[1]), double)
955
956    @parametrize("dtype", [single, double, csingle, cdouble])
957    def test_types(self, dtype):
958        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
959        assert_equal(np.linalg.det(x).dtype, dtype)
960        ph, s = np.linalg.slogdet(x)
961        assert_equal(s.dtype, get_real_dtype(dtype))
962        assert_equal(ph.dtype, dtype)
963
964    def test_0_size(self):
965        a = np.zeros((0, 0), dtype=np.complex64)
966        res = linalg.det(a)
967        assert_equal(res, 1.0)
968        assert_(res.dtype.type is np.complex64)
969        res = linalg.slogdet(a)
970        assert_equal(res, (1, 0))
971        assert_(res[0].dtype.type is np.complex64)
972        assert_(res[1].dtype.type is np.float32)
973
974        a = np.zeros((0, 0), dtype=np.float64)
975        res = linalg.det(a)
976        assert_equal(res, 1.0)
977        assert_(res.dtype.type is np.float64)
978        res = linalg.slogdet(a)
979        assert_equal(res, (1, 0))
980        assert_(res[0].dtype.type is np.float64)
981        assert_(res[1].dtype.type is np.float64)
982
983
984class LstsqCases(LinalgSquareTestCase, LinalgNonsquareTestCase):
985    def do(self, a, b, tags):
986        arr = np.asarray(a)
987        m, n = arr.shape
988        u, s, vt = linalg.svd(a, False)
989        x, residuals, rank, sv = linalg.lstsq(a, b, rcond=-1)
990        if m == 0:
991            assert_((x == 0).all())
992        if m <= n:
993            assert_almost_equal(b, dot(a, x), single_decimal=5)
994            assert_equal(rank, m)
995        else:
996            assert_equal(rank, n)
997        #     assert_almost_equal(sv, sv.__array_wrap__(s))
998        if rank == n and m > n:
999            expect_resids = (np.asarray(abs(np.dot(a, x) - b)) ** 2).sum(axis=0)
1000            expect_resids = np.asarray(expect_resids)
1001            if np.asarray(b).ndim == 1:
1002                expect_resids = expect_resids.reshape(
1003                    1,
1004                )
1005                assert_equal(residuals.shape, expect_resids.shape)
1006        else:
1007            expect_resids = np.array([])  # .view(type(x))
1008        assert_almost_equal(residuals, expect_resids, single_decimal=5)
1009        assert_(np.issubdtype(residuals.dtype, np.floating))
1010        assert_(consistent_subclass(x, b))
1011        assert_(consistent_subclass(residuals, b))
1012
1013
1014@instantiate_parametrized_tests
1015class TestLstsq(LstsqCases, TestCase):
1016    @xpassIfTorchDynamo  # (reason="Lstsq: we use the future default =None")
1017    def test_future_rcond(self):
1018        a = np.array(
1019            [
1020                [0.0, 1.0, 0.0, 1.0, 2.0, 0.0],
1021                [0.0, 2.0, 0.0, 0.0, 1.0, 0.0],
1022                [1.0, 0.0, 1.0, 0.0, 0.0, 4.0],
1023                [0.0, 0.0, 0.0, 2.0, 3.0, 0.0],
1024            ]
1025        ).T
1026
1027        b = np.array([1, 0, 0, 0, 0, 0])
1028        with suppress_warnings() as sup:
1029            w = sup.record(FutureWarning, "`rcond` parameter will change")
1030            x, residuals, rank, s = linalg.lstsq(a, b)
1031            assert_(rank == 4)
1032            x, residuals, rank, s = linalg.lstsq(a, b, rcond=-1)
1033            assert_(rank == 4)
1034            x, residuals, rank, s = linalg.lstsq(a, b, rcond=None)
1035            assert_(rank == 3)
1036            # Warning should be raised exactly once (first command)
1037            assert_(len(w) == 1)
1038
1039    @parametrize(
1040        "m, n, n_rhs",
1041        [
1042            (4, 2, 2),
1043            (0, 4, 1),
1044            (0, 4, 2),
1045            (4, 0, 1),
1046            (4, 0, 2),
1047            #    (4, 2, 0),    # Intel MKL ERROR: Parameter 4 was incorrect on entry to DLALSD.
1048            (0, 0, 0),
1049        ],
1050    )
1051    def test_empty_a_b(self, m, n, n_rhs):
1052        a = np.arange(m * n).reshape(m, n)
1053        b = np.ones((m, n_rhs))
1054        x, residuals, rank, s = linalg.lstsq(a, b, rcond=None)
1055        if m == 0:
1056            assert_((x == 0).all())
1057        assert_equal(x.shape, (n, n_rhs))
1058        assert_equal(residuals.shape, ((n_rhs,) if m > n else (0,)))
1059        if m > n and n_rhs > 0:
1060            # residuals are exactly the squared norms of b's columns
1061            r = b - np.dot(a, x)
1062            assert_almost_equal(residuals, (r * r).sum(axis=-2))
1063        assert_equal(rank, min(m, n))
1064        assert_equal(s.shape, (min(m, n),))
1065
1066    def test_incompatible_dims(self):
1067        # use modified version of docstring example
1068        x = np.array([0, 1, 2, 3])
1069        y = np.array([-1, 0.2, 0.9, 2.1, 3.3])
1070        A = np.vstack([x, np.ones(len(x))]).T
1071        #        with assert_raises_regex(LinAlgError, "Incompatible dimensions"):
1072        with assert_raises((RuntimeError, LinAlgError)):
1073            linalg.lstsq(A, y, rcond=None)
1074
1075
1076# @xfail  #(reason="no block()")
1077@skip  # FIXME: otherwise fails in setUp calling np.block
1078@instantiate_parametrized_tests
1079class TestMatrixPower(TestCase):
1080    def setUp(self):
1081        self.rshft_0 = np.eye(4)
1082        self.rshft_1 = self.rshft_0[[3, 0, 1, 2]]
1083        self.rshft_2 = self.rshft_0[[2, 3, 0, 1]]
1084        self.rshft_3 = self.rshft_0[[1, 2, 3, 0]]
1085        self.rshft_all = [self.rshft_0, self.rshft_1, self.rshft_2, self.rshft_3]
1086        self.noninv = array([[1, 0], [0, 0]])
1087        self.stacked = np.block([[[self.rshft_0]]] * 2)
1088        # FIXME the 'e' dtype might work in future
1089        self.dtnoinv = [object, np.dtype("e"), np.dtype("g"), np.dtype("G")]
1090
1091    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
1092    def test_large_power(self, dt):
1093        rshft = self.rshft_1.astype(dt)
1094        assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 0), self.rshft_0)
1095        assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 1), self.rshft_1)
1096        assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 2), self.rshft_2)
1097        assert_equal(matrix_power(rshft, 2**100 + 2**10 + 2**5 + 3), self.rshft_3)
1098
1099    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
1100    def test_power_is_zero(self, dt):
1101        def tz(M):
1102            mz = matrix_power(M, 0)
1103            assert_equal(mz, identity_like_generalized(M))
1104            assert_equal(mz.dtype, M.dtype)
1105
1106        for mat in self.rshft_all:
1107            tz(mat.astype(dt))
1108            if dt != object:
1109                tz(self.stacked.astype(dt))
1110
1111    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
1112    def test_power_is_one(self, dt):
1113        def tz(mat):
1114            mz = matrix_power(mat, 1)
1115            assert_equal(mz, mat)
1116            assert_equal(mz.dtype, mat.dtype)
1117
1118        for mat in self.rshft_all:
1119            tz(mat.astype(dt))
1120            if dt != object:
1121                tz(self.stacked.astype(dt))
1122
1123    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
1124    def test_power_is_two(self, dt):
1125        def tz(mat):
1126            mz = matrix_power(mat, 2)
1127            mmul = matmul if mat.dtype != object else dot
1128            assert_equal(mz, mmul(mat, mat))
1129            assert_equal(mz.dtype, mat.dtype)
1130
1131        for mat in self.rshft_all:
1132            tz(mat.astype(dt))
1133            if dt != object:
1134                tz(self.stacked.astype(dt))
1135
1136    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
1137    def test_power_is_minus_one(self, dt):
1138        def tz(mat):
1139            invmat = matrix_power(mat, -1)
1140            mmul = matmul if mat.dtype != object else dot
1141            assert_almost_equal(mmul(invmat, mat), identity_like_generalized(mat))
1142
1143        for mat in self.rshft_all:
1144            if dt not in self.dtnoinv:
1145                tz(mat.astype(dt))
1146
1147    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
1148    def test_exceptions_bad_power(self, dt):
1149        mat = self.rshft_0.astype(dt)
1150        assert_raises(TypeError, matrix_power, mat, 1.5)
1151        assert_raises(TypeError, matrix_power, mat, [1])
1152
1153    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
1154    def test_exceptions_non_square(self, dt):
1155        assert_raises(LinAlgError, matrix_power, np.array([1], dt), 1)
1156        assert_raises(LinAlgError, matrix_power, np.array([[1], [2]], dt), 1)
1157        assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2), dt), 1)
1158
1159    @skipif(IS_WASM, reason="fp errors don't work in wasm")
1160    @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
1161    def test_exceptions_not_invertible(self, dt):
1162        if dt in self.dtnoinv:
1163            return
1164        mat = self.noninv.astype(dt)
1165        assert_raises(LinAlgError, matrix_power, mat, -1)
1166
1167
1168class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase):
1169    def do(self, a, b, tags):
1170        pytest.xfail(reason="sort complex")
1171        # note that eigenvalue arrays returned by eig must be sorted since
1172        # their order isn't guaranteed.
1173        ev = linalg.eigvalsh(a, "L")
1174        evalues, evectors = linalg.eig(a)
1175        evalues.sort(axis=-1)
1176        assert_allclose(ev, evalues, rtol=get_rtol(ev.dtype))
1177
1178        ev2 = linalg.eigvalsh(a, "U")
1179        assert_allclose(ev2, evalues, rtol=get_rtol(ev.dtype))
1180
1181
1182@instantiate_parametrized_tests
1183class TestEigvalsh(TestCase):
1184    @parametrize("dtype", [single, double, csingle, cdouble])
1185    def test_types(self, dtype):
1186        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
1187        w = np.linalg.eigvalsh(x)
1188        assert_equal(w.dtype, get_real_dtype(dtype))
1189
1190    def test_invalid(self):
1191        x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)
1192        assert_raises((RuntimeError, ValueError), np.linalg.eigvalsh, x, UPLO="lrong")
1193        assert_raises((RuntimeError, ValueError), np.linalg.eigvalsh, x, "lower")
1194        assert_raises((RuntimeError, ValueError), np.linalg.eigvalsh, x, "upper")
1195
1196    def test_UPLO(self):
1197        Klo = np.array([[0, 0], [1, 0]], dtype=np.double)
1198        Kup = np.array([[0, 1], [0, 0]], dtype=np.double)
1199        tgt = np.array([-1, 1], dtype=np.double)
1200        rtol = get_rtol(np.double)
1201
1202        # Check default is 'L'
1203        w = np.linalg.eigvalsh(Klo)
1204        assert_allclose(w, tgt, rtol=rtol)
1205        # Check 'L'
1206        w = np.linalg.eigvalsh(Klo, UPLO="L")
1207        assert_allclose(w, tgt, rtol=rtol)
1208        # Check 'l'
1209        w = np.linalg.eigvalsh(Klo, UPLO="l")
1210        assert_allclose(w, tgt, rtol=rtol)
1211        # Check 'U'
1212        w = np.linalg.eigvalsh(Kup, UPLO="U")
1213        assert_allclose(w, tgt, rtol=rtol)
1214        # Check 'u'
1215        w = np.linalg.eigvalsh(Kup, UPLO="u")
1216        assert_allclose(w, tgt, rtol=rtol)
1217
1218    def test_0_size(self):
1219        # Check that all kinds of 0-sized arrays work
1220        #     class ArraySubclass(np.ndarray):
1221        #         pass
1222        a = np.zeros((0, 1, 1), dtype=np.int_)  # .view(ArraySubclass)
1223        res = linalg.eigvalsh(a)
1224        assert_(res.dtype.type is np.float64)
1225        assert_equal((0, 1), res.shape)
1226        # This is just for documentation, it might make sense to change:
1227        assert_(isinstance(res, np.ndarray))
1228
1229        a = np.zeros((0, 0), dtype=np.complex64)  # .view(ArraySubclass)
1230        res = linalg.eigvalsh(a)
1231        assert_(res.dtype.type is np.float32)
1232        assert_equal((0,), res.shape)
1233        # This is just for documentation, it might make sense to change:
1234        assert_(isinstance(res, np.ndarray))
1235
1236
1237class TestEighCases(HermitianTestCase, HermitianGeneralizedTestCase):
1238    def do(self, a, b, tags):
1239        pytest.xfail(reason="sort complex")
1240        # note that eigenvalue arrays returned by eig must be sorted since
1241        # their order isn't guaranteed.
1242        ev, evc = linalg.eigh(a)
1243        evalues, evectors = linalg.eig(a)
1244        evalues.sort(axis=-1)
1245        assert_almost_equal(ev, evalues)
1246
1247        assert_allclose(
1248            dot_generalized(a, evc),
1249            np.asarray(ev)[..., None, :] * np.asarray(evc),
1250            rtol=get_rtol(ev.dtype),
1251        )
1252
1253        ev2, evc2 = linalg.eigh(a, "U")
1254        assert_almost_equal(ev2, evalues)
1255
1256        assert_allclose(
1257            dot_generalized(a, evc2),
1258            np.asarray(ev2)[..., None, :] * np.asarray(evc2),
1259            rtol=get_rtol(ev.dtype),
1260            err_msg=repr(a),
1261        )
1262
1263
1264@instantiate_parametrized_tests
1265class TestEigh(TestCase):
1266    @parametrize("dtype", [single, double, csingle, cdouble])
1267    def test_types(self, dtype):
1268        x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
1269        w, v = np.linalg.eigh(x)
1270        assert_equal(w.dtype, get_real_dtype(dtype))
1271        assert_equal(v.dtype, dtype)
1272
1273    def test_invalid(self):
1274        x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)
1275        assert_raises((RuntimeError, ValueError), np.linalg.eigh, x, UPLO="lrong")
1276        assert_raises((RuntimeError, ValueError), np.linalg.eigh, x, "lower")
1277        assert_raises((RuntimeError, ValueError), np.linalg.eigh, x, "upper")
1278
1279    def test_UPLO(self):
1280        Klo = np.array([[0, 0], [1, 0]], dtype=np.double)
1281        Kup = np.array([[0, 1], [0, 0]], dtype=np.double)
1282        tgt = np.array([-1, 1], dtype=np.double)
1283        rtol = get_rtol(np.double)
1284
1285        # Check default is 'L'
1286        w, v = np.linalg.eigh(Klo)
1287        assert_allclose(w, tgt, rtol=rtol)
1288        # Check 'L'
1289        w, v = np.linalg.eigh(Klo, UPLO="L")
1290        assert_allclose(w, tgt, rtol=rtol)
1291        # Check 'l'
1292        w, v = np.linalg.eigh(Klo, UPLO="l")
1293        assert_allclose(w, tgt, rtol=rtol)
1294        # Check 'U'
1295        w, v = np.linalg.eigh(Kup, UPLO="U")
1296        assert_allclose(w, tgt, rtol=rtol)
1297        # Check 'u'
1298        w, v = np.linalg.eigh(Kup, UPLO="u")
1299        assert_allclose(w, tgt, rtol=rtol)
1300
1301    def test_0_size(self):
1302        # Check that all kinds of 0-sized arrays work
1303        #        class ArraySubclass(np.ndarray):
1304        #            pass
1305        a = np.zeros((0, 1, 1), dtype=np.int_)  # .view(ArraySubclass)
1306        res, res_v = linalg.eigh(a)
1307        assert_(res_v.dtype.type is np.float64)
1308        assert_(res.dtype.type is np.float64)
1309        assert_equal(a.shape, res_v.shape)
1310        assert_equal((0, 1), res.shape)
1311        # This is just for documentation, it might make sense to change:
1312        assert_(isinstance(a, np.ndarray))
1313
1314        a = np.zeros((0, 0), dtype=np.complex64)  # .view(ArraySubclass)
1315        res, res_v = linalg.eigh(a)
1316        assert_(res_v.dtype.type is np.complex64)
1317        assert_(res.dtype.type is np.float32)
1318        assert_equal(a.shape, res_v.shape)
1319        assert_equal((0,), res.shape)
1320        # This is just for documentation, it might make sense to change:
1321        assert_(isinstance(a, np.ndarray))
1322
1323
1324class _TestNormBase:
1325    dt = None
1326    dec = None
1327
1328    @staticmethod
1329    def check_dtype(x, res):
1330        if issubclass(x.dtype.type, np.inexact):
1331            assert_equal(res.dtype, x.real.dtype)
1332        else:
1333            # For integer input, don't have to test float precision of output.
1334            assert_(issubclass(res.dtype.type, np.floating))
1335
1336
1337class _TestNormGeneral(_TestNormBase):
1338    def test_empty(self):
1339        assert_equal(norm([]), 0.0)
1340        assert_equal(norm(array([], dtype=self.dt)), 0.0)
1341        assert_equal(norm(atleast_2d(array([], dtype=self.dt))), 0.0)
1342
1343    def test_vector_return_type(self):
1344        a = np.array([1, 0, 1])
1345
1346        exact_types = "Bbhil"  # np.typecodes["AllInteger"]
1347        inexact_types = "efdFD"  # np.typecodes["AllFloat"]
1348
1349        all_types = exact_types + inexact_types
1350
1351        for each_type in all_types:
1352            at = a.astype(each_type)
1353
1354            if each_type == np.dtype("float16"):
1355                # FIXME: move looping to parametrize, add decorators=[xfail]
1356                # pytest.xfail("float16**float64 => float64 (?)")
1357                raise SkipTest("float16**float64 => float64 (?)")
1358
1359            an = norm(at, -np.inf)
1360            self.check_dtype(at, an)
1361            assert_almost_equal(an, 0.0)
1362
1363            with suppress_warnings() as sup:
1364                sup.filter(RuntimeWarning, "divide by zero encountered")
1365                an = norm(at, -1)
1366                self.check_dtype(at, an)
1367                assert_almost_equal(an, 0.0)
1368
1369            an = norm(at, 0)
1370            self.check_dtype(at, an)
1371            assert_almost_equal(an, 2)
1372
1373            an = norm(at, 1)
1374            self.check_dtype(at, an)
1375            assert_almost_equal(an, 2.0)
1376
1377            an = norm(at, 2)
1378            self.check_dtype(at, an)
1379            assert_almost_equal(an, an.dtype.type(2.0) ** an.dtype.type(1.0 / 2.0))
1380
1381            an = norm(at, 4)
1382            self.check_dtype(at, an)
1383            assert_almost_equal(an, an.dtype.type(2.0) ** an.dtype.type(1.0 / 4.0))
1384
1385            an = norm(at, np.inf)
1386            self.check_dtype(at, an)
1387            assert_almost_equal(an, 1.0)
1388
1389    def test_vector(self):
1390        a = [1, 2, 3, 4]
1391        b = [-1, -2, -3, -4]
1392        c = [-1, 2, -3, 4]
1393
1394        def _test(v):
1395            np.testing.assert_almost_equal(norm(v), 30**0.5, decimal=self.dec)
1396            np.testing.assert_almost_equal(norm(v, inf), 4.0, decimal=self.dec)
1397            np.testing.assert_almost_equal(norm(v, -inf), 1.0, decimal=self.dec)
1398            np.testing.assert_almost_equal(norm(v, 1), 10.0, decimal=self.dec)
1399            np.testing.assert_almost_equal(norm(v, -1), 12.0 / 25, decimal=self.dec)
1400            np.testing.assert_almost_equal(norm(v, 2), 30**0.5, decimal=self.dec)
1401            np.testing.assert_almost_equal(
1402                norm(v, -2), ((205.0 / 144) ** -0.5), decimal=self.dec
1403            )
1404            np.testing.assert_almost_equal(norm(v, 0), 4, decimal=self.dec)
1405
1406        for v in (
1407            a,
1408            b,
1409            c,
1410        ):
1411            _test(v)
1412
1413        for v in (
1414            array(a, dtype=self.dt),
1415            array(b, dtype=self.dt),
1416            array(c, dtype=self.dt),
1417        ):
1418            _test(v)
1419
1420    def test_axis(self):
1421        # Vector norms.
1422        # Compare the use of `axis` with computing the norm of each row
1423        # or column separately.
1424        A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt)
1425        for order in [None, -1, 0, 1, 2, 3, np.inf, -np.inf]:
1426            expected0 = [norm(A[:, k], ord=order) for k in range(A.shape[1])]
1427            assert_almost_equal(norm(A, ord=order, axis=0), expected0)
1428            expected1 = [norm(A[k, :], ord=order) for k in range(A.shape[0])]
1429            assert_almost_equal(norm(A, ord=order, axis=1), expected1)
1430
1431        # Matrix norms.
1432        B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4)
1433        nd = B.ndim
1434        for order in [None, -2, 2, -1, 1, np.inf, -np.inf, "fro"]:
1435            for axis in itertools.combinations(range(-nd, nd), 2):
1436                row_axis, col_axis = axis
1437                if row_axis < 0:
1438                    row_axis += nd
1439                if col_axis < 0:
1440                    col_axis += nd
1441                if row_axis == col_axis:
1442                    assert_raises(
1443                        (RuntimeError, ValueError), norm, B, ord=order, axis=axis
1444                    )
1445                else:
1446                    n = norm(B, ord=order, axis=axis)
1447
1448                    # The logic using k_index only works for nd = 3.
1449                    # This has to be changed if nd is increased.
1450                    k_index = nd - (row_axis + col_axis)
1451                    if row_axis < col_axis:
1452                        expected = [
1453                            norm(B[:].take(k, axis=k_index), ord=order)
1454                            for k in range(B.shape[k_index])
1455                        ]
1456                    else:
1457                        expected = [
1458                            norm(B[:].take(k, axis=k_index).T, ord=order)
1459                            for k in range(B.shape[k_index])
1460                        ]
1461                    assert_almost_equal(n, expected)
1462
1463    def test_keepdims(self):
1464        A = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4)
1465
1466        allclose_err = "order {0}, axis = {1}"
1467        shape_err = "Shape mismatch found {0}, expected {1}, order={2}, axis={3}"
1468
1469        # check the order=None, axis=None case
1470        expected = norm(A, ord=None, axis=None)
1471        found = norm(A, ord=None, axis=None, keepdims=True)
1472        assert_allclose(
1473            np.squeeze(found), expected, err_msg=allclose_err.format(None, None)
1474        )
1475        expected_shape = (1, 1, 1)
1476        assert_(
1477            found.shape == expected_shape,
1478            shape_err.format(found.shape, expected_shape, None, None),
1479        )
1480
1481        # Vector norms.
1482        for order in [None, -1, 0, 1, 2, 3, np.inf, -np.inf]:
1483            for k in range(A.ndim):
1484                expected = norm(A, ord=order, axis=k)
1485                found = norm(A, ord=order, axis=k, keepdims=True)
1486                assert_allclose(
1487                    np.squeeze(found), expected, err_msg=allclose_err.format(order, k)
1488                )
1489                expected_shape = list(A.shape)
1490                expected_shape[k] = 1
1491                expected_shape = tuple(expected_shape)
1492                assert_(
1493                    found.shape == expected_shape,
1494                    shape_err.format(found.shape, expected_shape, order, k),
1495                )
1496
1497        # Matrix norms.
1498        for order in [None, -2, 2, -1, 1, np.inf, -np.inf, "fro", "nuc"]:
1499            for k in itertools.permutations(range(A.ndim), 2):
1500                expected = norm(A, ord=order, axis=k)
1501                found = norm(A, ord=order, axis=k, keepdims=True)
1502                assert_allclose(
1503                    np.squeeze(found), expected, err_msg=allclose_err.format(order, k)
1504                )
1505                expected_shape = list(A.shape)
1506                expected_shape[k[0]] = 1
1507                expected_shape[k[1]] = 1
1508                expected_shape = tuple(expected_shape)
1509                assert_(
1510                    found.shape == expected_shape,
1511                    shape_err.format(found.shape, expected_shape, order, k),
1512                )
1513
1514
1515class _TestNorm2D(_TestNormBase):
1516    # Define the part for 2d arrays separately, so we can subclass this
1517    # and run the tests using np.matrix in matrixlib.tests.test_matrix_linalg.
1518
1519    def test_matrix_empty(self):
1520        assert_equal(norm(np.array([[]], dtype=self.dt)), 0.0)
1521
1522    def test_matrix_return_type(self):
1523        a = np.array([[1, 0, 1], [0, 1, 1]])
1524
1525        exact_types = "Bbhil"  # np.typecodes["AllInteger"]
1526
1527        # float32, complex64, float64, complex128 types are the only types
1528        # allowed by `linalg`, which performs the matrix operations used
1529        # within `norm`.
1530        inexact_types = "fdFD"
1531
1532        all_types = exact_types + inexact_types
1533
1534        for each_type in all_types:
1535            at = a.astype(each_type)
1536
1537            an = norm(at, -np.inf)
1538            self.check_dtype(at, an)
1539            assert_almost_equal(an, 2.0)
1540
1541            with suppress_warnings() as sup:
1542                sup.filter(RuntimeWarning, "divide by zero encountered")
1543                an = norm(at, -1)
1544                self.check_dtype(at, an)
1545                assert_almost_equal(an, 1.0)
1546
1547            an = norm(at, 1)
1548            self.check_dtype(at, an)
1549            assert_almost_equal(an, 2.0)
1550
1551            an = norm(at, 2)
1552            self.check_dtype(at, an)
1553            assert_almost_equal(an, 3.0 ** (1.0 / 2.0))
1554
1555            an = norm(at, -2)
1556            self.check_dtype(at, an)
1557            assert_almost_equal(an, 1.0)
1558
1559            an = norm(at, np.inf)
1560            self.check_dtype(at, an)
1561            assert_almost_equal(an, 2.0)
1562
1563            an = norm(at, "fro")
1564            self.check_dtype(at, an)
1565            assert_almost_equal(an, 2.0)
1566
1567            an = norm(at, "nuc")
1568            self.check_dtype(at, an)
1569            # Lower bar needed to support low precision floats.
1570            # They end up being off by 1 in the 7th place.
1571            np.testing.assert_almost_equal(an, 2.7320508075688772, decimal=6)
1572
1573    def test_matrix_2x2(self):
1574        A = np.array([[1, 3], [5, 7]], dtype=self.dt)
1575        assert_almost_equal(norm(A), 84**0.5)
1576        assert_almost_equal(norm(A, "fro"), 84**0.5)
1577        assert_almost_equal(norm(A, "nuc"), 10.0)
1578        assert_almost_equal(norm(A, inf), 12.0)
1579        assert_almost_equal(norm(A, -inf), 4.0)
1580        assert_almost_equal(norm(A, 1), 10.0)
1581        assert_almost_equal(norm(A, -1), 6.0)
1582        assert_almost_equal(norm(A, 2), 9.1231056256176615)
1583        assert_almost_equal(norm(A, -2), 0.87689437438234041)
1584
1585        assert_raises((RuntimeError, ValueError), norm, A, "nofro")
1586        assert_raises((RuntimeError, ValueError), norm, A, -3)
1587        assert_raises((RuntimeError, ValueError), norm, A, 0)
1588
1589    def test_matrix_3x3(self):
1590        # This test has been added because the 2x2 example
1591        # happened to have equal nuclear norm and induced 1-norm.
1592        # The 1/10 scaling factor accommodates the absolute tolerance
1593        # used in assert_almost_equal.
1594        A = (1 / 10) * np.array([[1, 2, 3], [6, 0, 5], [3, 2, 1]], dtype=self.dt)
1595        assert_almost_equal(norm(A), (1 / 10) * 89**0.5)
1596        assert_almost_equal(norm(A, "fro"), (1 / 10) * 89**0.5)
1597        assert_almost_equal(norm(A, "nuc"), 1.3366836911774836)
1598        assert_almost_equal(norm(A, inf), 1.1)
1599        assert_almost_equal(norm(A, -inf), 0.6)
1600        assert_almost_equal(norm(A, 1), 1.0)
1601        assert_almost_equal(norm(A, -1), 0.4)
1602        assert_almost_equal(norm(A, 2), 0.88722940323461277)
1603        assert_almost_equal(norm(A, -2), 0.19456584790481812)
1604
1605    def test_bad_args(self):
1606        # Check that bad arguments raise the appropriate exceptions.
1607
1608        A = np.array([[1, 2, 3], [4, 5, 6]], dtype=self.dt)
1609        B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4)
1610
1611        # Using `axis=<integer>` or passing in a 1-D array implies vector
1612        # norms are being computed, so also using `ord='fro'`
1613        # or `ord='nuc'` or any other string raises a ValueError.
1614        assert_raises((RuntimeError, ValueError), norm, A, "fro", 0)
1615        assert_raises((RuntimeError, ValueError), norm, A, "nuc", 0)
1616        assert_raises((RuntimeError, ValueError), norm, [3, 4], "fro", None)
1617        assert_raises((RuntimeError, ValueError), norm, [3, 4], "nuc", None)
1618        assert_raises((RuntimeError, ValueError), norm, [3, 4], "test", None)
1619
1620        # Similarly, norm should raise an exception when ord is any finite
1621        # number other than 1, 2, -1 or -2 when computing matrix norms.
1622        for order in [0, 3]:
1623            assert_raises((RuntimeError, ValueError), norm, A, order, None)
1624            assert_raises((RuntimeError, ValueError), norm, A, order, (0, 1))
1625            assert_raises((RuntimeError, ValueError), norm, B, order, (1, 2))
1626
1627        # Invalid axis
1628        assert_raises((IndexError, np.AxisError), norm, B, None, 3)
1629        assert_raises((IndexError, np.AxisError), norm, B, None, (2, 3))
1630        assert_raises((RuntimeError, ValueError), norm, B, None, (0, 1, 2))
1631
1632
1633class _TestNorm(_TestNorm2D, _TestNormGeneral):
1634    pass
1635
1636
1637class TestNorm_NonSystematic(TestCase):
1638    def test_intmin(self):
1639        # Non-regression test: p-norm of signed integer would previously do
1640        # float cast and abs in the wrong order.
1641        x = np.array([-(2**31)], dtype=np.int32)
1642        old_assert_almost_equal(norm(x, ord=3), 2**31, decimal=5)
1643
1644
1645# Separate definitions so we can use them for matrix tests.
1646class _TestNormDoubleBase(_TestNormBase, TestCase):
1647    dt = np.double
1648    dec = 12
1649
1650
1651class _TestNormSingleBase(_TestNormBase, TestCase):
1652    dt = np.float32
1653    dec = 6
1654
1655
1656class _TestNormInt64Base(_TestNormBase, TestCase):
1657    dt = np.int64
1658    dec = 12
1659
1660
1661class TestNormDouble(_TestNorm, _TestNormDoubleBase, TestCase):
1662    pass
1663
1664
1665class TestNormSingle(_TestNorm, _TestNormSingleBase, TestCase):
1666    pass
1667
1668
1669class TestNormInt64(_TestNorm, _TestNormInt64Base):
1670    pass
1671
1672
1673class TestMatrixRank(TestCase):
1674    def test_matrix_rank(self):
1675        # Full rank matrix
1676        assert_equal(4, matrix_rank(np.eye(4)))
1677        # rank deficient matrix
1678        I = np.eye(4)
1679        I[-1, -1] = 0.0
1680        assert_equal(matrix_rank(I), 3)
1681        # All zeros - zero rank
1682        assert_equal(matrix_rank(np.zeros((4, 4))), 0)
1683        # 1 dimension - rank 1 unless all 0
1684        assert_equal(matrix_rank([1, 0, 0, 0]), 1)
1685        assert_equal(matrix_rank(np.zeros((4,))), 0)
1686        # accepts array-like
1687        assert_equal(matrix_rank([1]), 1)
1688        # greater than 2 dimensions treated as stacked matrices
1689        ms = np.array([I, np.eye(4), np.zeros((4, 4))])
1690        assert_equal(matrix_rank(ms), np.array([3, 4, 0]))
1691        # works on scalar
1692        assert_equal(matrix_rank(1), 1)
1693
1694    def test_symmetric_rank(self):
1695        assert_equal(4, matrix_rank(np.eye(4), hermitian=True))
1696        assert_equal(1, matrix_rank(np.ones((4, 4)), hermitian=True))
1697        assert_equal(0, matrix_rank(np.zeros((4, 4)), hermitian=True))
1698        # rank deficient matrix
1699        I = np.eye(4)
1700        I[-1, -1] = 0.0
1701        assert_equal(3, matrix_rank(I, hermitian=True))
1702        # manually supplied tolerance
1703        I[-1, -1] = 1e-8
1704        assert_equal(4, matrix_rank(I, hermitian=True, tol=0.99e-8))
1705        assert_equal(3, matrix_rank(I, hermitian=True, tol=1.01e-8))
1706
1707    def test_reduced_rank(self):
1708        # Test matrices with reduced rank
1709        #  rng = np.random.RandomState(20120714)
1710        np.random.seed(20120714)
1711        for i in range(100):
1712            # Make a rank deficient matrix
1713            X = np.random.normal(size=(40, 10))
1714            X[:, 0] = X[:, 1] + X[:, 2]
1715            # Assert that matrix_rank detected deficiency
1716            assert_equal(matrix_rank(X), 9)
1717            X[:, 3] = X[:, 4] + X[:, 5]
1718            assert_equal(matrix_rank(X), 8)
1719
1720
1721@instantiate_parametrized_tests
1722class TestQR(TestCase):
1723    def check_qr(self, a):
1724        # This test expects the argument `a` to be an ndarray or
1725        # a subclass of an ndarray of inexact type.
1726        a_type = type(a)
1727        a_dtype = a.dtype
1728        m, n = a.shape
1729        k = min(m, n)
1730
1731        # mode == 'complete'
1732        q, r = linalg.qr(a, mode="complete")
1733        assert_(q.dtype == a_dtype)
1734        assert_(r.dtype == a_dtype)
1735        assert_(isinstance(q, a_type))
1736        assert_(isinstance(r, a_type))
1737        assert_(q.shape == (m, m))
1738        assert_(r.shape == (m, n))
1739        assert_almost_equal(dot(q, r), a, single_decimal=5)
1740        assert_almost_equal(dot(q.T.conj(), q), np.eye(m))
1741        assert_almost_equal(np.triu(r), r)
1742
1743        # mode == 'reduced'
1744        q1, r1 = linalg.qr(a, mode="reduced")
1745        assert_(q1.dtype == a_dtype)
1746        assert_(r1.dtype == a_dtype)
1747        assert_(isinstance(q1, a_type))
1748        assert_(isinstance(r1, a_type))
1749        assert_(q1.shape == (m, k))
1750        assert_(r1.shape == (k, n))
1751        assert_almost_equal(dot(q1, r1), a, single_decimal=5)
1752        assert_almost_equal(dot(q1.T.conj(), q1), np.eye(k))
1753        assert_almost_equal(np.triu(r1), r1)
1754
1755        # mode == 'r'
1756        r2 = linalg.qr(a, mode="r")
1757        assert_(r2.dtype == a_dtype)
1758        assert_(isinstance(r2, a_type))
1759        assert_almost_equal(r2, r1)
1760
1761    @xpassIfTorchDynamo  # (reason="torch does not allow qr(..., mode='raw'")
1762    @parametrize("m, n", [(3, 0), (0, 3), (0, 0)])
1763    def test_qr_empty(self, m, n):
1764        k = min(m, n)
1765        a = np.empty((m, n))
1766
1767        self.check_qr(a)
1768
1769        h, tau = np.linalg.qr(a, mode="raw")
1770        assert_equal(h.dtype, np.double)
1771        assert_equal(tau.dtype, np.double)
1772        assert_equal(h.shape, (n, m))
1773        assert_equal(tau.shape, (k,))
1774
1775    @xpassIfTorchDynamo  # (reason="torch does not allow qr(..., mode='raw'")
1776    def test_mode_raw(self):
1777        # The factorization is not unique and varies between libraries,
1778        # so it is not possible to check against known values. Functional
1779        # testing is a possibility, but awaits the exposure of more
1780        # of the functions in lapack_lite. Consequently, this test is
1781        # very limited in scope. Note that the results are in FORTRAN
1782        # order, hence the h arrays are transposed.
1783        a = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.double)
1784
1785        # Test double
1786        h, tau = linalg.qr(a, mode="raw")
1787        assert_(h.dtype == np.double)
1788        assert_(tau.dtype == np.double)
1789        assert_(h.shape == (2, 3))
1790        assert_(tau.shape == (2,))
1791
1792        h, tau = linalg.qr(a.T, mode="raw")
1793        assert_(h.dtype == np.double)
1794        assert_(tau.dtype == np.double)
1795        assert_(h.shape == (3, 2))
1796        assert_(tau.shape == (2,))
1797
1798    def test_mode_all_but_economic(self):
1799        a = np.array([[1, 2], [3, 4]])
1800        b = np.array([[1, 2], [3, 4], [5, 6]])
1801        for dt in "fd":
1802            m1 = a.astype(dt)
1803            m2 = b.astype(dt)
1804            self.check_qr(m1)
1805            self.check_qr(m2)
1806            self.check_qr(m2.T)
1807
1808        for dt in "fd":
1809            m1 = 1 + 1j * a.astype(dt)
1810            m2 = 1 + 1j * b.astype(dt)
1811            self.check_qr(m1)
1812            self.check_qr(m2)
1813            self.check_qr(m2.T)
1814
1815    def check_qr_stacked(self, a):
1816        # This test expects the argument `a` to be an ndarray or
1817        # a subclass of an ndarray of inexact type.
1818        a_type = type(a)
1819        a_dtype = a.dtype
1820        m, n = a.shape[-2:]
1821        k = min(m, n)
1822
1823        # mode == 'complete'
1824        q, r = linalg.qr(a, mode="complete")
1825        assert_(q.dtype == a_dtype)
1826        assert_(r.dtype == a_dtype)
1827        assert_(isinstance(q, a_type))
1828        assert_(isinstance(r, a_type))
1829        assert_(q.shape[-2:] == (m, m))
1830        assert_(r.shape[-2:] == (m, n))
1831        assert_almost_equal(matmul(q, r), a, single_decimal=5)
1832        I_mat = np.identity(q.shape[-1])
1833        stack_I_mat = np.broadcast_to(I_mat, q.shape[:-2] + (q.shape[-1],) * 2)
1834        assert_almost_equal(matmul(swapaxes(q, -1, -2).conj(), q), stack_I_mat)
1835        assert_almost_equal(np.triu(r[..., :, :]), r)
1836
1837        # mode == 'reduced'
1838        q1, r1 = linalg.qr(a, mode="reduced")
1839        assert_(q1.dtype == a_dtype)
1840        assert_(r1.dtype == a_dtype)
1841        assert_(isinstance(q1, a_type))
1842        assert_(isinstance(r1, a_type))
1843        assert_(q1.shape[-2:] == (m, k))
1844        assert_(r1.shape[-2:] == (k, n))
1845        assert_almost_equal(matmul(q1, r1), a, single_decimal=5)
1846        I_mat = np.identity(q1.shape[-1])
1847        stack_I_mat = np.broadcast_to(I_mat, q1.shape[:-2] + (q1.shape[-1],) * 2)
1848        assert_almost_equal(matmul(swapaxes(q1, -1, -2).conj(), q1), stack_I_mat)
1849        assert_almost_equal(np.triu(r1[..., :, :]), r1)
1850
1851        # mode == 'r'
1852        r2 = linalg.qr(a, mode="r")
1853        assert_(r2.dtype == a_dtype)
1854        assert_(isinstance(r2, a_type))
1855        assert_almost_equal(r2, r1)
1856
1857    @skipif(numpy.__version__ < "1.22", reason="NP_VER: fails on CI with numpy 1.21.2")
1858    @parametrize("size", [(3, 4), (4, 3), (4, 4), (3, 0), (0, 3)])
1859    @parametrize("outer_size", [(2, 2), (2,), (2, 3, 4)])
1860    @parametrize("dt", [np.single, np.double, np.csingle, np.cdouble])
1861    def test_stacked_inputs(self, outer_size, size, dt):
1862        A = np.random.normal(size=outer_size + size).astype(dt)
1863        B = np.random.normal(size=outer_size + size).astype(dt)
1864        self.check_qr_stacked(A)
1865        self.check_qr_stacked(A + 1.0j * B)
1866
1867
1868@instantiate_parametrized_tests
1869class TestCholesky(TestCase):
1870    # TODO: are there no other tests for cholesky?
1871
1872    @parametrize("shape", [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)])
1873    @parametrize("dtype", (np.float32, np.float64, np.complex64, np.complex128))
1874    def test_basic_property(self, shape, dtype):
1875        # Check A = L L^H
1876        np.random.seed(1)
1877        a = np.random.randn(*shape)
1878        if np.issubdtype(dtype, np.complexfloating):
1879            a = a + 1j * np.random.randn(*shape)
1880
1881        t = list(range(len(shape)))
1882        t[-2:] = -1, -2
1883
1884        a = np.matmul(a.transpose(t).conj(), a)
1885        a = np.asarray(a, dtype=dtype)
1886
1887        c = np.linalg.cholesky(a)
1888
1889        b = np.matmul(c, c.transpose(t).conj())
1890        atol = 500 * a.shape[0] * np.finfo(dtype).eps
1891        assert_allclose(b, a, atol=atol, err_msg=f"{shape} {dtype}\n{a}\n{c}")
1892
1893    def test_0_size(self):
1894        #     class ArraySubclass(np.ndarray):
1895        #         pass
1896        a = np.zeros((0, 1, 1), dtype=np.int_)  # .view(ArraySubclass)
1897        res = linalg.cholesky(a)
1898        assert_equal(a.shape, res.shape)
1899        assert_(res.dtype.type is np.float64)
1900        # for documentation purpose:
1901        assert_(isinstance(res, np.ndarray))
1902
1903        a = np.zeros((1, 0, 0), dtype=np.complex64)  # .view(ArraySubclass)
1904        res = linalg.cholesky(a)
1905        assert_equal(a.shape, res.shape)
1906        assert_(res.dtype.type is np.complex64)
1907        assert_(isinstance(res, np.ndarray))
1908
1909
1910class TestMisc(TestCase):
1911    @xpassIfTorchDynamo  # (reason="endianness")
1912    def test_byteorder_check(self):
1913        # Byte order check should pass for native order
1914        if sys.byteorder == "little":
1915            native = "<"
1916        else:
1917            native = ">"
1918
1919        for dtt in (np.float32, np.float64):
1920            arr = np.eye(4, dtype=dtt)
1921            n_arr = arr.newbyteorder(native)
1922            sw_arr = arr.newbyteorder("S").byteswap()
1923            assert_equal(arr.dtype.byteorder, "=")
1924            for routine in (linalg.inv, linalg.det, linalg.pinv):
1925                # Normal call
1926                res = routine(arr)
1927                # Native but not '='
1928                assert_array_equal(res, routine(n_arr))
1929                # Swapped
1930                assert_array_equal(res, routine(sw_arr))
1931
1932    @pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm")
1933    def test_generalized_raise_multiloop(self):
1934        # It should raise an error even if the error doesn't occur in the
1935        # last iteration of the ufunc inner loop
1936
1937        invertible = np.array([[1, 2], [3, 4]])
1938        non_invertible = np.array([[1, 1], [1, 1]])
1939
1940        x = np.zeros([4, 4, 2, 2])[1::2]
1941        x[...] = invertible
1942        x[0, 0] = non_invertible
1943
1944        assert_raises(np.linalg.LinAlgError, np.linalg.inv, x)
1945
1946    def test_xerbla_override(self):
1947        # Check that our xerbla has been successfully linked in. If it is not,
1948        # the default xerbla routine is called, which prints a message to stdout
1949        # and may, or may not, abort the process depending on the LAPACK package.
1950
1951        XERBLA_OK = 255
1952
1953        try:
1954            pid = os.fork()
1955        except (OSError, AttributeError):
1956            # fork failed, or not running on POSIX
1957            raise SkipTest("Not POSIX or fork failed.")  # noqa: B904
1958
1959        if pid == 0:
1960            # child; close i/o file handles
1961            os.close(1)
1962            os.close(0)
1963            # Avoid producing core files.
1964            import resource
1965
1966            resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
1967            # These calls may abort.
1968            try:
1969                np.linalg.lapack_lite.xerbla()
1970            except ValueError:
1971                pass
1972            except Exception:
1973                os._exit(os.EX_CONFIG)
1974
1975            try:
1976                a = np.array([[1.0]])
1977                np.linalg.lapack_lite.dorgqr(
1978                    1, 1, 1, a, 0, a, a, 0, 0
1979                )  # <- invalid value
1980            except ValueError as e:
1981                if "DORGQR parameter number 5" in str(e):
1982                    # success, reuse error code to mark success as
1983                    # FORTRAN STOP returns as success.
1984                    os._exit(XERBLA_OK)
1985
1986            # Did not abort, but our xerbla was not linked in.
1987            os._exit(os.EX_CONFIG)
1988        else:
1989            # parent
1990            pid, status = os.wait()
1991            if os.WEXITSTATUS(status) != XERBLA_OK:
1992                raise SkipTest("Numpy xerbla not linked in.")
1993
1994    @pytest.mark.skipif(IS_WASM, reason="Cannot start subprocess")
1995    @slow
1996    def test_sdot_bug_8577(self):
1997        # Regression test that loading certain other libraries does not
1998        # result to wrong results in float32 linear algebra.
1999        #
2000        # There's a bug gh-8577 on OSX that can trigger this, and perhaps
2001        # there are also other situations in which it occurs.
2002        #
2003        # Do the check in a separate process.
2004
2005        bad_libs = ["PyQt5.QtWidgets", "IPython"]
2006
2007        template = textwrap.dedent(
2008            """
2009        import sys
2010        {before}
2011        try:
2012            import {bad_lib}
2013        except ImportError:
2014            sys.exit(0)
2015        {after}
2016        x = np.ones(2, dtype=np.float32)
2017        sys.exit(0 if np.allclose(x.dot(x), 2.0) else 1)
2018        """
2019        )
2020
2021        for bad_lib in bad_libs:
2022            code = template.format(
2023                before="import numpy as np", after="", bad_lib=bad_lib
2024            )
2025            subprocess.check_call([sys.executable, "-c", code])
2026
2027            # Swapped import order
2028            code = template.format(
2029                after="import numpy as np", before="", bad_lib=bad_lib
2030            )
2031            subprocess.check_call([sys.executable, "-c", code])
2032
2033
2034class TestMultiDot(TestCase):
2035    def test_basic_function_with_three_arguments(self):
2036        # multi_dot with three arguments uses a fast hand coded algorithm to
2037        # determine the optimal order. Therefore test it separately.
2038        A = np.random.random((6, 2))
2039        B = np.random.random((2, 6))
2040        C = np.random.random((6, 2))
2041
2042        assert_almost_equal(multi_dot([A, B, C]), A.dot(B).dot(C))
2043        assert_almost_equal(multi_dot([A, B, C]), np.dot(A, np.dot(B, C)))
2044
2045    def test_basic_function_with_two_arguments(self):
2046        # separate code path with two arguments
2047        A = np.random.random((6, 2))
2048        B = np.random.random((2, 6))
2049
2050        assert_almost_equal(multi_dot([A, B]), A.dot(B))
2051        assert_almost_equal(multi_dot([A, B]), np.dot(A, B))
2052
2053    def test_basic_function_with_dynamic_programming_optimization(self):
2054        # multi_dot with four or more arguments uses the dynamic programming
2055        # optimization and therefore deserve a separate
2056        A = np.random.random((6, 2))
2057        B = np.random.random((2, 6))
2058        C = np.random.random((6, 2))
2059        D = np.random.random((2, 1))
2060        assert_almost_equal(multi_dot([A, B, C, D]), A.dot(B).dot(C).dot(D))
2061
2062    def test_vector_as_first_argument(self):
2063        # The first argument can be 1-D
2064        A1d = np.random.random(2)  # 1-D
2065        B = np.random.random((2, 6))
2066        C = np.random.random((6, 2))
2067        D = np.random.random((2, 2))
2068
2069        # the result should be 1-D
2070        assert_equal(multi_dot([A1d, B, C, D]).shape, (2,))
2071
2072    def test_vector_as_last_argument(self):
2073        # The last argument can be 1-D
2074        A = np.random.random((6, 2))
2075        B = np.random.random((2, 6))
2076        C = np.random.random((6, 2))
2077        D1d = np.random.random(2)  # 1-D
2078
2079        # the result should be 1-D
2080        assert_equal(multi_dot([A, B, C, D1d]).shape, (6,))
2081
2082    def test_vector_as_first_and_last_argument(self):
2083        # The first and last arguments can be 1-D
2084        A1d = np.random.random(2)  # 1-D
2085        B = np.random.random((2, 6))
2086        C = np.random.random((6, 2))
2087        D1d = np.random.random(2)  # 1-D
2088
2089        # the result should be a scalar
2090        assert_equal(multi_dot([A1d, B, C, D1d]).shape, ())
2091
2092    def test_three_arguments_and_out(self):
2093        # multi_dot with three arguments uses a fast hand coded algorithm to
2094        # determine the optimal order. Therefore test it separately.
2095        A = np.random.random((6, 2))
2096        B = np.random.random((2, 6))
2097        C = np.random.random((6, 2))
2098
2099        out = np.zeros((6, 2))
2100        ret = multi_dot([A, B, C], out=out)
2101        assert out is ret
2102        assert_almost_equal(out, A.dot(B).dot(C))
2103        assert_almost_equal(out, np.dot(A, np.dot(B, C)))
2104
2105    def test_two_arguments_and_out(self):
2106        # separate code path with two arguments
2107        A = np.random.random((6, 2))
2108        B = np.random.random((2, 6))
2109        out = np.zeros((6, 6))
2110        ret = multi_dot([A, B], out=out)
2111        assert out is ret
2112        assert_almost_equal(out, A.dot(B))
2113        assert_almost_equal(out, np.dot(A, B))
2114
2115    def test_dynamic_programming_optimization_and_out(self):
2116        # multi_dot with four or more arguments uses the dynamic programming
2117        # optimization and therefore deserve a separate test
2118        A = np.random.random((6, 2))
2119        B = np.random.random((2, 6))
2120        C = np.random.random((6, 2))
2121        D = np.random.random((2, 1))
2122        out = np.zeros((6, 1))
2123        ret = multi_dot([A, B, C, D], out=out)
2124        assert out is ret
2125        assert_almost_equal(out, A.dot(B).dot(C).dot(D))
2126
2127    def test_dynamic_programming_logic(self):
2128        # Test for the dynamic programming part
2129        # This test is directly taken from Cormen page 376.
2130        arrays = [
2131            np.random.random((30, 35)),
2132            np.random.random((35, 15)),
2133            np.random.random((15, 5)),
2134            np.random.random((5, 10)),
2135            np.random.random((10, 20)),
2136            np.random.random((20, 25)),
2137        ]
2138        m_expected = np.array(
2139            [
2140                [0.0, 15750.0, 7875.0, 9375.0, 11875.0, 15125.0],
2141                [0.0, 0.0, 2625.0, 4375.0, 7125.0, 10500.0],
2142                [0.0, 0.0, 0.0, 750.0, 2500.0, 5375.0],
2143                [0.0, 0.0, 0.0, 0.0, 1000.0, 3500.0],
2144                [0.0, 0.0, 0.0, 0.0, 0.0, 5000.0],
2145                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2146            ]
2147        )
2148        s_expected = np.array(
2149            [
2150                [0, 1, 1, 3, 3, 3],
2151                [0, 0, 2, 3, 3, 3],
2152                [0, 0, 0, 3, 3, 3],
2153                [0, 0, 0, 0, 4, 5],
2154                [0, 0, 0, 0, 0, 5],
2155                [0, 0, 0, 0, 0, 0],
2156            ],
2157            dtype=int,
2158        )
2159        s_expected -= 1  # Cormen uses 1-based index, python does not.
2160
2161        s, m = _multi_dot_matrix_chain_order(arrays, return_costs=True)
2162
2163        # Only the upper triangular part (without the diagonal) is interesting.
2164        assert_almost_equal(np.triu(s[:-1, 1:]), np.triu(s_expected[:-1, 1:]))
2165        assert_almost_equal(np.triu(m), np.triu(m_expected))
2166
2167    def test_too_few_input_arrays(self):
2168        assert_raises((RuntimeError, ValueError), multi_dot, [])
2169        assert_raises((RuntimeError, ValueError), multi_dot, [np.random.random((3, 3))])
2170
2171
2172@instantiate_parametrized_tests
2173class TestTensorinv(TestCase):
2174    @parametrize(
2175        "arr, ind",
2176        [
2177            (np.ones((4, 6, 8, 2)), 2),
2178            (np.ones((3, 3, 2)), 1),
2179        ],
2180    )
2181    def test_non_square_handling(self, arr, ind):
2182        with assert_raises((LinAlgError, RuntimeError)):
2183            linalg.tensorinv(arr, ind=ind)
2184
2185    @parametrize(
2186        "shape, ind",
2187        [
2188            # examples from docstring
2189            ((4, 6, 8, 3), 2),
2190            ((24, 8, 3), 1),
2191        ],
2192    )
2193    def test_tensorinv_shape(self, shape, ind):
2194        a = np.eye(24).reshape(shape)
2195        ainv = linalg.tensorinv(a=a, ind=ind)
2196        expected = a.shape[ind:] + a.shape[:ind]
2197        actual = ainv.shape
2198        assert_equal(actual, expected)
2199
2200    @parametrize(
2201        "ind",
2202        [
2203            0,
2204            -2,
2205        ],
2206    )
2207    def test_tensorinv_ind_limit(self, ind):
2208        a = np.eye(24).reshape(4, 6, 8, 3)
2209        with assert_raises((ValueError, RuntimeError)):
2210            linalg.tensorinv(a=a, ind=ind)
2211
2212    def test_tensorinv_result(self):
2213        # mimic a docstring example
2214        a = np.eye(24).reshape(24, 8, 3)
2215        ainv = linalg.tensorinv(a, ind=1)
2216        b = np.ones(24)
2217        assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
2218
2219
2220@instantiate_parametrized_tests
2221class TestTensorsolve(TestCase):
2222    @parametrize(
2223        "a, axes",
2224        [
2225            (np.ones((4, 6, 8, 2)), None),
2226            (np.ones((3, 3, 2)), (0, 2)),
2227        ],
2228    )
2229    def test_non_square_handling(self, a, axes):
2230        with assert_raises((LinAlgError, RuntimeError)):
2231            b = np.ones(a.shape[:2])
2232            linalg.tensorsolve(a, b, axes=axes)
2233
2234    @skipif(numpy.__version__ < "1.22", reason="NP_VER: fails on CI with numpy 1.21.2")
2235    @parametrize(
2236        "shape",
2237        [(2, 3, 6), (3, 4, 4, 3), (0, 3, 3, 0)],
2238    )
2239    def test_tensorsolve_result(self, shape):
2240        a = np.random.randn(*shape)
2241        b = np.ones(a.shape[:2])
2242        x = np.linalg.tensorsolve(a, b)
2243        assert_allclose(np.tensordot(a, x, axes=len(x.shape)), b)
2244
2245
2246class TestMisc2(TestCase):
2247    @xpassIfTorchDynamo  # (reason="TODO")
2248    def test_unsupported_commontype(self):
2249        # linalg gracefully handles unsupported type
2250        arr = np.array([[1, -2], [2, 5]], dtype="float16")
2251        # with assert_raises_regex(TypeError, "unsupported in linalg"):
2252        with assert_raises(TypeError):
2253            linalg.cholesky(arr)
2254
2255    # @slow
2256    # @pytest.mark.xfail(not HAS_LAPACK64, run=False,
2257    #                   reason="Numpy not compiled with 64-bit BLAS/LAPACK")
2258    # @requires_memory(free_bytes=16e9)
2259    @skip(reason="Bad memory reports lead to OOM in ci testing")
2260    def test_blas64_dot(self):
2261        n = 2**32
2262        a = np.zeros([1, n], dtype=np.float32)
2263        b = np.ones([1, 1], dtype=np.float32)
2264        a[0, -1] = 1
2265        c = np.dot(b, a)
2266        assert_equal(c[0, -1], 1)
2267
2268    @skip(reason="lapack-lite specific")
2269    @xfail  # (
2270    #    not HAS_LAPACK64, reason="Numpy not compiled with 64-bit BLAS/LAPACK"
2271    # )
2272    def test_blas64_geqrf_lwork_smoketest(self):
2273        # Smoke test LAPACK geqrf lwork call with 64-bit integers
2274        dtype = np.float64
2275        lapack_routine = np.linalg.lapack_lite.dgeqrf
2276
2277        m = 2**32 + 1
2278        n = 2**32 + 1
2279        lda = m
2280
2281        # Dummy arrays, not referenced by the lapack routine, so don't
2282        # need to be of the right size
2283        a = np.zeros([1, 1], dtype=dtype)
2284        work = np.zeros([1], dtype=dtype)
2285        tau = np.zeros([1], dtype=dtype)
2286
2287        # Size query
2288        results = lapack_routine(m, n, a, lda, tau, work, -1, 0)
2289        assert_equal(results["info"], 0)
2290        assert_equal(results["m"], m)
2291        assert_equal(results["n"], m)
2292
2293        # Should result to an integer of a reasonable size
2294        lwork = int(work.item())
2295        assert_(2**32 < lwork < 2**42)
2296
2297
2298if __name__ == "__main__":
2299    run_tests()
2300