1# Owner(s): ["module: linear algebra"]
2
3import torch
4import numpy as np
5
6import unittest
7import itertools
8import warnings
9import math
10from math import inf, nan, isnan
11import re
12import random
13from random import randrange
14from itertools import product
15from functools import reduce, partial
16
17from torch.testing._internal.common_utils import \
18    (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
19     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
20     make_fullrank_matrices_with_distinct_singular_values,
21     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
22     setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
23from torch.testing._internal.common_device_type import \
24    (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
25     onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
26     skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA,
27     onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm,
28     dtypesIfMPS, largeTensorTest)
29from torch.testing import make_tensor
30from torch.testing._internal.common_dtype import (
31    all_types, all_types_and_complex_and, floating_and_complex_types, integral_types,
32    floating_and_complex_types_and, floating_types_and, complex_types,
33)
34from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
35    _get_torch_cuda_version, CDNA2OrLater
36from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
37from torch.testing._internal.common_mkldnn import bf32_on_and_off
38from torch.distributions.binomial import Binomial
39import torch.backends.opt_einsum as opt_einsum
40import operator
41
42# Protects against includes accidentally setting the default dtype
43assert torch.get_default_dtype() is torch.float32
44
45if TEST_SCIPY:
46    import scipy
47
48def blaslt_supported_device():
49    if torch.cuda.is_available():
50        if torch.version.hip:
51            for arch in ['gfx90a', 'gfx94']:
52                if arch in torch.cuda.get_device_properties(0).gcnArchName:
53                    return True
54        else:
55            return True
56    return False
57
58def set_tunableop_defaults():
59    if not torch.cuda.is_available():
60        # TunableOp not supported on CPU at this time.
61        return
62
63    # disable TunableOp and restore to default values
64    ordinal = torch.cuda.current_device()
65    filename = f"tunableop_results{ordinal}.csv"
66    torch.cuda.tunable.enable(False)
67    torch.cuda.tunable.tuning_enable(True)
68    torch.cuda.tunable.set_filename(filename)  # reset back to default filename for next unit test
69    torch.cuda.tunable.set_max_tuning_duration(30)
70    torch.cuda.tunable.set_max_tuning_iterations(100)
71
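# A minimal usage sketch (not exercised by the tests below) of the TunableOp reset helper
# above: enable tuning, run a matmul so tuning results can be collected, then restore the
# defaults. The function name `_tunableop_usage_sketch` is ours, not a PyTorch API, and it
# assumes a CUDA device is available.
def _tunableop_usage_sketch():
    if not torch.cuda.is_available():
        return
    torch.cuda.tunable.enable(True)
    torch.cuda.tunable.tuning_enable(True)
    a = torch.randn(64, 64, device='cuda')
    b = torch.randn(64, 64, device='cuda')
    (a @ b).sum().item()      # the GEMM goes through the tunable path
    set_tunableop_defaults()  # disable TunableOp and restore the defaults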
72
73class TestLinalg(TestCase):
74    def setUp(self):
75        super().setUp()
76        torch.backends.cuda.matmul.allow_tf32 = False
77
78    def tearDown(self):
79        torch.backends.cuda.matmul.allow_tf32 = True
80        super().tearDown()
81
82    exact_dtype = True
83
84    @dtypes(torch.float, torch.cfloat)
85    @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06})
86    @tf32_on_and_off(5e-3)
87    @bf32_on_and_off(5e-3)
88    def test_inner(self, device, dtype):
89        def check(a_sizes_, b_sizes_):
90            for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)):
91                a = torch.randn(a_sizes, dtype=dtype, device=device)
92                b = torch.randn(b_sizes, dtype=dtype, device=device)
93                res = torch.inner(a, b)
94                ref = np.inner(a.cpu().numpy(), b.cpu().numpy())
95                self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref)))
96                out = torch.zeros_like(res)
97                torch.inner(a, b, out=out)
98                self.assertEqual(res, out)
99
100        check([], [])                       # scalar x scalar
101        check([], [0])                      # scalar x empty
102        check([], [3])                      # scalar x 1D
103        check([], [2, 3, 4])                # scalar x 3D
104
105        check([0], [0])                     # empty x empty
106        check([0], [2, 0])                  # empty x 2D
107
108        check([2], [2])                     # 1D x 1D
109        check([2], [3, 1, 2])               # 1D x 3D
110        check([2], [3, 0, 2])               # 1D x 3D empty
111
112        check([1, 2], [3, 2])               # 2D x 2D
113        check([1, 2], [3, 4, 2])            # 2D x 3D
114        check([2, 1, 3, 2], [1, 3, 2, 2])   # 4D x 4D
115
116        # Test error message
117        with self.assertRaisesRegex(RuntimeError,
118                                    r"inner\(\) the last dimension must match on both "
119                                    r"input tensors but got shapes \[2, 3\] and \[2, 2\]"):
120            torch.randn(2, 3, device=device, dtype=dtype).inner(torch.randn(2, 2, device=device, dtype=dtype))
121
122    # Tests torch.outer, and its alias, torch.ger, vs. NumPy
123    @precisionOverride({torch.bfloat16: 1e-1})
124    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
125    def test_outer(self, device, dtype):
126        def run_test_case(a, b):
127            if dtype == torch.bfloat16:
128                a_np = a.to(torch.double).cpu().numpy()
129                b_np = b.to(torch.double).cpu().numpy()
130                exact_dtype = False
131            else:
132                a_np = a.cpu().numpy()
133                b_np = b.cpu().numpy()
134                exact_dtype = True
135            expected = np.outer(a_np, b_np)
136
137            self.assertEqual(torch.outer(a, b), expected, exact_dtype=False)
138            self.assertEqual(torch.Tensor.outer(a, b), expected, exact_dtype=False)
139
140            self.assertEqual(torch.ger(a, b), expected, exact_dtype=False)
141            self.assertEqual(torch.Tensor.ger(a, b), expected, exact_dtype=False)
142
143            # test out variant
144            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
145            torch.outer(a, b, out=out)
146            self.assertEqual(out, expected, exact_dtype=False)
147
148            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
149            torch.ger(a, b, out=out)
150            self.assertEqual(out, expected, exact_dtype=False)
151
152        a = torch.randn(50).to(device=device, dtype=dtype)
153        b = torch.randn(50).to(device=device, dtype=dtype)
154        run_test_case(a, b)
155
156        # test 0 strided tensor
157        zero_strided = torch.randn(1).to(device=device, dtype=dtype).expand(50)
158        run_test_case(zero_strided, b)
159        run_test_case(a, zero_strided)
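
    # A minimal sketch (ours, not used by test_outer) of the identity the test exercises:
    # the outer product of two 1-D tensors is just a broadcasted elementwise product.
    def _outer_via_broadcast_sketch(self, a, b):
        return a.unsqueeze(-1) * b.unsqueeze(0)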
160
161    def test_matrix_rank_removed_error(self, device):
162        a = make_tensor(5, 5, device=device, dtype=torch.float32)
163        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
164            torch.matrix_rank(a)
165
166    def test_solve_removed_error(self, device):
167        a = make_tensor(5, 5, device=device, dtype=torch.float32)
168        b = make_tensor(5, 1, device=device, dtype=torch.float32)
169        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
170            torch.solve(b, a)
171        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
172            b.solve(a)
173
174    def test_eig_removed_error(self, device):
175        a = make_tensor(5, 5, device=device, dtype=torch.float32)
176        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
177            torch.eig(a)
178        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
179            a.eig()
180
181    def test_symeig_removed_error(self, device):
182        a = make_tensor(5, 5, device=device, dtype=torch.float32)
183        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
184            torch.symeig(a)
185        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
186            a.symeig()
187
188    def test_lstsq_removed_error(self, device):
189        a = make_tensor(5, 5, device=device, dtype=torch.float32)
190        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
191            torch.lstsq(a, a)
192        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
193            a.lstsq(a)
194
195    @skipCUDAIfNoMagma
196    @skipCPUIfNoLapack
197    @skipIfTorchDynamo("flaky, needs investigation")
198    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
199    def test_linalg_lstsq(self, device, dtype):
200        from torch.testing._internal.common_utils import random_well_conditioned_matrix
201        if self.device_type == 'cpu':
202            drivers = ('gels', 'gelsy', 'gelsd', 'gelss', None)
203        else:
204            drivers = ('gels', None)
205
206        def check_solution_correctness(a, b, sol):
207            sol2 = a.pinverse() @ b
208            self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)
209
210        def check_correctness_ref(a, b, res, ref, driver="default"):
211            def apply_if_not_empty(t, f):
212                if t.numel():
213                    return f(t)
214                else:
215                    return t
216
217            def select_if_not_empty(t, i):
218                selected = apply_if_not_empty(t, lambda x: x.select(0, i))
219                return selected
220
221            m = a.size(-2)
222            n = a.size(-1)
223            nrhs = b.size(-1)
224            batch_size = int(np.prod(a.shape[:-2]))
225            if batch_size == 0:
226                batch_size = 1
227            a_3d = a.view(batch_size, m, n)
228            b_3d = b.view(batch_size, m, nrhs)
229
230            solution_3d = res.solution.view(batch_size, n, nrhs)
231            residuals_2d = apply_if_not_empty(res.residuals, lambda t: t.view(-1, nrhs))
232            rank_1d = apply_if_not_empty(res.rank, lambda t: t.view(-1))
233            singular_values_2d = res.singular_values.view(batch_size, res.singular_values.shape[-1])
234
235            if a.numel() > 0:
236                for i in range(batch_size):
237                    sol, residuals, rank, singular_values = ref(
238                        a_3d.select(0, i).numpy(),
239                        b_3d.select(0, i).numpy()
240                    )
241                    # Singular values are None when lapack_driver='gelsy' in SciPy
242                    if singular_values is None:
243                        singular_values = []
244                    self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
245                    self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
246                    self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)
247
248                    # SciPy and NumPy operate only on non-batched input and return an
249                    # empty array with shape (0,) if rank(a) != n.  PyTorch supports
250                    # batched inputs, and matrices within a batch can have different
251                    # ranks, so we compare residuals only when every matrix in the
252                    # batch has rank == n (a standalone sketch of this residual
253                    # definition follows the test).  See https://github.com/pytorch/pytorch/issues/56483
254                    if m > n:
255                        if torch.all(rank_1d == n):
256                            self.assertEqual(
257                                residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5, exact_dtype=False
258                            )
259                        else:
260                            self.assertTrue(residuals_2d.numel() == 0)
261
262            else:
263                self.assertEqual(res.solution.shape, (*a.shape[:-2], n, nrhs))
264                self.assertEqual(res.rank.shape, a.shape[:-2])
265
266                # residuals are only computed in some cases (m > n, non-gelsy drivers); otherwise they are empty
267                if m > n and driver != "gelsy":
268                    self.assertEqual(res.residuals.shape, (*a.shape[:-2], 0))
269                else:
270                    self.assertEqual(res.residuals.shape, (0, ))
271
272                # singular_values are only computed for some drivers; otherwise they are empty
273                if driver == "default" or driver == "gelsd" or driver == "gelss":
274                    self.assertEqual(res.singular_values.shape, (*a.shape[:-2], min(m, n)))
275                else:
276                    self.assertEqual(res.singular_values.shape, (0, ))
277
278        def check_correctness_scipy(a, b, res, driver, cond):
279            # SciPy provides 3 driver options: gelsd, gelss, gelsy
280            if TEST_SCIPY and driver in ('gelsd', 'gelss', 'gelsy'):
281                import scipy.linalg
282
283                def scipy_ref(a, b):
284                    return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
285                check_correctness_ref(a, b, res, scipy_ref, driver=driver)
286
287        def check_correctness_numpy(a, b, res, driver, rcond):
288            # NumPy uses only gelsd routine
289            if driver == 'gelsd':
290
291                def numpy_ref(a, b):
292                    return np.linalg.lstsq(a, b, rcond=rcond)
293                check_correctness_ref(a, b, res, numpy_ref)
294
295        ms = [2 ** i for i in range(5)]
296        m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
297        # cases m < n are only supported on CPU and for cuSOLVER path on CUDA
298        m_l_n_sizes = [(m // 2, m) for m in ms]
299        include_m_l_n_case = (has_cusolver() or device == 'cpu')
300        matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if include_m_l_n_case else [])
301        batches = [(), (2,), (2, 2), (2, 2, 2)]
302        # We generate matrices whose singular values are sampled from a normal distribution,
303        # so `cond=1.0` (the mean) cuts off roughly half of all the singular values; we then
304        # compare whether torch.linalg.lstsq agrees with SciPy and NumPy.
305        # If rcond is True, a driver-specific value is set in the loop below.
306        # rcond == -1 (or any other negative value) forces LAPACK to use the machine-precision
307        # tolerance.
308        rconds = (None, True, -1)
309
310        for batch, matrix_size, driver, rcond in itertools.product(batches, matrix_sizes, drivers, rconds):
311            # keep the rcond value if it is None or -1, set the driver specific value if it is True
312            if rcond and rcond != -1:
313                if driver in ('gelss', 'gelsd'):
314                    # SVD based algorithm; set to zero roughly half of all the singular values
315                    rcond = 1.0
316                else:
317                    # driver == 'gelsy'
318                    # QR based algorithm; setting the value too high might lead to non-unique solutions and flaky tests
319                    # so we skip this case
320                    continue
321
322            # specifying rcond value has no effect for gels driver so no need to run the tests again
323            if driver == 'gels' and rcond is not None:
324                continue
325
326            shape = batch + matrix_size
327            a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
328            b = torch.rand(*shape, dtype=dtype, device=device)
329
330            m = a.size(-2)
331            n = a.size(-1)
332            res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
333            sol = res.solution
334
335            # Only checks gelsd, gelss, gelsy drivers
336            check_correctness_scipy(a, b, res, driver, rcond)
337
338            # Only checks gelsd driver
339            check_correctness_numpy(a, b, res, driver, rcond)
340
341            # gels driver is not checked by comparing to NumPy or SciPy implementation
342            # because NumPy and SciPy do not implement this driver
343            if driver == 'gels' and rcond is None:
344                check_solution_correctness(a, b, sol)
345
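    # A minimal sketch (ours, not part of the test above) of the residual definition referenced
    # in the comments of test_linalg_lstsq: for an overdetermined, full-rank system the
    # `residuals` returned by torch.linalg.lstsq equal the squared 2-norm of b - a @ x,
    # accumulated per right-hand-side column.
    def _lstsq_residuals_sketch(self, a, b):
        x = torch.linalg.lstsq(a, b).solution
        return (b - a @ x).abs().pow(2).sum(dim=-2)
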
346    @skipCUDAIfNoMagma
347    @skipCPUIfNoLapack
348    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
349    def test_linalg_lstsq_batch_broadcasting(self, device, dtype):
350        from torch.testing._internal.common_utils import random_well_conditioned_matrix
351
352        def check_correctness(a, b):
353            sol = torch.linalg.lstsq(a, b).solution
354            sol2 = a.pinverse() @ b
355            self.assertEqual(sol, sol2, rtol=1e-5, atol=1e-5)
356
357        ms = [2 ** i for i in range(5)]
358        batches = [(), (0,), (2,), (2, 2), (2, 2, 2)]
359        # the case when a single matrix is batch-broadcasted over the rhs
360        for m, batch in itertools.product(ms, batches):
361            a = random_well_conditioned_matrix(m, m, dtype=dtype, device=device).view(*([1] * len(batch)), m, m)
362            b = torch.rand(*(batch + (m, m)), dtype=dtype, device=device)
363            check_correctness(a, b)
364
365        # cases with broadcastable shapes
366        for m in ms:
367            a = random_well_conditioned_matrix(1, 3, 1, 3, m, m, dtype=dtype, device=device)
368            b = torch.rand(3, 1, 3, 1, m, m // 2, dtype=dtype, device=device)
369            check_correctness(a, b)
370
371            # rhs are vectors, not matrices in this test
372            b = torch.rand(3, 1, 3, 1, m, dtype=dtype, device=device)
373            # unsqueeze for b because `check_correctness` checks against
374            # a.pinverse() @ b, which requires b to be a matrix
375            check_correctness(a, b.unsqueeze(-1))
376
377            a = random_well_conditioned_matrix(3, 1, 3, 1, m, m, dtype=dtype, device=device)
378            b = torch.rand(1, 3, 1, 3, m, m // 2, dtype=dtype, device=device)
379            check_correctness(a, b)
380
381            # rhs are vectors, not matrices in this test
382            b = torch.rand(1, 3, 1, 3, m, dtype=dtype, device=device)
383            check_correctness(a, b.unsqueeze(-1))
384
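    # A minimal sketch (ours) of the batch-broadcasting rule exercised above: the batch
    # dimensions of `a` and `b` are broadcast together, while the matrix dimensions come from
    # each operand ((m, n) from a and (m, k) from b).
    def _lstsq_broadcast_batch_shape_sketch(self, a, b):
        return torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
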
385    @skipCPUIfNoLapack
386    @skipCUDAIfNoMagma
387    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
388    def test_linalg_lstsq_input_checks(self, device, dtype):
389        # check empty inputs
390        # empty batches
391        a = torch.rand(0, 0, 3, 3, dtype=dtype, device=device)
392        b = torch.rand(0, 0, 3, 2, dtype=dtype, device=device)
393        self.assertEqual(
394            torch.linalg.lstsq(a, b)[0],
395            torch.zeros(0, 0, 3, 2, dtype=dtype, device=device)
396        )
397        # empty a and b
398        a = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
399        b = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
400        self.assertEqual(
401            torch.linalg.lstsq(a, b)[0],
402            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
403        )
404        # empty a and b
405        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
406        b = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
407        self.assertEqual(
408            torch.linalg.lstsq(a, b)[0],
409            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
410        )
411        # empty a but not b
412        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
413        b = torch.rand(2, 2, 3, 2, dtype=dtype, device=device)
414        self.assertEqual(
415            torch.linalg.lstsq(a, b)[0],
416            torch.zeros(2, 2, 0, 2, dtype=dtype, device=device)
417        )
418
419        # empty a and b
420        if torch.device(device).type == 'cpu':
421            # only CPU since CUDA does not support underdetermined (m < n) systems
422            a = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
423            b = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
424            self.assertEqual(
425                torch.linalg.lstsq(a, b)[0],
426                torch.zeros(2, 2, 3, 3, dtype=dtype, device=device)
427            )
428
429        a = torch.rand(2, 3, dtype=dtype, device=device)
430        b = torch.rand(3, dtype=dtype, device=device)
431
432        with self.assertRaisesRegex(RuntimeError, 'input must have at least 2 dimensions'):
433            torch.linalg.lstsq(b, b)
434
435        with self.assertRaisesRegex(RuntimeError, 'other must have at least 1 dimension'):
436            torch.linalg.lstsq(a, torch.tensor(1, dtype=dtype, device=device))
437
438        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-1\)'):
439            torch.linalg.lstsq(a, b)
440
441        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
442            torch.linalg.lstsq(a, b.unsqueeze(-1))
443
444        a = torch.randn(1, 1, 1, dtype=dtype, device=device)
445        b = torch.randn(3, 1, dtype=dtype, device=device)
446
447        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
448            torch.linalg.lstsq(a, b)
449
450        def complement_device(device):
451            if device == 'cpu' and torch.cuda.is_available():
452                return 'cuda'
453            else:
454                return 'cpu'
455
456        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
457        b = torch.rand(2, 2, 2, dtype=dtype, device=complement_device(device))
458        if a.device != b.device:
459            with self.assertRaisesRegex(RuntimeError, 'be on the same device'):
460                torch.linalg.lstsq(a, b)
461
462        b = (torch.rand(2, 2, 2, dtype=dtype, device=device) * 100).long()
463        with self.assertRaisesRegex(RuntimeError, 'the same dtype'):
464            torch.linalg.lstsq(a, b)
465
466        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
467        b = torch.rand(2, 2, 2, dtype=dtype, device=device)
468
469        if device != 'cpu':
470            with self.assertRaisesRegex(RuntimeError, '`driver` other than `gels` is not supported on CUDA'):
471                torch.linalg.lstsq(a, b, driver='fictitious_driver')
472        # if on cpu
473        else:
474            with self.assertRaisesRegex(RuntimeError, r'parameter `driver` should be one of \(gels, gelsy, gelsd, gelss\)'):
475                torch.linalg.lstsq(a, b, driver='fictitious_driver')
476
477        # cuSOLVER path supports underdetermined systems
478        version = torch.testing._internal.common_cuda._get_torch_cuda_version()
479        cusolver_not_available = (version < (10, 1))
480
481        if device != 'cpu' and cusolver_not_available:
482            a = torch.rand(2, 3, dtype=dtype, device=device)
483            b = torch.rand(2, 1, dtype=dtype, device=device)
484            with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems'):
485                torch.linalg.lstsq(a, b)
486
487    @skipCUDAIfNoMagma
488    @skipCPUIfNoLapack
489    @dtypes(*floating_and_complex_types())
490    def test_cholesky(self, device, dtype):
491        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
492
493        def run_test(shape, batch, contiguous):
494            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
495            if A.numel() > 0 and not contiguous:
496                A = A.mT
497                self.assertFalse(A.is_contiguous())
498            expected_L = np.linalg.cholesky(A.cpu().numpy())
499            actual_L = torch.linalg.cholesky(A)
500
501            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
502            # Let's compare the norms of matrices instead
503            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
504                # axis is specified to calculate matrix norm for batched input
505                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
506                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
507                # Compare the norms with standard tolerances
508                self.assertEqual(actual_norm, expected_norm)
509                # and individual values with a higher tolerance
510                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
511            else:
512                self.assertEqual(actual_L, expected_L)
513
514        shapes = (0, 3, 5)
515        batches = ((), (3, ), (2, 2))
516        larger_input_case = [(100, (5, ), True)]
517        for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case:
518            run_test(shape, batch, contiguous)
519
520        # check the out= variant
521        A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device)
522        out = torch.empty_like(A)
523        ans = torch.linalg.cholesky(A, out=out)
524        self.assertEqual(ans, out)
525        expected = torch.linalg.cholesky(A)
526        self.assertEqual(expected, out)
527
528        # check the upper= variant
529        expected = torch.linalg.cholesky(A).mH
530        actual = torch.linalg.cholesky(A, upper=True)
531        self.assertEqual(expected, actual)
532
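    # A minimal sketch (ours) of the relations exercised at the end of test_cholesky: for a
    # Hermitian positive-definite A, the lower factor L satisfies A = L @ L.mH, and the upper
    # factor is simply its conjugate transpose.
    def _cholesky_reconstruct_sketch(self, A):
        L = torch.linalg.cholesky(A)
        U = torch.linalg.cholesky(A, upper=True)  # equals L.mH
        return L @ L.mH, U
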
533    @skipCUDAIfNoMagma
534    @skipCPUIfNoLapack
535    @dtypes(*floating_and_complex_types())
536    def test_cholesky_errors_and_warnings(self, device, dtype):
537        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
538
539        # cholesky requires the input to be a square matrix or batch of square matrices
540        A = torch.randn(2, 3, device=device, dtype=dtype)
541        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
542            torch.linalg.cholesky(A)
543        A = torch.randn(2, 2, 3, device=device, dtype=dtype)
544        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
545            torch.linalg.cholesky(A)
546        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'):
547            np.linalg.cholesky(A.cpu().numpy())
548
549        # cholesky requires the input to be at least 2 dimensional tensor
550        A = torch.randn(2, device=device, dtype=dtype)
551        with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
552            torch.linalg.cholesky(A)
553        with self.assertRaisesRegex(np.linalg.LinAlgError,
554                                    r'1-dimensional array given\. Array must be at least two-dimensional'):
555            np.linalg.cholesky(A.cpu().numpy())
556
557        # if the input matrix is not positive definite, an error should be raised
558        A = torch.eye(3, 3, dtype=dtype, device=device)
559        A[-1, -1] = 0  # Now A is not positive definite
560        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
561            torch.linalg.cholesky(A)
562        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
563            np.linalg.cholesky(A.cpu().numpy())
564
565        # if at least one matrix in the batch is singular, an error should be raised
566        A = torch.eye(3, 3, dtype=dtype, device=device)
567        A = A.reshape((1, 3, 3))
568        A = A.repeat(5, 1, 1)
569        A[4, -1, -1] = 0  # Now A[4] is not positive definite
570        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 4\): The factorization could not be completed'):
571            torch.linalg.cholesky(A)
572
573        # if out tensor with wrong shape is passed a warning is given
574        A = random_hermitian_pd_matrix(3, dtype=dtype, device=device)
575        out = torch.empty(2, 3, dtype=dtype, device=device)
576        with warnings.catch_warnings(record=True) as w:
577            # Trigger warning
578            torch.linalg.cholesky(A, out=out)
579            # Check warning occurs
580            self.assertEqual(len(w), 1)
581            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
582
583        # dtypes should be safely castable
584        out = torch.empty(*A.shape, dtype=torch.int, device=device)
585        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
586            torch.linalg.cholesky(A, out=out)
587
588        # device should match
589        if torch.cuda.is_available():
590            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
591            out = torch.empty(0, device=wrong_device, dtype=dtype)
592            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
593                torch.linalg.cholesky(A, out=out)
594
595    # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py
596    @slowTest
597    @skipCUDAIfNoMagma
598    @skipCPUIfNoLapack
599    @dtypes(torch.double)
600    def test_old_cholesky_batched_many_batches(self, device, dtype):
601        from torch.testing._internal.common_utils import random_symmetric_pd_matrix
602
603        def cholesky_test_helper(n, batchsize, device, upper):
604            A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device)
605            chol_fact = torch.cholesky(A, upper=upper)
606            if upper:
607                # Correctness check
608                self.assertEqual(A, chol_fact.mT.matmul(chol_fact))
609                # Upper triangular check
610                self.assertEqual(chol_fact, chol_fact.triu())
611            else:
612                # Correctness check
613                self.assertEqual(A, chol_fact.matmul(chol_fact.mT))
614                # Lower triangular check
615                self.assertEqual(chol_fact, chol_fact.tril())
616
617        for upper, batchsize in itertools.product([True, False], [262144, 524288]):
618            cholesky_test_helper(2, batchsize, device, upper)
619
620    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
621    @skipCUDAIfNoMagma
622    @skipCPUIfNoLapack
623    @dtypes(*floating_and_complex_types())
624    def test_old_cholesky_batched(self, device, dtype):
625        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
626
627        def cholesky_test_helper(n, batch_dims, upper):
628            A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
629            cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)])
630            cholesky_exp = cholesky_exp.reshape_as(A)
631            self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))
632
633        for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]):
634            cholesky_test_helper(3, batchsize, upper)
635
636    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
637    @skipCUDAIfNoMagma
638    @skipCPUIfNoLapack
639    @dtypes(*floating_and_complex_types())
640    @tf32_on_and_off(0.01)
641    @bf32_on_and_off(0.01)
642    def test_old_cholesky(self, device, dtype):
643        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
644
645        A = random_hermitian_pd_matrix(10, dtype=dtype, device=device)
646
647        # default case
648        C = torch.cholesky(A)
649        B = torch.mm(C, C.t().conj())
650        self.assertEqual(A, B, atol=1e-14, rtol=0)
651
652        # test Upper Triangular
653        U = torch.cholesky(A, True)
654        B = torch.mm(U.t().conj(), U)
655        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix')
656
657        # test Lower Triangular
658        L = torch.cholesky(A, False)
659        B = torch.mm(L, L.t().conj())
660        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix')
661
662    @skipCUDAIfNoMagma
663    @skipCPUIfNoLapack
664    @dtypes(*floating_and_complex_types())
665    def test_old_cholesky_empty(self, device, dtype):
666        def run_test(upper):
667            A = torch.empty(0, 0, dtype=dtype, device=device)
668            chol = torch.cholesky(A, upper)
669            chol_A = torch.matmul(chol, chol.t().conj())
670            self.assertEqual(A, chol_A)
671        for upper in [True, False]:
672            run_test(upper)
673
674    # Test for issue
675    # https://github.com/pytorch/pytorch/issues/57032
676    # torch.cholesky with upper=True for batched CUDA inputs was wrong
677    # it was using the lower triangular part instead of the upper one
678    @onlyCUDA
679    @skipCUDAIfNoMagma
680    @dtypes(*floating_and_complex_types())
681    def test_old_cholesky_batched_upper(self, device, dtype):
682        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
683
684        batchsize = 2
685        A = random_hermitian_pd_matrix(3, batchsize, dtype=dtype, device=device)
686        A_triu = A.triu()  # fill the lower triangular part with zero
687
688        U = torch.cholesky(A_triu, upper=True)
689
690        reconstruct_A = U.mH @ U
691        self.assertEqual(A, reconstruct_A)
692
693    @skipCUDAIfNoMagmaAndNoCusolver
694    @skipCPUIfNoLapack
695    @dtypes(*floating_and_complex_types())
696    def test_cholesky_ex(self, device, dtype):
697        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
698
699        def run_test(n, batch):
700            A = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
701            expected_L = np.linalg.cholesky(A.cpu().numpy())
702            expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
703            actual_L, actual_info = torch.linalg.cholesky_ex(A)
704
705            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
706            # Let's compare the norms of matrices instead
707            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
708                # axis is specified to calculate matrix norm for batched input
709                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
710                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
711                # Compare the norms with standard tolerances
712                self.assertEqual(actual_norm, expected_norm)
713                # and individual values with a higher tolerance
714                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
715            else:
716                self.assertEqual(actual_L, expected_L)
717            self.assertEqual(actual_info, expected_info)
718
719        ns = (0, 3, 5)
720        batches = ((), (2, ), (2, 1))
721        for n, batch in itertools.product(ns, batches):
722            run_test(n, batch)
723
724    @skipCUDAIfNoMagmaAndNoCusolver
725    @skipCPUIfNoLapack
726    @dtypes(*floating_and_complex_types())
727    def test_cholesky_ex_non_pd(self, device, dtype):
728        # if the input matrix is not positive definite, info with positive integer is returned
729        A = torch.eye(3, 3, dtype=dtype, device=device)
730        A[-1, -1] = 0  # Now A is singular
731        _, info = torch.linalg.cholesky_ex(A)
732        self.assertEqual(info, 3)
733        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
734            torch.linalg.cholesky_ex(A, check_errors=True)
735
736        # if at least one matrix in the batch is not positive definite,
737        # batched info with positive integer for the corresponding matrix is returned
738        A = torch.eye(3, 3, dtype=dtype, device=device)
739        A = A.reshape((1, 3, 3))
740        A = A.repeat(5, 1, 1)
741        A[3, -2, -2] = 0  # Now A[3] is singular
742        _, info = torch.linalg.cholesky_ex(A)
743
744        expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
745        expected_info[3] = 2
746        self.assertEqual(info, expected_info)
747        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'):
748            torch.linalg.cholesky_ex(A, check_errors=True)
749
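    # A minimal sketch (ours) of the typical cholesky_ex pattern the test above checks:
    # factorize a whole batch without raising and use `info` to mask out the elements that
    # are not positive-definite instead of failing the entire batch.
    def _cholesky_ex_mask_sketch(self, A):
        L, info = torch.linalg.cholesky_ex(A)
        return L, info == 0  # True where the factorization succeeded
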
750    def _test_addr_vs_numpy(self, device, dtype, beta=1, alpha=1):
751        def check(m, a, b, beta, alpha):
752            if dtype == torch.bfloat16:
753                a_np = a.to(torch.double).cpu().numpy()
754                b_np = b.to(torch.double).cpu().numpy()
755                m_np = m.to(torch.double).cpu().numpy()
756                exact_dtype = False
757            else:
758                a_np = a.cpu().numpy()
759                b_np = b.cpu().numpy()
760                m_np = m.cpu().numpy()
761                exact_dtype = True
762            if beta == 0:
763                expected = alpha * np.outer(a_np, b_np)
764            else:
765                expected = beta * m_np + alpha * np.outer(a_np, b_np)
766
767            res = torch.addr(m, a, b, beta=beta, alpha=alpha)
768            self.assertEqual(res, expected, exact_dtype=exact_dtype)
769
770            # Test out variant
771            out = torch.empty_like(res)
772            torch.addr(m, a, b, beta=beta, alpha=alpha, out=out)
773            self.assertEqual(out, expected, exact_dtype=exact_dtype)
774
775        m = make_tensor((50, 50), device=device, dtype=dtype, low=-2, high=2)
776        a = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
777        b = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
778
779        check(m, a, b, beta, alpha)
780
781        # test transpose
782        m_transpose = torch.transpose(m, 0, 1)
783        check(m_transpose, a, b, beta, alpha)
784
785        # test 0 strided tensor
786        zero_strided = make_tensor((1,), device=device, dtype=dtype, low=-2, high=2).expand(50)
787        check(m, zero_strided, b, beta, alpha)
788
789        # test scalar
790        m_scalar = torch.tensor(1, device=device, dtype=dtype)
791        check(m_scalar, a, b, beta, alpha)
792
793        # test nans and infs are not propagated to the output when beta == 0
794        float_and_complex_dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
795        if beta == 0 and dtype in float_and_complex_dtypes:
796            m[0][10] = m[10][10] = m[20][20] = float('inf')
797            m[1][10] = m[11][10] = m[21][20] = float('nan')
798        check(m, a, b, 0, alpha)
799
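    # A minimal sketch (ours) of the identity _test_addr_vs_numpy verifies:
    # torch.addr(M, a, b, beta=beta, alpha=alpha) == beta * M + alpha * torch.outer(a, b).
    # Note that when beta == 0, torch.addr ignores M entirely, so NaN/Inf entries in M do not
    # propagate, unlike in this naive reference formula.
    def _addr_reference_sketch(self, M, a, b, beta=1, alpha=1):
        return beta * M + alpha * torch.outer(a, b)
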
800    @dtypes(torch.bool)
801    def test_addr_bool(self, device, dtype):
802        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=False)
803        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=True)
804        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=False)
805        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=True)
806
807    @dtypes(*integral_types())
808    def test_addr_integral(self, device, dtype):
809        with self.assertRaisesRegex(RuntimeError,
810                                    'argument beta must not be a floating point number.'):
811            self._test_addr_vs_numpy(device, dtype, beta=2., alpha=1)
812        with self.assertRaisesRegex(RuntimeError,
813                                    'argument alpha must not be a floating point number.'):
814            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=1.)
815        with self.assertRaisesRegex(RuntimeError,
816                                    'Boolean beta only supported for Boolean results.'):
817            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
818        with self.assertRaisesRegex(RuntimeError,
819                                    'Boolean alpha only supported for Boolean results.'):
820            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)
821
822        # when beta is zero
823        self._test_addr_vs_numpy(device, dtype, beta=0, alpha=2)
824        # when beta is not zero
825        self._test_addr_vs_numpy(device, dtype, beta=2, alpha=2)
826
827    @precisionOverride({torch.bfloat16: 1e-1})
828    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
829    def test_addr_float_and_complex(self, device, dtype):
830        with self.assertRaisesRegex(RuntimeError,
831                                    'Boolean beta only supported for Boolean results.'):
832            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
833        with self.assertRaisesRegex(RuntimeError,
834                                    'Boolean alpha only supported for Boolean results.'):
835            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)
836
837        # when beta is zero
838        self._test_addr_vs_numpy(device, dtype, beta=0., alpha=2)
839        # when beta is not zero
840        self._test_addr_vs_numpy(device, dtype, beta=0.5, alpha=2)
841        if dtype in complex_types():
842            self._test_addr_vs_numpy(device, dtype, beta=(0 + 0.1j), alpha=(0.2 - 0.2j))
843
844    @dtypes(*itertools.product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
845                               all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)))
846    def test_outer_type_promotion(self, device, dtypes):
847        a = torch.randn(5).to(device=device, dtype=dtypes[0])
848        b = torch.randn(5).to(device=device, dtype=dtypes[1])
849        for op in (torch.outer, torch.Tensor.outer, torch.ger, torch.Tensor.ger):
850            result = op(a, b)
851            self.assertEqual(result.dtype, torch.result_type(a, b))
852
853    # don't use @dtypes decorator to avoid generating ~1700 tests per device
854    def test_addr_type_promotion(self, device):
855        for dtypes0, dtypes1, dtypes2 in product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), repeat=3):
856            a = make_tensor((5,), device=device, dtype=dtypes0, low=-2, high=2)
857            b = make_tensor((5,), device=device, dtype=dtypes1, low=-2, high=2)
858            m = make_tensor((5, 5), device=device, dtype=dtypes2, low=-2, high=2)
859
860            desired_dtype = torch.promote_types(torch.promote_types(dtypes0, dtypes1),
861                                                dtypes2)
862            for op in (torch.addr, torch.Tensor.addr):
863                result = op(m, a, b)
864                self.assertEqual(result.dtype, desired_dtype)
865
866    # Tests migrated from test_torch.py
867    # 1) test the shape of the result tensor when there is empty input tensor
868    # 2) test the Runtime Exception when there is scalar input tensor
869    def test_outer_ger_addr_legacy_tests(self, device):
870        for size in ((0, 0), (0, 5), (5, 0)):
871            a = torch.rand(size[0], device=device)
872            b = torch.rand(size[1], device=device)
873
874            self.assertEqual(torch.outer(a, b).shape, size)
875            self.assertEqual(torch.ger(a, b).shape, size)
876
877            m = torch.empty(size, device=device)
878            self.assertEqual(torch.addr(m, a, b).shape, size)
879
880        m = torch.randn(5, 6, device=device)
881        a = torch.randn(5, device=device)
882        b = torch.tensor(6, device=device)
883        self.assertRaises(RuntimeError, lambda: torch.outer(a, b))
884        self.assertRaises(RuntimeError, lambda: torch.outer(b, a))
885        self.assertRaises(RuntimeError, lambda: torch.ger(a, b))
886        self.assertRaises(RuntimeError, lambda: torch.ger(b, a))
887        self.assertRaises(RuntimeError, lambda: torch.addr(m, a, b))
888        self.assertRaises(RuntimeError, lambda: torch.addr(m, b, a))
889
890    # Tests torch.det and its alias, torch.linalg.det, vs. NumPy
891    @skipCUDAIfNoMagma
892    @skipCPUIfNoLapack
893    @dtypes(torch.double, torch.cdouble)
894    def test_det(self, device, dtype):
895        tensors = (
896            torch.randn((2, 2), device=device, dtype=dtype),
897            torch.randn((129, 129), device=device, dtype=dtype),
898            torch.randn((3, 52, 52), device=device, dtype=dtype),
899            torch.randn((4, 2, 26, 26), device=device, dtype=dtype))
900
901
902        ops = (torch.det, torch.Tensor.det,
903               torch.linalg.det)
904        for t in tensors:
905            expected = np.linalg.det(t.cpu().numpy())
906            for op in ops:
907                actual = op(t)
908                self.assertEqual(actual, expected)
909                self.compare_with_numpy(op, np.linalg.det, t)
910
911        # NOTE: det requires a 2D+ tensor
912        t = torch.randn(1, device=device, dtype=dtype)
913        with self.assertRaises(RuntimeError):
914            torch.linalg.det(t)
915
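    # A minimal sketch (ours) of a numerically safer alternative to the determinant computed
    # in test_det: torch.linalg.slogdet avoids overflow/underflow for large matrices, and the
    # determinant can be recovered as sign * exp(logabsdet).
    def _det_via_slogdet_sketch(self, t):
        sign, logabsdet = torch.linalg.slogdet(t)
        return sign * torch.exp(logabsdet)
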
916    @skipCUDAIfNoMagma
917    @skipCPUIfNoLapack
918    @dtypes(*floating_and_complex_types())
919    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
920    def test_eigh(self, device, dtype):
921        from torch.testing._internal.common_utils import random_hermitian_matrix
922
923        def run_test(shape, batch, uplo):
924            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
925            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
926            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
927            self.assertEqual(actual_w, expected_w)
928            # The sign (phase) of an eigenvector is not unique, so absolute values are compared first.
929            self.assertEqual(abs(actual_v), abs(expected_v))
930            # Additionally, eigenvectors may differ by a phase factor e^{i\phi}; we fix the phase by
931            # requiring the first element of the torch and NumPy eigenvectors to agree (see the
932            # phase-fixing sketch after this test).  For real inputs this factor is +1 or -1.
933            if matrix.numel() > 0:
934                phase = torch.from_numpy(expected_v[..., 0, :]).to(device=device).div(actual_v[..., 0, :])
935                actual_v_rotated = actual_v * phase.unsqueeze(-2).expand_as(actual_v)
936                self.assertEqual(actual_v_rotated, expected_v)
937
938            # check the out= variant
939            out_w = torch.empty_like(actual_w)
940            out_v = torch.empty_like(actual_v)
941            ans_w, ans_v = torch.linalg.eigh(matrix, UPLO=uplo, out=(out_w, out_v))
942            self.assertEqual(ans_w, out_w)
943            self.assertEqual(ans_v, out_v)
944            self.assertEqual(ans_w, actual_w)
945            self.assertEqual(abs(ans_v), abs(actual_v))
946
947        shapes = (0, 3, 5)
948        batches = ((), (3, ), (2, 2))
949        uplos = ["U", "L"]
950        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
951            run_test(shape, batch, uplo)
952
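    # A minimal sketch (ours, not a PyTorch API) of the phase-fixing trick used in test_eigh:
    # eigenvectors are only defined up to a phase factor e^{i\phi} (a sign for real inputs), so
    # dividing each eigenvector by the phase of its first entry makes two eigendecompositions
    # directly comparable.
    def _fix_eigenvector_phase_sketch(self, v):
        ref = v[..., :1, :]                       # first entry of every eigenvector (column)
        phase = ref / ref.abs().clamp_min(1e-30)  # unit-modulus phase, avoids division by zero
        return v / phase
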
953    @skipCUDAIfNoMagma
954    @skipCPUIfNoLapack
955    @dtypes(*floating_and_complex_types())
956    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
957    def test_eigh_lower_uplo(self, device, dtype):
958        def run_test(shape, batch, uplo):
959            # check lower case uplo
960            # use non-symmetric input to check whether uplo argument is working as intended
961            matrix = torch.randn(shape, shape, *batch, dtype=dtype, device=device)
962            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
963            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
964            self.assertEqual(actual_w, expected_w)
965            self.assertEqual(abs(actual_v), abs(expected_v))
966
967        uplos = ["u", "l"]
968        for uplo in uplos:
969            run_test(3, (2, 2), uplo)
970
971    @skipCUDAIfNoMagma
972    @skipCPUIfNoLapack
973    @dtypes(*floating_and_complex_types())
974    def test_eigh_errors_and_warnings(self, device, dtype):
975        from torch.testing._internal.common_utils import random_hermitian_matrix
976
977        # eigh requires a square matrix
978        t = torch.randn(2, 3, device=device, dtype=dtype)
979        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
980            torch.linalg.eigh(t)
981
982        # eigh requires 'uplo' parameter to be 'U' or 'L'
983        t = torch.randn(3, 3, device=device, dtype=dtype)
984        for uplo in ["a", "wrong"]:
985            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
986                torch.linalg.eigh(t, UPLO=uplo)
987            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
988                np.linalg.eigh(t.cpu().numpy(), UPLO=uplo)
989
990        # if non-empty out tensor with wrong shape is passed a warning is given
991        a = random_hermitian_matrix(3, dtype=dtype, device=device)
992        real_dtype = a.real.dtype if dtype.is_complex else dtype
993        out_w = torch.empty(7, 7, dtype=real_dtype, device=device)
994        out_v = torch.empty(7, 7, dtype=dtype, device=device)
995        with warnings.catch_warnings(record=True) as w:
996            # Trigger warning
997            torch.linalg.eigh(a, out=(out_w, out_v))
998            # Check warning occurs
999            self.assertEqual(len(w), 2)
1000            self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
1001            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1002
1003        # dtypes should be safely castable
1004        out_w = torch.empty(0, dtype=real_dtype, device=device)
1005        out_v = torch.empty(0, dtype=torch.int, device=device)
1006        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1007            torch.linalg.eigh(a, out=(out_w, out_v))
1008
1009        out_w = torch.empty(0, dtype=torch.int, device=device)
1010        out_v = torch.empty(0, dtype=dtype, device=device)
1011        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1012            torch.linalg.eigh(a, out=(out_w, out_v))
1013
1014        # device should match
1015        if torch.cuda.is_available():
1016            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1017            out_w = torch.empty(0, device=wrong_device, dtype=dtype)
1018            out_v = torch.empty(0, device=device, dtype=dtype)
1019            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1020                torch.linalg.eigh(a, out=(out_w, out_v))
1021            out_w = torch.empty(0, device=device, dtype=dtype)
1022            out_v = torch.empty(0, device=wrong_device, dtype=dtype)
1023            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1024                torch.linalg.eigh(a, out=(out_w, out_v))
1025
1026    @skipCPUIfNoLapack
1027    @dtypes(torch.float, torch.double)
1028    @unittest.skipIf(_get_torch_cuda_version() < (12, 1), "Test is fixed on cuda 12.1 update 1.")
1029    def test_eigh_svd_illcondition_matrix_input_should_not_crash(self, device, dtype):
1030        # See https://github.com/pytorch/pytorch/issues/94772, https://github.com/pytorch/pytorch/issues/105359
1031        # This test crashes with `cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED` on cuda 11.8,
1032        # but passes on cuda 12.1 update 1 or later.
1033        a = torch.ones(512, 512, dtype=dtype, device=device)
1034        a[0, 0] = 1.0e-5
1035        a[-1, -1] = 1.0e5
1036
1037        eigh_out = torch.linalg.eigh(a)
1038        svd_out = torch.linalg.svd(a)
1039
1040        # The input matrix a is very ill-conditioned, so we only compare the two largest
1041        # singular values / eigenvalues, which are 1.0e5 and 511.0.
1042        # The loose tolerance (atol=1.0) makes sense since ill-conditioned inputs are hard
1043        # to converge to exact values.
1044        self.assertEqual(eigh_out.eigenvalues.sort(descending=True).values[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
1045        self.assertEqual(svd_out.S[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
1046
1047    @skipCUDAIfNoMagma
1048    @skipCPUIfNoLapack
1049    @dtypes(*floating_and_complex_types())
1050    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
1051    def test_eigvalsh(self, device, dtype):
1052        from torch.testing._internal.common_utils import random_hermitian_matrix
1053
1054        def run_test(shape, batch, uplo):
1055            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
1056            expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo)
1057            actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo)
1058            self.assertEqual(actual_w, expected_w)
1059
1060            # check the out= variant
1061            out = torch.empty_like(actual_w)
1062            ans = torch.linalg.eigvalsh(matrix, UPLO=uplo, out=out)
1063            self.assertEqual(ans, out)
1064            self.assertEqual(ans, actual_w)
1065
1066        shapes = (0, 3, 5)
1067        batches = ((), (3, ), (2, 2))
1068        uplos = ["U", "L"]
1069        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
1070            run_test(shape, batch, uplo)
1071
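    # A minimal sketch (ours) of the relation between the two Hermitian eigensolvers used in
    # these tests: torch.linalg.eigvalsh(A) returns the same ascending, real eigenvalues as
    # torch.linalg.eigh(A).eigenvalues, just without computing the eigenvectors.
    def _eigvalsh_via_eigh_sketch(self, A):
        return torch.linalg.eigh(A).eigenvalues
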
1072    @skipCUDAIfNoMagma
1073    @skipCPUIfNoLapack
1074    @dtypes(*floating_and_complex_types())
1075    def test_eigvalsh_errors_and_warnings(self, device, dtype):
1076        # eigvalsh requires a square matrix
1077        t = torch.randn(2, 3, device=device, dtype=dtype)
1078        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
1079            torch.linalg.eigvalsh(t)
1080
1081        # eigvalsh requires 'uplo' parameter to be 'U' or 'L'
1082        t = torch.randn(3, 3, device=device, dtype=dtype)
1083        for uplo in ["a", "wrong"]:
1084            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
1085                torch.linalg.eigvalsh(t, UPLO=uplo)
1086            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
1087                np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo)
1088
1089        # if non-empty out tensor with wrong shape is passed a warning is given
1090        real_dtype = t.real.dtype if dtype.is_complex else dtype
1091        out = torch.empty_like(t).to(real_dtype)
1092        with warnings.catch_warnings(record=True) as w:
1093            # Trigger warning
1094            torch.linalg.eigvalsh(t, out=out)
1095            # Check warning occurs
1096            self.assertEqual(len(w), 1)
1097            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1098
1099        # dtypes should be safely castable
1100        out = torch.empty(0, dtype=torch.int, device=device)
1101        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1102            torch.linalg.eigvalsh(t, out=out)
1103
1104        # device should match
1105        if torch.cuda.is_available():
1106            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1107            out = torch.empty(0, device=wrong_device, dtype=dtype)
1108            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1109                torch.linalg.eigvalsh(t, out=out)
1110
1111    @dtypes(*floating_and_complex_types())
1112    def test_kron(self, device, dtype):
1113
1114        def run_test_case(a_shape, b_shape):
1115            a = torch.rand(a_shape, dtype=dtype, device=device)
1116            b = torch.rand(b_shape, dtype=dtype, device=device)
1117
1118            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
1119            result = torch.kron(a, b)
1120            self.assertEqual(result, expected)
1121
1122            # check the out= variant
1123            out = torch.empty_like(result)
1124            ans = torch.kron(a, b, out=out)
1125            self.assertEqual(ans, out)
1126            self.assertEqual(ans, result)
1127
1128        shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)]
1129        for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
1130            run_test_case(a_shape, b_shape)
1131
1132    @dtypes(*floating_and_complex_types())
1133    def test_kron_empty(self, device, dtype):
1134
1135        def run_test_case(empty_shape):
1136            a = torch.eye(3, dtype=dtype, device=device)
1137            b = torch.empty(empty_shape, dtype=dtype, device=device)
1138            result = torch.kron(a, b)
1139            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
1140            self.assertEqual(result, expected)
1141
1142            # NumPy doesn't work if the first argument is empty, so only the shape is compared
1143            result = torch.kron(b, a)
1144            self.assertEqual(result.shape, expected.shape)
1145
1146        empty_shapes = [(0,), (2, 0), (1, 0, 3)]
1147        for empty_shape in empty_shapes:
1148            run_test_case(empty_shape)
1149
1150    @dtypes(*floating_and_complex_types())
1151    def test_kron_errors_and_warnings(self, device, dtype):
1152        # if a non-empty out tensor with the wrong shape is passed, a warning is given
1153        a = torch.eye(3, dtype=dtype, device=device)
1154        b = torch.ones((2, 2), dtype=dtype, device=device)
1155        out = torch.empty_like(a)
1156        with warnings.catch_warnings(record=True) as w:
1157            # Trigger warning
1158            torch.kron(a, b, out=out)
1159            # Check warning occurs
1160            self.assertEqual(len(w), 1)
1161            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1162
1163        # dtypes should match
1164        out = torch.empty_like(a).to(torch.int)
1165        with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
1166            torch.kron(a, b, out=out)
1167
1168    # This test confirms that torch.linalg.norm's dtype argument works
1169    # as expected, according to the function's documentation
1170    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
1171    def test_norm_dtype(self, device, dtype):
1172        make_arg = partial(make_tensor, dtype=dtype, device=device)
1173
1174        def run_test_case(input_size, ord, keepdim, to_dtype):
1175            msg = (
1176                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
1177                f'dtype={dtype}, to_dtype={to_dtype}')
1178            input = make_arg(input_size)
1179            result = torch.linalg.norm(input, ord, keepdim=keepdim)
1180            self.assertEqual(result.dtype, input.real.dtype, msg=msg)
1181
1182            result_out = torch.empty((0), dtype=result.dtype, device=device)
1183            torch.linalg.norm(input, ord, keepdim=keepdim, out=result_out)
1184            self.assertEqual(result, result_out, msg=msg)
1185
1186            result = torch.linalg.norm(input.to(to_dtype), ord, keepdim=keepdim)
1187            result_with_dtype = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype)
1188            self.assertEqual(result, result_with_dtype, msg=msg)
1189
1190            result_out_with_dtype = torch.empty_like(result_with_dtype)
1191            torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_with_dtype)
1192            self.assertEqual(result_with_dtype, result_out_with_dtype, msg=msg)
1193
1194        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
1195
1196        # For these orders we end up computing the 10th root and 10th power of the entries.
1197        # We avoid them for half-precision types as they make the test too badly conditioned.
1198        if dtype != torch.float16 and dtype != torch.bfloat16:
1199            ord_vector.extend([0.1, -0.1])
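        # Rough sketch of the conditioning issue: for ord=0.1 the norm is
        # (sum_i |x_i|**0.1)**10, so a relative error eps picked up in the sum is
        # amplified roughly tenfold by the outer power, which is more than the few
        # decimal digits of float16/bfloat16 can absorb.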
1200        ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None]
1201        S = 10
1202
1203        if dtype == torch.cfloat:
1204            norm_dtypes = (torch.cfloat, torch.cdouble)
1205        elif dtype == torch.cdouble:
1206            norm_dtypes = (torch.cdouble,)
1207        elif dtype in (torch.float16, torch.bfloat16, torch.float):
1208            norm_dtypes = (torch.float, torch.double)
1209        elif dtype == torch.double:
1210            norm_dtypes = (torch.double,)
1211        else:
1212            raise RuntimeError("Unsupported dtype")
1213
1214        for ord, keepdim, norm_dtype in product(ord_vector, (True, False), norm_dtypes):
1215            run_test_case((S,), ord, keepdim, norm_dtype)
1216
1217        for ord, keepdim, norm_dtype in product(ord_matrix, (True, False), norm_dtypes):
1218            if ord in [2, -2, 'nuc']:
1219                # We need torch.svdvals
1220                if dtype == torch.float16 or dtype == torch.bfloat16:
1221                    continue
1222
1223                # We need LAPACK or equivalent
1224                if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
1225                   (torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
1226                    continue
1227            run_test_case((S, S), ord, keepdim, norm_dtype)
1228
1229    # This test confirms that torch.linalg.norm returns the correct result for bfloat16 and half inputs.
1230    @dtypes(torch.bfloat16, torch.float16)
1231    def test_norm_bfloat16_and_half(self, device, dtype):
1232        make_arg = partial(make_tensor, dtype=dtype, device=device)
1233
1234        def run_test_case(input_size, ord, keepdim):
1235            msg = (
1236                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
1237                f'dtype={dtype}')
1238            input = make_arg(input_size).fill_(1)
1239            result_ref = torch.linalg.norm(input.float(), ord, keepdim=keepdim).to(dtype=dtype)
1240            result = torch.linalg.norm(input, ord, keepdim=keepdim)
1241            self.assertEqual(result_ref, result, msg=msg)
1242
1243        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
1244        for S, ord, keepdim in product((10, 2049), ord_vector, (True, False)):
1245            run_test_case((S,), ord, keepdim)
1246
1247    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
1248    def test_vector_norm(self, device, dtype):
1249        if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]:
1250            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
1252        # This test compares torch.linalg.vector_norm's output with
1253        # torch.linalg.norm given a flattened tensor
1254        ord_vector = [0, 0.9, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
1255        input_sizes = [
1256            (1, ),
1257            (10, ),
1258            (4, 5),
1259            (3, 4, 5),
1260            (0, ),
1261            (0, 10),
1262            (0, 0),
1263            (10, 0, 10),
1264        ]
1265
1266        def vector_norm_reference(input, ord, dim=None, keepdim=False, dtype=None):
1267            if dim is None:
1268                input_maybe_flat = input.flatten(0, -1)
1269            else:
1270                input_maybe_flat = input
1271
1272            result = torch.linalg.norm(input_maybe_flat, ord, dim=dim, keepdim=keepdim, dtype=dtype)
1273            if keepdim and dim is None:
1274                result = result.reshape([1] * input.dim())
1275            return result
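        # Worked example of the distinction (standard definitions): for
        # A = [[0., 1., 2.], [3., 4., 5.]],
        #   torch.linalg.vector_norm(A, ord=1) == 15.  (sum of all |entries|)
        #   torch.linalg.norm(A, ord=1) == 7.           (matrix 1-norm: max column abs-sum)
        # which is why the reference above flattens the input when dim is None.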
1276
1277        def run_test_case(input, ord, dim, keepdim, norm_dtype):
1278            if (input.numel() == 0 and
1279                (ord < 0. or ord == inf) and
1280               (dim is None or input.shape[dim] == 0)):
1281                # The operation does not have an identity.
1282                error_msg = "linalg.vector_norm cannot compute"
1283                with self.assertRaisesRegex(RuntimeError, error_msg):
1284                    torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim)
1285            else:
1286                msg = (f'input.size()={input.size()}, ord={ord}, dim={dim}, '
1287                       f'keepdim={keepdim}, dtype={dtype}, norm_dtype={norm_dtype}')
1288                result_dtype_reference = vector_norm_reference(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1289                result_dtype = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1290                if dtype.is_complex:
1291                    result_dtype_reference = result_dtype_reference.real
1292                self.assertEqual(result_dtype, result_dtype_reference, msg=msg)
1293
1294                if norm_dtype is not None:
1295                    ref = torch.linalg.vector_norm(input.to(norm_dtype), ord, dim=dim, keepdim=keepdim)
1296                    actual = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1297                    self.assertEqual(ref, actual, msg=msg)
1298
1299        if dtype == torch.cfloat:
1300            norm_dtypes = (None, torch.cfloat, torch.cdouble)
1301        elif dtype == torch.cdouble:
1302            norm_dtypes = (None, torch.cdouble)
1303        elif dtype in (torch.float16, torch.bfloat16, torch.float):
1304            norm_dtypes = (None, torch.float, torch.double)
1305        elif dtype == torch.double:
1306            norm_dtypes = (None, torch.double)
1307        else:
1308            raise RuntimeError("Unsupported dtype")
1309
1310        for amp in [False, True]:
1311            with torch.autocast(device_type=device, enabled=amp):
1312                for input_size, ord, keepdim, norm_dtype in product(input_sizes, ord_vector, [True, False], norm_dtypes):
1313                    input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1314                    for dim in [None, random.randint(0, len(input_size) - 1)]:
1315                        run_test_case(
1316                            input,
1317                            ord,
1318                            dim,
1319                            keepdim,
1320                            norm_dtype)
1321
1322    def test_vector_norm_dim_tuple_arg(self, device):
1323        test_cases = [
1324            # input size, dim, error, error message
1325            ((4, ), (0, ), None, None),
1326            ((4, ), (1, ), IndexError, r'Dimension out of range'),
1327            ((4, ), (-2, ), IndexError, r'Dimension out of range'),
1328            ((4, 3), (0, -1), None, None),
1329            ((4, 3), (0, 0), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
1330            ((4, 3), (0, -2), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
1331            ((4, 3), (0, 1.0), TypeError, r"argument 'dim' must be tuple of ints"),
1332            ((4, 3), (None, ), TypeError, r"argument 'dim' must be tuple of ints"),
1333        ]
1334        for input_size, dim_tuple, error, error_msg in test_cases:
1335            input = torch.randn(input_size, device=device)
1336            # vector_norm should accept a tuple or a list for dim arg
1337            for dim in [dim_tuple, list(dim_tuple)]:
1338                if error is None:
1339                    torch.linalg.vector_norm(input, dim=dim)
1340                else:
1341                    with self.assertRaises(error):
1342                        torch.linalg.vector_norm(input, dim=dim)
1343
1344    # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
1345    # their vector norm results match
1346    @dtypes(torch.float, torch.double)
1347    def test_norm_vector(self, device, dtype):
1348        def run_test_case(input, ord, dim, keepdim):
1349            result = torch.linalg.norm(input, ord, dim, keepdim)
1350            input_numpy = input.cpu().numpy()
1351            result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1352
1353            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1354            self.assertEqual(result, result_numpy, msg=msg)
1355
1356            result_out = torch.empty_like(result)
1357            torch.linalg.norm(input, ord, dim, keepdim, out=result_out)
1358            self.assertEqual(result, result_out, msg=msg)
1359
1360        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf]
1361        S = 10
1362        test_cases = [
1363            # input size, p settings, dim
1364            ((S, ), ord_vector, None),
1365            ((S, ), ord_vector, 0),
1366            ((S, S, S), ord_vector, 0),
1367            ((S, S, S), ord_vector, 1),
1368            ((S, S, S), ord_vector, 2),
1369            ((S, S, S), ord_vector, -1),
1370            ((S, S, S), ord_vector, -2),
1371        ]
1372        L = 1_000_000
1373        if dtype == torch.double:
1374            test_cases.append(((L, ), ord_vector, None))
1375        for keepdim in [True, False]:
1376            for input_size, ord_settings, dim in test_cases:
1377                input = torch.randn(*input_size, dtype=dtype, device=device)
1378                for ord in ord_settings:
1379                    run_test_case(input, ord, dim, keepdim)
1380
1381    # This test compares torch.linalg.norm, torch.linalg.matrix_norm and numpy.linalg.norm to
1382    # ensure that their matrix norm results match.
1383    @skipMeta  # https://github.com/pytorch/pytorch/issues/54082
1384    @skipCUDAIfNoMagma
1385    @dtypes(torch.float, torch.double)
1386    @precisionOverride({torch.float32: 2e-4})
1387    def test_norm_matrix(self, device, dtype):
1388        make_arg = partial(make_tensor, dtype=dtype, device=device)
1389
1390        def run_test_case(input, ord, dim, keepdim):
1391            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1392            result = torch.linalg.norm(input, ord, dim, keepdim)
1393            input_numpy = input.cpu().numpy()
1394            result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1395
1397            self.assertEqual(result, result_numpy, msg=msg)
1398            if ord is not None and dim is not None:
1399                result = torch.linalg.matrix_norm(input, ord, dim, keepdim)
1400                self.assertEqual(result, result_numpy, msg=msg)
1401
1402        ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro']
1403        S = 10
1404        test_cases = [
1405            # input size, dim
1406            ((S, S), None),
1407            ((S, S), (0, 1)),
1408            ((S, S), (1, 0)),
1409            ((S, S, S, S), (2, 0)),
1410            ((S, S, S, S), (-1, -2)),
1411            ((S, S, S, S), (-1, -3)),
1412            ((S, S, S, S), (-3, 2)),
1413        ]
1414
1415        for (shape, dim), keepdim, ord in product(test_cases, [True, False], ord_matrix):
1416            if ord in [2, -2, 'nuc']:
1417                # We need torch.svdvals
1418                if dtype == torch.float16 or dtype == torch.bfloat16:
1419                    continue
1420                # We need LAPACK or equivalent
1421                if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
1422                   (torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
1423                    continue
1424            run_test_case(make_arg(shape), ord, dim, keepdim)
1425
1426
1427    @onlyCUDA
1428    @dtypes(torch.bfloat16, torch.float16)
1429    def test_norm_fused_type_promotion(self, device, dtype):
1430        x = torch.randn(10, device=device, dtype=dtype)
1431
1432        def profile_and_check(fn, x, kwargs):
1433            with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p:
1434                fn(x, **kwargs, dtype=torch.float)
1435            # smoke check that profiler returned some events
1436            self.assertTrue("aten::linalg_vector_norm" in (e.name for e in p.events()))
1437            # test that there was no explicit copy
1438            self.assertFalse("aten::to" in (e.name for e in p.events()))
1439
1440        for f, kwargs in zip((torch.linalg.vector_norm, torch.norm), ({}, {"p": 2})):
1441            profile_and_check(f, x, kwargs)
1442
1443    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
1444    @skipCPUIfNoLapack
1445    @skipCUDAIfNoMagma
1446    @dtypes(*floating_and_complex_types())
1447    @precisionOverride({torch.float32: 1e-3})
1448    def test_cond(self, device, dtype):
1449        def run_test_case(input, p):
1450            result = torch.linalg.cond(input, p)
1451            result_numpy = np.linalg.cond(input.cpu().numpy(), p)
1452            self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision, exact_dtype=False)
1453            self.assertEqual(result.shape, result_numpy.shape)
1454
1455            # test out= variant
1456            out = torch.empty_like(result)
1457            ans = torch.linalg.cond(input, p, out=out)
1458            self.assertEqual(ans, out)
1459            self.assertEqual(ans, result)
1460
1461        norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
1462        input_sizes = [(32, 32), (2, 3, 3, 3)]
1463        for input_size in input_sizes:
1464            input = torch.randn(*input_size, dtype=dtype, device=device)
1465            for p in norm_types:
1466                run_test_case(input, p)
1467
1468        # test empty batch sizes
1469        input_sizes = [(0, 3, 3), (0, 2, 5, 5)]
1470        for input_size in input_sizes:
1471            input = torch.randn(*input_size, dtype=dtype, device=device)
1472            for p in norm_types:
1473                run_test_case(input, p)
1474
1475        # test non-square input
1476        input_sizes = [(16, 32), (32, 16), (2, 3, 5, 3), (2, 3, 3, 5)]
1477        for input_size in input_sizes:
1478            input = torch.randn(*input_size, dtype=dtype, device=device)
1479            for p in [2, -2, None]:
1480                run_test_case(input, p)
1481
1482        # test for singular input
1483        a = torch.eye(3, dtype=dtype, device=device)
1484        a[-1, -1] = 0  # make 'a' singular
1485        for p in norm_types:
1486            try:
1487                run_test_case(a, p)
1488            except np.linalg.LinAlgError:
1489                # Numpy may fail to converge for some BLAS backends (although this is very rare)
1490                # See the discussion in https://github.com/pytorch/pytorch/issues/67675
1491                pass
1492
1493        # test for 0x0 matrices. NumPy doesn't work for such input; we return 0
1494        input_sizes = [(0, 0), (2, 5, 0, 0)]
1495        for input_size in input_sizes:
1496            input = torch.randn(*input_size, dtype=dtype, device=device)
1497            for p in ['fro', 2]:
1498                expected_dtype = a.real.dtype if dtype.is_complex else dtype
1499                expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device)
1500                actual = torch.linalg.cond(input, p)
1501                self.assertEqual(actual, expected)
1502
1503    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
1504    @skipCPUIfNoLapack
1505    @skipCUDAIfNoMagma
1506    @dtypes(*floating_and_complex_types())
1507    @precisionOverride({torch.float32: 1e-3})
1508    def test_cond_errors_and_warnings(self, device, dtype):
1509        norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
1510
1511        # cond expects the input to be at least 2-dimensional
1512        a = torch.ones(3, dtype=dtype, device=device)
1513        for p in norm_types:
1514            with self.assertRaisesRegex(RuntimeError, r'at least 2 dimensions'):
1515                torch.linalg.cond(a, p)
1516
1517        # for some norm types cond expects the input to be square
1518        a = torch.ones(3, 2, dtype=dtype, device=device)
1519        norm_types = [1, -1, inf, -inf, 'fro', 'nuc']
1520        for p in norm_types:
1521            with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
1522                torch.linalg.cond(a, p)
1523
1524        # if a non-empty out tensor with the wrong shape is passed, a warning is given
1525        a = torch.ones((2, 2), dtype=dtype, device=device)
1526        for p in ['fro', 2]:
1527            real_dtype = a.real.dtype if dtype.is_complex else dtype
1528            out = torch.empty(a.shape, dtype=real_dtype, device=device)
1529            with warnings.catch_warnings(record=True) as w:
1530                # Trigger warning
1531                torch.linalg.cond(a, p, out=out)
1532                # Check warning occurs
1533                self.assertEqual(len(w), 1)
1534                self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1535
1536        # dtypes should be safely castable
1537        out = torch.empty(0, dtype=torch.int, device=device)
1538        for p in ['fro', 2]:
1539            with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
1540                torch.linalg.cond(a, p, out=out)
1541
1542        # device should match
1543        if torch.cuda.is_available():
1544            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1545            out = torch.empty(0, dtype=dtype, device=wrong_device)
1546            for p in ['fro', 2]:
1547                with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1548                    torch.linalg.cond(a, p, out=out)
1549
1550        # For batched input, if at least one matrix in the batch is not invertible,
1551        # we can't get the result for the other (possibly invertible) matrices in the batch without an explicit for loop.
1552        # This should change once at::inverse supports silent errors.
1553        # NumPy works fine in this case because it's possible to silence the error and get the inverse results,
1554        # possibly filled with NaNs.
1555        batch_dim = 3
1556        a = torch.eye(3, 3, dtype=dtype, device=device)
1557        a = a.reshape((1, 3, 3))
1558        a = a.repeat(batch_dim, 1, 1)
1559        a[1, -1, -1] = 0  # now a[1] is singular
1560        for p in [1, -1, inf, -inf, 'fro', 'nuc']:
1561            result = torch.linalg.cond(a, p)
1562            self.assertEqual(result[1], float('inf'))
1563
1564        # check invalid norm type
1565        a = torch.ones(3, 3, dtype=dtype, device=device)
1566        for p in ['wrong_norm', 5]:
1567            with self.assertRaisesRegex(RuntimeError, f"linalg.cond got an invalid norm type: {p}"):
1568                torch.linalg.cond(a, p)
1569
1570    # This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments
1571    # to ensure that they both throw errors
1572    @dtypes(torch.float, torch.double)
1573    def test_norm_errors(self, device, dtype):
1574        def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex):
1575            test_case_info = (
1576                f'test case input.size()={input.size()}, ord={ord}, dim={dim}, '
1577                f'keepdim={keepdim}, dtype={dtype}')
1578
1579            with self.assertRaisesRegex(error_type, error_regex, msg=test_case_info):
1580                torch.linalg.norm(input, ord, dim, keepdim)
1581
1582            input_numpy = input.cpu().numpy()
1583
1584            msg = f'numpy does not raise an error but pytorch does, for case "{test_case_info}"'
1585            with self.assertRaises(Exception, msg=msg):
1586                np.linalg.norm(input_numpy, ord, dim, keepdim)
1587
1588        S = 10
1589        error_test_cases = [
1590            # input size, p settings, dim, error type, error regex
1591            ((S, ), ['fro', 'nuc'], None, RuntimeError, r'A must have at least 2 dimensions'),
1592            ((S, S), [3.5], None, RuntimeError, r'matrix_norm: Order 3.5 not supported'),
1593            ((S, S), [0], None, RuntimeError, r'matrix_norm: Order 0 not supported'),
1594            ((S, S), ['fail'], None, RuntimeError, r'matrix_norm: Order fail not supported'),
1595            ((S, S), ['fro', 'nuc'], 0, RuntimeError, r'matrix_norm: dim must be a 2-tuple'),
1596            ((S, S), ['fro', 'nuc', 2], (0, 0), RuntimeError, r'dims must be different'),
1597            ((S, S), ['fro', 'nuc', 2], (-1, 1), RuntimeError, r'dims must be different'),
1598            ((S, S), ['fro', 'nuc', 2], (0, 4), IndexError, r'Dimension out of range'),
1599            ((S, ), [0], (4, ), IndexError, r'Dimension out of range'),
1600            ((S, ), [None], (0, 0), RuntimeError, r'dim 0 appears multiple times'),
1601            ((S, S, S), [1], (0, 1, 2), RuntimeError, r"If dim is specified, it must be of length 1 or 2."),
1602            ((S, S, S), [1], None, RuntimeError, r"If dim is not specified but ord is, the input must be 1D or 2D"),
1603        ]
1604        for keepdim in [True, False]:
1605            for input_size, ord_settings, dim, error_type, error_regex in error_test_cases:
1606                input = torch.randn(*input_size, dtype=dtype, device=device)
1607                for ord in ord_settings:
1608                    run_error_test_case(input, ord, dim, keepdim, error_type, error_regex)
1609
1610    # Test complex number inputs for linalg.norm
1611    @skipCUDAIfNoMagma
1612    @skipCPUIfNoLapack
1613    @dtypes(torch.cfloat, torch.cdouble)
1614    @precisionOverride({torch.cfloat: 5e-4})
1615    def test_norm_complex(self, device, dtype):
1616        def gen_error_message(input_size, ord, keepdim, dim=None):
1617            return f"complex norm failed for input size {input_size}, ord={ord}, keepdim={keepdim}, dim={dim}"
1618
1619        vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf]
1620        matrix_ords = [None, 'fro', 'nuc', 1, 2, inf, -1, -2, -inf]
1621
1622        # Test supported ords
1623        for keepdim in [False, True]:
1624            # vector norm
1625            x = torch.randn(25, device=device, dtype=dtype)
1626            xn = x.cpu().numpy()
1627            for ord in vector_ords:
1628                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
1629                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
1630                msg = gen_error_message(x.size(), ord, keepdim)
1631                self.assertEqual(res.shape, expected.shape, msg=msg)
1632                self.assertEqual(res, expected, msg=msg, exact_dtype=False)
1633
1634                res_out = torch.tensor([], device=device, dtype=res.dtype)
1635                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
1636                self.assertEqual(res_out.shape, expected.shape, msg=msg)
1637                self.assertEqual(res_out, expected, msg=msg)
1638
1639            # matrix norm
1640            x = torch.randn(25, 25, device=device, dtype=dtype)
1641            xn = x.cpu().numpy()
1642            for ord in matrix_ords:
1643                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
1644                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
1645                msg = gen_error_message(x.size(), ord, keepdim)
1646                self.assertEqual(res.shape, expected.shape, msg=msg)
1647                self.assertEqual(res, expected, msg=msg, exact_dtype=False)
1648
1649                res_out = torch.tensor([], device=device, dtype=res.dtype)
1650                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
1651                self.assertEqual(res_out.shape, expected.shape, msg=msg)
1652                self.assertEqual(res_out, expected, msg=msg)
1653
1654    # Test that linalg.vector_norm gives the same result as numpy when inputs
1655    # contain extreme values (inf, -inf, nan)
1656    def test_vector_norm_extreme_values(self, device):
1657        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
1658        vectors = []
1659        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
1660            vectors.append(list(pair))
1661        for vector in vectors:
1662            x = torch.tensor(vector, device=device)
1663            x_n = x.cpu().numpy()
1664            for ord in vector_ords:
1665                msg = f'ord={ord}, vector={vector}'
1666                result = torch.linalg.vector_norm(x, ord=ord)
1667                result_n = np.linalg.norm(x_n, ord=ord)
1668                self.assertEqual(result, result_n, msg=msg)
1669
1670    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1671    def test_vector_norm_reduce_over_1D_vector(self, device, dtype):
1672        input_sizes_and_dims = [
1673            ((6, 1), -1),
1674            ((3, 1, 2, 1), (1, 3)),
1675            ((1,), None),
1676        ]
1677        orders = [float('inf'), -float('inf'), 0, 1, -1, 2, -2]
1678        keepdims = [True, False]
1679
1680        for input_size_and_dim, ord, keepdim in product(input_sizes_and_dims, orders, keepdims):
1681            input_size = input_size_and_dim[0]
1682            dim = input_size_and_dim[1]
1683            if type(dim) is tuple and ord == 0:
1684                # skip because np.linalg.norm raises 'ValueError: Invalid norm order for matrices.'
1685                continue
1686            input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1687            result = torch.linalg.vector_norm(input, ord, dim, keepdim)
1688            result_numpy = np.linalg.norm(input.cpu().numpy(), ord, dim, keepdim)
1689
1690            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1691            self.assertEqual(result, result_numpy, msg=msg)
1692
1693    @skipCUDAIfNoMagmaAndNoCusolver
1694    @skipCPUIfNoLapack
1695    @dtypes(torch.float, torch.double)
1696    @precisionOverride({torch.float32: 2e-5})
1697    def test_matrix_norm(self, device, dtype):
1698        # Test only inputs for which torch.linalg.matrix_norm diverges from torch.linalg.norm
1699        A = make_tensor((2, 2, 2), dtype=dtype, device=device)
1700
1701        with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must have at least 2 dimensions.*'):
1702            torch.linalg.matrix_norm(make_tensor((2,), dtype=dtype, device=device))
1703        with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must be a 2-tuple.*'):
1704            torch.linalg.matrix_norm(A, dim=(0,))
1705        with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
1706            torch.linalg.matrix_norm(A, ord=0)
1707        with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
1708            torch.linalg.matrix_norm(A, ord=3.0)
1709
1710        # Test dim=None behavior
1711        ref = torch.linalg.norm(A, dim=(-2, -1))
1712        res = torch.linalg.matrix_norm(A)
1713        self.assertEqual(ref, res)
1714
1715    # Test that linalg.norm gives the same result as numpy when inputs
1716    # contain extreme values (inf, -inf, nan)
1717    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
1718    @unittest.skipIf(IS_MACOS, "Skipped on MacOS!")
1719    @skipCUDAIfNoMagma
1720    @skipCPUIfNoLapack
1721    def test_norm_extreme_values(self, device):
1722        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
1723        # matrix_ords 'nuc', 2, -2 are skipped currently
1724        # See issue https://github.com/pytorch/pytorch/issues/71911
1725        matrix_ords = ['fro', 1, inf, -1, -inf]
1726        vectors = []
1727        matrices = []
1728        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
1729            vectors.append(list(pair))
1730            matrices.append([[pair[0], pair[1]]])
1731            matrices.append([[pair[0]], [pair[1]]])
1732        for vector in vectors:
1733            x = torch.tensor(vector).to(device)
1734            x_n = x.cpu().numpy()
1735            for ord in vector_ords:
1736                msg = f'ord={ord}, vector={vector}'
1737                result = torch.linalg.norm(x, ord=ord)
1738                result_n = np.linalg.norm(x_n, ord=ord)
1739                self.assertEqual(result, result_n, msg=msg)
1740
1741        # TODO: Remove this function once the broken cases are fixed
1742        def is_broken_matrix_norm_case(ord, x):
1743            if self.device_type == 'cuda':
1744                if x.size() == torch.Size([1, 2]):
1745                    if ord in ['nuc', 2, -2] and isnan(x[0][0]) and x[0][1] == 1:
1746                        # These cases are broken because of an issue with svd
1747                        # https://github.com/pytorch/pytorch/issues/43567
1748                        return True
1749                if ord in ['nuc', 2, -2]:
1750                    # These cases are broken because of another issue with svd
1751                    # https://github.com/pytorch/pytorch/issues/52633
1752                    return True
1753            return False
1754
1755        for matrix in matrices:
1756            x = torch.tensor(matrix).to(device)
1757            x_n = x.cpu().numpy()
1758            for ord in matrix_ords:
1759                msg = f'ord={ord}, matrix={matrix}'
1760                if is_broken_matrix_norm_case(ord, x):
1761                    continue
1762                else:
1763                    result_n = np.linalg.norm(x_n, ord=ord)
1764                    result = torch.linalg.norm(x, ord=ord)
1765                    self.assertEqual(result, result_n, msg=msg)
1766
1767    # Test degenerate shape results match numpy for linalg.norm vector norms
1768    @skipCUDAIfNoMagma
1769    @skipCPUIfNoLapack
1770    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1771    def test_norm_vector_degenerate_shapes(self, device, dtype):
1772        def run_test_case(input, ord, dim, keepdim):
1773            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1774            if (input.numel() == 0 and
1775                (ord < 0. or ord == inf) and
1776               (dim is None or input.shape[dim] == 0)):
1777                with self.assertRaises(RuntimeError):
1778                    torch.linalg.norm(input, ord, dim, keepdim)
1779            else:
1780                input_numpy = input.cpu().numpy()
1781                result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1782                result = torch.linalg.norm(input, ord, dim, keepdim)
1783                self.assertEqual(result, result_numpy, msg=msg)
1784
1785        ord_vector = [0, 0.5, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
1786        S = 10
1787        test_cases = [
1788            # input size, dim
1789            ((0, ), None),
1790            ((0, S), 0),
1791            ((0, S), 1),
1792            ((S, 0), 0),
1793            ((S, 0), 1),
1794        ]
1795        for keepdim in [True, False]:
1796            for input_size, dim in test_cases:
1797                input = torch.randn(*input_size, dtype=dtype, device=device)
1798                for ord in ord_vector:
1799                    run_test_case(input, ord, dim, keepdim)
1800
1801    # Test degenerate shape results match numpy for linalg.norm matrix norms
1802    @skipCUDAIfNoMagma
1803    @skipCPUIfNoLapack
1804    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1805    def test_norm_matrix_degenerate_shapes(self, device, dtype):
1806        def run_test_case(input, ord, dim, keepdim, should_error):
1807            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1808            input_numpy = input.cpu().numpy()
1809            ops = [torch.linalg.norm]
1810
1811            if ord is not None and dim is not None:
1812                ops.append(torch.linalg.matrix_norm)
1813
1814            if should_error:
1815                with self.assertRaises(ValueError):
1816                    np.linalg.norm(input_numpy, ord, dim, keepdim)
1817                for op in ops:
1818                    with self.assertRaises(IndexError):
1819                        op(input, ord, dim, keepdim)
1820            else:
1821                result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1822                for op in ops:
1823                    result = op(input, ord, dim, keepdim)
1824                    self.assertEqual(result, result_numpy, msg=msg)
1825
1826        ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None]
1827        S = 10
1828        test_cases = [
1829            # input size, p settings that cause error, dim
1830            ((0, 0), [1, 2, inf, -1, -2, -inf], None),
1831            ((0, S), [2, inf, -2, -inf], None),
1832            ((S, 0), [1, 2, -1, -2], None),
1833            ((S, S, 0), [], (0, 1)),
1834            ((1, S, 0), [], (0, 1)),
1835            ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)),
1836            ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)),
1837        ]
1838
1839        for keepdim in [True, False]:
1840            for input_size, error_ords, dim in test_cases:
1841                input = torch.randn(*input_size, dtype=dtype, device=device)
1842                for ord in ord_matrix:
1843                    run_test_case(input, ord, dim, keepdim, ord in error_ords)
1844
1845    def test_norm_fastpaths(self, device):
1846        x = torch.randn(3, 5, device=device)
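        # Every check below is an instance of the vector p-norm
        # ||x||_p = (sum_i |x_i|**p)**(1/p), with p=0 conventionally counting the
        # non-zero entries; p in {0, 1, 2, 3} exercises the dedicated fast paths,
        # while p=4.5 goes through the generic slow path.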
1847
1848        # slow path
1849        result = torch.linalg.norm(x, 4.5, 1)
1850        expected = torch.pow(x.abs().pow(4.5).sum(1), 1.0 / 4.5)
1851        self.assertEqual(result, expected)
1852
1853        # fast 0-norm
1854        result = torch.linalg.norm(x, 0, 1)
1855        expected = (x != 0).type_as(x).sum(1)
1856        self.assertEqual(result, expected)
1857
1858        # fast 1-norm
1859        result = torch.linalg.norm(x, 1, 1)
1860        expected = x.abs().sum(1)
1861        self.assertEqual(result, expected)
1862
1863        # fast 2-norm
1864        result = torch.linalg.norm(x, 2, 1)
1865        expected = torch.sqrt(x.pow(2).sum(1))
1866        self.assertEqual(result, expected)
1867
1868        # fast 3-norm
1869        result = torch.linalg.norm(x, 3, 1)
1870        expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0)
1871        self.assertEqual(result, expected)
1872
1873    @skipCPUIfNoLapack
1874    @skipCUDAIfNoMagma
1875    # NumPy computes only in float64 and complex128 precision;
1876    # for float32 or complex64 inputs the results might be very different from float64 or complex128
1877    @dtypes(torch.float64, torch.complex128)
1878    def test_eig_numpy(self, device, dtype):
1879        def run_test(shape, *, symmetric=False):
1880            from torch.testing._internal.common_utils import random_symmetric_matrix
1881
1882            if not dtype.is_complex and symmetric:
1883                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
1884                # unlike NumPy the result is not cast to float32 or float64 dtype in this case
1885                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
1886            else:
1887                a = make_tensor(shape, dtype=dtype, device=device)
1888
1889            actual = torch.linalg.eig(a)
1890
1891            # compare with NumPy
1892            # the eigenvalues are not necessarily ordered
1893            # so order of NumPy and PyTorch can be different
1894            expected = np.linalg.eig(a.cpu().numpy())
1895
1896            # sort NumPy output
1897            ind = np.argsort(expected[0], axis=-1)[::-1]
1898            expected = (np.take_along_axis(expected[0], ind, axis=-1), np.take_along_axis(expected[1], ind[:, None], axis=-1))
1899
1900            # sort PyTorch output
1901            # torch.argsort doesn't work with complex inputs, NumPy sorting on CPU is used instead
1902            # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble
1903            # RuntimeError: "sorting_kernel_method_name" not implemented for 'ComplexDouble'
1904            ind = np.argsort(actual[0].cpu().numpy(), axis=-1)[::-1]
1905            actual_np = [x.cpu().numpy() for x in actual]
1906            sorted_actual = (
1907                np.take_along_axis(actual_np[0], ind, axis=-1),
1908                np.take_along_axis(actual_np[1], ind[:, None], axis=-1))
1909
1910            self.assertEqual(expected[0], sorted_actual[0], exact_dtype=False)
1911            self.assertEqual(abs(expected[1]), abs(sorted_actual[1]), exact_dtype=False)
1912
1913        shapes = [(0, 0),  # Empty matrix
1914                  (5, 5),  # Single matrix
1915                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
1916                  (2, 5, 5),  # 3-dim tensors
1917                  (2, 1, 5, 5)]  # 4-dim tensors
1918        for shape in shapes:
1919            run_test(shape)
1920            run_test(shape, symmetric=True)
1921
1922    @onlyCUDA
1923    @skipCUDAIfNoMagma
1924    @dtypes(*floating_and_complex_types())
1925    def test_eig_compare_backends(self, device, dtype):
1926        def run_test(shape, *, symmetric=False):
1927            from torch.testing._internal.common_utils import random_symmetric_matrix
1928
1929            if not dtype.is_complex and symmetric:
1930                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
1931                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
1932            else:
1933                a = make_tensor(shape, dtype=dtype, device=device)
1934
1935            actual = torch.linalg.eig(a)
1936
1937            complementary_device = 'cpu'
1938
1939            # compare with CPU
1940            expected = torch.linalg.eig(a.to(complementary_device))
1941            self.assertEqual(expected[0], actual[0])
1942            self.assertEqual(expected[1], actual[1])
1943
1944        shapes = [(0, 0),  # Empty matrix
1945                  (5, 5),  # Single matrix
1946                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
1947                  (2, 5, 5),  # 3-dim tensors
1948                  (2, 1, 5, 5)]  # 4-dim tensors
1949        for shape in shapes:
1950            run_test(shape)
1951            run_test(shape, symmetric=True)
1952
1953    @slowTest
1954    @onlyCUDA
1955    @skipCUDAIfNoMagma
1956    @dtypes(torch.float32)
1957    def test_eig_check_magma(self, device, dtype):
1958        # For CUDA inputs, only matrices larger than 2048x2048 actually call the MAGMA library
1959        shape = (2049, 2049)
1960        a = make_tensor(shape, dtype=dtype, device=device)
1961        w, v = torch.linalg.eig(a)
1962        # check correctness using eigendecomposition identity
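        # w has shape (n,) and v has shape (n, n), so w * v broadcasts w over the rows,
        # scaling the j-th eigenvector column of v by w[j]; the assert below therefore
        # checks the identity A @ V == V @ diag(w)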
1963        self.assertEqual(a.to(v.dtype) @ v, w * v, atol=1e-3, rtol=1e-3)
1964
1965    @skipCUDAIfNoMagma
1966    @skipCPUIfNoLapack
1967    @dtypes(*floating_and_complex_types())
1968    def test_eig_errors_and_warnings(self, device, dtype):
1969        # eig requires the input tensor to have at least 2 dimensions
1970        a = make_tensor(2, dtype=dtype, device=device)
1971        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
1972            torch.linalg.eig(a)
1973
1974        # eig requires a square matrix
1975        a = make_tensor((2, 3), dtype=dtype, device=device)
1976        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
1977            torch.linalg.eig(a)
1978
1979        # if an out tensor with a floating-point dtype is passed for complex output, an error is thrown
1980        if not dtype.is_complex:
1981            # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i
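            # (for the matrix below: trace = 3 + (-1) = 2 and det = 3*(-1) - (-2)*4 = 5,
            #  so p(lambda) = lambda**2 - trace*lambda + det = lambda**2 - 2*lambda + 5)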
1982            a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
1983            out0 = torch.empty(0, device=device, dtype=dtype)
1984            out1 = torch.empty(0, device=device, dtype=dtype)
1985            with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
1986                torch.linalg.eig(a, out=(out0, out1))
1987
1988            out0 = torch.empty(0, device=device, dtype=torch.complex128)
1989            with self.assertRaisesRegex(RuntimeError, "Expected eigenvectors to be safely castable"):
1990                torch.linalg.eig(a, out=(out0, out1))
1991
1992        # dtypes should be safely castable
1993        a = make_tensor((3, 3), dtype=dtype, device=device)
1994        out0 = torch.empty(0, dtype=torch.int, device=device)
1995        out1 = torch.empty(0, dtype=torch.int, device=device)
1996        with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
1997            torch.linalg.eig(a, out=(out0, out1))
1998
1999        out0 = torch.empty(0, dtype=torch.complex128, device=device)
2000        with self.assertRaisesRegex(RuntimeError, "but got eigenvectors with dtype Int"):
2001            torch.linalg.eig(a, out=(out0, out1))
2002
2003        # if a non-empty out tensor with the wrong shape is passed, a warning is given
2004        a = make_tensor((3, 3), dtype=dtype, device=device)
2005        out0 = torch.empty(1, device=device, dtype=torch.complex128)
2006        out1 = torch.empty(1, device=device, dtype=torch.complex128)
2007        with warnings.catch_warnings(record=True) as w:
2008            # Trigger warning
2009            torch.linalg.eig(a, out=(out0, out1))
2010            # Check warning occurs
2011            self.assertEqual(len(w), 2)
2012            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2013            self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
2014
2015        # device should match
2016        if torch.cuda.is_available():
2017            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2018            out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2019            out_v = torch.empty(0, device=device, dtype=torch.complex128)
2020            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2021                torch.linalg.eig(a, out=(out_w, out_v))
2022            out_w = torch.empty(0, device=device, dtype=torch.complex128)
2023            out_v = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2024            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2025                torch.linalg.eig(a, out=(out_w, out_v))
2026
2027    @skipCPUIfNoLapack
2028    @skipCUDAIfNoMagma
2029    @dtypes(*floating_and_complex_types())
2030    def test_eig_with_nan(self, device, dtype):
2031        for val in [np.inf, np.nan]:
2032            for batch_dim in [(), (10,)]:
2033                a = make_tensor((*batch_dim, 5, 5), device=device, dtype=dtype)
2034                a[..., -1, -1] = val
2035
2036                with self.assertRaisesRegex(RuntimeError, "torch.linalg.eig: input tensor should not"):
2037                    torch.linalg.eig(a)
2038
2039    @skipCPUIfNoLapack
2040    @skipCUDAIfNoMagma
2041    # NumPy computes only in float64 and complex128 precision;
2042    # for float32 or complex64 inputs the results might be very different from float64 or complex128
2043    @dtypes(torch.float64, torch.complex128)
2044    def test_eigvals_numpy(self, device, dtype):
2045        def run_test(shape, *, symmetric=False):
2046            from torch.testing._internal.common_utils import random_symmetric_matrix
2047
2048            if not dtype.is_complex and symmetric:
2049                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
2050                # unlike NumPy the result is not cast to float32 or float64 dtype in this case
2051                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
2052            else:
2053                a = make_tensor(shape, dtype=dtype, device=device)
2054
2055            actual = torch.linalg.eigvals(a)
2056
2057            # compare with NumPy
2058            # the eigenvalues are not necessarily ordered
2059            # so order of NumPy and PyTorch can be different
2060            expected = np.linalg.eigvals(a.cpu().numpy())
2061
2062            # sort NumPy output
2063            ind = np.argsort(expected, axis=-1)[::-1]
2064            expected = np.take_along_axis(expected, ind, axis=-1)
2065
2066            # sort PyTorch output
2067            # torch.argsort doesn't work with complex inputs, NumPy sorting on CPU is used instead
2068            # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble
2069            # RuntimeError: "sorting_kernel_method_name" not implemented for 'ComplexDouble'
2070            ind = np.argsort(actual.cpu().numpy(), axis=-1)[::-1]
2071            actual_np = actual.cpu().numpy()
2072            sorted_actual = np.take_along_axis(actual_np, ind, axis=-1)
2073
2074            self.assertEqual(expected, sorted_actual, exact_dtype=False)
2075
2076        shapes = [(0, 0),  # Empty matrix
2077                  (5, 5),  # Single matrix
2078                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
2079                  (2, 5, 5),  # 3-dim tensors
2080                  (2, 1, 5, 5)]  # 4-dim tensors
2081        for shape in shapes:
2082            run_test(shape)
2083            run_test(shape, symmetric=True)
2084
2085    @onlyCUDA
2086    @skipCUDAIfNoMagma
2087    @dtypes(*floating_and_complex_types())
2088    def test_eigvals_compare_backends(self, device, dtype):
2089        def run_test(shape, *, symmetric=False):
2090            from torch.testing._internal.common_utils import random_symmetric_matrix
2091
2092            if not dtype.is_complex and symmetric:
2093                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
2094                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
2095            else:
2096                a = make_tensor(shape, dtype=dtype, device=device)
2097
2098            actual = torch.linalg.eigvals(a)
2099
2100            complementary_device = 'cpu'
2101
2102            # compare with CPU
2103            expected = torch.linalg.eigvals(a.to(complementary_device))
2104            self.assertEqual(expected, actual)
2105
2106            # check out= variant
2107            complex_dtype = dtype
2108            if not dtype.is_complex:
2109                complex_dtype = torch.complex128 if dtype == torch.float64 else torch.complex64
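            # torch.linalg.eigvals always returns complex eigenvalues, even for real
            # inputs, so the out= tensor must use the corresponding complex dtype, e.g.
            #   torch.linalg.eigvals(torch.eye(2)).dtype == torch.complex64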
2110            out = torch.empty(0, dtype=complex_dtype, device=device)
2111            ans = torch.linalg.eigvals(a, out=out)
2112            self.assertEqual(ans, out)
2113            self.assertEqual(expected.to(complex_dtype), out)
2114
2115            # check non-contiguous out
2116            if a.numel() > 0:
2117                out = torch.empty(2 * shape[0], *shape[1:-1], dtype=complex_dtype, device=device)[::2]
2118                self.assertFalse(out.is_contiguous())
2119                ans = torch.linalg.eigvals(a, out=out)
2120                self.assertEqual(ans, out)
2121                self.assertEqual(expected.to(complex_dtype), out)
2122
2123        shapes = [(0, 0),  # Empty matrix
2124                  (5, 5),  # Single matrix
2125                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
2126                  (2, 5, 5),  # 3-dim tensors
2127                  (2, 1, 5, 5)]  # 4-dim tensors
2128        for shape in shapes:
2129            run_test(shape)
2130            run_test(shape, symmetric=True)
2131
2132    @skipCUDAIfNoMagma
2133    @skipCPUIfNoLapack
2134    @dtypes(*floating_and_complex_types())
2135    def test_eigvals_errors_and_warnings(self, device, dtype):
2136        # eigvals requires the input tensor to have at least 2 dimensions
2137        a = make_tensor(2, dtype=dtype, device=device)
2138        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
2139            torch.linalg.eigvals(a)
2140
2141        # eig requires a square matrix
2142        # eigvals requires a square matrix
2143        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
2144            torch.linalg.eigvals(a)
2145
2146        # if an out tensor with a floating-point dtype is passed for complex output, an error is thrown
2147        if not dtype.is_complex:
2148            # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i
2149            a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
2150            out = torch.empty(0, device=device, dtype=dtype)
2151            with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
2152                torch.linalg.eigvals(a, out=out)
2153
2154        # dtypes should be safely castable
2155        a = make_tensor((3, 3), dtype=dtype, device=device)
2156        out = torch.empty(0, dtype=torch.int, device=device)
2157        with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
2158            torch.linalg.eigvals(a, out=out)
2159
2160        # if a non-empty out tensor with the wrong shape is passed, a warning is given
2161        out = torch.empty(1, device=device, dtype=torch.complex128)
2162        with warnings.catch_warnings(record=True) as w:
2163            # Trigger warning
2164            torch.linalg.eigvals(a, out=out)
2165            # Check warning occurs
2166            self.assertEqual(len(w), 1)
2167            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2168
2169        # device should match
2170        if torch.cuda.is_available():
2171            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2172            out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2173            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2174                torch.linalg.eigvals(a, out=out_w)
2175
2176    @skipCUDAIfNoMagma
2177    @skipCPUIfNoLapack
2178    def test_norm_old(self, device):
2179        def gen_error_message(input_size, p, keepdim, dim=None):
2180            return f"norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}"
2181
2182        # 'nuc' norm uses SVD, and thus its precision is much lower than other norms.
2183        # test_svd takes @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4}),
2184        # and here we are doing the same thing for nuc norm.
2185        class PrecisionContext:
2186            def __init__(self, test, norm):
2187                self.norm = norm
2188                self.saved_overrides = getattr(test, 'precision_overrides', None)
2189                self.target_test = test
2190
2191            def __enter__(self):
2192                if 'nuc' != self.norm:
2193                    return None
2194                self.target_test.precision_overrides = {torch.float: 1e-4, torch.cfloat: 2e-4}
2195                return self.target_test.precision_overrides
2196
2197            def __exit__(self, exc_type, exc_value, tb) -> bool:
2198                if 'nuc' != self.norm:
2199                    return False
2200                if self.saved_overrides is None:
2201                    delattr(self.target_test, 'precision_overrides')
2202                else:
2203                    self.target_test.precision_overrides = self.saved_overrides
2204                return False
2205
2206        for keepdim in [False, True]:
2207            # full reduction
2208            x = torch.randn(25, device=device)
2209            xn = x.cpu().numpy()
2210            for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3, 1.5]:
2211                res = x.norm(p, keepdim=keepdim).cpu()
2212                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2213                self.assertEqual(res, expected, atol=1e-5, rtol=0, msg=gen_error_message(x.size(), p, keepdim))
2214
2215            # one dimension
2216            x = torch.randn(25, 25, device=device)
2217            xn = x.cpu().numpy()
2218            for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3]:
2219                dim = 1
2220                res = x.norm(p, dim, keepdim=keepdim).cpu()
2221                expected = np.linalg.norm(xn, p, dim, keepdims=keepdim)
2222                msg = gen_error_message(x.size(), p, keepdim, dim)
2223                self.assertEqual(res.shape, expected.shape, msg=msg)
2224                self.assertEqual(res, expected, msg=msg)
2225
2226            # matrix norm
2227            for p in ['fro', 'nuc']:
2228                res = x.norm(p, keepdim=keepdim).cpu()
2229                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2230                msg = gen_error_message(x.size(), p, keepdim)
2231                with PrecisionContext(self, p):
2232                    self.assertEqual(res.shape, expected.shape, msg=msg)
2233                    self.assertEqual(res, expected, msg=msg)
2234
2235            # zero dimensions
2236            x = torch.randn((), device=device)
2237            xn = x.cpu().numpy()
2238            res = x.norm(keepdim=keepdim).cpu()
2239            expected = np.linalg.norm(xn, keepdims=keepdim)
2240            msg = gen_error_message(x.size(), None, keepdim)
2241            self.assertEqual(res.shape, expected.shape, msg=msg)
2242            self.assertEqual(res, expected, msg=msg)
2243
2244            # larger tensor sanity check
2245            self.assertEqual(
2246                2 * torch.norm(torch.ones(10000), keepdim=keepdim),
2247                torch.norm(torch.ones(40000), keepdim=keepdim))
2248
2249            # matrix norm with non-square >2-D tensors, all combinations of reduction dims
2250            x = torch.randn(5, 6, 7, 8, device=device)
2251            xn = x.cpu().numpy()
2252            for p in ['fro', 'nuc']:
2253                for dim in itertools.product(*[list(range(4))] * 2):
2254                    if dim[0] == dim[1]:
2255                        continue
2256                    res = x.norm(p=p, dim=dim, keepdim=keepdim).cpu()
2257                    expected = np.linalg.norm(xn, ord=p, axis=dim, keepdims=keepdim)
2258                    msg = gen_error_message(x.size(), p, keepdim, dim)
2259                    with PrecisionContext(self, p):
2260                        self.assertEqual(res.shape, expected.shape, msg=msg)
2261                        self.assertEqual(res, expected, msg=msg)
2262
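    # --- Illustrative sketch (added commentary, not an original test) ---
    # The precision note in test_norm_old points out that the 'nuc' norm goes
    # through an SVD.  This hedged sketch spells out the identity it relies on:
    # the nuclear norm is the sum of the singular values (helper name is
    # hypothetical and not collected by the test runner).
    def _sketch_nuclear_norm_via_svd(self, device='cpu', dtype=torch.float64):
        A = torch.randn(5, 3, dtype=dtype, device=device)
        nuc = torch.linalg.matrix_norm(A, ord='nuc')
        assert torch.allclose(nuc, torch.linalg.svdvals(A).sum())
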
2263    # Test that torch.norm with p=+/-inf propagates NaN
2264    def test_norm_old_nan_propagation(self, device):
2265        ords = [inf, -inf]
2266        for pair in itertools.product([0.0, nan, 1.0], repeat=2):
2267            x = torch.tensor(list(pair), device=device)
2268            for ord in ords:
2269                result = torch.norm(x, p=ord)
2270                result_check = torch.linalg.norm(x, ord=ord)
2271                self.assertEqual(result, result_check)
2272
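    # --- Illustrative sketch (added commentary, not an original test) ---
    # A minimal standalone example of the behaviour checked above: with
    # ord=+inf the vector norm is a max-style reduction, and a NaN entry is
    # expected to propagate to the result rather than being dropped (helper
    # name is hypothetical, not collected by the test runner).
    def _sketch_inf_norm_nan_propagation(self, device='cpu'):
        x = torch.tensor([nan, 1.0], device=device)
        assert isnan(torch.norm(x, p=inf).item())
        assert isnan(torch.linalg.norm(x, ord=inf).item())
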
2273    @skipCUDAIfNoMagma
2274    @skipCPUIfNoLapack
2275    def test_norm_complex_old(self, device):
2276        def gen_error_message(input_size, p, keepdim, dim=None):
2277            return f"complex norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}"
2278
2279        for keepdim in [False, True]:
2280            # vector norm
2281            x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device)
2282            xn = x.cpu().numpy()
2283            for p in [0, 1, 2, 3, inf, -1, -2, -3, -inf]:
2284                res = x.norm(p, keepdim=keepdim).cpu()
2285                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2286                msg = gen_error_message(x.size(), p, keepdim)
2287                self.assertEqual(res.shape, expected.shape, msg=msg)
2288                self.assertEqual(res, expected, msg=msg)
2289
2290            # matrix norm
2291            x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device)
2292            xn = x.cpu().numpy()
2293            for p in ['nuc', 'fro']:
2294                res = x.norm(p, keepdim=keepdim).cpu()
2295                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2296                msg = gen_error_message(x.size(), p, keepdim)
2297                self.assertEqual(res.shape, expected.shape, msg=msg)
2298                self.assertEqual(res, expected, msg=msg, rtol=4e-6, atol=6e-4)
2299
2300    # Ensure torch.norm with p='fro' and p=2 give the same results for mutually supported input combinations
2301    @dtypes(torch.float)
2302    def test_norm_fro_2_equivalence_old(self, device, dtype):
2303        input_sizes = [
2304            (0,),
2305            (10,),
2306            (0, 0),
2307            (4, 30),
2308            (0, 45),
2309            (100, 0),
2310            (45, 10, 23),
2311            (0, 23, 59),
2312            (23, 0, 37),
2313            (34, 58, 0),
2314            (0, 0, 348),
2315            (0, 3434, 0),
2316            (0, 0, 0),
2317            (5, 3, 8, 1, 3, 5)]
2318
2319        for input_size in input_sizes:
2320            a = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
2321
2322            # Try full reduction
2323            dim_settings = [None]
2324
2325            # Try all possible 1-D reductions
2326            dim_settings += list(range(-a.dim(), a.dim()))
2327
2328            def wrap_dim(dim, ndims):
2329                assert (dim < ndims) and (dim >= -ndims)
2330                if dim >= 0:
2331                    return dim
2332                else:
2333                    return dim + ndims
2334
2335            # Try all possible 2-D reductions
2336            dim_settings += [
2337                (d0, d1) for d0, d1 in itertools.combinations(range(-a.dim(), a.dim()), 2)
2338                if wrap_dim(d0, a.dim()) != wrap_dim(d1, a.dim())]
2339
2340            for dim in dim_settings:
2341                for keepdim in [True, False]:
2342                    a_norm_2 = torch.norm(a, p=2, dim=dim, keepdim=keepdim)
2343                    a_norm_fro = torch.norm(a, p='fro', dim=dim, keepdim=keepdim)
2344                    self.assertEqual(a_norm_fro, a_norm_2)
2345
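    # --- Illustrative sketch (added commentary, not an original test) ---
    # The equivalence exercised above comes down to the definition of the
    # Frobenius norm: it is the 2-norm of the flattened tensor.  A hedged,
    # self-contained restatement (helper name is hypothetical):
    def _sketch_fro_equals_flattened_two_norm(self, device='cpu'):
        a = torch.randn(4, 30, device=device)
        assert torch.allclose(torch.norm(a, p='fro'), torch.norm(a.flatten(), p=2))
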
2346    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
2347    @skipCUDAIfNoMagma
2348    @skipCPUIfNoLapack
2349    def test_nuclear_norm_axes_small_brute_force_old(self, device):
2350        def check_single_nuclear_norm(x, axes):
2351            if self.device_type != 'cpu' and randrange(100) < 95:
2352                return  # too many cpu <==> device copies
2353
2354            a = np.array(x.cpu(), copy=False)
2355            expected = np.linalg.norm(a, "nuc", axis=axes)
2356
2357            ans = torch.norm(x, "nuc", dim=axes)
2358            self.assertTrue(ans.is_contiguous())
2359            self.assertEqual(ans.shape, expected.shape)
2360            self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)
2361
2362            out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device)
2363            ans = torch.norm(x, "nuc", dim=axes, out=out)
2364            self.assertIs(ans, out)
2365            self.assertTrue(ans.is_contiguous())
2366            self.assertEqual(ans.shape, expected.shape)
2367            self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)
2368
2369        for n in range(1, 3):
2370            for m in range(1, 3):
2371                for axes in itertools.permutations([0, 1], 2):
2372                    # 2d, inner dimensions C
2373                    x = torch.randn(n, m, device=device)
2374                    check_single_nuclear_norm(x, axes)
2375
2376                    # 2d, inner dimensions Fortran
2377                    x = torch.randn(m, n, device=device).mT
2378                    check_single_nuclear_norm(x, axes)
2379
2380                    # 2d, inner dimensions non-contiguous
2381                    x = torch.randn(n, 2 * m, device=device)[:, ::2]
2382                    check_single_nuclear_norm(x, axes)
2383
2384                    # 2d, all dimensions non-contiguous
2385                    x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2]
2386                    check_single_nuclear_norm(x, axes)
2387
2388                for o in range(1, 3):
2389                    for axes in itertools.permutations([0, 1, 2], 2):
2390                        # 3d, inner dimensions C
2391                        x = torch.randn(o, n, m, device=device)
2392                        check_single_nuclear_norm(x, axes)
2393
2394                        # 3d, inner dimensions Fortran
2395                        x = torch.randn(o, m, n, device=device).mT
2396                        check_single_nuclear_norm(x, axes)
2397
2398                        # 3d, inner dimensions non-contiguous
2399                        x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2]
2400                        check_single_nuclear_norm(x, axes)
2401
2402                        # 3d, all dimensions non-contiguous
2403                        x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2]
2404                        check_single_nuclear_norm(x, axes)
2405
2406                    for r in range(1, 3):
2407                        for axes in itertools.permutations([0, 1, 2, 3], 2):
2408                            # 4d, inner dimensions C
2409                            x = torch.randn(r, o, n, m, device=device)
2410                            check_single_nuclear_norm(x, axes)
2411
2412                            # 4d, inner dimensions Fortran
2413                            x = torch.randn(r, o, n, m, device=device).mT
2414                            check_single_nuclear_norm(x, axes)
2415
2416                            # 4d, inner dimensions non-contiguous
2417                            x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2]
2418                            check_single_nuclear_norm(x, axes)
2419
2420                            # 4d, all dimensions non-contiguous
2421                            x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2]
2422                            check_single_nuclear_norm(x, axes)
2423
2424    @skipCUDAIfNoMagma
2425    def test_nuclear_norm_exceptions_old(self, device):
2426        for lst in [], [1], [1, 2]:
2427            x = torch.tensor(lst, dtype=torch.double, device=device)
2428            for axes in (), (0,):
2429                self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes)
2430            self.assertRaises(RuntimeError, torch.norm, x, "nuc", (0, 1))
2431
2432        x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device)
2433        self.assertRaisesRegex(RuntimeError, "must be different", torch.norm, x, "nuc", (0, 0))
2434        self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2))
2435
2436    @skipCUDAIfNoCusolver
2437    @skipCPUIfNoLapack
2438    @dtypes(torch.double, torch.cdouble)
2439    def test_svd_lowrank(self, device, dtype):
2440        from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix
2441
2442        def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options):
2443            density = options.pop('density', 1)
2444            if isinstance(matrix_size, int):
2445                rows = columns = matrix_size
2446            else:
2447                rows, columns = matrix_size
2448            if density == 1:
2449                a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
2450                a = a_input
2451            else:
2452                assert batches == ()
2453                a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
2454                a = a_input.to_dense()
2455
2456            q = min(rows, columns)  # smallest dimension of the input matrix
2457            u, s, v = svd_lowrank(a_input, q=q, **options)
2458
2459            # check that (u, s, v) forms an SVD of a
2460            u, s, v = u[..., :q], s[..., :q], v[..., :q]
2461            A = (u * s.unsqueeze(-2)).matmul(v.mH)
2462            self.assertEqual(A, a, rtol=1e-7, atol=2e-7)
2463
2464            # check that svd_lowrank produces the same singular values as torch.linalg.svd
2465            U, S, Vh = torch.linalg.svd(a, full_matrices=False)
2466            V = Vh.mH
2467            self.assertEqual(s, S)
2468
2469            if density == 1:
2470                # actual_rank is known only for dense inputs
2471                #
2472                # check if pairs (u, U) and (v, V) span the same
2473                # subspaces, respectively
2474                u, v = u[..., :actual_rank], v[..., :actual_rank]
2475                U, V = U[..., :actual_rank], V[..., :actual_rank]
2476                expected_ones = u.mH.matmul(U).det().abs()
2477                self.assertEqual(expected_ones, torch.ones_like(expected_ones))
2478                self.assertEqual(v.mH.matmul(V).det().abs(), torch.ones_like(expected_ones))
2479
2480        all_batches = [(), (1,), (3,), (2, 3)]
2481        for actual_rank, size, all_batches in [  # noqa: B020
2482                (2, (17, 4), all_batches),
2483                (4, (17, 4), all_batches),
2484                (4, (17, 17), all_batches),
2485                (10, (100, 40), all_batches),
2486                (7, (1000, 1000), [()]),
2487        ]:
2488            # dense input
2489            for batches in all_batches:
2490                run_subtest(actual_rank, size, batches, device, torch.svd_lowrank)
2491                if size != size[::-1]:
2492                    run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank)
2493
2494        # sparse input
2495        for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]:
2496            for density in [0.005, 0.1]:
2497                run_subtest(None, size, (), device, torch.svd_lowrank, density=density)
2498
2499        # jitting support
2500        jitted = torch.jit.script(torch.svd_lowrank)
2501        actual_rank, size, batches = 2, (17, 4), ()
2502        run_subtest(actual_rank, size, batches, device, jitted)
2503
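    # --- Illustrative sketch (added commentary, not an original test) ---
    # A hedged, minimal usage example of torch.svd_lowrank mirroring what
    # run_subtest checks above: with q at least the true rank, the truncated
    # factors reconstruct the matrix.  Names below are local to this sketch.
    def _sketch_svd_lowrank_usage(self, device='cpu', dtype=torch.float64):
        rank, m, n = 3, 20, 12
        B = torch.randn(m, rank, dtype=dtype, device=device)
        C = torch.randn(rank, n, dtype=dtype, device=device)
        A = B @ C                                  # exactly rank 3 by construction
        u, s, v = torch.svd_lowrank(A, q=6)        # q >= rank, with some oversampling
        A_hat = (u * s.unsqueeze(-2)) @ v.mH       # same reconstruction used above
        assert torch.allclose(A, A_hat, atol=1e-7)
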
2504    @skipCUDAIfNoMagmaAndNoCusolver
2505    @skipCPUIfNoLapack
2506    @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4})
2507    @setLinalgBackendsToDefaultFinally
2508    @dtypes(*floating_and_complex_types())
2509    @serialTest()
2510    def test_svd(self, device, dtype):
2511        # tests linalg.svd, svd, linalg.svdvals
2512        make_arg = partial(make_tensor, dtype=dtype, device=device)
2513
2514        backends = ["default"]
2515
2516        if torch.device(device).type == 'cuda':
2517            if torch.cuda.has_magma:
2518                backends.append("magma")
2519            if has_cusolver() or has_hipsolver():
2520                backends.append("cusolver")
2521
2522        ns = (12, 4, 2, 0)
2523        batches = ((), (0,), (1,), (2,), (2, 1), (0, 2))
2524        drivers = (None, 'gesvd', 'gesvdj', 'gesvda')
2525
2526        for backend in backends:
2527            torch.backends.cuda.preferred_linalg_library(backend)
2528
2529            for batch, m, n, driver in product(batches, ns, ns, drivers):
2530                if not (backend == 'cusolver' or driver is None):
2531                    # only test cases below and skip otherwise:
2532                    # - backend == 'cusolver' (driver can be anything)
2533                    # - backend != 'cusolver' (driver should only be None)
2534                    continue
2535
2536                shape = batch + (m, n)
2537                k = min(m, n)
2538                A = make_arg(shape)
2539                U, S, Vh = torch.linalg.svd(A, full_matrices=False, driver=driver)
2540                self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ Vh, A)
2541
2542                U_f, S_f, Vh_f = torch.linalg.svd(A, full_matrices=True, driver=driver)
2543                self.assertEqual(S_f, S)
2544                self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ Vh_f[..., :k, :], A)
2545
2546                S_s = torch.linalg.svdvals(A, driver=driver)
2547                self.assertEqual(S_s, S)
2548
2549                U, S, V = torch.svd(A, some=True)
2550                self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ V.mH, A)
2551
2552                U_f, S_f, V_f = torch.svd(A, some=False)
2553                self.assertEqual(S_f, S)
2554                self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ V_f[..., :k].mH, A)
2555
2556                S_s = torch.svd(A, compute_uv=False).S
2557                self.assertEqual(S_s, S)
2558
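    # --- Illustrative sketch (added commentary, not an original test) ---
    # test_svd above exercises both torch.linalg.svd (which returns Vh) and the
    # older torch.svd (which returns V).  A hedged reminder of how the two
    # conventions relate (helper name is hypothetical):
    def _sketch_svd_vs_linalg_svd(self, device='cpu', dtype=torch.float64):
        A = torch.randn(5, 3, dtype=dtype, device=device)
        U1, S1, Vh = torch.linalg.svd(A, full_matrices=False)
        U2, S2, V = torch.svd(A, some=True)
        assert torch.allclose(S1, S2)
        # Up to the sign/phase freedom of singular vectors, V is the conjugate
        # transpose of Vh; both conventions reconstruct A.
        assert torch.allclose(U1 @ torch.diag(S1) @ Vh, A)
        assert torch.allclose(U2 @ torch.diag(S2) @ V.mH, A)
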
2559    @skipCUDAIfNoMagmaAndNoCusolver
2560    @skipCPUIfNoLapack
2561    @dtypes(torch.complex128)
2562    def test_invariance_error_spectral_decompositions(self, device, dtype):
2563        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
2564        A = make_arg((3, 3))
2565        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2566            U, _, Vh = torch.linalg.svd(A, full_matrices=False)
2567            (U + Vh).sum().abs().backward()
2568
2569        A = make_arg((3, 3))
2570        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2571            V = torch.linalg.eig(A).eigenvectors
2572            V.sum().abs().backward()
2573
2574        A = make_arg((3, 3))
2575        A = A + A.mH
2576        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2577            Q = torch.linalg.eigh(A).eigenvectors
2578            Q.sum().abs().backward()
2579
2580    @skipCUDAIfNoCusolver  # MAGMA backend doesn't work in this case
2581    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
2582    @skipCPUIfNoLapack
2583    @dtypes(*floating_and_complex_types())
2584    def test_svd_memory_allocation(self, device, dtype):
2585        # test for https://github.com/pytorch/pytorch/issues/61949
2586        # the problem was that tensors of incorrect size were allocated and then narrowed
2587        m = 3
2588        n = 2**20
2589        a = make_tensor((m, n), dtype=dtype, device=device)
2590        # the following should run without errors
2591        S = torch.linalg.svdvals(a)
2592        result = torch.linalg.svd(a, full_matrices=False)
2593        self.assertEqual(result.S, S)
2594
2595    def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
2596        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2597
2598        b = torch.randn(*b_dims, dtype=dtype, device=device)
2599        A = random_hermitian_pd_matrix(*A_dims, dtype=dtype, device=device)
2600        L = torch.cholesky(A, upper=upper)
2601        return b, A, L
2602
2603    @skipCUDAIfNoMagma
2604    @skipCPUIfNoLapack
2605    @dtypes(*floating_and_complex_types())
2606    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2607                        torch.float64: 1e-8, torch.complex128: 1e-8})
2608    def test_cholesky_solve(self, device, dtype):
2609        for (k, n), upper in itertools.product(zip([2, 3, 5], [3, 5, 7]), [True, False]):
2610            b, A, L = self.cholesky_solve_test_helper((n,), (n, k), upper, device, dtype)
2611            x = torch.cholesky_solve(b, L, upper=upper)
2612            self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
2613
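    # --- Illustrative sketch (added commentary, not an original test) ---
    # A hedged, minimal version of the workflow the cholesky_solve tests rely
    # on: factor a Hermitian positive-definite A once, then reuse the factor to
    # solve A x = b (names below are local to this sketch).
    def _sketch_cholesky_solve_usage(self, device='cpu', dtype=torch.float64):
        n, k = 4, 2
        M = torch.randn(n, n, dtype=dtype, device=device)
        A = M @ M.mH + n * torch.eye(n, dtype=dtype, device=device)  # positive definite
        b = torch.randn(n, k, dtype=dtype, device=device)
        L = torch.linalg.cholesky(A)       # lower-triangular factor, A == L @ L.mH
        x = torch.cholesky_solve(b, L)     # solves using the factor, not A itself
        assert torch.allclose(A @ x, b, atol=1e-10)
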
2614    @skipCUDAIfNoMagma
2615    @skipCPUIfNoLapack
2616    @dtypes(*floating_and_complex_types())
2617    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2618                        torch.float64: 1e-8, torch.complex128: 1e-8})
2619    def test_cholesky_solve_batched(self, device, dtype):
2620        def cholesky_solve_batch_helper(A_dims, b_dims, upper):
2621            b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype)
2622            x_exp_list = []
2623            for i in range(b_dims[0]):
2624                x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper))
2625            x_exp = torch.stack(x_exp_list)  # Stacked output
2626            x_act = torch.cholesky_solve(b, L, upper=upper)  # Actual output
2627            self.assertEqual(x_act, x_exp)  # Equality check
2628            Ax = np.matmul(A.cpu(), x_act.cpu())
2629            self.assertEqual(b, Ax)  # Correctness check
2630
2631        for upper, batchsize in itertools.product([True, False], [1, 3, 4]):
2632            cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), upper)
2633
2634    @slowTest
2635    @skipCUDAIfNoMagma
2636    @skipCPUIfNoLapack
2637    @dtypes(*floating_and_complex_types())
2638    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2639                        torch.float64: 1e-8, torch.complex128: 1e-8})
2640    def test_cholesky_solve_batched_many_batches(self, device, dtype):
2641        for A_dims, b_dims in zip([(5, 256, 256), (5,)], [(5, 10), (512, 512, 5, 10)]):
2642            for upper in [True, False]:
2643                b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype)
2644                x = torch.cholesky_solve(b, L, upper)
2645                Ax = torch.matmul(A, x)
2646                self.assertEqual(Ax, b.expand_as(Ax))
2647
2648    @skipCUDAIfNoMagma
2649    @skipCPUIfNoLapack
2650    @dtypes(*floating_and_complex_types())
2651    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2652                        torch.float64: 1e-8, torch.complex128: 1e-8})
2653    def test_cholesky_solve_batched_broadcasting(self, device, dtype):
2654        from numpy.linalg import solve
2655        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2656
2657        def run_test(A_dims, b_dims, upper):
2658            A_matrix_size = A_dims[-1]
2659            A_batch_dims = A_dims[:-2]
2660            A = random_hermitian_pd_matrix(A_matrix_size, *A_batch_dims,
2661                                           dtype=dtype, device='cpu')
2662            b = torch.randn(*b_dims, dtype=dtype, device='cpu')
2663            x_exp = torch.tensor(solve(A.numpy(), b.numpy()), dtype=dtype, device=device)
2664            A, b = A.to(dtype=dtype, device=device), b.to(dtype=dtype, device=device)
2665            L = torch.linalg.cholesky(A, upper=upper)
2666            x = torch.cholesky_solve(b, L, upper=upper)
2667            self.assertEqual(x, x_exp)
2668            # https://github.com/pytorch/pytorch/issues/42695
2669            x = torch.cholesky_solve(b, L, upper=upper, out=x)
2670            self.assertEqual(x, x_exp)
2671
2672        # test against numpy.linalg.solve
2673        for upper in [True, False]:
2674            run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), upper)  # no broadcasting
2675            run_test((2, 1, 3, 4, 4), (4, 6), upper)  # broadcasting b
2676            run_test((4, 4), (2, 1, 3, 4, 2), upper)  # broadcasting A
2677            run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), upper)  # broadcasting A & b
2678
2679    @skipCUDAIfNoMagma
2680    @skipCPUIfNoLapack
2681    @dtypes(*floating_and_complex_types())
2682    def test_cholesky_solve_out_errors_and_warnings(self, device, dtype):
2683        # dtypes should be safely castable
2684        a = torch.eye(2, dtype=dtype, device=device)
2685        b = torch.randn(2, 1, dtype=dtype, device=device)
2686        out = torch.empty(0, dtype=torch.int, device=device)
2687        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
2688            torch.cholesky_solve(b, a, out=out)
2689
2690        # device should match
2691        if torch.cuda.is_available():
2692            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2693            out = torch.empty(0, dtype=dtype, device=wrong_device)
2694            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2695                torch.cholesky_solve(b, a, out=out)
2696
2697        # if an out tensor with the wrong shape is passed, a warning is given
2698        with warnings.catch_warnings(record=True) as w:
2699            out = torch.empty(1, dtype=dtype, device=device)
2700            # Trigger warning
2701            torch.cholesky_solve(b, a, out=out)
2702            # Check warning occurs
2703            self.assertEqual(len(w), 1)
2704            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2705
2706    @skipCUDAIfNoMagma
2707    @skipCPUIfNoLapack
2708    @dtypes(torch.double)
2709    def test_cholesky_solve_backward(self, device, dtype):
2710        b_dims = (5, 2)
2711        L_dims = (5, 5)
2712
2713        for test_L_grad in (False, True):
2714            b = torch.randn(*b_dims, dtype=dtype, device=device, requires_grad=True)
2715            L = torch.randn(*L_dims, dtype=dtype, device=device, requires_grad=test_L_grad)
2716            if test_L_grad:
2717                torch.autograd.gradcheck(lambda b, L: torch.cholesky_solve(b, torch.tril(L), upper=False), (b, L))
2718            else:
2719                torch.autograd.gradcheck(lambda b: torch.cholesky_solve(b, L, upper=False), (b,))
2720
2721    @skipCUDAIfNoMagmaAndNoCusolver
2722    @skipCPUIfNoLapack
2723    @dtypes(*floating_and_complex_types())
2724    @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
2725                        torch.float64: 1e-8, torch.complex128: 1e-8})
2726    def test_inverse(self, device, dtype):
2727        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
2728        make_arg = partial(make_fullrank, device=device, dtype=dtype)
2729
2730        def run_test(torch_inverse, matrix, batches, n):
2731            matrix_inverse = torch_inverse(matrix)
2732
2733            # Compare against NumPy output
2734            # NumPy uses the 'gesv' LAPACK routine, solving the equation A A_inv = I
2735            # PyTorch instead uses 'getrf' + 'getrs'. As such, there may be some element-wise differences
2736            expected = np.linalg.inv(matrix.cpu().numpy())
2737            self.assertEqual(matrix_inverse, expected, atol=self.precision, rtol=self.precision)
2738
2739            # Additional correctness tests, check matrix*matrix_inverse == identity
2740            identity = torch.eye(n, dtype=dtype, device=device)
2741            self.assertEqual(identity.expand_as(matrix), np.matmul(matrix.cpu(), matrix_inverse.cpu()))
2742            self.assertEqual(identity.expand_as(matrix), np.matmul(matrix_inverse.cpu(), matrix.cpu()))
2743
2744            # check the out= variant
2745            # prepare the expected out tensor
2746            matrix_inverse_out = torch.empty(*batches, n, n, dtype=dtype, device=device)
2747            matrix_inverse_out_t = matrix_inverse_out.mT.clone(memory_format=torch.contiguous_format)
2748            matrix_inverse_out = matrix_inverse_out_t.mT
2749            ans = torch_inverse(matrix, out=matrix_inverse_out)
2750            self.assertEqual(matrix_inverse_out, ans, atol=0, rtol=0)
2751            self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0)
2752
2753            # batched matrices: 3+ dimensional tensors, check matrix_inverse same as single-inverse for each matrix
2754            if matrix.ndim > 2 and batches[0] != 0:
2755                expected_inv_list = []
2756                p = int(np.prod(batches))  # use `p` instead of -1, so that the test works for empty input as well
2757                for mat in matrix.contiguous().view(p, n, n):
2758                    expected_inv_list.append(torch_inverse(mat))
2759                expected_inv = torch.stack(expected_inv_list).view(*batches, n, n)
2760                if self.device_type == 'cuda' and dtype in [torch.float32, torch.complex64]:
2761                    # single-inverse is done using cuSOLVER, while batched inverse is done using MAGMA
2762                    # individual values can be significantly different for fp32, hence rather high rtol is used
2763                    # the important thing is that torch_inverse passes above checks with identity
2764                    self.assertEqual(matrix_inverse, expected_inv, atol=1e-1, rtol=1e-2)
2765                else:
2766                    self.assertEqual(matrix_inverse, expected_inv)
2767
2768        # helper function for testing torch.linalg.inv_ex
2769        def test_inv_ex(input, out=None):
2770            if out is not None:
2771                info = torch.empty(0, dtype=torch.int32, device=device)
2772                return torch.linalg.inv_ex(input, out=(out, info)).inverse
2773            return torch.linalg.inv_ex(input).inverse
2774
2775        for torch_inverse in [torch.inverse, torch.linalg.inv, test_inv_ex]:
2776            for batches, n in itertools.product(
2777                [[], [0], [2], [2, 1]],
2778                [0, 5]
2779            ):
2780                matrices = make_arg(*batches, n, n)
2781                run_test(torch_inverse, matrices, batches, n)
2782
2783                # test non-contiguous input
2784                run_test(torch_inverse, matrices.mT, batches, n)
2785                if n > 0:
2786                    run_test(
2787                        torch_inverse,
2788                        make_arg(*batches, 2 * n, 2 * n)
2789                        .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n),
2790                        batches, n
2791                    )
2792
2793    @skipCUDAIfNoMagmaAndNoCusolver
2794    @skipCPUIfNoLapack
2795    @dtypes(*floating_and_complex_types())
2796    def test_inv_ex_info_device(self, device, dtype):
2797        A = torch.eye(3, 3, dtype=dtype, device=device)
2798        info = torch.linalg.inv_ex(A).info
2799        self.assertTrue(info.device == A.device)
2800
2801    @skipCUDAIfNoMagmaAndNoCusolver
2802    @skipCPUIfNoLapack
2803    @dtypes(*floating_and_complex_types())
2804    def test_inv_ex_singular(self, device, dtype):
2805        # if the input matrix is not invertible, info containing a positive integer is returned
2806        A = torch.eye(3, 3, dtype=dtype, device=device)
2807        A[-1, -1] = 0  # Now A is singular
2808        info = torch.linalg.inv_ex(A).info
2809        self.assertEqual(info, 3)
2810        with self.assertRaisesRegex(torch.linalg.LinAlgError,
2811                                    r'diagonal element 3 is zero, the inversion could not be completed'):
2812            torch.linalg.inv_ex(A, check_errors=True)
2813
2814        # if at least one matrix in the batch is not invertible,
2815        # the batched info contains a positive integer for the corresponding matrix
2816        A = torch.eye(3, 3, dtype=dtype, device=device)
2817        A = A.reshape((1, 3, 3))
2818        A = A.repeat(5, 1, 1)
2819        A[3, -2, -2] = 0  # Now A[3] is singular
2820        info = torch.linalg.inv_ex(A).info
2821
2822        expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
2823        expected_info[3] = 2
2824        self.assertEqual(info, expected_info)
2825        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The diagonal element 2 is zero'):
2826            torch.linalg.inv_ex(A, check_errors=True)
2827
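    # --- Illustrative sketch (added commentary, not an original test) ---
    # test_inv_ex_singular above checks the batched version of this contract.
    # A hedged single-matrix sketch of what linalg.inv_ex returns: an inverse
    # plus a LAPACK-style `info` tensor, instead of raising on singular input
    # (helper name is hypothetical, not collected by the test runner).
    def _sketch_inv_ex_info(self, device='cpu', dtype=torch.float64):
        A = torch.eye(3, dtype=dtype, device=device)
        inverse, info = torch.linalg.inv_ex(A)
        assert info.item() == 0                      # zero means success
        assert torch.allclose(inverse, A)            # the identity is its own inverse
        A[-1, -1] = 0                                # now singular
        _, info = torch.linalg.inv_ex(A)             # does not raise unless check_errors=True
        assert info.item() != 0                      # a nonzero value flags the failure
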
2828    @slowTest
2829    @skipCUDAIfNoMagmaAndNoCusolver
2830    @skipCPUIfNoLapack
2831    @dtypes(*floating_and_complex_types())
2832    @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
2833                        torch.float64: 1e-5, torch.complex128: 1e-5})
2834    def test_inverse_many_batches(self, device, dtype):
2835        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
2836        make_arg = partial(make_fullrank, device=device, dtype=dtype)
2837
2838        def test_inverse_many_batches_helper(torch_inverse, b, n):
2839            matrices = make_arg(b, n, n)
2840            matrices_inverse = torch_inverse(matrices)
2841
2842            # Compare against NumPy output
2843            expected = np.linalg.inv(matrices.cpu().numpy())
2844            self.assertEqual(matrices_inverse, expected, atol=self.precision, rtol=1e-3)
2845
2846        for torch_inverse in [torch.inverse, torch.linalg.inv]:
2847            test_inverse_many_batches_helper(torch_inverse, 5, 256)
2848            test_inverse_many_batches_helper(torch_inverse, 3, 512)
2849
2850    @skipCUDAIfNoMagmaAndNoCusolver
2851    @skipCPUIfNoLapack
2852    @onlyNativeDeviceTypes   # TODO: XLA doesn't raise exception
2853    @dtypes(*floating_and_complex_types())
2854    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
2855    def test_inverse_errors(self, device, dtype):
2856        # inverse expects batches of square matrices as input
2857        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
2858            torch.inverse(torch.randn(2, 3, 4, 3))
2859
2860        # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch
2861        def run_test_singular_input(batch_dim, n):
2862            x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
2863            x[n, -1, -1] = 0
2864            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
2865                torch.inverse(x)
2866
2867        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
2868            run_test_singular_input(*params)
2869
2870    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
2871    @skipCUDAIfNoMagmaAndNoCusolver
2872    @skipCPUIfNoLapack
2873    @onlyNativeDeviceTypes   # TODO: XLA doesn't raise exception
2874    @dtypes(*floating_and_complex_types())
2875    def test_inverse_errors_large(self, device, dtype):
2876        # Test that the batched inverse of singular matrices reports errors without crashing (gh-51930)
2877        x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
2878        x[:] = torch.eye(616, dtype=dtype, device=device)
2879        x[..., 10, 10] = 0
2880        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 0\): The diagonal element 11 is zero'):
2881            torch.inverse(x)
2882
2883    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
2884    @skipCUDAIfNoMagma
2885    @skipCPUIfNoLapack
2886    @dtypes(*floating_and_complex_types())
2887    def test_pinv(self, device, dtype):
2888        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2889
2890        def run_test_main(A, hermitian):
2891            # Test against the defining (Moore-Penrose) identities of the pseudo-inverse
2892            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
2893            np_A = A.cpu().numpy()
2894            np_A_pinv = A_pinv.cpu().numpy()
2895            if A.numel() > 0:
2896                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=self.precision, rtol=self.precision)
2897                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=self.precision, rtol=self.precision)
2898                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1))
2899                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1))
2900            else:
2901                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))
2902
2903            # Check out= variant
2904            out = torch.empty_like(A_pinv)
2905            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
2906            self.assertEqual(ans, out)
2907            self.assertEqual(ans, A_pinv)
2908
2909        def run_test_numpy(A, hermitian):
2910            # Check against NumPy output
2911            # Test float rcond, and specific value for each matrix
2912            rconds = [float(torch.rand(1)), ]
2913            # Test different types of rcond tensor
2914            for rcond_type in all_types():
2915                rconds.append(torch.rand(A.shape[:-2], dtype=torch.double, device=device).to(rcond_type))
2916            # Test broadcasting of rcond
2917            if A.ndim > 2:
2918                rconds.append(torch.rand(A.shape[-3], device=device))
2919            for rcond in rconds:
2920                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
2921                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
2922                self.assertEqual(actual, torch_rtol)
2923                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
2924                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
2925                self.assertEqual(actual, expected, atol=self.precision, rtol=1e-5)
2926
2927        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
2928                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
2929                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
2930                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
2931            A = torch.randn(*sizes, dtype=dtype, device=device)
2932            hermitian = False
2933            run_test_main(A, hermitian)
2934            run_test_numpy(A, hermitian)
2935
2936        # Check hermitian = True
2937        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
2938                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
2939            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
2940            hermitian = True
2941            run_test_main(A, hermitian)
2942            run_test_numpy(A, hermitian)
2943
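    # --- Illustrative sketch (added commentary, not an original test) ---
    # run_test_main above checks the defining Moore-Penrose identities.  A
    # hedged single-matrix restatement, plus the classic least-squares use of
    # the pseudo-inverse cross-checked via the normal equations (helper name
    # is hypothetical, not collected by the test runner):
    def _sketch_pinv_moore_penrose(self, device='cpu', dtype=torch.float64):
        A = torch.randn(5, 3, dtype=dtype, device=device)
        A_pinv = torch.linalg.pinv(A)
        assert torch.allclose(A @ A_pinv @ A, A, atol=1e-10)
        assert torch.allclose(A_pinv @ A @ A_pinv, A_pinv, atol=1e-10)
        # For a tall, full-rank A, pinv(A) @ b solves the least-squares problem.
        b = torch.randn(5, 1, dtype=dtype, device=device)
        x_normal_eq = torch.linalg.solve(A.mH @ A, A.mH @ b)
        assert torch.allclose(A_pinv @ b, x_normal_eq, atol=1e-6)
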
2944    @skipCUDAIfNoMagma
2945    @skipCPUIfNoLapack
2946    @dtypes(*floating_and_complex_types())
2947    def test_pinv_errors_and_warnings(self, device, dtype):
2948        # pinv requires at least 2D tensor
2949        a = torch.randn(1, device=device, dtype=dtype)
2950        with self.assertRaisesRegex(RuntimeError, "expected a tensor with 2 or more dimensions"):
2951            torch.linalg.pinv(a)
2952
2953        # if a non-empty out tensor with the wrong shape is passed, a warning is given
2954        a = torch.randn(3, 3, dtype=dtype, device=device)
2955        out = torch.empty(7, 7, dtype=dtype, device=device)
2956        with warnings.catch_warnings(record=True) as w:
2957            # Trigger warning
2958            torch.linalg.pinv(a, out=out)
2959            # Check warning occurs
2960            self.assertEqual(len(w), 1)
2961            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2962
2963        # dtypes of out and input should be safely castable
2964        out = torch.empty_like(a).to(torch.int)
2965        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
2966            torch.linalg.pinv(a, out=out)
2967
2968        if torch.cuda.is_available():
2969            # device of out and input should match
2970            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2971            out = torch.empty_like(a).to(wrong_device)
2972            with self.assertRaisesRegex(RuntimeError, "Expected result and input tensors to be on the same device"):
2973                torch.linalg.pinv(a, out=out)
2974
2975            # device of rcond and input should match
2976            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2977            rcond = torch.full((), 1e-2, device=wrong_device)
2978            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
2979                torch.linalg.pinv(a, rcond=rcond)
2980
2981        # rcond can't be complex
2982        rcond = torch.full((), 1j, device=device)
2983        with self.assertRaisesRegex(RuntimeError, "rcond tensor of complex type is not supported"):
2984            torch.linalg.pinv(a, rcond=rcond)
2985
2986        # atol can't be complex
2987        atol = torch.full((), 1j, device=device)
2988        with self.assertRaisesRegex(RuntimeError, "atol tensor of complex type is not supported"):
2989            torch.linalg.pinv(a, atol=atol)
2990
2991        # rtol can't be complex
2992        rtol = torch.full((), 1j, device=device)
2993        with self.assertRaisesRegex(RuntimeError, "rtol tensor of complex type is not supported"):
2994            torch.linalg.pinv(a, rtol=rtol)
2995
2996    @skipCUDAIfNoMagmaAndNoCusolver
2997    @skipCPUIfNoLapack
2998    @dtypes(*floating_and_complex_types())
2999    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
3000    def test_inv_errors_and_warnings(self, device, dtype):
3001        # inv expects batches of square matrices as input
3002        a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)
3003        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
3004            torch.linalg.inv(a)
3005
3006        # inv requires the input to be at least 2 dimensional tensor
3007        a = torch.randn(2, device=device, dtype=dtype)
3008        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
3009            torch.linalg.inv(a)
3010
3011        # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch
3012        def run_test_singular_input(batch_dim, n):
3013            a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
3014            a[n, -1, -1] = 0
3015            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
3016                torch.linalg.inv(a)
3017
3018        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
3019            run_test_singular_input(*params)
3020
3021        # dtypes should match
3022        a = torch.eye(2, dtype=dtype, device=device)
3023        out = torch.empty(0, dtype=torch.int, device=device)
3024        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
3025            torch.linalg.inv(a, out=out)
3026
3027        # device should match
3028        if torch.cuda.is_available():
3029            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3030            out = torch.empty(0, device=wrong_device, dtype=dtype)
3031            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3032                torch.linalg.inv(a, out=out)
3033
3034        # if an out tensor with the wrong shape is passed, a warning is given
3035        with warnings.catch_warnings(record=True) as w:
3036            a = torch.eye(2, dtype=dtype, device=device)
3037            out = torch.empty(1, dtype=dtype, device=device)
3038            # Trigger warning
3039            torch.linalg.inv(a, out=out)
3040            # Check warning occurs
3041            self.assertEqual(len(w), 1)
3042            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3043
3044        # if an out tensor in batched column-major format but with the wrong shape is passed, a warning is given
3045        with warnings.catch_warnings(record=True) as w:
3046            a = torch.eye(2, dtype=dtype, device=device)
3047            out = torch.empty(3, 3, dtype=dtype, device=device)
3048            out = out.mT.clone(memory_format=torch.contiguous_format)
3049            out = out.mT
3050            self.assertTrue(out.mT.is_contiguous())
3051            # Trigger warning
3052            torch.linalg.inv(a, out=out)
3053            # Check warning occurs
3054            self.assertEqual(len(w), 1)
3055            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3056
3057    def solve_test_helper(self, A_dims, b_dims, device, dtype):
3058        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
3059        make_A = partial(make_fullrank, device=device, dtype=dtype)
3060
3061        b = torch.randn(*b_dims, dtype=dtype, device=device)
3062        A = make_A(*A_dims)
3063        return b, A
3064
3065    @skipCUDAIfNoMagma
3066    @skipCPUIfNoLapack
3067    @dtypes(*floating_and_complex_types())
3068    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
3069    def test_solve(self, device, dtype):
3070        def run_test(n, batch, rhs):
3071            A_dims = (*batch, n, n)
3072            b_dims = (*batch, n, *rhs)
3073            b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
3074
3075            # Correctness test
3076            x = torch.linalg.solve(A, b)
3077            if rhs == ():
3078                Ax = np.matmul(A.cpu(), x.unsqueeze(-1).cpu())
3079                Ax.squeeze_(-1)
3080            else:
3081                Ax = np.matmul(A.cpu(), x.cpu())
3082            self.assertEqual(b.expand_as(Ax), Ax)
3083
3084            # Check against NumPy
3085            expected = np.linalg.solve(A.cpu().numpy(), b.expand_as(x).cpu().numpy())
3086            self.assertEqual(x, expected)
3087
3088        batches = [(), (0, ), (3, ), (2, 3)]
3089        ns = [0, 5, 32]
3090        nrhs = [(), (1, ), (5, )]
3091        for n, batch, rhs in itertools.product(ns, batches, nrhs):
3092            run_test(n, batch, rhs)
3093
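    # --- Illustrative sketch (added commentary, not an original test) ---
    # A hedged note on what test_solve exercises: torch.linalg.solve(A, B)
    # computes the same X as torch.linalg.inv(A) @ B, but via a factorization
    # of A rather than by forming the inverse explicitly (helper name is
    # hypothetical, not collected by the test runner).
    def _sketch_solve_vs_inv(self, device='cpu', dtype=torch.float64):
        M = torch.randn(4, 4, dtype=dtype, device=device)
        A = M @ M.mH + torch.eye(4, dtype=dtype, device=device)  # well conditioned and invertible
        B = torch.randn(4, 3, dtype=dtype, device=device)
        X = torch.linalg.solve(A, B)
        assert torch.allclose(A @ X, B, atol=1e-10)
        assert torch.allclose(X, torch.linalg.inv(A) @ B, atol=1e-8)
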
3094    @skipCUDAIfNoMagmaAndNoCusolver
3095    @skipCPUIfNoLapack
3096    @dtypes(*floating_and_complex_types())
3097    def test_solve_batched_broadcasting(self, device, dtype):
3098        from numpy.linalg import solve
3099
3100        def run_test(A_dims, B_dims):
3101            A_matrix_size = A_dims[-1]
3102            A_batch_dims = A_dims[:-2]
3103            B, A = self.solve_test_helper(A_batch_dims + (A_matrix_size, A_matrix_size), B_dims, device, dtype)
3104            actual = torch.linalg.solve(A, B)
3105            expected = solve(A.cpu().numpy(), B.cpu().numpy())
3106            self.assertEqual(actual, expected)
3107
3108        # test against numpy.linalg.solve
3109        run_test((5, 5), (2, 0, 5, 3))  # broadcasting with 0 batch dim
3110        run_test((2, 0, 5, 5), (5, 3))  # broadcasting with 0 batch dim
3111        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting B
3112        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
3113        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & B
3114
3115    @skipCUDAIfNoMagma
3116    @skipCPUIfNoLapack
3117    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3118    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
3119    def test_tensorsolve(self, device, dtype):
3120        def run_test(a_shape, dims):
3121            a = torch.randn(a_shape, dtype=dtype, device=device)
3122            b = torch.randn(a_shape[:2], dtype=dtype, device=device)
3123            result = torch.linalg.tensorsolve(a, b, dims=dims)
3124            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
3125            self.assertEqual(result, expected)
3126
3127            # check the out= variant
3128            out = torch.empty_like(result)
3129            ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out)
3130            self.assertEqual(ans, out)
3131            self.assertEqual(ans, result)
3132
3133        a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
3134        dims = [None, (0, 2)]
3135        for a_shape, d in itertools.product(a_shapes, dims):
3136            run_test(a_shape, d)
3137
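    # --- Illustrative sketch (added commentary, not an original test) ---
    # Conceptually, tensorsolve is an ordinary linear solve on a reshaped
    # square matrix: with `a` of shape (2, 3, 6) and `b` of shape (2, 3) it
    # solves a.reshape(6, 6) @ x.reshape(6) == b.reshape(6).  A hedged
    # restatement (helper name is hypothetical, not collected by the runner):
    def _sketch_tensorsolve_as_matrix_solve(self, device='cpu', dtype=torch.float64):
        a = torch.randn(2, 3, 6, dtype=dtype, device=device)
        b = torch.randn(2, 3, dtype=dtype, device=device)
        x = torch.linalg.tensorsolve(a, b)                        # x has shape (6,)
        x_flat = torch.linalg.solve(a.reshape(6, 6), b.reshape(6))
        assert torch.allclose(x, x_flat.reshape(x.shape), atol=1e-6)
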
3138    @skipCUDAIfNoMagma
3139    @skipCPUIfNoLapack
3140    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3141    def test_tensorsolve_empty(self, device, dtype):
3142        # Check for empty inputs. NumPy does not work for these cases.
3143        a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
3144        b = torch.empty(a.shape[:2], dtype=dtype, device=device)
3145        x = torch.linalg.tensorsolve(a, b)
3146        self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b)
3147
3148    @skipCUDAIfNoMagma
3149    @skipCPUIfNoLapack
3150    @dtypes(torch.float32)
3151    def test_tensorsolve_errors_and_warnings(self, device, dtype):
3152        # tensorsolve expects an input that can be reshaped to a square matrix
3153        a = torch.eye(2 * 3 * 4, dtype=dtype, device=device).reshape((2 * 3, 4, 2, 3, 4))
3154        b = torch.randn(8, 4, dtype=dtype, device=device)
3155        self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape))
3156        with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'):
3157            torch.linalg.tensorsolve(a, b)
3158
3159        # if a non-empty out tensor with the wrong shape is passed, a warning is given
3160        out = torch.empty_like(a)
3161        b = torch.randn(6, 4, dtype=dtype, device=device)
3162        with warnings.catch_warnings(record=True) as w:
3163            # Trigger warning
3164            torch.linalg.tensorsolve(a, b, out=out)
3165            # Check warning occurs
3166            self.assertEqual(len(w), 1)
3167            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3168
3169        # dtypes should be safely castable
3170        out = torch.empty_like(a).to(torch.int)
3171        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
3172            torch.linalg.tensorsolve(a, b, out=out)
3173
3174        # device should match
3175        if torch.cuda.is_available():
3176            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3177            out = torch.empty(0, dtype=dtype, device=wrong_device)
3178            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3179                torch.linalg.tensorsolve(a, b, out=out)
3180
3181    @skipCUDAIfNoMagma
3182    @skipCPUIfNoLapack
3183    @dtypes(*floating_and_complex_types())
3184    @precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3})
3185    def test_tensorinv(self, device, dtype):
3186
3187        def run_test(a_shape, ind):
3188            a = torch.randn(a_shape, dtype=dtype, device=device)
3189            a_numpy = a.cpu().numpy()
3190            result = torch.linalg.tensorinv(a, ind=ind)
3191            expected = np.linalg.tensorinv(a_numpy, ind=ind)
3192            self.assertEqual(result, expected)
3193
3194            # check the out= variant
3195            out = torch.empty_like(result)
3196            ans = torch.linalg.tensorinv(a, ind=ind, out=out)
3197            self.assertEqual(ans, out)
3198            self.assertEqual(ans, result)
3199
3200        # compare to NumPy output
3201        run_test((12, 3, 4), ind=1)
3202        run_test((3, 8, 24), ind=2)
3203        run_test((18, 3, 3, 2), ind=1)
3204        run_test((1, 4, 2, 2), ind=2)
3205        run_test((2, 3, 5, 30), ind=3)
3206        run_test((24, 2, 2, 3, 2), ind=1)
3207        run_test((3, 4, 2, 3, 2), ind=2)
3208        run_test((1, 2, 3, 2, 3), ind=3)
3209        run_test((3, 2, 1, 2, 12), ind=4)
3210
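    # --- Illustrative sketch (added commentary, not an original test) ---
    # Analogously to tensorsolve, tensorinv(a, ind) is the matrix inverse of
    # `a` reshaped to a square matrix (grouping the first `ind` dimensions as
    # rows), with the result reshaped so the two index groups swap places.  A
    # hedged restatement (helper name is hypothetical):
    def _sketch_tensorinv_as_matrix_inverse(self, device='cpu', dtype=torch.float64):
        a = torch.randn(12, 3, 4, dtype=dtype, device=device)       # 12 == 3 * 4
        a_inv = torch.linalg.tensorinv(a, ind=1)                    # shape (3, 4, 12)
        a_inv_ref = torch.linalg.inv(a.reshape(12, 12)).reshape(3, 4, 12)
        assert torch.allclose(a_inv, a_inv_ref, atol=1e-6)
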
3211    @skipMeta  # See https://github.com/pytorch/pytorch/issues/53739
3212    @skipCUDAIfNoMagma
3213    @skipCPUIfNoLapack
3214    @dtypes(*floating_and_complex_types())
3215    def test_tensorinv_empty(self, device, dtype):
3216        for ind in range(1, 4):
3217            # Check for empty inputs. NumPy does not work for these cases.
3218            a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
3219            a_inv = torch.linalg.tensorinv(a, ind=ind)
3220            self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind])
3221
3222    @skipMeta  # See https://github.com/pytorch/pytorch/issues/53739
3223    @skipCUDAIfNoMagma
3224    @skipCPUIfNoLapack
3225    @dtypes(*floating_and_complex_types())
3226    def test_tensorinv_errors_and_warnings(self, device, dtype):
3227
3228        def check_shape(a_shape, ind):
3229            # tensorinv requires the input to satisfy
3230            # prod(a.shape[ind:]) == prod(a.shape[:ind])
3231            a = torch.randn(a_shape, dtype=dtype, device=device)
3232            with self.assertRaisesRegex(RuntimeError, "Expected self to satisfy the requirement"):
3233                torch.linalg.tensorinv(a, ind=ind)
3234
3235        def check_ind(a_shape, ind):
3236            a = torch.randn(a_shape, dtype=dtype, device=device)
3237            with self.assertRaisesRegex(RuntimeError, "Expected a strictly positive integer"):
3238                torch.linalg.tensorinv(a, ind=ind)
3239
3240        def check_out(a_shape, ind):
3241            # if a non-empty out tensor with the wrong shape is passed, a warning is given
3242            a = torch.randn(a_shape, dtype=dtype, device=device)
3243            out = torch.empty_like(a)
3244            with warnings.catch_warnings(record=True) as w:
3245                # Trigger warning
3246                torch.linalg.tensorinv(a, ind=ind, out=out)
3247                # Check warning occurs
3248                self.assertEqual(len(w), 1)
3249                self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3250
3251            # dtypes should be safely castable
3252            out = torch.empty(0, dtype=torch.int, device=device)
3253            with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
3254                torch.linalg.tensorinv(a, ind=ind, out=out)
3255
3256            # device should match
3257            if torch.cuda.is_available():
3258                wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3259                out = torch.empty(0, dtype=dtype, device=wrong_device)
3260                with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3261                    torch.linalg.tensorinv(a, ind=ind, out=out)
3262
3263        # test for invalid shape
3264        check_shape((2, 3, 4), ind=1)
3265        check_shape((1, 2, 3, 4), ind=3)
3266
3267        # test for invalid ind
3268        check_ind((12, 3, 4), ind=-1)
3269        check_ind((18, 3, 3, 2), ind=0)
3270
3271        # test for invalid out tensor
3272        check_out((12, 3, 4), ind=1)
3273        check_out((3, 8, 24), ind=2)
3274
3275    @skipCUDAIfNoMagma
3276    @skipCPUIfNoLapack
3277    @dtypes(*floating_and_complex_types())
3278    def test_tensorinv_singular_input(self, device, dtype):
3279
3280        def check_singular_input(a_shape, ind):
3281            prod_ind_end = np.prod(a_shape[ind:])
3282            a = torch.eye(prod_ind_end, dtype=dtype, device=device)
3283            a[-1, -1] = 0   # Now `a` is singular
3284            a = a.reshape(a_shape)
3285            with self.assertRaisesRegex(torch.linalg.LinAlgError, "The diagonal element"):
3286                torch.linalg.tensorinv(a, ind=ind)
3287
3288        # test for non-invertible input
3289        check_singular_input((12, 3, 4), ind=1)
3290        check_singular_input((3, 6, 18), ind=2)
3291
3292    def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
3293        def check(x, y):
3294            # Compare with numpy
3295            res = torch_fn(x, y)
3296            if x.dtype == torch.bfloat16:
3297                ref = torch.from_numpy(np.array(np_fn(x.cpu().float().numpy(), y.cpu().float().numpy())))
3298            else:
3299                ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy())))
3300            if res.dtype == torch.bfloat16:
3301                self.assertEqual(res.cpu(), ref.bfloat16())
3302            else:
3303                self.assertEqual(res.cpu(), ref)
3304
3305            # Test out variant
3306            out = torch.empty_like(res)
3307            torch_fn(x, y, out=out)
3308            self.assertEqual(out, res)
3309
3310        # Empty
3311        x = torch.tensor([], dtype=dtype, device=device)
3312        y = torch.tensor([], dtype=dtype, device=device)
3313        check(x, y)
3314
3315        # Contiguous
3316        x = 0.1 * torch.randn(5000, dtype=dtype, device=device)
3317        y = 0.1 * torch.randn(5000, dtype=dtype, device=device)
3318        check(x, y)
3319
3320        # 0 strided
3321        y = 0.1 * torch.randn(1, dtype=dtype, device=device).expand(5000)
3322        check(x, y)
3323
3324        # 2 strided
3325        check(x[::2], y[::2])
3326
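    # --- Illustrative sketch (added commentary, not an original test) ---
    # The helper above is shared by the dot and vdot tests; the two ops only
    # differ for complex inputs, where vdot conjugates its first argument and
    # dot does not.  A hedged reminder (helper name is hypothetical):
    def _sketch_dot_vs_vdot(self, device='cpu'):
        x = torch.tensor([1 + 2j, 3 - 1j], device=device)
        y = torch.tensor([2 - 1j, 1 + 1j], device=device)
        assert torch.allclose(torch.dot(x, y), (x * y).sum())
        assert torch.allclose(torch.vdot(x, y), (x.conj() * y).sum())
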
3327    @dtypes(torch.float, torch.cfloat, torch.bfloat16, torch.float16)
3328    @dtypesIfCUDA(torch.float, torch.cfloat)
3329    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5, torch.bfloat16: 1e-0})
3330    def test_dot_vs_numpy(self, device, dtype):
3331        self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)
3332
3333    @dtypes(torch.float, torch.cfloat)
3334    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
3335    def test_vdot_vs_numpy(self, device, dtype):
3336        self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot)
3337
3338    def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False):
3339        def check(x, y, regex):
3340            with self.assertRaisesRegex(RuntimeError, regex):
3341                torch_fn(x, y)
3342
3343        if complex_dtypes:
3344            x = torch.randn(1, dtype=torch.cfloat, device=device)
3345            y = torch.randn(3, dtype=torch.cdouble, device=device)
3346        else:
3347            x = torch.randn(1, dtype=torch.float, device=device)
3348            y = torch.randn(3, dtype=torch.double, device=device)
3349
3350        check(x, y, 'dot : expected both vectors to have same dtype')
3351        check(x.reshape(1, 1), y, '1D tensors expected')
3352        check(x.expand(9), y.to(x.dtype), 'inconsistent tensor size')
3353
3354        if self.device_type != 'cpu':
3355            x_cpu = x.expand(3).cpu()
3356            check(x_cpu, y.to(x.dtype), 'Expected all tensors to be on the same device')
3357
3358    @onlyNativeDeviceTypes
3359    def test_vdot_invalid_args(self, device):
3360        self._test_dot_vdot_invalid_args(device, torch.vdot)
3361        self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True)
3362
3363    @onlyNativeDeviceTypes
3364    def test_dot_invalid_args(self, device):
3365        self._test_dot_vdot_invalid_args(device, torch.dot)
3366        self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True)
3367
3368    @skipCUDAIfNoMagma
3369    @skipCPUIfNoLapack
3370    @dtypes(*floating_and_complex_types())
3371    def test_matrix_rank(self, device, dtype):
3372        matrix_rank = torch.linalg.matrix_rank
3373
3374        def run_test(shape0, shape1, batch):
3375            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
3376            rank_a = matrix_rank(a)
3377
3378            self.assertEqual(rank_a, matrix_rank(a.mH))
3379            aaH = torch.matmul(a, a.mH)
3380            rank_aaH = matrix_rank(aaH)
3381            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
3382            self.assertEqual(rank_aaH, rank_aaH_hermitian)
3383            aHa = torch.matmul(a.mH, a)
3384            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
3385
3386            # check against NumPy
3387            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
3388            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))
3389
3390            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
3391            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))
3392
3393            # hermitian flag for NumPy was added in 1.14.0
3394            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
3395                self.assertEqual(rank_aaH_hermitian,
3396                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
3397                self.assertEqual(matrix_rank(aaH, 0.01, True),
3398                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))
3399
3400            # check out= variant
3401            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
3402            ans = matrix_rank(a, out=out)
3403            self.assertEqual(ans, out)
3404            self.assertEqual(ans, rank_a)
3405
3406        shapes = (3, 13)
3407        batches = ((), (0, ), (4, ), (3, 5, ))
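        # zip pairs each (shape0, shape1) from product(shapes, reversed(shapes)) with one batch:
        # (3, 13) with (), (3, 3) with (0,), (13, 13) with (4,), and (13, 3) with (3, 5)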
3408        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
3409            run_test(shape0, shape1, batch)
3410
3411    @skipCUDAIfNoMagma
3412    @skipCPUIfNoLapack
3413    @dtypes(*floating_and_complex_types())
3414    def test_matrix_rank_atol(self, device, dtype):
3415
3416        def run_test_atol(shape0, shape1, batch):
3417            a = make_tensor((*batch, shape0, shape1), dtype=dtype, device=device)
3418            # Check against NumPy output
3419            # Test a plain Python float tol, and tensor tolerances with a specific value for each matrix
3420            tolerances = [float(torch.rand(1)), ]
3421            # Test different types of tol tensor
3422            for tol_type in all_types():
3423                tolerances.append(make_tensor(a.shape[:-2], dtype=tol_type, device=device, low=0))
3424            # Test broadcasting of tol
3425            if a.ndim > 2:
3426                tolerances.append(make_tensor(a.shape[-3], dtype=torch.float32, device=device, low=0))
3427            for tol in tolerances:
3428                actual = torch.linalg.matrix_rank(a, atol=tol)
3429                actual_tol = torch.linalg.matrix_rank(a, tol=tol)
3430                self.assertEqual(actual, actual_tol)
3431                numpy_tol = tol if isinstance(tol, float) else tol.cpu().numpy()
3432                expected = np.linalg.matrix_rank(a.cpu().numpy(), tol=numpy_tol)
3433                self.assertEqual(actual, expected)
3434
3435        shapes = (3, 13)
3436        batches = ((), (0, ), (4, ), (3, 5, ))
3437        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
3438            run_test_atol(shape0, shape1, batch)
3439
3440    @skipCUDAIfNoMagma
3441    @skipCPUIfNoLapack
3442    @dtypes(torch.float64)
3443    def test_matrix_rank_atol_rtol(self, device, dtype):
3444        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
3445        make_arg = partial(make_fullrank, device=device, dtype=dtype)
3446
3447        # creates a full-rank matrix (rank = n) with singular values in the range [2/3, 3/2]
3448        # the singular values are 1 + 1/2, 1 - 1/3, 1 + 1/4, 1 - 1/5, ...
3449        n = 9
3450        a = make_arg(n, n)
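        # Per the comment above, the singular values are approximately
        # 1.5, 0.667, 1.25, 0.8, 1.167, 0.857, 1.125, 0.889, 1.1, so the largest one is 1.5.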
3451
3452        # test float and tensor variants
3453        for tol_value in [0.81, torch.tensor(0.81, device=device)]:
3454            # using rtol (relative tolerance) takes into account the largest singular value (1.5 in this case)
3455            result = torch.linalg.matrix_rank(a, rtol=tol_value)
3456            self.assertEqual(result, 2)  # there are 2 singular values above 1.5*0.81 = 1.215
3457
3458            # atol is used directly to compare with singular values
3459            result = torch.linalg.matrix_rank(a, atol=tol_value)
3460            self.assertEqual(result, 7)  # there are 7 singular values above 0.81
3461
3462            # when both are specified the maximum tolerance is used
3463            result = torch.linalg.matrix_rank(a, atol=tol_value, rtol=tol_value)
3464            self.assertEqual(result, 2)  # there are 2 singular values above max(0.81, 1.5*0.81)
3465
3466    @skipCUDAIfNoMagma
3467    @skipCPUIfNoLapack
3468    @skipCUDAVersionIn([(11, 6), (11, 7)])  # https://github.com/pytorch/pytorch/issues/75391
3469    @dtypes(*floating_and_complex_types())
3470    def test_matrix_rank_empty(self, device, dtype):
3471        matrix_rank = torch.linalg.matrix_rank
3472
3473        # NumPy doesn't work for input with no elements
3474        def run_test(shape0, shape1, batch):
3475            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
3476            rank_a = matrix_rank(a)
3477            expected = torch.zeros(batch, dtype=torch.int64, device=device)
3478
3479            self.assertEqual(rank_a, matrix_rank(a.mH))
3480
3481            aaH = torch.matmul(a, a.mH)
3482            rank_aaH = matrix_rank(aaH)
3483            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
3484            self.assertEqual(rank_aaH, rank_aaH_hermitian)
3485
3486            aHa = torch.matmul(a.mH, a)
3487            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
3488
3489            self.assertEqual(rank_a, expected)
3490            self.assertEqual(matrix_rank(a, 0.01), expected)
3491
3492            self.assertEqual(rank_aaH, expected)
3493            self.assertEqual(matrix_rank(aaH, 0.01), expected)
3494
3495            self.assertEqual(rank_aaH_hermitian, expected)
3496            self.assertEqual(matrix_rank(aaH, 0.01, True), expected)
3497
3498        batches = ((), (4, ), (3, 5, ))
3499        for batch in batches:
3500            run_test(0, 0, batch)
3501            run_test(0, 3, batch)
3502            run_test(3, 0, batch)
3503
3504    @skipCUDAIfNoMagma
3505    @skipCPUIfNoLapack
3506    @dtypes(*floating_and_complex_types())
3507    def test_matrix_rank_out_errors_and_warnings(self, device, dtype):
3508        # dtypes should be safely castable
3509        a = torch.eye(2, dtype=dtype, device=device)
3510        out = torch.empty(0, dtype=torch.bool, device=device)
3511        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Bool"):
3512            torch.linalg.matrix_rank(a, out=out)
3513
3514        # device should match
3515        if torch.cuda.is_available():
3516            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3517            out = torch.empty(0, dtype=dtype, device=wrong_device)
3518            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3519                torch.linalg.matrix_rank(a, out=out)
3520
3521        # if an out tensor with the wrong shape is passed, a warning is given
3522        with warnings.catch_warnings(record=True) as w:
3523            out = torch.empty(3, dtype=dtype, device=device)
3524            # Trigger warning
3525            torch.linalg.matrix_rank(a, out=out)
3526            # Check warning occurs
3527            self.assertEqual(len(w), 1)
3528            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3529
3530    @skipCUDAIfNoMagma
3531    @skipCPUIfNoLapack
3532    @dtypes(*floating_and_complex_types())
3533    def test_matrix_rank_basic(self, device, dtype):
3534        matrix_rank = torch.linalg.matrix_rank
3535
3536        a = torch.eye(10, dtype=dtype, device=device)
3537        self.assertEqual(matrix_rank(a).item(), 10)
3538        self.assertEqual(matrix_rank(a, hermitian=True).item(), 10)
3539
3540        a[5, 5] = 0
3541        self.assertEqual(matrix_rank(a).item(), 9)
3542        self.assertEqual(matrix_rank(a, hermitian=True).item(), 9)
3543
3544    @onlyNativeDeviceTypes
3545    @dtypes(torch.double)
3546    # This tests only the cases where torch.chain_matmul differs from torch.linalg.multi_dot, for which it is an "alias".
3547    def test_chain_matmul(self, device, dtype):
3548        # chain_matmul accepts a single input tensor while multi_dot does not
3549        t = make_tensor((2, 2), dtype=dtype, device=device)
3550        self.assertEqual(t, torch.chain_matmul(t))
3551        with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
3552            torch.chain_matmul()
3553
3554        # chain_matmul expects all tensors to be 2D whereas multi_dot allows the first and last tensors to
3555        # be either 1D or 2D
3556        with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"):
3557            torch.chain_matmul(make_tensor(1, dtype=dtype, device=device), make_tensor(1, dtype=dtype, device=device))
3558
3559    @onlyNativeDeviceTypes
3560    @dtypes(torch.double, torch.cdouble)
3561    def test_multi_dot(self, device, dtype):
3562        def check(*shapes):
3563            tensors = [make_tensor(shape, dtype=dtype, device=device) for shape in shapes]
3564            np_arrays = [tensor.cpu().numpy() for tensor in tensors]
3565            res = torch.linalg.multi_dot(tensors).cpu()
3566            ref = torch.from_numpy(np.array(np.linalg.multi_dot(np_arrays)))
3567            self.assertEqual(res, ref)
3568
3569        # test for inputs with empty dimensions
3570        check([0], [0])
3571        check([2], [2, 0])
3572        check([1, 0], [0])
3573        check([0, 2], [2, 1])
3574        check([2, 2], [2, 0])
3575        check([2, 0], [0, 3])
3576        check([0, 0], [0, 1])
3577        check([4, 2], [2, 0], [0, 3], [3, 2])
3578
3579        # test variable output shapes
3580        check([2], [2])
3581        check([1, 2], [2])
3582        check([2], [2, 1])
3583        check([1, 2], [2, 1])
3584        check([3, 2], [2, 4])
3585
3586        # test multiple input tensors
3587        check([3], [3, 4], [4, 2], [2, 5], [5])
3588        check([1, 2], [2, 2], [2, 3], [3, 1])
3589
3590        # test large tensors
3591        check([10, 100], [100, 5], [5, 50])
3592        check([10, 20], [20, 30], [30, 5])
3593
3594    @onlyNativeDeviceTypes
3595    @dtypes(torch.float)
3596    def test_multi_dot_errors(self, device, dtype):
3597        def check(tensors, out, msg):
3598            with self.assertRaisesRegex(RuntimeError, msg):
3599                torch.linalg.multi_dot(tensors, out=out)
3600
3601        a = make_tensor(2, dtype=dtype, device=device)
3602
3603        check([], None, "expected at least 2 tensors")
3604        check([a], None, "expected at least 2 tensors")
3605
3606        check([torch.tensor(1, device=device, dtype=dtype), a], None, "the first tensor must be 1D or 2D")
3607        check([a, torch.tensor(1, device=device, dtype=dtype)], None, "the last tensor must be 1D or 2D")
3608
3609        check([a, a, a], None, "tensor 1 must be 2D")
3610        check([a, make_tensor((2, 2, 2), dtype=dtype, device=device), a], None, "tensor 1 must be 2D")
3611
3612        check([a, make_tensor(2, dtype=torch.double, device=device)], None, "all tensors must have be the same dtype")
3613        check([a, a], torch.empty(0, device=device, dtype=torch.double), "expected out tensor to have dtype")
3614
3615        if self.device_type == 'cuda':
3616            check([a, make_tensor(2, dtype=dtype, device="cpu")], None, "all tensors must be on the same device")
3617            check([a, a], torch.empty(0, dtype=dtype), "expected out tensor to be on device")
3618
3619        check([a, make_tensor(3, dtype=dtype, device=device)], None, "cannot be multiplied")
3620        check([a, make_tensor((3, 2), dtype=dtype, device=device), a], None, "cannot be multiplied")
3621
3622    @precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6})
3623    @skipCUDAIfNoCusolver
3624    @skipCPUIfNoLapack
3625    @dtypes(*floating_and_complex_types())
3626    def test_qr(self, device, dtype):
3627        def run_test(tensor_dims, some):
3628            A = torch.randn(*tensor_dims, dtype=dtype, device=device)
3629            Q, R = torch.qr(A, some=some)
3630
3631            # Check0: Q.shape[-2:] == (m, n_columns), R.shape[-2:] == (n_columns, n)
3632            m, n = tensor_dims[-2:]
3633            n_columns = m if (not some) and m > n else min(m, n)
3634            self.assertEqual(Q.size(-2), m)
3635            self.assertEqual(R.size(-1), n)
3636            self.assertEqual(Q.size(-1), n_columns)
3637
3638            A_ = A.cpu().numpy()
3639            Q_ = Q.cpu().numpy()
3640            R_ = R.cpu().numpy()
3641
3642            # Check1: A = QR
3643            self.assertEqual(A_, np.matmul(Q_, R_))
3644
3645            # Check2: A = QR (with out)
3646            Q_out, R_out = torch.full_like(Q, math.nan), torch.full_like(R, math.nan)
3647            torch.qr(A, some=some, out=(Q_out, R_out))
3648            Q_out_ = Q_out.cpu().numpy()
3649            R_out_ = R_out.cpu().numpy()
3650            self.assertEqual(A_, np.matmul(Q_out_, R_out_))
3651
3652            # Check3: Q == Q_out, R == R_out
3653            self.assertEqual(Q_, Q_out_)
3654            self.assertEqual(R_, R_out_)
3655
3656            # Check4: Q^{T}Q = I, triu(R) = R
3657            eye = torch.eye(n_columns, device=device, dtype=dtype).expand(Q.shape[:-2] + (n_columns, n_columns)).cpu().numpy()
3658            self.assertEqual(np.matmul(Q_.swapaxes(-1, -2).conj(), Q_), eye)
3659            self.assertEqual(R.triu(), R)
3660
3661        tensor_dims_list = [(0, 5), (0, 0), (5, 0),  # Empty Tensors
3662                            (2, 1, 0, 5), (2, 1, 0, 0), (2, 1, 5, 0), (2, 0, 5, 5),  # Batched empty Tensors
3663                            (3, 5), (5, 5), (5, 3),  # Single matrix
3664                            (7, 3, 5), (7, 5, 5), (7, 5, 3),  # 3-dim Tensors
3665                            (7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)]  # 4-dim Tensors
3666        for tensor_dims, some in itertools.product(tensor_dims_list, [True, False]):
3667            run_test(tensor_dims, some)
3668
3669    @skipCUDAIfNoCusolver
3670    @skipCPUIfNoLapack
3671    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3672    def test_qr_vs_numpy(self, device, dtype):
3673        """
3674        test torch.linalg.qr vs numpy.linalg.qr
3675        """
3676        sizes_to_test = [
3677            (7, 5),
3678            (5, 7),
3679            (5, 0),    # empty
3680            (0, 5),    # empty
3681        ]
3682        for size in sizes_to_test:
3683            t = torch.randn(size, device=device, dtype=dtype)
3684            np_t = t.cpu().numpy()
3685            for mode in ['reduced', 'complete']:
3686                exp_q, exp_r = np.linalg.qr(np_t, mode=mode)
3687                q, r = torch.linalg.qr(t, mode=mode)
3688                self.assertEqual(q, exp_q)
3689                self.assertEqual(r, exp_r)
3690            #
3691            # for mode='r' we need special logic because numpy returns only r
3692            exp_r = np.linalg.qr(np_t, mode='r')
3693            q, r = torch.linalg.qr(t, mode='r')
3694            # check that q is empty
3695            self.assertEqual(q.shape, (0,))
3696            self.assertEqual(q.dtype, t.dtype)
3697            self.assertEqual(q.device, t.device)
3698            # check r
3699            self.assertEqual(r, exp_r)
3700
3701    @skipCUDAIfNoCusolver
3702    @skipCPUIfNoLapack
3703    @dtypes(torch.float)
3704    def test_linalg_qr_autograd_errors(self, device, dtype):
3705        # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but
3706        # without 'q' you cannot compute the backward pass. Check that
3707        # linalg_qr_backward complains cleanly in that case.
3708        inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True)
3709        q, r = torch.linalg.qr(inp, mode='r')
3710        self.assertEqual(q.shape, (0,))  # empty tensor
3711        b = torch.sum(r)
3712        with self.assertRaisesRegex(RuntimeError,
3713                                    "The derivative of linalg.qr depends on Q"):
3714            b.backward()
3715        inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True)
3716        q, r = torch.linalg.qr(inp, mode='complete')
3717        b = torch.sum(r)
3718        with self.assertRaisesRegex(RuntimeError,
3719                                    "The QR decomposition is not differentiable when mode='complete' and nrows > ncols"):
3720            b.backward()
3721
3722    @skipCUDAIfNoCusolver
3723    @skipCPUIfNoLapack
3724    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3725    def test_qr_batched(self, device, dtype):
3726        """
3727        test torch.linalg.qr vs numpy.linalg.qr. We need some special logic
3728        because numpy does not support batched qr
3729        """
3730        def np_qr_batched(a, mode):
3731            """poor's man batched version of np.linalg.qr"""
3732            all_q = []
3733            all_r = []
3734            for matrix in a:
3735                result = np.linalg.qr(matrix, mode=mode)
3736                if mode == 'r':
3737                    all_r.append(result)
3738                else:
3739                    q, r = result
3740                    all_q.append(q)
3741                    all_r.append(r)
3742            if mode == 'r':
3743                return np.array(all_r)
3744            else:
3745                return np.array(all_q), np.array(all_r)
3746
3747        t = torch.randn((3, 7, 5), device=device, dtype=dtype)
3748        np_t = t.cpu().numpy()
3749        for mode in ['reduced', 'complete']:
3750            exp_q, exp_r = np_qr_batched(np_t, mode=mode)
3751            q, r = torch.linalg.qr(t, mode=mode)
3752            self.assertEqual(q, exp_q)
3753            self.assertEqual(r, exp_r)
3754        # for mode='r' we need special logic because numpy returns only r
3755        exp_r = np_qr_batched(np_t, mode='r')
3756        q, r = torch.linalg.qr(t, mode='r')
3757        # check that q is empty
3758        self.assertEqual(q.shape, (0,))
3759        self.assertEqual(q.dtype, t.dtype)
3760        self.assertEqual(q.device, t.device)
3761        # check r
3762        self.assertEqual(r, exp_r)
3763
3764    @skipCUDAIfNoCusolver
3765    @skipCPUIfNoLapack
3766    @dtypes(torch.float)
3767    def test_qr_error_cases(self, device, dtype):
3768        t1 = torch.randn(5, device=device, dtype=dtype)
3769        with self.assertRaisesRegex(RuntimeError, 'linalg.qr: The input tensor A must have at least 2 dimensions.'):
3770            torch.linalg.qr(t1)
3771        t2 = torch.randn((5, 7), device=device, dtype=dtype)
3772        with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"):
3773            torch.linalg.qr(t2, mode='hello')
3774
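    # Compares torch.einsum(*args) against np.einsum(*np_args); np_args lets callers pass an
    # equivalent equation-format call when NumPy cannot handle the sublist format directly.
    # When opt_einsum is available, the check is repeated with it disabled and with the
    # 'greedy' and 'optimal' strategies.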
3775    def _check_einsum(self, *args, np_args=None):
3776        if np_args is None:
3777            np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args]
3778        ref = np.einsum(*np_args)
3779        res = torch.einsum(*args)
3780        self.assertEqual(ref, res)
3781
3782        # Check that the other variations for opt_einsum work too
3783        if TEST_OPT_EINSUM:
3784            with opt_einsum.flags(enabled=False):
3785                res = torch.einsum(*args)
3786                self.assertEqual(ref, res)
3787
3788            with opt_einsum.flags(enabled=True, strategy='greedy'):
3789                res = torch.einsum(*args)
3790                self.assertEqual(ref, res)
3791
3792            with opt_einsum.flags(enabled=True, strategy='optimal'):
3793                res = torch.einsum(*args)
3794                self.assertEqual(ref, res)
3795
3796    @dtypes(torch.double, torch.cdouble)
3797    def test_einsum(self, device, dtype):
3798        # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f
3799        x = make_tensor((5,), dtype=dtype, device=device)
3800        y = make_tensor((7,), dtype=dtype, device=device)
3801        A = make_tensor((3, 5), dtype=dtype, device=device)
3802        B = make_tensor((2, 5), dtype=dtype, device=device)
3803        C = make_tensor((2, 3, 5), dtype=dtype, device=device)
3804        D = make_tensor((2, 5, 7), dtype=dtype, device=device)
3805        E = make_tensor((7, 9), dtype=dtype, device=device)
3806        F = make_tensor((2, 3, 3, 5), dtype=dtype, device=device)
3807        G = make_tensor((5, 4, 6), dtype=dtype, device=device)
3808        H = make_tensor((4, 4), dtype=dtype, device=device)
3809        I = make_tensor((2, 3, 2), dtype=dtype, device=device)
3810
3811        # Vector operations
3812        self._check_einsum('i->', x)                     # sum
3813        self._check_einsum('i,i->', x, x)                # dot
3814        self._check_einsum('i,i->i', x, x)               # vector element-wise mul
3815        self._check_einsum('i,j->ij', x, y)              # outer
3816
3817        # Matrix operations
3818        self._check_einsum("ij->ji", A)                  # transpose
3819        self._check_einsum("ij->j", A)                   # row sum
3820        self._check_einsum("ij->i", A)                   # col sum
3821        self._check_einsum("ij,ij->ij", A, A)            # matrix element-wise mul
3822        self._check_einsum("ij,j->i", A, x)              # matrix vector multiplication
3823        self._check_einsum("ij,kj->ik", A, B)            # matmul
3824        self._check_einsum("ij,ab->ijab", A, E)          # matrix outer product
3825
3826        # Tensor operations
3827        self._check_einsum("Aij,Ajk->Aik", C, D)         # batch matmul
3828        self._check_einsum("ijk,jk->i", C, A)            # tensor matrix contraction
3829        self._check_einsum("aij,jk->aik", D, E)          # tensor matrix contraction
3830        self._check_einsum("abCd,dFg->abCFg", F, G)      # tensor tensor contraction
3831        self._check_einsum("ijk,jk->ik", C, A)           # tensor matrix contraction with double indices
3832        self._check_einsum("ijk,jk->ij", C, A)           # tensor matrix contraction with double indices
3833        self._check_einsum("ijk,ik->j", C, B)            # non contiguous
3834        self._check_einsum("ijk,ik->jk", C, B)           # non contiguous with double indices
3835
3836        # Test diagonals
3837        self._check_einsum("ii", H)                      # trace
3838        self._check_einsum("ii->i", H)                   # diagonal
3839        self._check_einsum('iji->j', I)                  # non-contiguous trace
3840        self._check_einsum('ngrg...->nrg...', make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device))
3841
3842        # Test ellipsis
3843        self._check_einsum("i...->...", H)
3844        self._check_einsum("ki,...k->i...", A.t(), B)
3845        self._check_einsum("k...,jk->...", A.t(), B)
3846        self._check_einsum('...ik, ...j -> ...ij', C, x)
3847        self._check_einsum('Bik,k...j->i...j', C, make_tensor((5, 3), dtype=dtype, device=device))
3848        self._check_einsum('i...j, ij... -> ...ij', C, make_tensor((2, 5, 2, 3), dtype=dtype, device=device))
3849
3850        # torch.bilinear with noncontiguous tensors
3851        l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
3852        r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True)
3853        w = make_tensor((15, 10, 20), dtype=dtype, device=device)
3854        self._check_einsum("bn,anm,bm->ba", l, w, r)
3855
3856        # with strided tensors
3857        self._check_einsum("bn,Anm,bm->bA", l[:, ::2], w[:, ::2, ::2], r[:, ::2])
3858
3859        # test multiple inputs
3860        self._check_einsum("...,be,b...,beg,gi,bc...->bi...", A, B, C, D, E, F)
3861
3862    @dtypes(torch.double, torch.cdouble)
3863    def test_einsum_sublist_format(self, device, dtype):
3864        x = make_tensor((5,), dtype=dtype, device=device)
3865        y = make_tensor((7,), dtype=dtype, device=device)
3866        A = make_tensor((3, 5), dtype=dtype, device=device)
3867        B = make_tensor((2, 5), dtype=dtype, device=device)
3868        C = make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device)
3869
3870        self._check_einsum(x, [0])
3871        self._check_einsum(x, [0], [])
3872        self._check_einsum(x, [0], y, [1], [0, 1])
3873        self._check_einsum(A, [0, 1], [1, 0])
3874        self._check_einsum(A, [0, 1], x, [1], [0])
3875        self._check_einsum(A, [0, 1], B, [2, 1])
3876        self._check_einsum(A, [0, 1], B, [2, 1], [0, 2])
3877        self._check_einsum(C, [0, 1, 2, 1, Ellipsis], [0, 2, 1, Ellipsis])
3878        self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0])
3879        self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0], [1, Ellipsis])
3880        self._check_einsum(A.t(), [0, Ellipsis], B, [1, 0], [Ellipsis])
3881
3882        # torch.bilinear with noncontiguous tensors
3883        l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
3884        r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True)
3885        w = make_tensor((15, 10, 20), dtype=dtype, device=device)
3886        self._check_einsum(l, [40, 41], w, [2, 41, 50], r, [40, 50], [40, 2])
3887
3888    @dtypes(torch.double, torch.cdouble)
3889    def test_einsum_random(self, device, dtype):
3890        def convert_label(label):
3891            if label == ...:
3892                return '...'
3893            elif label < 26:
3894                return chr(ord('A') + label)
3895            else:
3896                return chr(ord('a') + label - 26)
3897
3898        def convert_sublist(sublist):
3899            return ''.join(convert_label(label) for label in sublist)
3900
3901        def test(n=10,                       # how many tests to generate
3902                 n_labels=5,                 # how many labels available
3903                 min_ops=1, max_ops=4,       # min and max number of operands per test
3904                 min_dims=1, max_dims=3,     # min and max number of dimensions per operand
3905                 min_size=1, max_size=8,     # min and max size of each dimension
3906                 max_out_dim=3,              # max number of dimensions for the output
3907                 enable_diagonals=True,      # controls if labels can be repeated for diagonals
3908                 ellipsis_prob=0.5,          # probability of including ellipsis in operand
3909                 broadcasting_prob=0.1):     # probability of turning some dim sizes 1 for broadcasting
3910
3911            all_labels = torch.arange(52)
3912
3913            assert 0 <= n
3914            assert 0 <= n_labels < len(all_labels)
3915            assert 0 < min_ops <= max_ops
3916            assert 0 <= min_dims <= max_dims
3917            assert 0 <= min_size <= max_size
3918            assert 0 <= max_out_dim
3919            assert enable_diagonals or max_dims <= n_labels
3920
3921            for _ in range(n):
3922
3923                # Select a subset of labels for this test and give them random sizes
3924                possible_labels = all_labels[torch.randperm(len(all_labels))[:n_labels]]
3925                labels_size = torch.randint_like(all_labels, min_size, max_size + 1)
3926                ellipsis_shape = torch.randint(min_size, max_size + 1, (max_dims - min_dims,))
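                # labels_size[l] fixes the dimension size used for label l across all operands;
                # ellipsis_shape holds the shared trailing sizes for dimensions covered by '...'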
3927
3928                operands = []
3929                sublists = []
3930
3931                ell_size = 0
3932                valid_labels = set()
3933
3934                # create random input operands
3935                for _ in range(random.randint(min_ops, max_ops)):
3936                    n_dim = random.randint(min_dims, max_dims)
3937                    labels_idx = torch.ones(len(possible_labels)).multinomial(n_dim, enable_diagonals)
3938                    labels = possible_labels[labels_idx]
3939                    valid_labels.update(labels.tolist())
3940                    shape = labels_size[labels]
3941
3942                    # turn some dimensions to size 1 for testing broadcasting
3943                    mask = Binomial(probs=broadcasting_prob).sample((n_dim,))
3944                    broadcast_labels = torch.unique(labels[mask == 1])
3945                    shape[(labels[..., None] == broadcast_labels).any(-1)] = 1
3946
3947                    labels = labels.tolist()
3948                    shape = shape.tolist()
3949
3950                    # include ellipsis if not all dimensions were assigned a label already
3951                    if n_dim < max_dims and torch.rand(1) < ellipsis_prob:
3952                        ell_num_dim = random.randint(1, max_dims - n_dim)
3953                        ell_size = max(ell_size, ell_num_dim)
3954                        ell_shape = ellipsis_shape[-ell_num_dim:]
3955                        # again, turn some dimensions to size 1 for broadcasting
3956                        mask = Binomial(probs=broadcasting_prob).sample((ell_num_dim,))
3957                        ell_shape[mask == 1] = 1
3958                        ell_index = random.randint(0, n_dim)
3959                        shape[ell_index:ell_index] = ell_shape
3960                        labels.insert(ell_index, ...)
3961
3962                    operands.append(make_tensor(shape, dtype=dtype, device=device))
3963                    sublists.append(labels)
3964
3965                # NumPy has a bug with the sublist format, so for now we compare the PyTorch sublist
3966                # implementation against NumPy's equation-format implementation
3967                # see https://github.com/numpy/numpy/issues/10926
3968                np_operands = [op.cpu().numpy() for op in operands]
3969
3970                # test equation format
3971                equation = ','.join(convert_sublist(l) for l in sublists)
3972                self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
3973
3974                # test sublist format
3975                args = list(itertools.chain.from_iterable(zip(operands, sublists)))
3976                self._check_einsum(*args, np_args=(equation, *np_operands))
3977
3978                # generate an explicit output
3979                out_sublist = []
3980                num_out_labels = max(0, random.randint(0, min(max_out_dim, len(valid_labels))) - ell_size)
3981                if num_out_labels > 0:
3982                    out_labels_idx = torch.ones(len(valid_labels)).multinomial(num_out_labels)
3983                    out_sublist = torch.tensor(list(valid_labels))[out_labels_idx].tolist()
3984                out_sublist.insert(random.randint(0, num_out_labels), ...)
3985
3986                # test equation format with explicit output
3987                equation += '->' + convert_sublist(out_sublist)
3988                self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
3989
3990                # test sublist format with explicit output
3991                args.append(out_sublist)
3992                self._check_einsum(*args, np_args=(equation, *np_operands))
3993
3994        test(500)
3995
3996    def test_einsum_corner_cases(self, device):
3997        def check(equation, *operands, expected_output):
3998            tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple)
3999                       else make_tensor(operand, dtype=torch.float32, device=device) for operand in operands]
4000            output = torch.einsum(equation, tensors)
4001            self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device))
4002
4003        # Test equation variations
4004        check(' ', 1, expected_output=1)
4005        check(' -> ', 1, expected_output=1)
4006        check(' , ', 2, 2, expected_output=4)
4007        check(' , , ', 2, 2, 2, expected_output=8)
4008        check(' , -> ', 2, 2, expected_output=4)
4009        check(' i ', [1], expected_output=[1])
4010        check(' i -> ', [1], expected_output=1)
4011        check(' i -> i ', [1], expected_output=[1])
4012        check(' i , i ', [2], [2], expected_output=4)
4013        check(' i , i -> i ', [2], [2], expected_output=[4])
4014
4015        # Test tensors with 0 size dimensions
4016        check('i', [], expected_output=[])
4017        check(' i j -> j', [[], []], expected_output=[])
4018        check('ij->i', [[], []], expected_output=[0., 0.])
4019        check(' i j k  ,  k  -> i j ', (3, 0, 6), (6,), expected_output=[[], [], []])
4020
4021        # Test broadcasting
4022        check('i,j', [2], [1, 2], expected_output=[[2, 4]])
4023        check('i,ij->ij', [1, 2], [[1, 2, 3], [2, 3, 4]], expected_output=[[1, 2, 3], [4, 6, 8]])
4024
4025        # Test ellipsis broadcasting
4026        check('...', 1, expected_output=1)
4027        check('...->', 1, expected_output=1)
4028        check('...->...', 1, expected_output=1)
4029        check('...', [1], expected_output=[1])
4030        check('...->', [1], expected_output=1)
4031        check('z...->z', [1], expected_output=[1])
4032        check('Z...->...Z', [1], expected_output=[1])
4033        check('...a->', [[2], [4]], expected_output=6)
4034        check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]])
4035
4036    def test_einsum_error_cases(self, device):
4037        def check(*args, regex, exception=RuntimeError):
4038            with self.assertRaisesRegex(exception, r'einsum\(\):.*' + regex):
4039                torch.einsum(*args)
4040
4041        x = make_tensor((2,), dtype=torch.float32, device=device)
4042        y = make_tensor((2, 3), dtype=torch.float32, device=device)
4043
4044        check('', [], regex=r'at least one operand', exception=ValueError)
4045        check('. ..', [x], regex=r'found \'.\' for operand 0 that is not part of any ellipsis')
4046        check('... ...', [x], regex=r'found \'.\' for operand 0 for which an ellipsis was already found')
4047        check('1', [x], regex=r'invalid subscript given at index 0')
4048        check(',', [x], regex=r'fewer operands were provided than specified in the equation')
4049        check('', [x, x], regex=r'more operands were provided than specified in the equation')
4050        check('', [x], regex=r'the number of subscripts in the equation \(0\) does not match the number '
4051              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
4052        check('ai', [x], regex=r'the number of subscripts in the equation \(2\) does not match the number '
4053              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
4054        check('ai...', [x], regex=r'the number of subscripts in the equation \(2\) is more than the number '
4055              r'of dimensions \(1\) for operand 0')
4056        check('a->... .', [x], regex=r'found \'.\' for output but an ellipsis \(...\) was already found')
4057        check('a->..', [x], regex=r'found \'.\' for output that is not part of any ellipsis \(...\)')
4058        check('a->1', [x], regex=r'invalid subscript given at index 3')
4059        check('a->aa', [x], regex=r'output subscript a appears more than once in the output')
4060        check('a->i', [x], regex=r'output subscript i does not appear in the equation for any input operand')
4061        check('aa', [y], regex=r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2')
4062        check('...,...', [x, y], regex=r'does not broadcast')
4063        check('a,a', [x, make_tensor((3,), dtype=torch.float32, device=device)], regex=r'does not broadcast')
4064        check('a, ba', [x, y], regex=r'subscript a has size 3 for operand 1 which does not broadcast with previously'
4065              r' seen size 2')
4066
4067        check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
4068        check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
4069
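    # Yields (A, B, left, upper, unitriangular) inputs for the solve_triangular tests, covering
    # transposed, conjugated, and zero-stride (expanded) batch layouts. With well_conditioned=True
    # the triangular matrix is taken from the LU factorization of a full-rank matrix.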
4070    def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_conditioned=False):
4071        make_arg = partial(make_tensor, dtype=dtype, device=device)
4072        make_fullrank = partial(make_fullrank_matrices_with_distinct_singular_values, dtype=dtype, device=device)
4073        b, n, k = shape
4074        for left, uni, expand_a, tr_a, conj_a, expand_b, tr_b, conj_b in product((True, False), repeat=8):
4075            # expand means that we generate a batch of matrices with a stride of zero in the batch dimension
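            # e.g. a tensor of shape (n, n) expanded via .expand(b, n, n) shares storage across
            # the batch and reports stride 0 along the batch dimension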
4076            if (conj_a or conj_b) and not dtype.is_complex:
4077                continue
4078            # We only expand along the batch dimension, so there is nothing to expand when b == 1
4079            if (expand_a or expand_b) and b == 1:
4080                continue
4081
4082            size_a = (b, n, n) if left else (b, k, k)
4083            size_b = (b, n, k) if not tr_b else (b, k, n)
4084
4085            # If expand_a or expand_b, we'll expand them to the correct size later
4086            if b == 1 or expand_a:
4087                size_a = size_a[1:]
4088            if b == 1 or expand_b:
4089                size_b = size_b[1:]
4090
4091            if well_conditioned:
4092                PLU = torch.linalg.lu(make_fullrank(*size_a))
4093                if uni:
4094                    # A = L from PLU
4095                    A = PLU[1].transpose(-2, -1).contiguous()
4096                else:
4097                    # A = U from PLU
4098                    A = PLU[2].contiguous()
4099            else:
4100                A = make_arg(size_a)
4101                A.triu_()
4102
4103            diag = A.diagonal(0, -2, -1)
4104            if uni:
4105                diag.fill_(1.)
4106            else:
4107                diag[diag.abs() < 1e-6] = 1.
4108
4109            B = make_arg(size_b)
4110
4111            if tr_a:
4112                A.transpose_(-2, -1)
4113            if tr_b:
4114                B.transpose_(-2, -1)
4115            if conj_a:
4116                A = A.conj()
4117            if conj_b:
4118                B = B.conj()
4119            if expand_a:
4120                A = A.expand(b, *size_a)
4121            if expand_b:
4122                B = B.expand(b, n, k)
4123            yield A, B, left, not tr_a, uni
4124
4125    def _test_linalg_solve_triangular(self, A, B, upper, left, uni):
4126        X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
4127        if left:
4128            self.assertEqual(A @ X, B)
4129        else:
4130            self.assertEqual(X @ A, B)
4131        out = B
4132        # B may be expanded
4133        if not B.is_contiguous() and not B.transpose(-2, -1).is_contiguous():
4134            out = B.clone()
4135        torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni, out=out)
4136        self.assertEqual(X, out)
4137
4138    # Tolerances dictated by widest acceptable range on CPU before failure
4139    @dtypes(*floating_and_complex_types())
4140    @precisionOverride({torch.float32: 1e-3 if TEST_WITH_ROCM else 1e-1,
4141                        torch.float64: 1e-8,
4142                        torch.complex64: 1e-1,
4143                        torch.complex128: 1e-8})
4144    def test_linalg_solve_triangular(self, device, dtype):
4145        # This exercises the API, the CPU BLAS path, and batched cuBLAS
4146        ks = (3, 1, 0)
4147        ns = (5, 0)
4148        bs = (1, 2, 0)
4149
4150        gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
4151        for b, n, k in product(bs, ns, ks):
4152            for A, B, left, upper, uni in gen_inputs((b, n, k), dtype, device, well_conditioned=True):
4153                self._test_linalg_solve_triangular(A, B, upper, left, uni)
4154
4155    @slowTest
4156    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
4157    @onlyCUDA
4158    @skipCUDAIfNoMagma  # Magma needed for the PLU decomposition
4159    @dtypes(*floating_and_complex_types())
4160    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
4161                        torch.float64: 1e-8, torch.complex128: 1e-8})
4162    def test_linalg_solve_triangular_large(self, device, dtype):
4163        # Exercises magma and cublas
4164        magma = (9, 513, 1)
4165        iterative_cublas = (2, 64, 1)
4166
4167        gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
4168        for shape in (magma, iterative_cublas):
4169            for A, B, left, upper, uni in gen_inputs(shape, dtype, device, well_conditioned=True):
4170                self._test_linalg_solve_triangular(A, B, upper, left, uni)
4171
4172    @dtypes(*floating_and_complex_types())
4173    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
4174                        torch.float64: 1e-8, torch.complex128: 1e-8})
4175    def test_linalg_solve_triangular_broadcasting(self, device, dtype):
4176        make_arg = partial(make_tensor, dtype=dtype, device=device)
4177
4178        sizes = (((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)),
4179                 ((2, 1, 3, 4, 4), (4, 6)),
4180                 ((4, 4), (2, 1, 3, 4, 2)),
4181                 ((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)))
4182        for size_A, size_B in sizes:
4183            for left, upper, uni in itertools.product([True, False], repeat=3):
4184                A = make_arg(size_A)
4185                if upper:
4186                    A.triu_()
4187                else:
4188                    A.tril_()
4189                diag = A.diagonal(0, -2, -1)
4190                if uni:
4191                    diag.fill_(1.)
4192                else:
4193                    diag[diag.abs() < 1e-6] = 1.
4194                B = make_arg(size_B)
4195                if not left:
4196                    B.transpose_(-2, -1)
4197
4198                X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
4199                if left:
4200                    B_other = A @ X
4201                else:
4202                    B_other = X @ A
4203
4204                self.assertEqual(*torch.broadcast_tensors(B, B_other))
4205
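    # Helper: returns (b, A_triangular) where b is a random right-hand side and A_triangular is
    # the upper or lower triangle of a random A @ A^H matrix, optionally with its diagonal set
    # to 1 (unitriangular).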
4206    def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular,
4207                                     device, dtype):
4208        triangle_function = torch.triu if upper else torch.tril
4209        b = torch.randn(*b_dims, dtype=dtype, device=device)
4210        A = torch.randn(*A_dims, dtype=dtype, device=device)
4211        # create positive definite matrix
4212        A = torch.matmul(A, A.mT)
4213        A_triangular = triangle_function(A)
4214        if unitriangular:
4215            A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.)
4216        return b, A_triangular
4217
4218    @skipCUDAIfNoMagma
4219    @skipCPUIfNoLapack
4220    @skipIfTorchDynamo("flaky, needs investigation")
4221    @dtypes(*floating_and_complex_types())
4222    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4223                        torch.float64: 1e-8, torch.complex128: 1e-8})
4224    def test_triangular_solve(self, device, dtype):
4225        ks = [0, 1, 3]
4226        ns = [0, 5]
4227        for k, n, (upper, unitriangular, transpose) in itertools.product(ks, ns,
4228                                                                         itertools.product([True, False], repeat=3)):
4229            b, A = self.triangular_solve_test_helper((n, n), (n, k), upper,
4230                                                     unitriangular, device, dtype)
4231            x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4232            if transpose:
4233                self.assertEqual(b, np.matmul(A.t().cpu(), x.cpu()))
4234            else:
4235                self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
4236
4237    @skipCPUIfNoLapack
4238    @skipCUDAIfNoMagma
4239    @dtypes(*floating_and_complex_types())
4240    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4241                        torch.float64: 1e-8, torch.complex128: 1e-8})
4242    def test_triangular_solve_batched(self, device, dtype):
4243        def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
4244            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4245                                                     unitriangular, device, dtype)
4246            x_exp_list = []
4247            for i in range(b_dims[0]):
4248                x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper,
4249                                                         unitriangular=unitriangular,
4250                                                         transpose=transpose)[0])
4251            x_exp = torch.stack(x_exp_list)  # Stacked output
4252            x_act = torch.triangular_solve(b, A, upper=upper,
4253                                           unitriangular=unitriangular,
4254                                           transpose=transpose)[0]  # Actual output
4255            self.assertEqual(x_act, x_exp)  # Equality check
4256            if transpose:
4257                A = A.mT
4258
4259            Ax = np.matmul(A.cpu(), x_act.cpu())
4260            self.assertEqual(b, Ax)
4261
4262        def triangular_solve_zero_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
4263            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4264                                                     unitriangular, device, dtype)
4265            x = torch.triangular_solve(b, A, upper=upper,
4266                                       unitriangular=unitriangular,
4267                                       transpose=transpose)[0]
4268            self.assertTrue(x.shape == b.shape)
4269
4270        for upper, unitriangular, transpose in itertools.product([True, False], repeat=3):
4271            batchsize = 3
4272            triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
4273                                          upper, unitriangular, transpose)
4274
4275            # test empty input
4276            triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 10),
4277                                          upper, unitriangular, transpose)
4278            triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 0),
4279                                          upper, unitriangular, transpose)
4280
4281            # test zero batch case
4282            batchsize = 0
4283            triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
4284                                               upper, unitriangular, transpose)
4285
4286
4287    @slowTest
4288    @skipCUDAIfNoMagma
4289    @skipCPUIfNoLapack
4290    @dtypes(*floating_and_complex_types())
4291    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4292                        torch.float64: 1e-8, torch.complex128: 1e-8})
4293    def test_triangular_solve_batched_many_batches(self, device, dtype):
4294        for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
4295            # test batched A case
4296            b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1),
4297                                                     upper, unitriangular, device, dtype)
4298            x, _ = torch.triangular_solve(b, A,
4299                                          upper=upper, transpose=transpose, unitriangular=unitriangular)
4300            if transpose:
4301                A = A.mT
4302
4303            Ax = torch.matmul(A, x)
4304
4305            rtol = 1e-2 if dtype in [torch.float32, torch.complex64] else self.precision
4306            self.assertEqual(Ax, b.expand_as(Ax), atol=self.precision, rtol=rtol)
4307
4308            # test batched b case
4309            b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1),
4310                                                     upper, unitriangular, device, dtype)
4311            x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose,
4312                                          unitriangular=unitriangular)
4313            if transpose:
4314                A = A.mT
4315
4316            self.assertEqual(torch.matmul(A, x), b)
4317
4318    @skipCUDAIfNoMagma
4319    @skipCPUIfNoLapack
4320    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
4321    @skipIfTorchDynamo("flaky, needs investigation")
4322    @dtypes(*floating_and_complex_types())
4323    def test_triangular_solve_batched_broadcasting(self, device, dtype):
4324        from scipy.linalg import solve_triangular as tri_solve
4325
4326        def scipy_tri_solve_batched(A, B, upper, trans, diag):
4327            batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2]
4328            single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:]
4329            expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A),
4330                                                     torch.Size(batch_dims_B)))
4331            expand_A = np.broadcast_to(A, expand_dims + single_dim_A)
4332            expand_B = np.broadcast_to(B, expand_dims + single_dim_B)
4333            flat_A = expand_A.reshape((-1,) + single_dim_A)
4334            flat_B = expand_B.reshape((-1,) + single_dim_B)
4335            flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag)
4336                                for a, b in zip(flat_A, flat_B)])
4337            return flat_X.reshape(expand_B.shape)
4338
4339        def run_test(A_dims, b_dims, device, upper, transpose, unitriangular):
4340            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4341                                                     unitriangular, device, dtype)
4342            x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(),
4343                                                            upper, transpose, unitriangular))
4344            x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
4345
4346            self.assertEqual(x, x_exp.to(device))
4347
4348        for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
4349            # test against scipy.linalg.solve_triangular
4350            run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular)  # no broadcasting
4351            run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular)  # broadcasting b
4352            run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular)  # broadcasting A
4353            run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular)  # broadcasting A & b
4354
4355    @onlyCUDA
4356    @dtypes(torch.float)
4357    def test_triangular_solve_large(self, device, dtype):
4358        # Repro for https://github.com/pytorch/pytorch/issues/79191
4359        A = torch.randn(1, 2, 2, device=device, dtype=dtype).tril_()
4360        B = torch.randn(1, 2, 524281, device=device, dtype=dtype)
4361        X = torch.linalg.solve_triangular(A, B, upper=False)
4362        self.assertEqual(A @ X, B)
4363
4364    @skipCUDAIfNoMagma
4365    @skipCPUIfNoLapack
4366    @dtypes(*floating_and_complex_types())
4367    def test_triangular_solve_out_errors_and_warnings(self, device, dtype):
4368        # dtypes should be safely castable
4369        a = torch.eye(2, dtype=dtype, device=device)
4370        b = torch.randn(2, 1, dtype=dtype, device=device)
4371        out = torch.empty_like(b).to(torch.int)
4372        clone_a = torch.empty_like(a)
4373        with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
4374            torch.triangular_solve(b, a, out=(out, clone_a))
4375
4376        out = torch.empty_like(b)
4377        clone_a = clone_a.to(torch.int)
4378        with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
4379            torch.triangular_solve(b, a, out=(out, clone_a))
4380
4381        # device should match
4382        if torch.cuda.is_available():
4383            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
4384            out = torch.empty(0, dtype=dtype, device=wrong_device)
4385            clone_a = torch.empty_like(a)
4386            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
4387                torch.triangular_solve(b, a, out=(out, clone_a))
4388            out = torch.empty(0, dtype=dtype, device=device)
4389            clone_a = torch.empty_like(a).to(wrong_device)
4390            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
4391                torch.triangular_solve(b, a, out=(out, clone_a))
4392
4393        # Trigger the WARN_ONCE deprecation error
4394        torch.triangular_solve(b, a)
4395
4396        # if an out tensor with the wrong shape is passed, a warning is given
4397        with warnings.catch_warnings(record=True) as w:
4398            out = torch.empty(1, dtype=dtype, device=device)
4399            clone_a = torch.empty(1, dtype=dtype, device=device)
4400            # Trigger warning
4401            torch.triangular_solve(b, a, out=(out, clone_a))
4402            # Check warning occurs
4403            self.assertEqual(len(w), 2)
4404            self.assertTrue("An output with one or more elements was resized" in str(w[0].message))
4405            self.assertTrue("An output with one or more elements was resized" in str(w[1].message))
4406
4407
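    # Compares torch.matmul(x, y) and its out= variant against np.matmul; for floating-point and
    # complex dtypes the absolute tolerance is scaled with x.shape[-1].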
4408    def check_single_matmul(self, x, y):
4409
4410        def assertEqual(answer, expected):
4411            if x.dtype.is_floating_point or x.dtype.is_complex:
4412                k = max(x.shape[-1], 1)  # Scale the atol with the size of the matrix
4413                self.assertEqual(answer, expected,
4414                                 msg=f"{x.shape} x {y.shape} = {answer.shape}",
4415                                 atol=k * 5e-5,
4416                                 rtol=1e-4)
4417            else:
4418                self.assertEqual(answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}")
4419
4420        # test x @ y
4421        expected = np.matmul(x.cpu(), y.cpu())
4422        ans = torch.matmul(x, y)
4423        self.assertTrue(ans.is_contiguous())
4424        assertEqual(ans, expected)
4425
4426        # test out
4427        out = torch.empty_like(ans)
4428        ans = torch.matmul(x, y, out=out)
4429        self.assertIs(ans, out)
4430        self.assertTrue(ans.is_contiguous())
4431        assertEqual(ans, expected)
4432
4433    def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
4434        """
4435        Generates sequences of tuples (x, y) with size(x) = x_dim and
4436        size(y) <= y_dim that are compatible w.r.t. matmul
4437        """
4438        assert x_dim >= 1
4439        assert y_dim >= 2
4440        x = x_dim
4441        for y in range(1, y_dim + 1):
4442            for batch, mn in product(product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
4443                                     product(range(matrix_size), repeat=min(y, 2))):
4444                if x == 1:
4445                    size_x = mn[:1]
4446                    size_y = batch + mn
4447                    yield size_x, size_y
4448                else:
4449                    for k in range(matrix_size):
4450                        size_x = (k,) + mn[:1]
4451                        if x > 2:
4452                            size_x = batch[-(x - 2):] + size_x
4453                        size_y = mn
4454                        if y > 2:
4455                            size_y = batch[-(y - 2):] + size_y
4456                        yield size_x, size_y
4457
4458    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
4459    @dtypes(torch.int64, torch.float, torch.complex64)
4460    @setBlasBackendsToDefaultFinally
4461    def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
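        # On CUDA, exercise both the cuBLAS and cuBLASLt backends via
        # torch.backends.cuda.preferred_blas_library; on CPU the backend switch is
        # skipped and the same checks simply run twice.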
4462        for backend in ["cublas", "cublaslt"]:
4463            if torch.device(device).type == 'cuda':
4464                torch.backends.cuda.preferred_blas_library(backend)
4465
4466            make_arg = partial(make_tensor, device=device, dtype=dtype)
4467
4468            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
4469                x = make_arg(size_x, noncontiguous=nctg_x)
4470                y = make_arg(size_y, noncontiguous=nctg_y)
4471                self.check_single_matmul(x, y)
4472
4473    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
4474    @dtypes(torch.int64, torch.float, torch.complex64)
4475    @setBlasBackendsToDefaultFinally
4476    def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
4477        for backend in ["cublas", "cublaslt"]:
4478            if torch.device(device).type == 'cuda':
4479                torch.backends.cuda.preferred_blas_library(backend)
4480
4481            make_arg = partial(make_tensor, device=device, dtype=dtype)
4482
4483            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)):
4484                x = make_arg(size_x, noncontiguous=nctg_x)
4485                y = make_arg(size_y, noncontiguous=nctg_y)
4486                self.check_single_matmul(x, y)
4487
4488    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
4489    @dtypes(torch.int64, torch.float, torch.complex64)
4490    @setBlasBackendsToDefaultFinally
4491    def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
4492        for backend in ["cublas", "cublaslt"]:
4493            if torch.device(device).type == 'cuda':
4494                torch.backends.cuda.preferred_blas_library(backend)
4495
4496            make_arg = partial(make_tensor, device=device, dtype=dtype)
4497
4498            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(3), (True, False), (True, False)):
4499                x = make_arg(size_x, noncontiguous=nctg_x)
4500                y = make_arg(size_y, noncontiguous=nctg_y)
4501                self.check_single_matmul(x, y)
4502
4503    @onlyCUDA
4504    @dtypes(*floating_types_and(torch.half))
4505    def test_matmul_small_brute_force_tunableop(self, device, dtype):
4506        # disable TunableOp buffer rotation for all tests everywhere; it can be slow
4507        import os
4508        os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0"
4509        set_tunableop_defaults()
4510
4511        torch.cuda.tunable.enable()
4512        # set these to single iterations to keep it short but still exercise the code
4513        torch.cuda.tunable.set_max_tuning_duration(1)
4514        torch.cuda.tunable.set_max_tuning_iterations(1)
4515
4516        make_arg = partial(make_tensor, device=device, dtype=dtype)
4517
4518        for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
4519            x = make_arg(size_x, noncontiguous=nctg_x)
4520            y = make_arg(size_y, noncontiguous=nctg_y)
4521            self.check_single_matmul(x, y)
4522
4523        filename1 = torch.cuda.tunable.get_filename()
4524        filename2 = "tunableop_results_tmp1.csv"
4525        filename3 = "tunableop_results_tmp2.csv"
4526        ordinal = torch.cuda.current_device()
4527        assert filename1 == f"tunableop_results{ordinal}.csv"
4528        assert len(torch.cuda.tunable.get_validators()) > 0
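        # get_validators() returns (key, value) pairs describing the library versions
        # (e.g. PT_VERSION, and on ROCm HIPBLASLT_VERSION) that the results were tuned with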
4529        validators = {}
4530        for key, value in torch.cuda.tunable.get_validators():
4531            validators[key] = value
4532        if torch.version.hip:
4533            assert "HIPBLASLT_VERSION" in validators
4534            assert re.match(r'^\d{3}-[a-z0-9]{8}$', validators["HIPBLASLT_VERSION"])
4535        assert len(torch.cuda.tunable.get_results()) > 0
4536
4537        assert torch.cuda.tunable.write_file()  # use default filename
4538        assert torch.cuda.tunable.write_file(filename2)  # use custom, one-time filename
4539        torch.cuda.tunable.set_filename(filename3)
4540        assert torch.cuda.tunable.write_file()  # use previously set filename
4541        assert torch.cuda.tunable.read_file()  # use previously set filename, will ignore duplicates and return True
4542
4543        with open(filename1) as file1:
4544            file1_contents = file1.read()
4545        with open(filename2) as file2:
4546            file2_contents = file2.read()
4547        with open(filename3) as file3:
4548            file3_contents = file3.read()
4549        assert file1_contents == file2_contents
4550        assert file1_contents == file3_contents
4551
4552        # remove the files created above to avoid the 'Build left local git repository checkout dirty' error; ignore missing files
4553        for filename in [filename1, filename2, filename3]:
4554            try:
4555                import os
4556                os.remove(filename)
4557            except FileNotFoundError:
4558                pass
4559
4560        # disables TunableOp
4561        torch.cuda.tunable.enable(False)
4562
4563    @onlyCUDA
4564    @skipCUDAIfNotRocm
4565    @dtypes(torch.float)
4566    def test_bmm_tunableop_rocm(self, device, dtype):
4567        # buffer rotation (on by default) with strided batched GEMM in TunableOp was causing a memory fault
4568        set_tunableop_defaults()
4569        torch.cuda.tunable.enable(True)
4570        torch.cuda.tunable.set_max_tuning_iterations(10)
4571        # the following 4 cases cover all previous failure cases and are here to catch regressions
4572        B = 16
4573        N = M = K = 256
4574        dtype = torch.bfloat16
4575        device = torch.device("cuda:0")
4576        # case 1
4577        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4578        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4579        out = torch.bmm(i1, i2)
4580        # case 2
4581        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4582        i1 = torch.permute(i1, (1, 2, 0))
4583        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4584        i2 = torch.permute(i2, (1, 0, 2))
4585        out = torch.bmm(i1, i2)
4586        # case 3
4587        i1 = torch.randn((N, B, M), device=device, dtype=dtype)
4588        i1 = torch.permute(i1, (1, 0, 2))
4589        i2 = torch.randn((M, B, K), device=device, dtype=dtype)
4590        i2 = torch.permute(i2, (1, 2, 0))
4591        out = torch.bmm(i1, i2)
4592        # case 4
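        # (case 4 uses baddbmm with as_strided inputs whose strides do not follow the
        # default contiguous layout; another previously failing pattern)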
4593        input_tensor = torch.rand((1920, 1, 100), device=device, dtype=dtype)
4594        input_tensor = torch.as_strided(
4595            input_tensor, size=(1920, 1, 100), stride=(100, 100, 1)
4596        )
4597        batch1_tensor = torch.rand((1920, 256, 512), device=device, dtype=dtype)
4598        batch1_tensor = torch.as_strided(
4599            batch1_tensor, size=(1920, 256, 512), stride=(512, 983040, 1)
4600        )
4601        batch2_tensor = torch.rand((1920, 512, 100), device=device, dtype=dtype)
4602        batch2_tensor = torch.as_strided(
4603            batch2_tensor, size=(1920, 512, 100), stride=(51200, 100, 1)
4604        )
4605        out = torch.baddbmm(input_tensor, batch1_tensor, batch2_tensor)
4606        # clean up, remove any file that was generated
4607        try:
4608            import os
4609            filename = torch.cuda.tunable.get_filename()
4610            os.remove(filename)
4611        except FileNotFoundError:
4612            pass
4613
4614        # disable TunableOp
4615        torch.cuda.tunable.enable(False)
4616
4617    @onlyCUDA
4618    @skipCUDAIfNotRocm
4619    @dtypes(torch.float)
4620    def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
4621        from torch.testing._internal.common_utils import CudaMemoryLeakCheck
4622        import os
4623        # run the operator once without tuning to ensure all ROCm libs are loaded;
4624        # otherwise the leak checker reports a false positive
4625        B = 16
4626        N = M = K = 256
4627        dtype = torch.bfloat16
4628        device = torch.device("cuda:0")
4629        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4630        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4631        out = torch.bmm(i1, i2)
4632        # enable tunableop numeric check via env variable.
4633        PYTORCH_TUNABLEOP_NUMERICAL_CHECK = "PYTORCH_TUNABLEOP_NUMERICAL_CHECK"
4634        prev_val = os.getenv(PYTORCH_TUNABLEOP_NUMERICAL_CHECK)
4635        try:
4636            os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = "1"
4637            torch.cuda.tunable.enable(True)
4638            ordinal = torch.cuda.current_device()
4639            filename = f"tunableop_results{ordinal}.csv"
4640            torch.cuda.tunable.set_filename(filename)
4641            iterations = torch.cuda.tunable.get_max_tuning_iterations()
4642            torch.cuda.tunable.set_max_tuning_iterations(10)
4643            with CudaMemoryLeakCheck(self):
4644                out = torch.bmm(i1, i2)
4645                torch.cuda.tunable.set_max_tuning_iterations(iterations)
4646                torch.cuda.tunable.enable(False)
4647                # clean up, remove any file that was generated
4648                try:
4649                    os.remove(filename)
4650                except FileNotFoundError:
4651                    pass
4652        finally:
4653            if prev_val is None:
4654                del os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK]
4655            else:
4656                os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = prev_val
4657
4658    @onlyCUDA
4659    @skipCUDAIfNotRocm
4660    @dtypes(torch.float)
4661    def test_validator_tunableop_rocm(self, device, dtype):
4662        # Test that the validator on ROCm has exactly 5 lines
4663        # Format of the Validator lines is as follows:
4664        # Validator,PT_VERSION,X.Y.Z
4665        # Validator,ROCBLAS_VERSION,X.Y.Z
4666        # Validator,HIPBLASLT_VERSION,X.Y.Z
4667        # Validator,ROCM_Version,X.Y.Z
4668        # Validator,GCN_ARCH_NAME,<architecture name>
4669        validator_num_lines = 5
4670
4671        # Test in try-finally block to avoid leaking state
4672        # if test is interrupted.
4673        try:
4674            set_tunableop_defaults()
4675            torch.cuda.tunable.enable()
4676            # set these to single iterations to keep it short but still exercise the code
4677            torch.cuda.tunable.set_max_tuning_iterations(1)
4678
4679            N = M = K = 4
4680            A = torch.randn(N, K, device=device, dtype=dtype)
4681            B = torch.randn(K, M, device=device, dtype=dtype)
4682            C = torch.matmul(A, B)
4683            self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
4684        finally:
4685            # disable TunableOp
4686            torch.cuda.tunable.enable(False)
4687
4688            # clean up, remove any file that was generated
4689            try:
4690                import os
4691                filename = torch.cuda.tunable.get_filename()
4692                os.remove(filename)
4693            except FileNotFoundError:
4694                pass
4695
4696    @onlyCUDA
4697    @dtypes(torch.half)
4698    def test_minimum_tuning_iteration_tunableop(self, device, dtype):
4699        # Make sure that there is at least one tuning iteration under various scenarios
4700
4701        # Test in try-finally block to avoid leaking state
4702        # if test is interrupted.
4703        try:
4704            set_tunableop_defaults()
4705            torch.cuda.tunable.enable()
4706            # set these to single iterations to keep it short but still exercise the code
4707            torch.cuda.tunable.set_max_tuning_iterations(1)
4708
4709            # Set tuning duration to zero milliseconds
4710            # Tune a single GEMM and verify that we get a new tuning result
4711            import os
4712            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "0"
4713            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
4714            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "30"  # reset to default
4715
4716            # Reference number of results
4717            ref_num_results = len(torch.cuda.tunable.get_results())
4718
4719            N = M = K = 8
4720            A = torch.randn(N, K, device=device, dtype=dtype)
4721            B = torch.randn(K, M, device=device, dtype=dtype)
4722            C = torch.matmul(A, B)
4723
4724            # This stores the total number of cumulative results
4725            total_num_results = len(torch.cuda.tunable.get_results())
4726
4727            # There must be a new tuning result
4728            self.assertEqual((total_num_results - ref_num_results), 1)
4729
4730            # Set tuning iterations to zero
4731            # Tune a single GEMM and verify that we get a new tuning result
4732            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "0"
4733            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
4734            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "100"  # reset to default
4735
4736            # Reference number of results
4737            ref_num_results = total_num_results
4738
4739            N = M = K = 16
4740            A = torch.randn(N, K, device=device, dtype=dtype)
4741            B = torch.randn(K, M, device=device, dtype=dtype)
4742            C = torch.matmul(A, B)
4743
4744            # This stores the total number of cumulative results
4745            total_num_results = len(torch.cuda.tunable.get_results())
4746
4747            # There must be a new tuning result
4748            self.assertEqual((total_num_results - ref_num_results), 1)
4749
4750        finally:
4751            # disable TunableOp
4752            torch.cuda.tunable.enable(False)
4753
4754            # clean up, remove any file that was generated
4755            try:
4756                import os
4757                filename = torch.cuda.tunable.get_filename()
4758                os.remove(filename)
4759            except FileNotFoundError:
4760                pass
4761
4762    @onlyCUDA
4763    @dtypes(torch.half)
4764    def test_matmul_check_entries_tunableop(self, device, dtype):
4765        # Tune a couple of matrix multiplies
4766        # Verify we get the correct number of results
4767
4768        try:
4769            set_tunableop_defaults()
4770            torch.cuda.tunable.enable()
4771            # set these to single iterations to keep it short but still exercise the code
4772            torch.cuda.tunable.set_max_tuning_iterations(1)
4773
4774            # Reference number of results
4775            ref_num_results = len(torch.cuda.tunable.get_results())
4776
4777            # Execute matrix multiplies. We intentionally include the same M value
4778            # twice; the CSV file should only record unique GEMMs
4779            count_matmul = 4
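            # The loops below run 3 * 2 = 6 GEMMs, but M=32 appears twice, so only
            # 4 unique (N, M, K) shapes should be tuned and recorded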
4780            K = 64
4781            for M in [32, 64, 32]:
4782                for N in [32, 64]:
4783                    A = torch.randn(N, K, device=device, dtype=dtype)
4784                    B = torch.randn(K, M, device=device, dtype=dtype)
4785                    C = torch.matmul(A, B)
4786
4787            # This stores the total number of cumulative results
4788            total_num_results = len(torch.cuda.tunable.get_results())
4789
4790            # Take the difference to calculate the number of results from
4791            # this test and verify that it agrees with the number of
4792            # unique GEMMs.
4793            self.assertEqual((total_num_results - ref_num_results), count_matmul)
4794
4795        finally:
4796            # disable TunableOp
4797            torch.cuda.tunable.enable(False)
4798
4799            # clean up, remove any file that was generated
4800            try:
4801                import os
4802                filename = torch.cuda.tunable.get_filename()
4803                os.remove(filename)
4804            except FileNotFoundError:
4805                pass
4806
4890
4891    @dtypes(torch.float, torch.complex64)
4892    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
4893        a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
4894        b = torch.empty((4, 128, 512), device=device, dtype=dtype, requires_grad=True).transpose(-1, -2)
4895        c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)
4896
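        # out= is allowed here because the detached inputs do not require grad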
4897        torch.matmul(a.detach(), b.detach(), out=c)
4898
4899        with self.assertRaisesRegex(RuntimeError, "functions with out=... arguments don't support automatic differentiation"):
4900            torch.matmul(a, b, out=c)
4901
4902        with torch.no_grad():
4903            torch.matmul(a, b, out=c)
4904
4905    # 4GB should do, but we run tests in parallel in CI, so let's be generous
4906    @largeTensorTest('16GB', device='cuda')
4907    def test_large_bmm_mm_backward(self, device):
4908        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
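        # .mT.contiguous().mT makes A contiguous in its transposed (column-major-like)
        # layout rather than the default row-major layout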
4909        B = torch.randn([1024, 65536], device="cuda", requires_grad=True)
4910        G = torch.randn([1024, 2, 65536], device="cuda")
4911
4912        # Should not create an intermediate tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
4913        (A @ B).backward(G)
4914
4915    # 4GB should do, but we run tests in parallel in CI, so let's be generous
4916    @largeTensorTest('16GB', device='cuda')
4917    def test_large_bmm_backward(self, device):
4918        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
4919        B = torch.randn([1, 1024, 65536], device="cuda", requires_grad=True)
4920        G = torch.randn([1024, 2, 65536], device="cuda")
4921
4922        # Should not create an intermediate tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
4923        (A @ B).backward(G)
4924
4925    def test_linear_algebra_scalar_raises(self, device) -> None:
4926        m = torch.randn(5, 5, device=device)
4927        v = torch.randn(5, device=device)
4928        s = torch.tensor(7, device=device)
4929        self.assertRaises(RuntimeError, lambda: torch.mv(m, s))
4930        self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s))
4931
4932    @dtypes(torch.float32, torch.complex64)
4933    def test_cross(self, device, dtype):
4934        x = torch.rand(100, 3, 100, dtype=dtype, device=device)
4935        y = torch.rand(100, 3, 100, dtype=dtype, device=device)
4936        res1 = torch.cross(x, y)
4937        res2 = torch.tensor((), dtype=dtype, device=device)
4938        torch.cross(x, y, out=res2)
4939        self.assertEqual(res1, res2)
4940
4941    @dtypes(torch.float32, torch.complex64)
4942    def test_linalg_cross(self, device, dtype):
4943        x = torch.rand(100, 3, 100, dtype=dtype, device=device)
4944        y = torch.rand(100, 3, 100, dtype=dtype, device=device)
4945        res1 = torch.linalg.cross(x, y, dim=1)
4946        res2 = torch.tensor((), dtype=dtype, device=device)
4947        torch.linalg.cross(x, y, dim=1, out=res2)
4948        self.assertEqual(res1, res2)
4949
4950        # test for broadcastable inputs
4951        x = torch.rand(1, 3, 2, dtype=dtype, device=device)
4952        y = torch.rand(4, 3, 1, dtype=dtype, device=device)
4953        res1 = torch.linalg.cross(x, y, dim=1)
4954        res2 = torch.tensor((), dtype=dtype, device=device)
4955        torch.linalg.cross(x, y, dim=1, out=res2)
4956        self.assertEqual(res1, res2)
4957
4958    @dtypes(torch.float32, torch.complex64)
4959    def test_cross_with_and_without_dim(self, device, dtype):
4960        x = torch.rand(100, 3, dtype=dtype, device=device)
4961        y = torch.rand(100, 3, dtype=dtype, device=device)
4962        res1 = torch.cross(x, y, dim=1)
4963        res2 = torch.cross(x, y, dim=-1)
4964        res3 = torch.cross(x, y)
4965        self.assertEqual(res1, res2)
4966        self.assertEqual(res1, res3)
4967
4968    @dtypes(torch.float32, torch.complex64)
4969    def test_linalg_cross_with_and_without_dim(self, device, dtype):
4970        x = torch.rand(100, 3, dtype=dtype, device=device)
4971        y = torch.rand(100, 3, dtype=dtype, device=device)
4972        res1 = torch.linalg.cross(x, y, dim=1)
4973        res2 = torch.linalg.cross(x, y, dim=-1)
4974        res3 = torch.linalg.cross(x, y)
4975        self.assertEqual(res1, res2)
4976        self.assertEqual(res1, res3)
4977
4978    def test_renorm(self, device):
4979        m1 = torch.randn(20, 20, device=device)  # big enough to exercise vectorized path
4980        res1 = torch.tensor((), device=device)
4981
4982        def renorm(matrix, value, dim, max_norm):
4983            m1 = matrix.transpose(dim, 0).contiguous()
4984            # collapse non-dim dimensions.
4985            m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0))))
4986            norms = m2.norm(value, 1, True)
4987            # clip
4988            new_norms = norms.clone()
4989            new_norms[torch.gt(norms, max_norm)] = max_norm
4990            new_norms.div_(norms.add_(1e-7))
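            # the 1e-7 added above guards against division by zero for zero-norm rows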
4991            # renormalize
4992            m1.mul_(new_norms.expand_as(m1))
4993            return m1.transpose(dim, 0)
4994
4995        # note that the axis fed to torch.renorm is different (2~=1)
4996        maxnorm = m1.norm(2, 1).mean()
4997        m2 = renorm(m1, 2, 1, maxnorm)
4998        m1.renorm_(2, 1, maxnorm)
4999        self.assertEqual(m1, m2, atol=1e-5, rtol=0)
5000        self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), atol=1e-5, rtol=0)
5001
5002        m1 = torch.randn(3, 4, 5, device=device)
5003        m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
5004        maxnorm = m2.norm(2, 0).mean()
5005        m2 = renorm(m2, 2, 1, maxnorm)
5006        m1.renorm_(2, 1, maxnorm)
5007        m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
5008        self.assertEqual(m3, m2)
5009        self.assertEqual(m3.norm(2, 0), m2.norm(2, 0))
5010
5011    @skipCPUIfNoLapack
5012    @skipCUDAIfNoCusolver
5013    @dtypes(*floating_and_complex_types())
5014    def test_ormqr(self, device, dtype):
5015
5016        def run_test(batch, m, n, fortran_contiguous):
5017            A = make_tensor((*batch, m, n), dtype=dtype, device=device)
5018            reflectors, tau = torch.geqrf(A)
5019            if not fortran_contiguous:
5020                self.assertTrue(reflectors.mT.is_contiguous())
5021                reflectors = reflectors.contiguous()
5022
5023            # Q is of size m x m
5024            Q, _ = torch.linalg.qr(A, mode='complete')
5025            C_right = make_tensor((*batch, m, n), dtype=dtype, device=device)
5026            C_left = make_tensor((*batch, n, m), dtype=dtype, device=device)
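            # ormqr applies Q (stored implicitly as reflectors and tau) to C from the
            # left or right, optionally conjugate-transposed; compare against the explicit Q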
5027
5028            expected = Q @ C_right
5029            actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=False)
5030            self.assertEqual(expected, actual)
5031
5032            expected = C_left @ Q
5033            actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=False)
5034            self.assertEqual(expected, actual)
5035
5036            expected = Q.mH @ C_right
5037            actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=True)
5038            self.assertEqual(expected, actual)
5039
5040            expected = C_left @ Q.mH
5041            actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=True)
5042            self.assertEqual(expected, actual)
5043
5044            # if tau is all zeros then the implicit matrix Q is the identity matrix
5045            # so the actual result should be C_right in this case
5046            zero_tau = torch.zeros_like(tau)
5047            actual = torch.ormqr(reflectors, zero_tau, C_right, left=True, transpose=False)
5048            self.assertEqual(C_right, actual)
5049
5050        batches = [(), (0, ), (2, ), (2, 1)]
5051        ns = [5, 2, 0]
5052        for batch, (m, n), fortran_contiguous in product(batches, product(ns, ns), [True, False]):
5053            run_test(batch, m, n, fortran_contiguous)
5054
5055    @skipCPUIfNoLapack
5056    @skipCUDAIfNoCusolver
5057    @dtypes(*floating_and_complex_types())
5058    def test_ormqr_errors_and_warnings(self, device, dtype):
5059        test_cases = [
5060            # input1 size, input2 size, input3 size, error regex
5061            ((10,), (2,), (2,), r"input must have at least 2 dimensions"),
5062            ((2, 2), (2,), (2,), r"other must have at least 2 dimensions"),
5063            ((10, 6), (20,), (10, 6), r"other.shape\[-2\] must be greater than or equal to tau.shape\[-1\]"),
5064            ((6, 6), (5,), (5, 5), r"other.shape\[-2\] must be equal to input.shape\[-2\]"),
5065            ((1, 2, 2), (2, 2), (1, 2, 2), r"batch dimensions of tau to be equal to input.shape\[:-2\]"),
5066            ((1, 2, 2), (1, 2), (2, 2, 2), r"batch dimensions of other to be equal to input.shape\[:-2\]"),
5067        ]
5068        for a_size, tau_size, c_size, error_regex in test_cases:
5069            a = make_tensor(a_size, dtype=dtype, device=device)
5070            tau = make_tensor(tau_size, dtype=dtype, device=device)
5071            c = make_tensor(c_size, dtype=dtype, device=device)
5072            with self.assertRaisesRegex(RuntimeError, error_regex):
5073                torch.ormqr(a, tau, c)
5074
5075    def test_blas_empty(self, device):
5076        def fn(torchfn, *args, test_out=False, **kwargs):
5077            def call_torch_fn(*args, **kwargs):
5078                return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
5079                                      for shape in args), **kwargs)
5080            result = call_torch_fn(*args, **kwargs)
5081            if not test_out:
5082                return result
5083            else:
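                # pre-fill out with NaNs so any elements the op fails to overwrite stand out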
5084                out = torch.full_like(result, math.nan)
5085                out1 = call_torch_fn(*args, **kwargs, out=out)
5086                return out
5087
5088        # mm, addmm
5089        self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
5090        self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
5091        self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
5092        self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
5093        self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)))
5094        self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True))
5095
5096        self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
5097        self.assertEqual((0, 1), fn(torch.addmm, (1, ), (0, 17), (17, 1)).shape)
5098        t = torch.randn((5, 6), device=device)
5099        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6)))
5100        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True))
5101
5102        # mv, addmv
5103        self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
5104        self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
5105        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))
5106        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True))
5107
5108        self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
5109        t = torch.randn((3,), device=device)
5110        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,)))
5111        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True))
5112
5113        # bmm, baddbmm
5114        self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
5115        self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
5116        self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
5117        self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6)))
5118        self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True))
5119
5120        self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape)
5121        self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape)
5122        self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape)
5123        self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape)
5124        c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5)
5125        self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2))  # Issue #33467
5126        self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True))  # Issue #33467
5127
5128        # addbmm
5129        self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
5130        self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
5131        t = torch.randn((5, 6), device=device)
5132        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6)))
5133        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True))
5134
5135        # matmul
5136        self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,)))
5137        self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,), test_out=True))
5138        self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
5139        self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
5140        self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
5141        self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4)))
5142        self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True))
5143
5144        # dot
5145        self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
5146        self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True))
5147
5148    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
5149                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
5150    @dtypesIfCUDA(*floating_and_complex_types_and(
5151                  torch.half,
5152                  *[torch.bfloat16] if SM53OrLater else []
5153                  ))
5154    @dtypes(*all_types_and_complex_and(torch.bfloat16))
5155    def test_corner_cases_of_cublasltmatmul(self, device, dtype):
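        # These shapes exercise cublasLt corner cases: strided slices whose leading
        # dimension is much larger than the number of rows, and dimensions above 65535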
5156        # common case
5157        M = torch.randn(128, device=device).to(dtype)
5158        m1 = torch.randn(2048, 2400, device=device).to(dtype)
5159        m2 = torch.randn(128, 2400, device=device).to(dtype)
5160        torch.nn.functional.linear(m1, m2, M)
5161        # Ntrans_B has ld >> rows
5162        m1 = torch.rand([128, 2400]).to(dtype).to(device).t()
5163        m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340]
5164        M = torch.rand([128]).to(dtype).to(device)
5165        torch.addmm(M, m2.t(), m1)
5166        # trans_A has ld >> rows
5167        m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t()
5168        m2 = torch.randn(2048, 2400, device=device).to(dtype)
5169        M = torch.rand([128]).to(dtype).to(device)
5170        torch.addmm(M, m2, m1)
5171        # large tensor dim > 65535
5172        M = torch.randn(16, device=device).to(dtype)
5173        m1 = torch.randn(32, 131071, device=device).to(dtype)
5174        m2 = torch.randn(16, 131071, device=device).to(dtype)
5175        torch.nn.functional.linear(m1, m2, M)
5176
5177    @onlyCUDA
5178    @skipCUDAIfNotRocm
5179    @dtypes(*floating_types_and(torch.bfloat16, torch.half))
5180    def test_hipblaslt_corner_cases_rocm(self, device, dtype):
5181        if dtype == torch.double:
5182            raise unittest.SkipTest("hipblasLt doesn't support doubles yet")
5183
5184        # enable hipblaslt path via env variable.
5185        import os
5186        DISABLE_ADDMM_HIP_LT = "DISABLE_ADDMM_HIP_LT"
5187        prev_val = os.getenv(DISABLE_ADDMM_HIP_LT)
5188        try:
5189            os.environ[DISABLE_ADDMM_HIP_LT] = "0"
5190            # common case
5191            M = torch.randn(128, device=device, dtype=dtype)
5192            m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
5193            m2 = torch.randn(128, 2400, device=device, dtype=dtype)
5194            out1 = torch.nn.functional.linear(m1, m2, M)
5195            M_cpu = M.to('cpu')
5196            m1_cpu = m1.to('cpu')
5197            m2_cpu = m2.to('cpu')
5198            out1_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, M_cpu)
5199            self.assertTrue(torch.allclose(out1_cpu, out1.cpu(), rtol=1e-2, atol=1e-2))
5200
5201            # common case without bias
5202            m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
5203            m2 = torch.randn(128, 2400, device=device, dtype=dtype)
5204            out2 = torch.nn.functional.linear(m1, m2, bias=None)
5205            m1_cpu = m1.to('cpu')
5206            m2_cpu = m2.to('cpu')
5207            out2_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, bias=None)
5208            self.assertTrue(torch.allclose(out2_cpu, out2.cpu(), rtol=1e-2, atol=1e-2))
5209        finally:
5210            if prev_val is None:
5211                del os.environ[DISABLE_ADDMM_HIP_LT]
5212            else:
5213                os.environ[DISABLE_ADDMM_HIP_LT] = prev_val
5214
5215    @dtypesIfCUDA(*floating_and_complex_types_and(
5216                  torch.half,
5217                  *[torch.bfloat16] if SM53OrLater else []
5218                  ))
5219    @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.half))
5220    def test_blas_alpha_beta_empty(self, device, dtype):
5221        # This test is disabled on CUDA 9, see:
5222        # https://github.com/pytorch/pytorch/issues/31006
5223        if dtype is torch.bfloat16 and self.device_type == 'xla':
5224            # TODO (@zasdfgbnm): this causes the following error on test
5225            # TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16:
5226            #
5227            #   RuntimeError: _th_equal not supported on CPUType for BFloat16
5228            return
5229        # ensure beta is respected
5230        value = 11
5231        input = torch.full((2,), value, dtype=dtype, device=device)
5232        mat = torch.ones((2, 0), dtype=dtype, device=device)
5233        vec = torch.ones((0,), dtype=dtype, device=device)
5234        out = torch.empty((2,), dtype=dtype, device=device)
5235        if dtype.is_complex:
5236            alpha = 6 + 7j
5237            beta = 3 + 4j
5238        else:
5239            alpha = 6
5240            beta = 3
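        # the reduction dimension is empty, so the alpha * (mat @ vec) term contributes
        # zero and the result should equal beta * input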
5241        self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device),
5242                         torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta))
5243        self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device),
5244                         torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out))
5245
5246        # torch.addmm
5247        input = torch.full((2, 3), value, dtype=dtype, device=device)
5248        mat2 = torch.ones((0, 3), dtype=dtype, device=device)
5249        out = torch.empty((2, 3), dtype=dtype, device=device)
5250        self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device),
5251                         torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta))
5252        self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device),
5253                         torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out))
5254
5255    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
5256    def test_blas_nan_out(self, device, dtype):
5257        # These functions should work correctly with NaN filled outputs,
5258        # but need special handling, see [NOTE: cpu_zero]
5259        b = 3
5260        n = 5
5261        m = 7
5262        p = 11
5263
5264        # torch.mv
5265        nm = torch.randn((m, n), device=device).t()
5266        _m = torch.randn((), device=device).expand(m)
5267        _m_out = torch.full((m,), float('nan'), device=device)
5268        self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
5269        self.assertEqual(0, torch.isnan(torch.mv(nm, _m)).sum())
5270
5271        # torch.mm
5272        mp = torch.randn((p, m), device=device).t()
5273        np_out = torch.full((n, p), float('nan'), device=device)
5274        self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out))
5275
5276        # torch.bmm
5277        bnm = torch.randn((b, m, n), device=device).transpose(1, 2)
5278        bmp = torch.randn((b, p, m), device=device).transpose(1, 2)
5279        bnp_out = torch.full((b, n, p), float('nan'), device=device)
5280        self.assertEqual(torch.bmm(bnm, bmp), torch.bmm(bnm, bmp, out=bnp_out))
5281
5282    @onlyCPU  # not supported by CUBLAS
5283    def test_blas_mv_large_input(self, device):
5284        # This would previously fail if the allocated output had NaNs, see:
5285        # https://github.com/pytorch/pytorch/issues/31663 and [NOTE: cpu_zero]
5286        n = 3000
5287        m = 200
5288
5289        nm = torch.randn((m, n), device=device).t()
5290        _m = torch.randn((), device=device).expand(m)
5291        _m_out = torch.full((m,), 0., device=device)
5292
5293        self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
5294
5295    @onlyCPU
5296    def test_renorm_ps(self, device):
5297        # full reduction
5298        x = torch.randn(5, 5)
5299        xn = x.numpy()
5300        for p in [1, 2, 3, 4, inf]:
5301            res = x.renorm(p, 1, 1)
5302            expected = x / x.norm(p, 0, keepdim=True).clamp(min=1)
5303            self.assertEqual(res, expected, msg=f"renorm failed for {p}-norm")
5304
5305    @skipCPUIfNoLapack
5306    @skipCUDAIfNoCusolver
5307    @dtypes(*floating_and_complex_types())
5308    def test_householder_product(self, device, dtype):
5309        def generate_reflectors_and_tau(A):
5310            """
5311            This function uses numpy.linalg.qr with mode "raw" to extract the output of LAPACK's geqrf.
5312            There is a torch.geqrf function, but it doesn't work with complex-valued input.
5313            """
5314            if A.numel() > 0:
5315                A_cpu = A.cpu()
5316                flattened_batch_shape = [-1, *A_cpu.shape[-2:]]
5317                reflectors = torch.empty_like(A_cpu).view(*flattened_batch_shape)
5318                tau_shape = [*A_cpu.shape[:-2], A_cpu.shape[-1]]
5319                tau = torch.empty(tau_shape, dtype=dtype).view(-1, A_cpu.shape[-1])
5320                for A_i, reflectors_i, tau_i in zip(A_cpu.contiguous().view(*flattened_batch_shape), reflectors, tau):
5321                    reflectors_tmp, tau_i[:] = map(torch.from_numpy, np.linalg.qr(A_i, mode='raw'))
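                    # note: np.linalg.qr(mode='raw') returns the reflectors transposed
                    # (Fortran order), hence the transpose back on the next line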
5322                    reflectors_i[:] = reflectors_tmp.T
5323                reflectors = reflectors.view(*A_cpu.shape)
5324                tau = tau.view(tau_shape)
5325                return reflectors.to(A.device), tau.to(A.device)
5326
5327            reflectors = torch.empty_like(A)
5328            tau = torch.empty(*A.shape[:-2], A.shape[-1], dtype=dtype, device=device)
5329            return reflectors, tau
5330
5331        def run_test(shape):
5332            A = torch.randn(*shape, dtype=dtype, device=device)
5333            reflectors, tau = generate_reflectors_and_tau(A)
5334            expected, _ = torch.linalg.qr(A)
5335            actual = torch.linalg.householder_product(reflectors, tau)
5336            # torch.linalg.qr does not work correctly for zero batch dimension tensors
5337            # see https://github.com/pytorch/pytorch/issues/50576
5338            if (A.numel() > 0):
5339                self.assertEqual(expected, actual)
5340            else:
5341                self.assertTrue(actual.shape == shape)
5342
5343            # if tau is empty and A is not, the result should be a matrix with ones on the diagonal
5344            if (A.numel() > 0):
5345                tau_empty = torch.empty(*shape[:-2], 0, dtype=dtype, device=device)
5346                identity_mat = torch.zeros_like(reflectors)
5347                identity_mat.diagonal(dim1=-1, dim2=-2)[:] = 1
5348                actual = torch.linalg.householder_product(reflectors, tau_empty)
5349                self.assertEqual(actual, identity_mat)
5350
5351            out = torch.empty_like(A)
5352            ans = torch.linalg.householder_product(reflectors, tau, out=out)
5353            self.assertEqual(ans, out)
5354            if (A.numel() > 0):
5355                self.assertEqual(expected, out)
5356
5357        shapes = [(0, 0), (5, 0),  # Empty matrix
5358                  (5, 5), (5, 3),  # Single matrix
5359                  (0, 0, 0), (0, 5, 5), (0, 5, 3),  # Zero batch dimension tensors
5360                  (2, 5, 5), (2, 5, 3),  # 3-dim tensors
5361                  (2, 1, 5, 5), (2, 1, 5, 3)]  # 4-dim tensors
5362        for shape in shapes:
5363            run_test(shape)
5364
5365    @skipCPUIfNoLapack
5366    @skipCUDAIfNoCusolver
5367    def test_householder_product_errors_and_warnings(self, device):
5368        test_cases = [
5369            # input1 size, input2 size, error regex
5370            ((10,), (2,), r"input must have at least 2 dimensions"),
5371            ((10, 6), (20,), r"input.shape\[-1\] must be greater than or equal to tau.shape\[-1\]"),
5372            ((6, 10), (5,), r"input.shape\[-2\] must be greater than or equal to input.shape\[-1\]"),
5373        ]
5374        for a_size, tau_size, error_regex in test_cases:
5375            a = torch.rand(*a_size, device=device)
5376            tau = torch.rand(*tau_size, device=device)
5377            with self.assertRaisesRegex(RuntimeError, error_regex):
5378                torch.linalg.householder_product(a, tau)
5379
5380        # if an out tensor with the wrong shape is passed, a warning is given
5381        reflectors = torch.randn(3, 3, device=device)
5382        tau = torch.randn(3, device=device)
5383        out = torch.empty(2, 3, device=device)
5384        with warnings.catch_warnings(record=True) as w:
5385            # Trigger warning
5386            torch.linalg.householder_product(reflectors, tau, out=out)
5387            # Check warning occurs
5388            self.assertEqual(len(w), 1)
5389            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
5390
5391        # dtypes should be safely castable
5392        out = torch.empty_like(reflectors).to(torch.int)
5393        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
5394            torch.linalg.householder_product(reflectors, tau, out=out)
5395
5396        with self.assertRaisesRegex(RuntimeError, "tau dtype Int does not match input dtype"):
5397            torch.linalg.householder_product(reflectors, tau.to(torch.int))
5398
5399        if torch.cuda.is_available():
5400            # device of out and input should match
5401            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
5402            out = torch.empty_like(reflectors).to(wrong_device)
5403            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
5404                torch.linalg.householder_product(reflectors, tau, out=out)
5405
5406            # device of tau and input should match
5407            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
5408            tau = tau.to(wrong_device)
5409            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
5410                torch.linalg.householder_product(reflectors, tau)
5411
5412    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
5413    @skipCUDAIfNoMagmaAndNoCusolver
5414    @skipIfTorchDynamo("Runtime error with torch._C._linalg.linalg_lu_factor")
5415    @skipCPUIfNoLapack
5416    @dtypes(*floating_and_complex_types())
5417    def test_linalg_lu_family(self, device, dtype):
5418        # Tests torch.lu
5419        #       torch.linalg.lu_factor
5420        #       torch.linalg.lu_factor_ex
5421        #       torch.lu_unpack
5422        #       torch.linalg.lu_solve
5423        #       torch.linalg.solve
5424        make_arg_full = partial(make_fullrank_matrices_with_distinct_singular_values, device=device, dtype=dtype)
5425        make_arg = partial(make_tensor, device=device, dtype=dtype)
5426
5427        def run_test(A, pivot, singular, fn):
5428            k = min(A.shape[-2:])
5429            batch = A.shape[:-2]
5430            check_errors = (fn == torch.linalg.lu_factor)
5431            if singular and check_errors:
5432                # It may or may not throw as the LU decomposition without pivoting
5433                # may still succeed for singular matrices
5434                try:
5435                    LU, pivots = fn(A, pivot=pivot)
5436                except RuntimeError:
5437                    return
5438            else:
5439                LU, pivots = fn(A, pivot=pivot)[:2]
5440
5441            self.assertEqual(LU.size(), A.shape)
5442            self.assertEqual(pivots.size(), batch + (k,))
5443
5444            if not pivot:
5445                self.assertEqual(pivots, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(batch + (k, )))
5446
5447            P, L, U = torch.lu_unpack(LU, pivots, unpack_pivots=pivot)
5448
5449            self.assertEqual(P @ L @ U if pivot else L @ U, A)
5450
5451            PLU = torch.linalg.lu(A, pivot=pivot)
5452            self.assertEqual(P, PLU.P)
5453            self.assertEqual(L, PLU.L)
5454            self.assertEqual(U, PLU.U)
5455
5456            if not singular and A.size(-2) == A.size(-1):
5457                nrhs = ((), (1,), (3,))
5458                for left, rhs in product((True, False), nrhs):
5459                    # Vector case when left = False is not allowed
5460                    if not left and rhs == ():
5461                        continue
5462                    if left:
5463                        shape_B = A.shape[:-1] + rhs
5464                    else:
5465                        shape_B = A.shape[:-2] + rhs + A.shape[-1:]
5466                    B = make_arg(shape_B)
5467
5468                    # Test linalg.lu_solve. It does not support vectors as rhs
5469                    # See https://github.com/pytorch/pytorch/pull/74045#issuecomment-1112304913
5470                    if rhs != ():
5471                        for adjoint in (True, False):
5472                            X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint)
5473                            A_adj = A.mH if adjoint else A
5474                            if left:
5475                                self.assertEqual(B, A_adj @ X)
5476                            else:
5477                                self.assertEqual(B, X @ A_adj)
5478
5479                    # Test linalg.solve
5480                    X = torch.linalg.solve(A, B, left=left)
5481                    X_ = X.unsqueeze(-1) if rhs == () else X
5482                    B_ = B.unsqueeze(-1) if rhs == () else B
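                    # promote a vector rhs/solution to a column matrix so the residual
                    # check below is a well-defined matrix product in both cases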
5483                    if left:
5484                        self.assertEqual(B_, A @ X_)
5485                    else:
5486                        self.assertEqual(B_, X_ @ A)
5487
5488
5489        sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
5490        batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5))
5491        # The non-pivoting variant is only implemented on CUDA
5492        pivots = (True, False) if self.device_type == "cuda" else (True,)
5493        fns = (partial(torch.lu, get_infos=True), torch.linalg.lu_factor, torch.linalg.lu_factor_ex)
5494        for ms, batch, pivot, singular, fn in itertools.product(sizes, batches, pivots, (True, False), fns):
5495            shape = batch + ms
5496            A = make_arg(shape) if singular else make_arg_full(*shape)
5497            # For empty matrices, only run the singular=True case
5498            if A.numel() == 0 and not singular:
5499                continue
5500            run_test(A, pivot, singular, fn)
5501
5502            # Reproducer of a magma bug,
5503            # see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on
5504            # This is also a bug in cuSOLVER < 11.3
5505            if (dtype == torch.double
5506               and singular):
5507                A = torch.ones(batch + ms, dtype=dtype, device=device)
5508                run_test(A, pivot, singular, fn)
5509
5510        # Info should be positive for rank deficient matrices
5511        A = torch.ones(5, 3, 3, device=device)
5512        self.assertTrue((torch.linalg.lu_factor_ex(A, pivot=True).info >= 0).all())
5513
5514        if self.device_type == 'cpu':
5515            # Error checking for the no-pivoting variant on CPU
5516            fns = [torch.lu, torch.linalg.lu_factor, torch.linalg.lu_factor_ex, torch.linalg.lu]
5517            for f in fns:
5518                with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
5519                    f(torch.empty(1, 2, 2), pivot=False)
5520
5521
5522    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
5523    @skipCUDAIfNoMagmaAndNoCusolver
5524    @skipCPUIfNoLapack
5525    @setLinalgBackendsToDefaultFinally
5526    @dtypes(*floating_and_complex_types())
5527    def test_linalg_lu_solve(self, device, dtype):
5528        make_arg = partial(make_tensor, dtype=dtype, device=device)
5529
5530        backends = ["default"]
5531
5532        if torch.device(device).type == 'cuda':
5533            if torch.cuda.has_magma:
5534                backends.append("magma")
5535            if has_cusolver():
5536                backends.append("cusolver")
5537
5538        def gen_matrices():
5539            rhs = 3
5540            ns = (5, 2, 0)
5541            batches = ((), (0,), (1,), (2,), (2, 1), (0, 2))
5542            for batch, n in product(batches, ns):
5543                yield make_arg(batch + (n, n)), make_arg(batch + (n, rhs))
5544            # Shapes to exercise all the paths
5545            shapes = ((1, 64), (2, 128), (1025, 2))
5546            for b, n in shapes:
5547                yield make_arg((b, n, n)), make_arg((b, n, rhs))
5548
5549
5550        for A, B in gen_matrices():
5551            LU, pivots = torch.linalg.lu_factor(A)
5552            for backend in backends:
5553                torch.backends.cuda.preferred_linalg_library(backend)
5554
5555                for left, adjoint in product((True, False), repeat=2):
5556                    B_left = B if left else B.mT
5557                    X = torch.linalg.lu_solve(LU, pivots, B_left, left=left, adjoint=adjoint)
5558                    A_adj = A.mH if adjoint else A
5559                    if left:
5560                        self.assertEqual(B_left, A_adj @ X)
5561                    else:
5562                        self.assertEqual(B_left, X @ A_adj)
5563
5564
5565    @onlyCPU
5566    @dtypes(*floating_and_complex_types())
5567    def test_linalg_lu_cpu_errors(self, device, dtype):
5568        # Square tests
5569        sample = torch.randn(3, 2, 2, device=device, dtype=dtype)
5570        B = torch.randn(3, 2, 2, device=device, dtype=dtype)
5571        LU, pivots = torch.linalg.lu_factor(sample)
5572
5573        # This should run without issues
5574        torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
5575        torch.lu_unpack(LU, pivots)
5576
5577        pivots[0] = 0
5578        with self.assertRaisesRegex(RuntimeError, r"greater or equal to 1"):
5579            torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
5580        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5581            torch.lu_unpack(LU, pivots)
5582
5583        pivots[0] = 3
5584        with self.assertRaisesRegex(RuntimeError, r"smaller or equal to LU.size\(-2\)"):
5585            torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
5586        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5587            torch.lu_unpack(LU, pivots)
5588
        # Rectangular tests (tall: more rows than columns)
5590        sample = torch.randn(3, 4, 2, device=device, dtype=dtype)
5591        B = torch.randn(3, 4, 2, device=device, dtype=dtype)
5592        LU, pivots = torch.linalg.lu_factor(sample)
5593
5594        # This should run without issues
5595        torch.lu_unpack(LU, pivots)
5596
5597        pivots[0] = 0
5598        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5599            torch.lu_unpack(LU, pivots)
5600
5601        pivots[0] = 5
5602        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5603            torch.lu_unpack(LU, pivots)
5604
5605
        # Rectangular tests (wide: more columns than rows)
5607        sample = torch.randn(2, 3, 5, device=device, dtype=dtype)
5608        B = torch.randn(2, 3, 5, device=device, dtype=dtype)
5609        LU, pivots = torch.linalg.lu_factor(sample)
5610
5611        # This should run without issues
5612        torch.lu_unpack(LU, pivots)
5613
5614        pivots[0] = 0
5615        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5616            torch.lu_unpack(LU, pivots)
5617
5618        pivots[0] = 4
5619        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5620            torch.lu_unpack(LU, pivots)
5621
5622
5623    @skipCPUIfNoLapack
5624    @skipCUDAIfNoMagma
5625    @dtypes(torch.double)
5626    def test_lu_unpack_check_input(self, device, dtype):
5627        x = torch.rand(5, 5, 5, device=device, dtype=dtype)
5628        lu_data, lu_pivots = torch.linalg.lu_factor(x)
5629
5630        with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
5631            torch.lu_unpack(lu_data, lu_pivots.long())
5632
        # check that once the unpack flags are unset, empty tensors are returned
5634        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False)
5635        self.assertTrue(l.numel() == 0 and u.numel() == 0)
5636        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False)
5637        self.assertTrue(p.numel() == 0)
5638        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False, unpack_pivots=False)
5639        self.assertTrue(p.numel() == 0 and l.numel() == 0 and u.numel() == 0)
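
    # A minimal sketch (not a test) of what lu_unpack produces when both unpack flags
    # are left at their defaults: explicit P, L, U factors with P @ L @ U
    # reconstructing the input.  The `_sketch_*` name is a hypothetical helper used
    # only for illustration.
    @staticmethod
    def _sketch_lu_unpack_reconstruction():
        A = torch.randn(5, 5, dtype=torch.float64)
        LU, pivots = torch.linalg.lu_factor(A)
        P, L, U = torch.lu_unpack(LU, pivots)
        assert torch.allclose(P @ L @ U, A)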
5640
5641    @skipCUDAIfNoMagma
5642    @skipCPUIfNoLapack
5643    @dtypes(torch.double)
5644    def test_lobpcg_basic(self, device, dtype):
5645        self._test_lobpcg_method(device, dtype, 'basic')
5646
5647    @skipCUDAIfNoCusolver
5648    @skipCPUIfNoLapack
5649    @dtypes(torch.double)
5650    def test_lobpcg_ortho(self, device, dtype):
5651        if torch.version.hip:
5652            torch.backends.cuda.preferred_linalg_library('magma')
5653        self._test_lobpcg_method(device, dtype, 'ortho')
5654        if torch.version.hip:
5655            torch.backends.cuda.preferred_linalg_library('default')
5656
5657    def _test_lobpcg_method(self, device, dtype, method):
5658        from torch.testing._internal.common_utils import random_symmetric_pd_matrix, random_sparse_pd_matrix
5659        from torch._linalg_utils import matmul, qform
5660        from torch._lobpcg import lobpcg
5661
5662        def test_tracker(worker):
5663            k = worker.iparams['k']
5664            nc = worker.ivars['converged_count']
5665            if k <= nc:
5666                tol = worker.fparams['tol']
5667                rerr = worker.tvars['rerr']
5668                X = worker.X
5669                E = worker.E
5670                B = worker.B
5671                A = worker.A
5672                dtype = X.dtype
5673                device = X.device
5674
5675                # Check convergence
5676                self.assertLessEqual(rerr[:k].max(), tol)
5677
5678                # Check B-orthogonality
5679                I = torch.eye(k, k, dtype=dtype, device=device)
5680                self.assertEqual(qform(B, X[:, :k]), I)
5681
5682                # Check block equation
5683                self.assertEqual(qform(A, X[:, :k]) / E[:k], I, atol=0.2, rtol=0)
5684
5685        orig_lobpcg = lobpcg
5686
5687        def lobpcg(*args, **kwargs):
5688            kwargs['tracker'] = test_tracker
5689            kwargs['niter'] = 1000
5690            kwargs['method'] = method
5691            kwargs['tol'] = 1e-8
5692            return orig_lobpcg(*args, **kwargs)
5693        prec = 5e-4
5694
5695        # check dense input
5696        mm = torch.matmul
5697        for batches in [(), (2,), (2, 3)]:
5698            for m, n, k in [
5699                    (9, 3, 1),
5700                    (9, 3, 2),
5701                    (9, 2, 2),
5702                    (100, 15, 5),
5703            ]:
5704                # skip tests that are known to fail with the basic
5705                # LOBPCG method due to calling cholesky on singular
5706                # input
5707                if method == 'basic' and (m, n, k) in [(9, 2, 2), (100, 15, 5)]:
5708                    continue
5709                A = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)
5710                B = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)
5711
5712                # classical eigenvalue problem, smallest eigenvalues
5713                E, V = lobpcg(A, k=k, n=n, largest=False)
5714                self.assertEqual(E.shape, batches + (k,))
5715                self.assertEqual(V.shape, batches + (m, k))
5716                self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5717                e = torch.linalg.eigvalsh(A)
5718                e_smallest = e[..., :k]
5719                self.assertEqual(E, e_smallest)
5720
5721                # classical eigenvalue problem, largest eigenvalues
5722                E, V = lobpcg(A, k=k, n=n, largest=True)
5723                e_largest, _ = torch.sort(e[..., -k:], descending=True)
5724                self.assertEqual(E, e_largest, atol=prec, rtol=0)
5725                self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5726
5727                # generalized eigenvalue problem, smallest eigenvalues
5728                E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
5729                self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), atol=prec, rtol=0)
5730
5731                # generalized eigenvalue problem, largest eigenvalues
5732                E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
5733                self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
5734                                 atol=prec, rtol=0)
5735
5736        # check sparse input
5737        for m, n, k, density in [
5738                (5, 1, 1, 0.8),
5739                (9, 3, 2, 0.5),
5740                (100, 1, 1, 0.1),
5741                (1000, 7, 3, 0.01),
5742        ]:
            # skip tests that are known to fail with the basic LOBPCG
5744            # method due to insufficient accuracy
5745            if method == 'basic' and (m, n, k, density) in [(1000, 7, 3, 0.01)]:
5746                continue
5747            A = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
5748            B = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
5749            A_eigenvalues = torch.arange(1, m + 1, dtype=dtype) / m
5750            e_smallest = A_eigenvalues[..., :k]
5751            e_largest, _ = torch.sort(A_eigenvalues[..., -k:], descending=True)
5752
5753            # classical eigenvalue problem, smallest eigenvalues
5754            E, V = lobpcg(A, k=k, n=n, largest=False)
5755            self.assertEqual(E, e_smallest)
5756            self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5757
5758            # classical eigenvalue problem, largest eigenvalues
5759            E, V = lobpcg(A, k=k, n=n, largest=True)
5760            self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5761            self.assertEqual(E, e_largest)
5762
5763            # generalized eigenvalue problem, smallest eigenvalues
5764            E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
5765            self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), atol=prec, rtol=0)
5766
5767            # generalized eigenvalue problem, largest eigenvalues
5768            E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
5769            self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
5770                             atol=prec, rtol=0)
5771
5772    @skipCPUIfNoLapack
5773    @onlyCPU
5774    @dtypes(torch.double)
5775    def test_lobpcg_torchscript(self, device, dtype):
5776        from torch.testing._internal.common_utils import random_sparse_pd_matrix
5777        from torch._linalg_utils import matmul as mm
5778
5779        lobpcg = torch.jit.script(torch.lobpcg)
5780
5781        m = 500
5782        k = 5
5783        A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
5784        X1 = torch.randn((m, k), dtype=dtype, device=device)
5785        E1, V1 = lobpcg(A1, X=X1)
5786        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
5787        self.assertLess(eq_err, 1e-6)
5788
5789    @unittest.skipIf(not TEST_SCIPY or (TEST_SCIPY and scipy.__version__ < '1.4.1'), "Scipy not found or older than 1.4.1")
5790    @skipCPUIfNoLapack
5791    @skipIfTorchDynamo("fails in tracing scipy.sparse.lobpcg")
5792    @onlyCPU
5793    @dtypes(torch.double)
5794    def test_lobpcg_scipy(self, device, dtype):
5795        """Compare torch and scipy.sparse.linalg implementations of lobpcg
5796        """
5797        import time
5798        from torch.testing._internal.common_utils import random_sparse_pd_matrix
5799        from torch._linalg_utils import matmul as mm
5800        from scipy.sparse.linalg import lobpcg as scipy_lobpcg
5801        import scipy.sparse
5802
5803        def toscipy(A):
5804            if A.layout == torch.sparse_coo:
5805                values = A.coalesce().values().cpu().numpy().copy()
5806                indices = A.coalesce().indices().cpu().numpy().copy()
5807                return scipy.sparse.coo_matrix((values, (indices[0], indices[1])), A.shape)
5808            return A.cpu().numpy().copy()
5809
5810        niter = 1000
5811        repeat = 10
5812        m = 500   # size of the square matrix
5813        k = 7     # the number of requested eigenpairs
5814        A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
5815        B1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
5816        X1 = torch.randn((m, k), dtype=dtype, device=device)
5817
5818        A2 = toscipy(A1)
5819        B2 = toscipy(B1)
5820        X2 = toscipy(X1)
5821
5822        lambdas1 = []
5823
5824        def tracker(worker):
5825            lambdas1.append(worker.E[:])
5826
5827        tol = 1e-8
        # tol for scipy lobpcg will be chosen so that the number of
        # iterations will be equal to or very close to pytorch lobpcg's
        # (that is, around 170-180)
5831
5832        # Standard eigenvalue problem
5833        E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
5834        E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=1.1 * tol)
5835        iters1 = len(lambdas1)
5836        iters2 = len(lambdas2)
5837        self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))
5838
5839        E2a, V2a = scipy_lobpcg(A2, X2, maxiter=niter, largest=False)
5840
5841        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
5842        eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
5843        self.assertLess(eq_err, 1e-6)        # std
5844        self.assertLess(eq_err_scipy, 1e-6)  # std
5845
5846        self.assertEqual(E1, torch.from_numpy(E2.copy()))
5847
5848        # Generalized eigenvalue problem
5849        lambdas1 = []
5850
5851        def tracker(worker):
5852            lambdas1.append(worker.E[:])
5853
5854        E1, V1 = torch.lobpcg(A1, B=B1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
5855        E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=39 * tol)
5856        E2a, V2a = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=False)
5857        iters1 = len(lambdas1)
5858        iters2 = len(lambdas2)
5859        self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))
5860
5861        eq_err = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()
5862        eq_err_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
5863        self.assertLess(eq_err, 1e-6)        # general
5864        self.assertLess(eq_err_scipy, 1e-6)  # general
5865
5866        self.assertEqual(E1, torch.from_numpy(E2.copy()))
5867
5868        # Timings
5869        elapsed_ortho = 0
5870        elapsed_ortho_general = 0
5871        elapsed_scipy = 0
5872        elapsed_general_scipy = 0
5873        for i in range(repeat):
5874            start = time.time()
5875            torch.lobpcg(A1, X=X1, niter=niter, method='ortho', tol=tol)
5876            end = time.time()
5877            elapsed_ortho += end - start
5878
5879            start = time.time()
5880            torch.lobpcg(A1, X=X1, B=B1, niter=niter, method='ortho', tol=tol)
5881            end = time.time()
5882            elapsed_ortho_general += end - start
5883
5884            start = time.time()
5885            scipy_lobpcg(A2, X2, maxiter=niter, tol=1.1 * tol)
5886            end = time.time()
5887            elapsed_scipy += end - start
5888
5889            start = time.time()
5890            scipy_lobpcg(A2, X2, B=B2, maxiter=niter, tol=39 * tol)
5891            end = time.time()
5892            elapsed_general_scipy += end - start
5893
5894        elapsed_ortho_ms = 1000.0 * elapsed_ortho / repeat
5895        elapsed_ortho_general_ms = 1000.0 * elapsed_ortho_general / repeat
5896        elapsed_scipy_ms = 1000.0 * elapsed_scipy / repeat
5897        elapsed_general_scipy_ms = 1000.0 * elapsed_general_scipy / repeat
5898
5899        print(f'''
5900CPU timings: torch.lobpcg vs scipy.sparse.linalg.lobpcg
5901-------------------------------------------------------
5902              | standard    | generalized | method
5903torch.lobpcg  | {elapsed_ortho_ms:10.2f}  | {elapsed_ortho_general_ms:10.2f}  | ortho
5904scipy_lobpcg  | {elapsed_scipy_ms:10.2f}  | {elapsed_general_scipy_ms:10.2f}  | N/A
5905-(input size: {m:4}, eigenpairs:{k:2}, units: ms per call)-
5906        ''')
5907
        # Handling of very small tolerance
5909        tol = 1e-100
5910
5911        lambdas1 = []
5912
5913        def tracker(worker):
5914            lambdas1.append(worker.E[:])
5915
5916        E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
5917        iters1 = len(lambdas1)
5918        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
5919
5920        try:
5921            E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
5922            iters2 = len(lambdas2)
5923            eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
5924        except Exception as msg:
5925            print('Calling scipy_lobpcg failed [standard]:', msg)
5926            iters2 = -1
5927            eq_err_scipy = -1
5928
5929        lambdas1 = []
5930
5931        def tracker(worker):
5932            lambdas1.append(worker.E[:])
5933
5934        E1, V1 = torch.lobpcg(A1, X=X1, B=B1, niter=niter, largest=True, tracker=tracker, tol=tol)
5935        iters1_general = len(lambdas1)
5936        eq_err_general = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()
5937
5938        try:
5939            E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
5940            iters2_general = len(lambdas2)
5941            eq_err_general_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
5942        except Exception as msg:
5943            print('Calling scipy_lobpcg failed [generalized]:', msg)
5944            iters2_general = -1
5945            eq_err_general_scipy = -1
5946
5947        print(f'''\
5948Handling of small tol={tol:6.0e}: torch.lobpcg vs scipy.sparse.linalg.lobpcg
5949----------------------------------------------------------------------------
5950              | standard    | generalized |  niter | method
5951torch.lobpcg  | {eq_err:10.2e}  | {eq_err_general:10.2e}  | {iters1:6} | ortho
5952scipy_lobpcg  | {eq_err_scipy:10.2e}  | {eq_err_general_scipy:10.2e}  | {iters2:6} | N/A
5953---(input size: {m:4}, eigenpairs:{k:2}, units: relative error, maxiter={niter:4})---
5954''')
5955
5956    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None):
5957        dtype = t.dtype
5958        numpy_dtype = dtype
5959        if dtype in {torch.bfloat16, torch.half}:
5960            numpy_dtype = torch.float
5961        if dtype.is_complex:
5962            alpha = 0.9 + 0.3j if alpha is None else alpha
5963            beta = 0.5 + 0.6j if beta is None else beta
5964        else:
5965            alpha = 1.2 if alpha is None else alpha
5966            beta = 0.8 if beta is None else beta
5967        if activation == "gelu":
5968            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
5969        else:
5970            res1 = f(t, m, v, alpha=alpha, beta=beta)
5971        res2 = torch.full_like(res1, math.nan)
5972        if transpose_out:
5973            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
5974        if activation == "gelu":
5975            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
5976        else:
5977            f(t, m, v, alpha=alpha, beta=beta, out=res2)
5978        res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
5979        if beta != 0:
5980            res3 += (beta * t).to(numpy_dtype).cpu().numpy()
5981        if activation == "relu":
5982            res3 = res3 * (res3 > 0)
5983        elif activation == "gelu":
5984            res3_t = torch.from_numpy(res3).to(dtype)
5985            approximate = "tanh" if t.is_cuda else "none"
5986            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
5987            res3 = res3_t.to(numpy_dtype).cpu().numpy()
5988        else:
5989            assert activation is None, f"unsupported activation {activation}"
5990        res3 = torch.from_numpy(res3).to(dtype)
5991        self.assertEqual(res1, res2)
5992        self.assertEqual(res1, res3)
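
    # A minimal sketch (not a test) of the reference formula the helper above checks:
    # addmv computes beta * t + alpha * (m @ v), and addmm is the matrix-matrix
    # analogue beta * M + alpha * (m1 @ m2).  Values and shapes are arbitrary; the
    # `_sketch_*` name is a hypothetical helper used only for illustration.
    @staticmethod
    def _sketch_addmm_addmv_formula():
        t = torch.randn(5, dtype=torch.float64)
        m = torch.randn(5, 3, dtype=torch.float64)
        v = torch.randn(3, dtype=torch.float64)
        out = torch.addmv(t, m, v, beta=0.8, alpha=1.2)
        assert torch.allclose(out, 0.8 * t + 1.2 * (m @ v))

        M = torch.randn(5, 4, dtype=torch.float64)
        m2 = torch.randn(3, 4, dtype=torch.float64)
        out = torch.addmm(M, m, m2, beta=0.8, alpha=1.2)
        assert torch.allclose(out, 0.8 * M + 1.2 * (m @ m2))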
5993
5994    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4, torch.double: 1e-8,
5995                        torch.cfloat: 1e-4, torch.cdouble: 1e-8})
5996    @dtypesIfCUDA(*floating_and_complex_types_and(
5997                  *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [],
5998                  torch.half))
5999    @dtypes(torch.bfloat16, torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
6000    def test_addmv(self, device, dtype):
6001        if IS_ARM64 and device == 'cpu' and dtype == torch.float16:
6002            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
        # use torch.randn(...).to(bfloat16) instead of
        # torch.randn(..., dtype=bfloat16); randn historically did not
        # support bfloat16.
6006        # "*0.2" to reduce errors for low precision
6007        ts = [
6008            0.2 * torch.randn(50, device=device).to(dtype),
6009            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
6010        ]
6011        vs = [
6012            0.2 * torch.randn(100, device=device).to(dtype),
6013            0.2 * torch.ones(1, device=device).to(dtype).expand(100),  # to reduce errors for low precision
6014        ]
6015        ms = [
6016            # 0d
6017            0.2 * torch.ones((), device=device).to(dtype).expand(50, 100),  # to reduce errors for low precision
6018            # 1d
6019            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
            # this initialization reduces low-precision errors for broadcasted matrices
            # by making sure that intermediate and result values are exactly representable
            # in the low-precision type
6023            0.2 * torch.randint(3, (50, 1), dtype=torch.float, device=device).to(dtype).expand(50, 100),
6024            # 2d
6025            0.2 * torch.randn((50, 100), device=device).to(dtype),
6026            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
6027        ]
6028        for m, v, t in itertools.product(ms, vs, ts):
6029            self._test_addmm_addmv(torch.addmv, t, m, v)
6030        # Test beta=0, t=nan
6031        t = torch.full((50,), math.nan, device=device).to(dtype)
6032        for m, v in itertools.product(ms, vs):
6033            self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)
6034
6035    @dtypesIfCUDA(*floating_types_and(*[torch.bfloat16] if TEST_WITH_ROCM or
6036                  SM53OrLater else []))
6037    @dtypes(torch.float, torch.double)
6038    def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
6039        # tests (o, s)*(s).  o is output size, s is summed size.
6040        o = 5
6041        s = 3
6042        a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
6043        x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
6044        y_data = torch.ones(o, device=device, dtype=dtype)
6045        control = torch.tensor([15., 33., 51., 69., 87.], device=device, dtype=dtype)
6046
6047        def _test(row_major, incx, incy, lda_tail):
6048            if row_major:
6049                a_storage = torch.full((o, s + lda_tail), float('nan'), device=device, dtype=dtype)
6050            else:
6051                a_storage = torch.full((s, o + lda_tail), float('nan'), device=device, dtype=dtype).permute(1, 0)
6052            a = a_storage[:o, :s].copy_(a_data)
6053
6054            x_storage = torch.full((s, incx), float('nan'), device=device, dtype=dtype)
6055            x = x_storage[:, 0].copy_(x_data)
6056
6057            y_storage = torch.full((o, incy), float('nan'), device=device, dtype=dtype)
6058            y = y_storage[:, 0].copy_(y_data)
6059
6060            self._test_addmm_addmv(torch.addmv, y, a, x)
6061
6062        for row_major, incx, incy, lda_tail in itertools.product((False, True), (1, 2), (1, 2), (0, 1)):
6063            _test(row_major, incx, incy, lda_tail)
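
    # A minimal sketch (not a test) of the layout trick used above: allocating extra
    # "lda_tail" columns and slicing yields a matrix whose leading dimension exceeds
    # its logical width, which exercises the BLAS lda/incx/incy parameters.  The
    # `_sketch_*` name is a hypothetical helper used only for illustration.
    @staticmethod
    def _sketch_lda_padding_view():
        storage = torch.full((5, 3 + 1), float('nan'))
        a = storage[:5, :3].copy_(torch.arange(15.).view(5, 3))
        assert a.stride() == (4, 1)  # leading dimension 4 > logical width 3
        assert not a.is_contiguous()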
6064
6065    def _test_addmm_impl(self, func, activation, device, dtype):
6066        M = torch.randn(10, 25, device=device).to(dtype)
6067        m1 = torch.randn(10, 50, device=device).to(dtype)
6068        m2 = torch.randn(50, 25, device=device).to(dtype)
6069        self._test_addmm_addmv(func, M, m1, m2, activation=activation)
6070
6071        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
6072        V = torch.randn(25, device=device).to(dtype)
6073        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
6074
6075        # Test 0-strided
6076        M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
6077        m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
6078        m2 = torch.randn(50, 25, device=device).to(dtype)
6079        self._test_addmm_addmv(func, M, m1, m2, activation=activation)
6080
6081        # Test beta=0, M=nan
6082        M = torch.full((10, 25), math.nan, device=device).to(dtype)
6083        m1 = torch.randn(10, 50, device=device).to(dtype)
6084        m2 = torch.randn(50, 25, device=device).to(dtype)
6085        self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)
6086
6087        # Test transpose
6088        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
6089            def maybe_transpose(cond, m):
6090                if not cond:
6091                    return m
6092                return m.t().clone(memory_format=torch.contiguous_format).t()
6093
6094            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
6095            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
6096            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
6097            self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
6098
6099            if t1:
6100                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
6101                self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
6102
6103    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
6104                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6105    @dtypesIfMPS(torch.float32)
6106    @dtypesIfCUDA(*floating_and_complex_types_and(
6107                  *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
6108    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6109    @tf32_on_and_off(0.05)
6110    @bf32_on_and_off(0.05)
6111    def test_addmm(self, device, dtype):
6112        self._test_addmm_impl(torch.addmm, None, device, dtype)
6113
6114    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
6115                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6116    @dtypesIfCUDA(*floating_types_and(
6117                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
6118    @dtypes(*floating_types_and(torch.bfloat16))
6119    @tf32_on_and_off(0.05)
6120    @bf32_on_and_off(0.05)
6121    def test_addmm_relu(self, device, dtype):
6122        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
6123
6124    @onlyCUDA
6125    @skipCUDAIfNotRocm
6126    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
6127                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6128    @dtypesIfCUDA(*floating_types_and(
6129                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
6130    @dtypes(*floating_types_and(torch.bfloat16))
6131    @tf32_on_and_off(0.05)
6132    @bf32_on_and_off(0.05)
6133    def test_addmm_relu_tunableop_rocm(self, device, dtype):
6134        torch.cuda.tunable.enable(True)
6135        ordinal = torch.cuda.current_device()
6136        filename = f"tunableop_results{ordinal}.csv"
6137        torch.cuda.tunable.set_filename(filename)
6138        iterations = torch.cuda.tunable.get_max_tuning_iterations()
6139        torch.cuda.tunable.set_max_tuning_iterations(10)
6140        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
6141        # clean up, remove any file that was generated
6142        try:
6143            import os
6144            os.remove(filename)
6145        except FileNotFoundError:
6146            pass
6147        # reset back to prior settings
6148        torch.cuda.tunable.set_max_tuning_iterations(iterations)
6149        torch.cuda.tunable.enable(False)
6150
6151    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
6152                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6153    @dtypesIfCUDA(*floating_types_and(
6154                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
6155    @dtypes(*floating_types_and(torch.bfloat16))
6156    @tf32_on_and_off(0.05)
6157    @bf32_on_and_off(0.05)
6158    def test_addmm_gelu(self, device, dtype):
6159        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)
6160
6161    @dtypes(torch.float, torch.double)
6162    @dtypesIfCUDA(*floating_and_complex_types())
6163    @tf32_on_and_off(0.005)
6164    @bf32_on_and_off(0.005)
6165    def test_addmm_sizes(self, device, dtype):
6166        for m in [0, 1, 25]:
6167            for n in [0, 1, 10]:
6168                for k in [0, 1, 8]:
6169                    M = torch.randn(n, m, device=device).to(dtype)
6170                    m1 = torch.randn(n, k, device=device).to(dtype)
6171                    m2 = torch.randn(k, m, device=device).to(dtype)
6172                    self._test_addmm_addmv(torch.addmm, M, m1, m2)
6173
6174                    m1 = torch.randn(n, k + 1, device=device).to(dtype)
6175                    m2 = torch.randn(k, m, device=device).to(dtype)
6176                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2))
6177                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2))
6178
6179    @dtypes(torch.half)
6180    @onlyCUDA
6181    def test_addmm_baddbmm_overflow(self, device, dtype):
6182        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
6183        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
6184        inp = torch.zeros(128, 128, dtype=torch.half, device=device)
6185        mat1 = torch.ones(128, 1000, dtype=torch.half, device=device) * 100
6186        mat2 = torch.ones(1000, 128, dtype=torch.half, device=device) * 100
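        # Each output element sums 1000 products of 100 * 100 = 1e4, i.e. 1e7, which
        # overflows fp16 (max finite value ~65504) if partial sums are accumulated in
        # half precision; with fp32 accumulation, alpha scales the result down to
        # 0.001 * 1e7 = 1e4, which is representable.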
6187        out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.)
6188        # just check for no overflow on ROCM
6189        if TEST_WITH_ROCM:
6190            self.assertFalse(out.isinf().any())
6191        else:
6192            self.assertTrue((out == 10000.).all())
6193        inp = torch.zeros(3, 128, 128, dtype=torch.half, device=device)
6194        mat1 = torch.ones(3, 128, 1000, dtype=torch.half, device=device) * 100
6195        mat2 = torch.ones(3, 1000, 128, dtype=torch.half, device=device) * 100
6196        out = torch.baddbmm(inp, mat1, mat2, alpha=0.001, beta=0.)
6197        if TEST_WITH_ROCM:
6198            self.assertFalse(out.isinf().any())
6199        else:
6200            self.assertTrue((out == 10000.).all())
6201        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
6202
6203    @dtypes(torch.float)
6204    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
6205        for shape in [[3, 2, 2], [2, 20, 20]]:
6206            mat1, mat2 = (torch.randn(shape, dtype=dtype, device=device) for _ in range(2))
6207            inputs = [torch.randn(shape, dtype=dtype, device=device),
6208                      torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
6209            outs = [None, torch.randn(shape, dtype=dtype, device=device),
6210                    torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
6211            options = itertools.product(inputs, outs)
6212            for input, out in options:
6213                y_ref = torch.bmm(mat1, mat2)
6214                y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
6215                self.assertEqual(y_ref, y)
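
    # A minimal sketch (not a test) of the semantics checked above: with beta=0 the
    # `input` argument of baddbmm is ignored rather than multiplied, so NaNs in it
    # must not propagate and the result reduces to bmm.  Shapes are arbitrary; the
    # `_sketch_*` name is a hypothetical helper used only for illustration.
    @staticmethod
    def _sketch_baddbmm_zero_beta_ignores_input():
        b1 = torch.randn(2, 3, 4)
        b2 = torch.randn(2, 4, 5)
        nan_input = torch.full((2, 3, 5), math.nan)
        out = torch.baddbmm(nan_input, b1, b2, beta=0.0)
        assert torch.allclose(out, torch.bmm(b1, b2))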
6216
6217    @dtypes(torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64)
6218    def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
6219        batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
6220        batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
6221        input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
6222        if dtype != torch.float32:
6223            with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
6224                y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
6225        else:
6226            out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
6227            y_ref = torch.bmm(batch1, batch2)
6228            y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
6229            self.assertEqual(out, y_ref)
6230
6231
6232    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6233    @onlyCUDA
6234    def test_matmul_45724(self, device):
6235        # https://github.com/pytorch/pytorch/issues/45724
6236        a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
6237        b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
6238        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
6239        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half()
6240        torch.matmul(a, b, out=c)
6241        self.assertEqual(c, cpu_result)
6242
6243    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6244    @unittest.skipIf(SM90OrLater and not TEST_WITH_ROCM, "Expected failure on sm90")
6245    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6246    @onlyCUDA
6247    @parametrize("k", [16, 32])
6248    @parametrize("n", [16, 32])
6249    @parametrize("use_transpose_a", [True, False])
6250    @parametrize("use_transpose_b", [True, False])
6251    def test__int_mm(self, device, k, n, use_transpose_a, use_transpose_b):
6252        def genf_int_float(x, y, use_transpose):
6253            if use_transpose:
6254                x, y = y, x
6255            x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
6256            x_float = x_int8.to(torch.float32)
6257            if use_transpose:
6258                return x_int8.t(), x_float.t()
6259            return x_int8, x_float
6260
6261        def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
6262            a_int8, a_float = genf_int_float(m, k, transpose_a)
6263            b_int8, b_float = genf_int_float(k, n, transpose_b)
6264            c_int32 = torch._int_mm(a_int8, b_int8)
6265            self.assertTrue(c_int32.dtype is torch.int32)
6266            self.assertEqual(c_int32.device, torch.device(device))
6267            if test_equal:
6268                self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
6269            else:
6270                self.assertNotEqual(c_int32.float(), torch.mm(a_float, b_float))
6271            c_int32_result = c_int32.new_empty(c_int32.size())
6272            # Checking out variant
6273            torch._int_mm(a_int8, b_int8, out=c_int32_result)
6274            if test_equal:
6275                self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
6276            else:
6277                self.assertNotEqual(c_int32_result.float(), torch.mm(a_float, b_float))
6278
6279        # NOTE: We're just exercising terrible failures here.
6280        version = _get_torch_cuda_version()
6281        SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
6282        SM70 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 0)
6283        SM75 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 5)
6284
6285        if TEST_WITH_ROCM:
6286            _test(17, k, n, use_transpose_a, use_transpose_b, True)
6287        elif version >= (11, 7):
6288            if not use_transpose_a and use_transpose_b:
6289                if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
6290                    _test(17, k, n, use_transpose_a, use_transpose_b, version > (11, 7))
6291                else:
6292                    with self.assertRaisesRegex(RuntimeError,
6293                                                "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6294                        _test(17, k, n, use_transpose_a, use_transpose_b)
6295
6296            if use_transpose_a and not use_transpose_b:
6297                with self.assertRaisesRegex(RuntimeError,
6298                                            "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6299                    _test(17, k, n, use_transpose_a, use_transpose_b)
6300
6301            if use_transpose_a and use_transpose_b:
6302                with self.assertRaisesRegex(RuntimeError,
6303                                            "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6304                    _test(17, k, n, use_transpose_a, use_transpose_b)
6305
6306            if not use_transpose_a and not use_transpose_b:
6307                if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
6308                    _test(17, k, n, use_transpose_a, use_transpose_b)
6309                else:
6310                    with self.assertRaisesRegex(RuntimeError,
6311                                                "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6312                        _test(17, k, n, use_transpose_a, use_transpose_b)
6313        else:
6314            with self.assertRaisesRegex(RuntimeError, "_int_mm_out_cuda not compiled for CUDA"):
6315                _test(17, k, n, use_transpose_a, use_transpose_b, False)
6316
6317    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6318    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6319    @onlyCUDA
6320    def test__int_mm_errors(self, device):
6321        if TEST_WITH_ROCM:
6322            self.skipTest("_int_mm not compiled for ROCM")
6323
6324        version = _get_torch_cuda_version()
6325        if version < (11, 7):
6326            self.skipTest("_int_mm only compiled for CUDA 11.7")
6327
6328        def genf_int(x, y):
6329            return torch.empty((x, y), dtype=torch.int8, device=device)
6330
6331        def _gen_pair(m, k, n):
6332            return genf_int(m, k), genf_int(k, n)
6333
6334        self.assertRaisesRegex(RuntimeError,
6335                               r"self.size\(0\) needs to be greater than 16, but got 16",
6336                               lambda: torch._int_mm(*_gen_pair(16, 8, 32)))
6337        self.assertRaisesRegex(RuntimeError,
6338                               r"self.size\(1\) needs to be greater than 0 and a multiple of 8, but got 7",
6339                               lambda: torch._int_mm(*_gen_pair(17, 7, 32)))
6340        self.assertRaisesRegex(RuntimeError,
6341                               r"self.size\(1\) needs to match mat2.size\(0\) but got 8 and 7",
6342                               lambda: torch._int_mm(genf_int(17, 8), genf_int(7, 32)))
6343        self.assertRaisesRegex(RuntimeError,
6344                               r"mat2.size\(1\) needs to be greater than 0 and a multiple of 8, but got 31",
6345                               lambda: torch._int_mm(*_gen_pair(17, 8, 31)))
6346        self.assertRaisesRegex(RuntimeError,
6347                               r"expected scalar type Char but found Float",
6348                               lambda: torch._int_mm(genf_int(17, 8).float(), genf_int(8, 32)))
6349        self.assertRaisesRegex(RuntimeError,
6350                               r"expected scalar type Char but found Float",
6351                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32).float()))
6352        self.assertRaisesRegex(RuntimeError,
6353                               r"Expected result dtype to be of type kInt but got float",
6354                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 32).float()))
6355        self.assertRaisesRegex(RuntimeError,
6356                               r"Expected result.size\(0\) to be 17 but got 15",
6357                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(15, 32).int()))
6358        self.assertRaisesRegex(RuntimeError,
6359                               r"Expected result.size\(0\) to be 17 but got 16",
6360                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))
6361
6362    @onlyCPU
6363    @parametrize("m", [0, 8, 17])
6364    @parametrize("k", [0, 16, 32])
6365    @parametrize("n", [16, 32])
6366    @parametrize("use_transpose_a", [True, False])
6367    @parametrize("use_transpose_b", [True, False])
6368    @parametrize("non_contig_type", [0, 1, 2])
6369    def test__int_mm_cpu(self, device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type):
6370        # non_contig_type:
6371        # 0: the whole data buffer is contiguous (can be transposed)
6372        # 1: stride of one dimension is 1, but the whole buffer is not contiguous
6373        # 2: Neither stride is 1
6374
6375        def genf_int_float(x, y, use_transpose, non_contig_type):
6376            if use_transpose:
6377                x, y = y, x
6378            if non_contig_type != 0:
6379                y = y * 2
6380            x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
6381            x_float = x_int8.to(torch.float32)
6382            if non_contig_type == 1:
6383                x_int8 = x_int8[:, : y // 2]
6384                x_float = x_float[:, : y // 2]
6385            elif non_contig_type == 2:
6386                x_int8 = x_int8[:, ::2]
6387                x_float = x_float[:, ::2]
6388            if use_transpose:
6389                return x_int8.t(), x_float.t()
6390            return x_int8, x_float
6391
6392        if non_contig_type != 0 and (m == 0 or k == 0):
6393            return
6394        a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type)
6395        b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type)
6396        c_int32 = torch._int_mm(a_int8, b_int8)
6397        self.assertTrue(c_int32.dtype is torch.int32)
6398        self.assertEqual(c_int32.device, torch.device(device))
6399        self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
6400        c_int32_result = c_int32.new_empty(c_int32.size())
6401        # Checking out variant
6402        torch._int_mm(a_int8, b_int8, out=c_int32_result)
6403        self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
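
    # A minimal sketch (not a test) of what the private op _int_mm computes, using a
    # CPU-friendly shape from the parametrization above: an int8 x int8 matmul
    # accumulated in int32, bit-exact against the same product carried out in int32.
    # The `_sketch_*` name is a hypothetical helper used only for illustration.
    @staticmethod
    def _sketch_int_mm_reference():
        a = torch.randint(-10, 10, (17, 32), dtype=torch.int8)
        b = torch.randint(-10, 10, (32, 32), dtype=torch.int8)
        c = torch._int_mm(a, b)
        assert c.dtype is torch.int32
        assert torch.equal(c, a.to(torch.int32) @ b.to(torch.int32))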
6404
6405    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6406    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6407    @onlyNativeDeviceTypes
6408    def test__convert_weight_to_int4pack(self, device):
6409        # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
6410        test_list = [((64, 32), 2), ((64, 48), 2), ((64, 64), 2), ((256, 128), 4), ((256, 128), 8)]
6411        if self.device_type == 'cuda' and not SM80OrLater:
6412            self.skipTest("requires SM80 or later")
6413
6414        if TEST_WITH_ROCM:
6415            if not CDNA2OrLater():
6416                self.skipTest("_int4_mm is supported only for CDNA2 or later")
6417
6418        torch.manual_seed(1)
6419        for shape, innerKTiles in test_list:
6420            b = torch.rand(shape, dtype=torch.bfloat16, device=device)
6421            b_uint8, _ = _group_quantize_tensor(b, n_bit=4, q_group_size=32)
6422            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=innerKTiles)
6423            b_int4pack_meta = torch._convert_weight_to_int4pack(b_uint8.to(device="meta"), innerKTiles=innerKTiles)
6424            self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
6425
6426    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6427    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6428    @onlyNativeDeviceTypes
6429    @parametrize("m", [32, 64])
6430    @parametrize("k", [32, 64])
6431    @parametrize("n", [48, 64])
6432    def test__int4_mm(self, device, m, k, n):
6433        if self.device_type == 'cuda' and not SM80OrLater:
6434            self.skipTest("requires SM80 or later")
6435
6436        if TEST_WITH_ROCM:
6437            if not CDNA2OrLater():
6438                self.skipTest("_int4_mm is supported only for CDNA2 or later")
6439
6440        q_group = 32
6441        inner_k_tiles = 2
6442
6443        torch.manual_seed(1)
6444        a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6445        b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
6446
6447        def convert_weight_to_int4pack(b):
6448            b_uint8, b_scales_and_zeros = _group_quantize_tensor(
6449                b, n_bit=4, q_group_size=q_group
6450            )
6451            b_int4pack = torch._convert_weight_to_int4pack(
6452                b_uint8, inner_k_tiles
6453            )
6454
6455            return b_int4pack, b_scales_and_zeros
6456
6457        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
6458            return torch._weight_int4pack_mm(
6459                a, b_int4pack, q_group, b_scales_and_zeros
6460            )
6461
6462        b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)
6463
6464        for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []):
6465            a = a_bf16.to(dtype=dtype)
6466            b = b_bf16.to(dtype=dtype)
6467            b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
6468            ref = torch.mm(a, b)
6469            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
6470
6471            mean_err = ((res - ref).abs() / ref).mean()
6472            self.assertTrue(mean_err < 0.05)
6473
6474
6475    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6476    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6477    @onlyNativeDeviceTypes
6478    @parametrize("m", [32, 64])
6479    @parametrize("k", [32, 64])
6480    @parametrize("n", [48, 64])
6481    def test_compile_int4_mm(self, device, m, k, n):
6482        if self.device_type == 'cuda' and not SM80OrLater:
6483            self.skipTest("requires SM80 or later")
6484
6485        if TEST_WITH_ROCM:
6486            if not CDNA2OrLater():
6487                self.skipTest("_int4_mm is supported only for CDNA2 or later")
6488
6489        q_group = 32
6490        inner_k_tiles = 2
6491
6492        torch.manual_seed(1)
6493        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6494        b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
6495
6496        b_int32, b_scales_and_zeros = _group_quantize_tensor(
6497            b, n_bit=4, q_group_size=q_group
6498        )
6499
6500        @torch.compile
6501        def int4_mm(a, b_int32, b_scales_and_zeros):
6502            b_int4pack = torch._convert_weight_to_int4pack(
6503                b_int32, inner_k_tiles
6504            )
6505            return torch._weight_int4pack_mm(
6506                a, b_int4pack, q_group, b_scales_and_zeros
6507            )
6508
6509        res = int4_mm(a, b_int32, b_scales_and_zeros)
6510        ref = torch.mm(a, b)
6511
6512        mean_err = ((res - ref).abs() / ref).mean()
6513        self.assertTrue(mean_err < 0.05)
6514
6515    @onlyCPU
6516    @parametrize("m", [32, 64])
6517    @parametrize("k", [32, 64])
6518    @parametrize("n", [48, 64])
6519    def test__int8_mm(self, device, m, k, n):
6520        torch.manual_seed(1)
6521        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6522        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
6523
6524        def convert_weight_to_int8pack(b):
6525            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
6526                b, -128, 127, torch.int8
6527            )
6528            return b_int8pack, b_scales
6529
6530        def weight_int8pack_mm(a, b_int8pack, b_scales):
6531            return torch._weight_int8pack_mm(
6532                a, b_int8pack, b_scales
6533            )
6534
6535        b_int8pack, b_scales = convert_weight_to_int8pack(b)
6536        res = weight_int8pack_mm(a, b_int8pack, b_scales)
6537        ref = torch.mm(a, b.transpose(0, 1))
6538
6539        mean_err = ((res - ref).abs() / ref).mean()
6540        self.assertTrue(mean_err < 0.05)
6541
6542    @onlyCPU
6543    @parametrize("m", [32, 64])
6544    @parametrize("k", [32, 64])
6545    @parametrize("n", [48, 64])
6546    def test_compile_int8_mm(self, device, m, k, n):
6547        torch.manual_seed(1)
6548        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6549        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
6550
6551        b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
6552            b, -128, 127, torch.int8
6553        )
6554
6555        @torch.compile
6556        def int8_mm(a, b_int8pack, b_scales):
6557            return torch._weight_int8pack_mm(
6558                a, b_int8pack, b_scales
6559            )
6560
6561        res = int8_mm(a, b_int8pack, b_scales)
6562        ref = torch.mm(a, b.transpose(0, 1))
6563
6564        mean_err = ((res - ref).abs() / ref).mean()
6565        self.assertTrue(mean_err < 0.05)
6566
6567    @onlyCPU
6568    @parametrize("m", [32, 35, 36, 40, 64])
6569    @parametrize("k", [32, 35, 36, 40, 64])
6570    # NOTE: This is intended to cover fp16_gemv_trans in
    # BlasKernel.cpp. Currently, sizes divisible by 32, by 8 but not 32, and by 4
    # but not 8 all matter.
6573    def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k):
6574        torch.manual_seed(1)
6575        a = torch.rand((m, k), dtype=torch.half, device=device)
6576        b = torch.rand((1, k), dtype=torch.half, device=device)
6577
6578        prev = torch._C._get_cpu_allow_fp16_reduced_precision_reduction()
6579        try:
6580            torch._C._set_cpu_allow_fp16_reduced_precision_reduction(False)
6581            ref = torch.mm(a, b.t())
6582            try:
6583                torch._C._set_cpu_allow_fp16_reduced_precision_reduction(True)
6584            except RuntimeError as e:
6585                raise unittest.SkipTest from e
6586            res = torch.mm(a, b.t())
6587            torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2)
6588        finally:
6589            torch._C._set_cpu_allow_fp16_reduced_precision_reduction(prev)
6590
6591    @slowTest
6592    @onlyNativeDeviceTypes
6593    # bfloat16 doesn't have sufficient precision to pass this test
6594    @dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble)
6595    @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble)
6596    @tf32_on_and_off(0.01)
6597    @bf32_on_and_off(0.01)
6598    def test_mm(self, device, dtype):
6599        def _test_mm(n, m, p, dtype, genf):
6600            # helper function
6601            def matrixmultiply(mat1, mat2):
6602                n = mat1.size(0)
6603                m = mat1.size(1)
6604                p = mat2.size(1)
6605                dtype_ = torch.float if dtype == torch.half else dtype
6606                if dtype == torch.half:
6607                    mat1 = mat1.float()
6608                    mat2 = mat2.float()
6609                res = torch.zeros(n, p, dtype=dtype_, device=device)
6610                for i, j in iter_indices(res):
6611                    res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
6612                return res.half() if dtype == torch.half else res
6613
6614            # contiguous case
6615            mat1 = genf(n, m)
6616            mat2 = genf(m, p)
6617            res = torch.mm(mat1, mat2)
6618
6619            res2 = matrixmultiply(mat1, mat2)
6620            self.assertEqual(res, res2)
6621
6622            # non contiguous case 1
6623            mat1 = genf(n, m)
6624            mat2 = genf(p, m).t()
6625            res = torch.mm(mat1, mat2)
6626
6627            res2 = matrixmultiply(mat1, mat2)
6628            self.assertEqual(res, res2)
6629
6630            # non contiguous case 2
6631            mat1 = genf(m, n).t()
6632            mat2 = genf(m, p)
6633            res = torch.mm(mat1, mat2)
6634
6635            res2 = matrixmultiply(mat1, mat2)
6636            self.assertEqual(res, res2)
6637
6638            # non contiguous case 3
6639            mat1 = genf(m, n).t()
6640            mat2 = genf(p, m).t()
6641            res = torch.mm(mat1, mat2)
6642
6643            res2 = matrixmultiply(mat1, mat2)
6644            self.assertEqual(res, res2)
6645
6646            # test with zero stride
6647            mat1 = genf(n, m)
6648            mat2 = genf(m, 1).expand(m, p)
6649            res = torch.mm(mat1, mat2)
6650
6651            res2 = matrixmultiply(mat1, mat2)
6652            self.assertEqual(res, res2)
6653
6654            # explicitly exercise the _out variant in torch.mm().
6655            # contiguous case
6656            mat1 = genf(n, m)
6657            mat2 = genf(m, p)
6658            res = genf(n, p)
6659            torch.mm(mat1, mat2, out=res)
6660
6661            res2 = matrixmultiply(mat1, mat2)
6662            self.assertEqual(res, res2)
6663
6664            # explicitly exercise the _out variant in torch.mm().
6665            # non contiguous case 3
6666            mat1 = genf(m, n).t()
6667            mat2 = genf(p, m).t()
6668            res = genf(n, p)
6669            torch.mm(mat1, mat2, out=res)
6670
6671            res2 = matrixmultiply(mat1, mat2)
6672            self.assertEqual(res, res2)
6673
6674        def genf_int(x, y):
6675            return torch.randint(0, 100, (x, y), dtype=dtype, device=device)
6676
6677        def genf_bfloat(x, y):
6678            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1
6679
6680        def genf_float(x, y):
6681            return torch.randn(x, y, dtype=dtype, device=device)
6682
6683        def genf_Half(x, y):
6684            return torch.randn(x, y, dtype=dtype, device=device)
6685
6686        for (n, m, p) in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
6687            if (dtype == torch.int32) or (dtype == torch.int64):
6688                genf = genf_int
6689            elif (dtype == torch.bfloat16):
6690                genf = genf_bfloat
6691            elif (dtype == torch.half):
6692                genf = genf_Half
6693            else:
6694                genf = genf_float
6695
6696            _test_mm(n, m, p, dtype, genf)
6697
6698    @onlyNativeDeviceTypes
6699    def test_mm_bmm_non_memory_dense(self, device):
6700        def _slice(tensor, fn):
6701            return fn(tensor)[..., ::2]
6702        A = torch.randn(3, 6, dtype=torch.cfloat, device=device)
6703        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
6704        out = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
6705        out1 = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
6706        A_conj = _slice(A, torch.conj)
6707        A_conj_physical = _slice(A, torch.conj_physical)
6708
6709        self.assertEqual(torch.mm(A_conj, B, out=out), torch.mm(A_conj_physical, B, out=out))
6710        self.assertEqual(torch.mm(A_conj.t(), B, out=out), torch.mm(A_conj_physical.t(), B, out=out))
6711
6712        Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device)
6713        Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device)
6714        Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3)
6715        out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).mT
6716
6717        Ab_conj = _slice(Ab, torch.conj)
6718        Ab_conj_physical = _slice(Ab, torch.conj_physical)
6719
6720        def t_b(tensor):
6721            return tensor.mT
6722
6723        self.assertEqual(torch.bmm(Ab_conj, Bb, out=out_b), torch.bmm(Ab_conj_physical, Bb, out=out_b))
6724        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb, out=out_b))
6725
6726        # test broadcasting
6727        self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b))
6728        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b))
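
    # A minimal sketch (not a test) of the distinction exercised above: torch.conj
    # returns a lazily conjugated view (is_conj() is True) while conj_physical
    # materializes the conjugated values; matmul must treat both identically.  The
    # `_sketch_*` name is a hypothetical helper used only for illustration.
    @staticmethod
    def _sketch_conj_vs_conj_physical():
        A = torch.randn(3, 3, dtype=torch.cfloat)
        B = torch.randn(3, 3, dtype=torch.cfloat)
        A_lazy = A.conj()
        A_materialized = A.conj_physical()
        assert A_lazy.is_conj() and not A_materialized.is_conj()
        assert torch.allclose(A_lazy @ B, A_materialized @ B)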
6729
6730    @onlyNativeDeviceTypes
6731    def test_mm_conjtranspose(self, device):
6732        A = torch.randn(3, 3, dtype=torch.cfloat, device=device)
6733        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
6734
6735        # A conjtranspose
6736        out1 = torch.mm(A.t().conj(), B)
6737        out1_ref = torch.mm(A.t().conj_physical(), B)
6738        self.assertEqual(out1, out1_ref)
6739
6740        # B conjtranspose
6741        out1 = torch.mm(A, B.t().conj())
6742        out1_ref = torch.mm(A, B.t().conj_physical())
6743        self.assertEqual(out1, out1_ref)
6744
6745        # A&B conjtranspose
6746        out1 = torch.mm(A.t().conj(), B.t().conj())
6747        out1_ref = torch.mm(A.t().conj_physical(), B.t().conj_physical())
6748        self.assertEqual(out1, out1_ref)
6749
6750    @onlyNativeDeviceTypes
6751    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
6752        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
6753        b = torch.randn(10, 20, dtype=torch.float32, device=device)
6754        with self.assertRaisesRegex(RuntimeError, "expected .* and .* to have the same dtype, but got:"):
6755            torch.mm(a, b)
6756
6757    @onlyNativeDeviceTypes
6758    @dtypes(torch.float32, torch.float64)
6759    def test_strided_mm_bmm(self, device, dtype):
6760        # Tests strided view case with stride smaller than corresponding dimension size
6761        x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], dtype=dtype, device=device)
6762        new_shape = [2, 2, 2]
6763        new_stride = [3, 1, 1]
6764        sx = torch.as_strided(x, size=new_shape, stride=new_stride)
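        # sx overlaps in memory: with stride [3, 1, 1] it reads as
        # [[[1., 2.], [2., 3.]], [[4., 5.], [5., 6.]]]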
6765
6766        torch_fn = lambda x: torch.bmm(x, x)  # noqa: E731
6767        np_fn = lambda x: np.matmul(x, x)  # noqa: E731
6768        self.compare_with_numpy(torch_fn, np_fn, sx)
6769
6770        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
6771        self.compare_with_numpy(torch_fn, np_fn, sx[0])
6772
6773    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
6774    @onlyNativeDeviceTypes
6775    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6776    @tf32_on_and_off(0.05)
6777    @bf32_on_and_off(0.05)
6778    def test_bmm(self, device, dtype):
6779        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
6780            # cuBLAS does not guarantee BFloat16 support on SM < 53.
6781            # So in PyTorch, we consider BFloat16 support on SM < 53 as
6782            # undefined behavior
6783            return
6784
6785        batch_sizes = [1, 10]
6786        M, N, O = 23, 15, 12
6787        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
6788
6789        is_supported = True
6790        if dtype == torch.bfloat16 and self.device_type == 'cuda':
6791            is_supported = TEST_WITH_ROCM or SM53OrLater
6792
6793        if not is_supported:
6794            for num_batches in batch_sizes:
6795                b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
6796                b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
6797                self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
6798                                       lambda: torch.bmm(b1, b2))
6799            return
6800
6801        def invert_perm(p):
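            # Inverse permutation: x.permute(p).contiguous().permute(invert_perm(p))
            # restores the original shape while leaving a permuted memory layout.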
6802            d = {x: i for i, x in enumerate(p)}
6803            return (d[0], d[1], d[2])
6804
6805        def generate_inputs(num_batches):
6806            # transposed tensors
6807            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
6808                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1)
6809                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1)
6810                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
6811                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
6812                yield b1, b2
6813            # broadcasting tensors
6814            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
6815                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
6816                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
6817                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, M, N)
6818                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, N, O)
6819                yield b1, b2
6820            # zero-sized tensors
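            # (shape2 reuses z1 and z3 so the batch and inner dimensions of the two
            #  operands stay matmul-compatible even when zeroed)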
6821            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
6822                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
6823                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
6824                b1 = torch.randn(shape1, dtype=dtype, device=device)
6825                b2 = torch.randn(shape2, dtype=dtype, device=device)
6826                yield b1, b2
6827
6828        for num_batches in batch_sizes:
6829            for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))):
6830                res1 = torch.bmm(b1, b2)
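                # res2 is pre-filled with NaN and given a permuted, non-contiguous
                # layout so that any element the out= kernel fails to write is caught.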
6831                res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \
6832                    .permute(perm3).contiguous().permute(invert_perm(perm3))
6833                torch.bmm(b1, b2, out=res2)
6834                expect = torch.from_numpy(
6835                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
6836                self.assertEqual(expect, res1)
6837                self.assertEqual(expect, res2)
6838
6839                if self.device_type == 'cuda':
6840                    # check that mixed-device arguments are rejected
6841                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
6842                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
6843                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()))
6844
6845    def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
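        # `ref` is the plain batched product b1 @ b2 (summed over the batch for addbmm)
        # and `out_tensor` starts at zero, so the repeated in-place calls below
        # accumulate known multiples of `ref` (2x, then 2.5x, ...).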
6846        getattr(out_tensor, func + "_")(b1, b2)
6847        self.assertEqual(out_tensor, ref)
6848        res3 = out_tensor.clone()
6849
6850        with self.assertWarnsOnceRegex(
6851                UserWarning, f"This overload of {func}_ is deprecated"):
6852            getattr(out_tensor, func + "_")(1, b1, b2)
6853        self.assertEqual(out_tensor, ref * 2)
6854        getattr(res3, func + "_")(b1, b2, beta=1)
6855        self.assertEqual(out_tensor, res3)
6856
6857        with self.assertWarnsOnceRegex(
6858                UserWarning, f"This overload of {func}_ is deprecated"):
6859            getattr(out_tensor, func + "_")(1., .5, b1, b2)
6860        self.assertEqual(out_tensor, ref * 2.5)
6861        getattr(res3, func + "_")(b1, b2, beta=1., alpha=.5)
6862        self.assertEqual(out_tensor, res3)
6863
6864        with self.assertWarnsOnceRegex(
6865                UserWarning, f"This overload of {func} is deprecated"):
6866            self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))
6867
6868        res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5)
6869        self.assertEqual(res4, ref * 3)
6870
6871        nan = torch.full_like(out_tensor, math.nan)
6872        res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
6873        self.assertEqual(res5, ref)
6874
6875        if b1.is_complex():
6876            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1j, alpha=.5j)
6877            self.assertEqual(res6, out_tensor * .1j + .5j * ref)
6878        else:
6879            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1, alpha=.5)
6880            self.assertEqual(res6, out_tensor * .1 + .5 * ref)
6881
6882        res7 = torch.full_like(out_tensor, math.nan)
6883        getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
6884        self.assertEqual(res7, ref)
6885
6886    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
6887    @onlyNativeDeviceTypes
6888    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6889    @tf32_on_and_off(0.05)
6890    @bf32_on_and_off(0.05)
6891    def test_addbmm(self, device, dtype):
6892        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
6893            # cuBLAS does not guarantee BFloat16 support on SM < 53.
6894            # So in PyTorch, we consider BFloat16 support on SM < 53 as
6895            # undefined behavior
6896            return
6897
6898        num_batches = 2
6899        M, N, O = 16, 17, 18
6900
6901        is_supported = True
6902        if dtype == torch.bfloat16:
6903            if self.device_type == 'cpu':
6904                self.precision = 1  # 43 vs 43.75
6905            else:
6906                is_supported = TEST_WITH_ROCM or SM53OrLater
6907
6908        if not is_supported:
6909            b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
6910            b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
6911            t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
6912            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
6913                                   lambda: torch.addbmm(t, b1, b2))
6914            return
6915
6916        def invert_perm(p):
6917            d = {x: i for i, x in enumerate(p)}
6918            return (d[0], d[1], d[2])
6919
6920        def generate_tensor():
6921            numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
6922            # transposed tensors
6923            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
6924                for perm3 in itertools.permutations((0, 1)):
6925                    b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) * 0.1
6926                    b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) * 0.1
6927                    b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
6928                    b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
6929                    ref = torch.from_numpy(
6930                        b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6931                    ).to(device=device, dtype=dtype).sum(0)
6932                    out_tensor = torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
6933                    yield b1, b2, ref, out_tensor
6934            # broadcasting tensors
6935            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
6936                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
6937                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
6938                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N) * 0.1
6939                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O) * 0.1
6940                ref = torch.from_numpy(
6941                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6942                ).to(device=device, dtype=dtype).sum(0)
6943                out_tensor = torch.zeros_like(ref)
6944                yield b1, b2, ref, out_tensor
6945            # zero-sized tensors
6946            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
6947                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
6948                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
6949                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) * 0.1
6950                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) * 0.1
6951                ref = torch.from_numpy(
6952                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6953                ).to(device=device, dtype=dtype).sum(0)
6954                out_tensor = torch.zeros_like(ref)
6955                yield b1, b2, ref, out_tensor
6956
6957        for b1, b2, ref, out_tensor in generate_tensor():
6958            self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)
6959
6960    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
6961    @onlyNativeDeviceTypes
6962    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6963    @tf32_on_and_off(0.05)
6964    @bf32_on_and_off(0.05)
6965    def test_baddbmm(self, device, dtype):
6966        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
6967            # cuBLAS does not guarantee BFloat16 support on SM < 53.
6968            # So in PyTorch, we consider BFloat16 support on SM < 53 as
6969            # undefined behavior
6970            return
6971
6972        num_batches = 10
6973        M, N, O = 12, 8, 50
6974
6975        is_supported = True
6976        if dtype == torch.bfloat16 and self.device_type == 'cuda':
6977            is_supported = TEST_WITH_ROCM or SM53OrLater
6978
6979        if not is_supported:
6980            b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
6981            b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
6982            t = make_tensor((num_batches, M, O), dtype=dtype, device=device, low=-1, high=1)
6983            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
6984                                   lambda: torch.baddbmm(t, b1, b2))
6985            return
6986
6987        def invert_perm(p):
6988            d = {x: i for i, x in enumerate(p)}
6989            return (d[0], d[1], d[2])
6990
6991        def generate_tensor():
6992            numpy_dtype = dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
6993            # transposed tensors
6994            for perm1, perm2, perm3 in itertools.product(itertools.permutations((0, 1, 2)), repeat=3):
6995                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
6996                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
6997                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
6998                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
6999                ref = torch.from_numpy(
7000                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
7001                out_tensor = torch.zeros_like(ref)
7002                out_tensor = out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
7003                yield b1, b2, ref, out_tensor
7004            # broadcasting tensors
7005            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
7006                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
7007                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
7008                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N)
7009                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O)
7010                ref = torch.from_numpy(
7011                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
7012                out_tensor = torch.zeros_like(ref)
7013                yield b1, b2, ref, out_tensor
7014            # zero-sized tensors
7015            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
7016                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
7017                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
7018                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
7019                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
7020                ref = torch.from_numpy(
7021                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
7022                out_tensor = torch.zeros_like(ref)
7023                yield b1, b2, ref, out_tensor
7024
7025        for b1, b2, ref, out_tensor in generate_tensor():
7026            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)
7027
7028    @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
7029    @skipCUDAIfNoMagma
7030    @skipCPUIfNoLapack
7031    @dtypes(*floating_and_complex_types())
7032    def test_pinverse(self, device, dtype):
7033        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
7034        make_arg = partial(make_fullrank, device=device, dtype=dtype)
7035
7036        def run_test(M):
7037            # Testing against definition for pseudo-inverses
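            # i.e. the four Moore-Penrose conditions:
            #   M @ M+ @ M == M,   M+ @ M @ M+ == M+,
            #   and both M @ M+ and M+ @ M are Hermitian.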
7038            MPI = torch.pinverse(M)
7039            MPI_ = MPI.cpu().numpy()
7040            M_ = M.cpu().numpy()
7041            if M.numel() > 0:
7042                self.assertEqual(M_, np.matmul(np.matmul(M_, MPI_), M_))
7043                self.assertEqual(MPI_, np.matmul(np.matmul(MPI_, M_), MPI_))
7044                self.assertEqual(np.matmul(M_, MPI_), np.matmul(M_, MPI_).swapaxes(-2, -1).conj())
7045                self.assertEqual(np.matmul(MPI_, M_), np.matmul(MPI_, M_).swapaxes(-2, -1).conj())
7046            else:
7047                self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2]))
7048        for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5),  # square matrices
7049                      (3, 2), (5, 3, 2), (7, 5, 3, 2),  # fat matrices
7050                      (2, 3), (5, 2, 3), (7, 5, 2, 3),  # thin matrices
7051                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
7052            M = torch.randn(*sizes, dtype=dtype, device=device)
7053            run_test(M)
7054
7055        # Test inverse and pseudo-inverse for invertible matrix
7056        for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]:
7057            matsize = sizes[-1]
7058            batchdims = sizes[:-2]
7059            M = make_arg(*batchdims, matsize, matsize)
7060            self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
7061                             atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix')
7062
7063    @skipCPUIfNoLapack
7064    @skipCUDAIfNoMagmaAndNoCusolver
7065    @dtypes(torch.double, torch.cdouble)
7066    def test_matrix_power_non_negative(self, device, dtype):
7067        def check(*size):
7068            t = make_tensor(size, dtype=dtype, device=device)
7069            for n in range(8):
7070                res = torch.linalg.matrix_power(t, n)
7071                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
7072                self.assertEqual(res.cpu(), torch.from_numpy(ref))
7073
7074        check(0, 0)
7075        check(1, 1)
7076        check(5, 5)
7077        check(0, 3, 3)
7078        check(2, 3, 3)
7079
7080    @skipCPUIfNoLapack
7081    @skipCUDAIfNoMagmaAndNoCusolver
7082    @dtypes(torch.double, torch.cdouble)
7083    def test_matrix_power_negative(self, device, dtype):
7084        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
7085        make_arg = partial(make_fullrank, device=device, dtype=dtype)
7086
7087        def check(*size):
7088            t = make_arg(*size)
7089            for n in range(-7, 0):
7090                res = torch.linalg.matrix_power(t, n)
7091                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
7092                self.assertEqual(res.cpu(), torch.from_numpy(ref))
7093
7094        check(0, 0)
7095        check(5, 5)
7096        check(2, 0, 0)
7097        check(0, 3, 3)
7098        check(2, 3, 3)
7099        check(2, 3, 5, 5)
7100
7101    @skipCUDAIfNoMagma
7102    @skipCPUIfNoLapack
7103    @dtypes(torch.float, torch.complex64)
7104    def test_linalg_matrix_exp_utils(self, device, dtype):
7105        # test linear combination
7106        def run_test(coeff_shape, data_shape):
7107            coeffs = torch.rand(*coeff_shape, device=device, dtype=torch.float)
7108            x = torch.rand(coeff_shape[1], *data_shape, device=device, dtype=dtype)
7109
7110            res1 = torch._compute_linear_combination(x, coeffs)
7111            res2 = (x.unsqueeze(0) * coeffs.view(*coeff_shape, *([1] * len(data_shape)))).sum(1)
7112            self.assertEqual(res1, res2, atol=1e-5, rtol=0.0)
7113
7114            # check `out=` version
7115            res3 = torch.zeros(coeff_shape[0], *data_shape, device=device, dtype=dtype)
7116            torch._compute_linear_combination(x, coeffs, out=res3)
7117            self.assertEqual(res1, res3, atol=1e-5, rtol=0.0)
7118
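            # The checks below rely on the out= overload accumulating into `out`
            # rather than overwriting it, hence the pre-filled values are
            # subtracted back out.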
7119            res4 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
7120            torch._compute_linear_combination(x, coeffs, out=res4)
7121            self.assertEqual(res1, res4 - 1.0, atol=1e-5, rtol=0.0)
7122
7123            res5 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
7124            res5_clone = res5.clone()
7125            torch._compute_linear_combination(x, coeffs, out=res5)
7126            self.assertEqual(res1, res5 - res5_clone, atol=1e-5, rtol=0.0)
7127
7128        run_test([1, 3], [2, 2])
7129        run_test([3, 1], [2, 2])
7130        run_test([1, 10], [10, 10])
7131        run_test([10, 1], [10, 10])
7132        run_test([5, 3], [2, 2])
7133        run_test([5, 3], [100, 100])
7134        run_test([3, 4], [3, 3, 3])
7135        run_test([3, 4], [3, 3, 3, 3])
7136
7137        # Regression test for https://github.com/pytorch/pytorch/issues/94124
7138        with self.assertRaises(RuntimeError):
7139            x = torch.rand([], device=device, dtype=dtype)
7140            coeffs = torch.rand([2, 2], device=device, dtype=dtype)
7141            res = torch._compute_linear_combination(x, coeffs)
7142
7143    @onlyCPU
7144    @skipCPUIfNoLapack
7145    @dtypes(torch.complex64)
7146    def test_linalg_matrix_exp_no_warnings(self, device, dtype):
7147        # this tests https://github.com/pytorch/pytorch/issues/80948
7148        with freeze_rng_state():
7149            torch.manual_seed(42)
7150            tens = 0.5 * torch.randn(10, 3, 3, dtype=dtype, device=device)
7151            tens = (0.5 * (tens.transpose(-1, -2) + tens))
7152            with warnings.catch_warnings(record=True) as w:
7153                tens.imag = torch.matrix_exp(tens.imag)
7154                self.assertFalse(len(w))
7155
7156    @skipCUDAIfNoMagma
7157    @skipCPUIfNoLapack
7158    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
7159    def test_linalg_matrix_exp_boundary_cases(self, device, dtype):
7160        expm = torch.linalg.matrix_exp
7161
7162        with self.assertRaisesRegex(RuntimeError, "Expected a floating point or complex tensor"):
7163            expm(torch.randn(3, 3).type(torch.int))
7164
7165        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
7166            expm(torch.randn(3))
7167
7168        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
7169            expm(torch.randn(3, 2, 1))
7170
7171        # check 1x1 matrices
7172        x = torch.randn(3, 3, 1, 1)
7173        self.assertEqual(expm(x), x.exp())
7174
7175    @skipCUDAIfNoMagma
7176    @skipCPUIfNoLapack
7177    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
7178    def test_linalg_matrix_exp_perverse_nan_values(self, device, dtype):
7179        expm = torch.linalg.matrix_exp
7180
7181        def with_nan(x):
7182            x[0, 0, 0] = torch.nan
7183            return x
7184
7185        # Check small batches
7186        x = with_nan(torch.randn(1, 1, 1))
7187        self.assertTrue(torch.isnan(expm(x)).any())
7188        x = with_nan(torch.randn(1, 2, 2))
7189        for v in [1, 2, 3, 4, 5, 6, 7, 8, 9, 100, 1000]:
7190            self.assertTrue(torch.isnan(expm(x / v)).any())
7191
7192        # Check large batches
7193        x = with_nan(torch.randn(2, 2, 2))
7194        self.assertTrue(torch.isnan(expm(x)).any())
7195        x = with_nan(torch.randn(4096, 2, 2))
7196        self.assertTrue(torch.isnan(expm(x)).any())
7197
7198    @slowTest
7199    @skipCUDAIfNoMagma
7200    @skipCPUIfNoLapack
7201    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
7202    def test_linalg_matrix_exp_analytic(self, device, dtype):
7203        expm = torch.linalg.matrix_exp
7204        # check zero matrix
7205        x = torch.zeros(20, 20, dtype=dtype, device=device)
7206        self.assertTrue((expm(x) == torch.eye(20, 20, dtype=dtype, device=device)).all().item())
7207
7208        def normalize_to_1_operator_norm(sample, desired_norm):
7209            sample_norm, _ = sample.abs().sum(-2).max(-1)
7210            sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
7211            return sample_to_1_norm * desired_norm
7212
7213        def gen_good_cond_number_matrices(*n):
7214            """
7215            Generates a diagonally-dominant matrix
7216            with the eigenvalues centered at 1
7217            and the radii at most (n[-1] - 1) / (n[-2] ** 2)
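            (a consequence of the Gershgorin circle theorem: every off-diagonal
            entry lies in [0, 1 / n[-1] ** 2), so each row's off-diagonal sum is small)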
7218            """
7219            identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
7220            x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
7221            x = (x - x * identity) + identity
7222            return x
7223
7224        def run_test(*n):
7225            if dtype == torch.float:
7226                thetas = [
7227                    1.192092800768788e-07,  # deg 1
7228                    5.978858893805233e-04,  # deg 2
7229                    5.116619363445086e-02,  # deg 4
7230                    5.800524627688768e-01,  # deg 8
7231                    1.461661507209034e+00,  # deg 12
7232                    3.010066362817634e+00   # deg 18
7233                ]
7234            else:  # if torch.double
7235                thetas = [
7236                    2.220446049250313e-16,  # deg 1
7237                    2.580956802971767e-08,  # deg 2
7238                    3.397168839976962e-04,  # deg 4
7239                    4.991228871115323e-02,  # deg 8
7240                    2.996158913811580e-01,  # deg 12
7241                    1.090863719290036e+00   # deg 18
7242                ]
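            # The thetas above appear to be the operator-norm thresholds at which
            # matrix_exp switches approximation degree (see the deg labels); the
            # sample norms built below are midpoints between consecutive thetas,
            # plus one below the first and one above the last, so every degree
            # regime is exercised.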
7243
7244            # generate input
7245            q = gen_good_cond_number_matrices(*n)
7246            q_ = q.cpu().numpy()
7247            qinv = torch.inverse(q)
7248            qinv_ = qinv.cpu().numpy()
7249            d = torch.randn(n[:-1], dtype=dtype, device=device)
7250            x = torch.from_numpy(
7251                np.matmul(q_, np.matmul(torch.diag_embed(d).cpu().numpy(), qinv_))).to(device)
7252            x_norm, _ = x.abs().sum(-2).max(-1)
7253
7254            # test against the analytic solution, whatever norm was generated
7255            mexp = expm(x)
7256            mexp_analytic = np.matmul(
7257                q_,
7258                np.matmul(
7259                    torch.diag_embed(d.exp()).cpu().numpy(),
7260                    qinv_
7261                )
7262            )
7263            self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)
7264
7265            # generate norms to test different degree expansions
7266            sample_norms = []
7267            for i in range(len(thetas) - 1):
7268                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
7269            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
7270
7271            # rescale the matrix to each sampled norm and compare against the analytic solution
7272            for sample_norm in sample_norms:
7273                x_normalized = normalize_to_1_operator_norm(x, sample_norm)
7274
7275                mexp = expm(x_normalized)
7276                mexp_analytic = np.matmul(
7277                    q_,
7278                    np.matmul(
7279                        torch.diag_embed((d / x_norm.unsqueeze(-1) * sample_norm).exp()).cpu().numpy(),
7280                        qinv_
7281                    )
7282                )
7283                self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)
7284
7285        # single matrix
7286        run_test(2, 2)
7287        run_test(3, 3)
7288        run_test(4, 4)
7289        run_test(5, 5)
7290        run_test(100, 100)
7291        run_test(200, 200)
7292
7293        # small batch of matrices
7294        run_test(3, 2, 2)
7295        run_test(3, 3, 3)
7296        run_test(3, 4, 4)
7297        run_test(3, 5, 5)
7298        run_test(3, 100, 100)
7299        run_test(3, 200, 200)
7300
7301        # large batch of matrices
7302        run_test(3, 3, 2, 2)
7303        run_test(3, 3, 3, 3)
7304        run_test(3, 3, 4, 4)
7305        run_test(3, 3, 5, 5)
7306        run_test(3, 3, 100, 100)
7307        run_test(3, 3, 200, 200)
7308
7309    @skipCUDAIfNoMagma
7310    @skipCPUIfNoLapack
7311    @dtypes(torch.float, torch.double)
7312    def test_linalg_matrix_exp_batch(self, device, dtype):
7313
7314        def run_test(*n):
7315            tensors_batch = torch.zeros(n, dtype=dtype, device=device)
7316            tensors_batch = tensors_batch.view(-1, n[-2], n[-1])
7317
7318            num_matrices = tensors_batch.size(0)
7319            tensors_list = []
7320            for i in range(num_matrices):
7321                tensors_list.append(torch.randn(n[-2], n[-1], dtype=dtype, device=device))
7322
7323            for i in range(num_matrices):
7324                tensors_batch[i, ...] = tensors_list[i]
7325
7326            tensors_exp_map = (torch.linalg.matrix_exp(x) for x in tensors_list)
7327            tensors_exp_batch = torch.linalg.matrix_exp(tensors_batch)
7328
7329            for i, tensor_exp in enumerate(tensors_exp_map):
7330                self.assertEqual(tensors_exp_batch[i, ...], tensor_exp)
7331
7332        # small batch of matrices
7333        run_test(3, 2, 2)
7334        run_test(3, 3, 3)
7335        run_test(3, 4, 4)
7336        run_test(3, 5, 5)
7337
7338        # large batch of matrices
7339        run_test(3, 3, 2, 2)
7340        run_test(3, 3, 3, 3)
7341        run_test(3, 3, 4, 4)
7342        run_test(3, 3, 5, 5)
7343
7344    @skipCUDAIfNoMagma
7345    @skipCPUIfNoLapack
7346    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
7347    def test_linalg_matrix_exp_compare_with_taylor(self, device, dtype):
7348
7349        def normalize_to_1_operator_norm(sample, desired_norm):
7350            sample_norm, _ = sample.abs().sum(-2).max(-1)
7351            sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
7352            return sample_to_1_norm * desired_norm
7353
7354        def gen_good_cond_number_matrices(*n):
7355            """
7356            Generates a diagonally-dominant matrix
7357            with the eigenvalues centered at 1
7358            and the radii at most (n[-1] - 1) / (n[-2] ** 2)
7359            """
7360            identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
7361            x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
7362            x = (x - x * identity) + identity
7363            return x
7364
7365        def get_taylor_approximation(a, deg):
7366            a_ = a.cpu().numpy()
7367            identity = torch.eye(a.size(-2), a.size(-1), dtype=dtype, device=device).expand_as(a)
7368            res = identity.cpu().numpy()
7369            taylor_term = identity.cpu().numpy()
7370
7371            for i in range(1, deg + 1):
7372                taylor_term = np.matmul(a_, taylor_term) / i
7373                res = res + taylor_term
7374
7375            return res
7376
7377        def scale_square(a, deg):
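            # Reference scaling-and-squaring: exp(A) = (exp(A / 2**s)) ** (2**s).
            # Small-norm inputs use a single Taylor expansion; otherwise the input
            # is scaled down by 2**s, expanded, and squared back up s times.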
7378            if a.abs().pow(2).sum().sqrt() < 1.0:
7379                return get_taylor_approximation(a, 12)
7380            else:
7381                s = int(torch.log2(a.abs().pow(2).sum().sqrt()).ceil().item())
7382                b = a / (2 ** s)
7383                b = get_taylor_approximation(b, 18)
7384                for _ in range(s):
7385                    b = np.matmul(b, b)
7386                return torch.from_numpy(b).to(a.device)
7387
7388        def run_test(*n):
7389            degs = [1, 2, 4, 8, 12, 18]
7390            if dtype == torch.float:
7391                thetas = [
7392                    1.192092800768788e-07,  # deg 1
7393                    5.978858893805233e-04,  # deg 2
7394                    5.116619363445086e-02,  # deg 4
7395                    5.800524627688768e-01,  # deg 8
7396                    1.461661507209034e+00,  # deg 12
7397                    3.010066362817634e+00   # deg 18
7398                ]
7399            else:  # if torch.double
7400                thetas = [
7401                    2.220446049250313e-16,  # deg 1
7402                    2.580956802971767e-08,  # deg 2
7403                    3.397168839976962e-04,  # deg 4
7404                    4.991228871115323e-02,  # deg 8
7405                    2.996158913811580e-01,  # deg 12
7406                    1.090863719290036e+00   # deg 18
7407                ]
7408
7409            # generate norms to test different degree expansions
7410            sample_norms = []
7411            for i in range(len(thetas) - 1):
7412                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
7413            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
7414            degs = [degs[0]] + degs
7415
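            # sample_norms has one more entry than thetas (one norm per degree
            # regime), so degs is padded at the front to line up with zip() below.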
7416            for sample_norm, deg in zip(sample_norms, degs):
7417                x = gen_good_cond_number_matrices(*n)
7418                x = normalize_to_1_operator_norm(x, sample_norm)
7419
7420                mexp = torch.linalg.matrix_exp(x)
7421                mexp_taylor = scale_square(x, deg)
7422
7423                self.assertEqual(mexp, mexp_taylor, atol=1e-2, rtol=0.0)
7424
7425        # single matrix
7426        run_test(2, 2)
7427        run_test(3, 3)
7428        run_test(4, 4)
7429        run_test(5, 5)
7430
7431        # small batch of matrices
7432        run_test(3, 2, 2)
7433        run_test(3, 3, 3)
7434        run_test(3, 4, 4)
7435        run_test(3, 5, 5)
7436
7437        # large batch of matrices
7438        run_test(3, 3, 2, 2)
7439        run_test(3, 3, 3, 3)
7440        run_test(3, 3, 4, 4)
7441        run_test(3, 3, 5, 5)
7442
7443    @skipCUDAIfNoMagma
7444    @skipCPUIfNoLapack
7445    @dtypes(*floating_and_complex_types())
7446    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
7447                        torch.float64: 1e-8, torch.complex128: 1e-8})
7448    def test_slogdet(self, device, dtype):
7449        from torch.testing._internal.common_utils import (random_hermitian_matrix, random_hermitian_psd_matrix,
7450                                                          random_hermitian_pd_matrix, random_square_matrix_of_rank)
7451
7452        # mat_chars denotes matrix characteristics
7453        # possible values are: hermitian, hermitian_psd, hermitian_pd, singular, non_singular
7454        def run_test(matsize, batchdims, mat_chars):
7455            num_matrices = np.prod(batchdims)
7456            list_of_matrices = []
7457            if num_matrices != 0:
7458                for idx in range(num_matrices):
7459                    mat_type = idx % len(mat_chars)
7460                    if mat_chars[mat_type] == 'hermitian':
7461                        list_of_matrices.append(random_hermitian_matrix(matsize, dtype=dtype, device=device))
7462                    elif mat_chars[mat_type] == 'hermitian_psd':
7463                        list_of_matrices.append(random_hermitian_psd_matrix(matsize, dtype=dtype, device=device))
7464                    elif mat_chars[mat_type] == 'hermitian_pd':
7465                        list_of_matrices.append(random_hermitian_pd_matrix(matsize, dtype=dtype, device=device))
7466                    elif mat_chars[mat_type] == 'singular':
7467                        list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
7468                    elif mat_chars[mat_type] == 'non_singular':
7469                        list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
7470                full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
7471            else:
7472                full_tensor = torch.randn(*batchdims, matsize, matsize, dtype=dtype, device=device)
7473
7474            actual_value = torch.linalg.slogdet(full_tensor)
7475            expected_value = np.linalg.slogdet(full_tensor.cpu().numpy())
7476            self.assertEqual(expected_value[0], actual_value[0], atol=self.precision, rtol=self.precision)
7477            self.assertEqual(expected_value[1], actual_value[1], atol=self.precision, rtol=self.precision)
7478
7479            # test out=variant
7480            sign_out = torch.empty_like(actual_value[0])
7481            logabsdet_out = torch.empty_like(actual_value[1])
7482            ans = torch.linalg.slogdet(full_tensor, out=(sign_out, logabsdet_out))
7483            self.assertEqual(ans[0], sign_out)
7484            self.assertEqual(ans[1], logabsdet_out)
7485            self.assertEqual(sign_out, actual_value[0])
7486            self.assertEqual(logabsdet_out, actual_value[1])
7487
7488        for matsize, batchdims in itertools.product([0, 3, 5], [(0,), (3,), (5, 3)]):
7489            run_test(matsize, batchdims, mat_chars=['hermitian_pd'])
7490            run_test(matsize, batchdims, mat_chars=['singular'])
7491            run_test(matsize, batchdims, mat_chars=['non_singular'])
7492            run_test(matsize, batchdims, mat_chars=['hermitian', 'hermitian_pd', 'hermitian_psd'])
7493            run_test(matsize, batchdims, mat_chars=['singular', 'non_singular'])
7494
7495    @skipCUDAIfNoMagma
7496    @skipCPUIfNoLapack
7497    @dtypes(*floating_and_complex_types())
7498    def test_slogdet_errors_and_warnings(self, device, dtype):
7499        # slogdet requires the input to be a square matrix or batch of square matrices
7500        a = torch.randn(2, 3, device=device, dtype=dtype)
7501        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
7502            torch.linalg.slogdet(a)
7503
7504        # slogdet requires the input to be at least 2 dimensional tensor
7505        a = torch.randn(2, device=device, dtype=dtype)
7506        with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
7507            torch.linalg.slogdet(a)
7508
7509        a = torch.randn(2, 2, device=device, dtype=torch.bfloat16)
7510        with self.assertRaisesRegex(RuntimeError, r'Low precision dtypes not supported'):
7511            torch.linalg.slogdet(a)
7512
7513        # if non-empty out tensor with wrong shape is passed a warning is given
7514        a = torch.randn(2, 3, 3, device=device, dtype=dtype)
7515        sign_out = torch.empty(1, device=device, dtype=dtype)
7516        real_dtype = a.real.dtype if dtype.is_complex else dtype
7517        logabsdet_out = torch.empty(1, device=device, dtype=real_dtype)
7518        with warnings.catch_warnings(record=True) as w:
7519            # Trigger warning
7520            torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
7521            # Check warning occurs
7522            self.assertEqual(len(w), 1)
7523            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
7524
7525        # device should match
7526        if torch.cuda.is_available():
7527            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
7528            sign_out = torch.empty(0, device=wrong_device, dtype=dtype)
7529            logabsdet_out = torch.empty(0, device=wrong_device, dtype=real_dtype)
7530            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
7531                torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
7532
7533    # FIXME One of the backends of lu_factor fails in windows. I haven't investigated which or why
7534    # https://github.com/pytorch/pytorch/issues/75225
7535    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
7536    @skipCUDAIfNoCusolver
7537    @skipCPUIfNoLapack
7538    @dtypes(torch.double)
7539    def test_det_logdet_slogdet(self, device, dtype):
7540        def reference_slogdet(M):
7541            sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy())
7542            return M.new_tensor(sdet), M.new_tensor(logabsdet)
7543
7544        def test_single_det(M, target, desc):
7545            target_sdet, target_logabsdet = target
7546
7547            det = M.det()
7548            logdet = M.logdet()
7549            sdet, logabsdet = M.slogdet()
7550            linalg_sdet, linalg_logabsdet = torch.linalg.slogdet(M)
7551
7552            # Test det
7553            self.assertEqual(det, target_sdet * target_logabsdet.exp(),
7554                             atol=1e-6, rtol=0, msg=f'{desc} (det)')
7555
7556            # Test slogdet
7557            # Compare the overall value rather than individual parts because of
7558            # precision issues when det is near zero.
7559            self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(),
7560                             atol=1e-6, rtol=0, msg=f'{desc} (slogdet)')
7561            self.assertEqual(linalg_sdet * linalg_logabsdet.exp(), target_sdet * target_logabsdet.exp(),
7562                             atol=1e-6, rtol=0, msg=f'{desc} (linalg_slogdet)')
7563
7564            # Test logdet
7565            # Compare logdet against our own pytorch slogdet because they should
7566            # be consistent, while it may behave slightly differently with other
7567            # slogdet implementations when det is near zero due to precision
7568            # issues.
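            # A negative determinant has no real logarithm, so logdet is expected
            # to be NaN; NaN is detected via the `x != x` property.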
7569            if sdet.item() < 0:
7570                self.assertTrue(logdet.item() != logdet.item(), f'{desc} (logdet negative case)')
7571            else:
7572                self.assertEqual(logdet.exp(), target_logabsdet.exp(),
7573                                 atol=1e-6, rtol=0, msg=f'{desc} (logdet non-negative case)')
7574
7575        eye = torch.eye(5, dtype=dtype, device=device)
7576        test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity')
7577        # Testing bug in #34061 (https://github.com/pytorch/pytorch/issues/34061)
7578        for n in range(250, 551, 100):
7579            mat = torch.randn(n, n, dtype=dtype, device=device)
7580            q, _ = torch.qr(mat)
7581            ref_det, ref_logabsdet = reference_slogdet(q)
7582            test_single_det(q, (ref_det, ref_logabsdet), 'orthogonal')
7583
7584        def test(M):
7585            assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
7586            M = M.to(device)
7587
7588            ref_M_sdet, ref_M_logabsdet = reference_slogdet(M)
7589
7590            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic')
7591            if ref_M_logabsdet.exp().item() >= 1e-6:  # skip singular
7592                M_inv = M.inverse()
7593                test_single_det(M_inv, reference_slogdet(M_inv), 'inverse')
7594
7595            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose')
7596
7597            for x in [0, 2, 4]:
7598                for scale in [-2, -0.1, 0, 10]:
7599                    if scale > 0:
7600                        target = ref_M_sdet, ref_M_logabsdet + math.log(scale)
7601                    elif scale == 0:
7602                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
7603                    else:
7604                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale)
7605
7606                    # dim 0
7607                    M_clone = M.clone()
7608                    M_clone[:, x] *= scale
7609                    test_single_det(M_clone, target, 'scale a column')
7610                    # dim 1
7611                    M_clone = M.clone()
7612                    M_clone[x, :] *= scale
7613                    test_single_det(M_clone, target, 'scale a row')
7614
7615            for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
7616                assert x1 != x2, 'x1 and x2 needs to be different for this test'
7617                target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
7618                # dim 0
7619                M_clone = M.clone()
7620                M_clone[:, x2] = M_clone[:, x1]
7621                test_single_det(M_clone, target, 'two columns are same')
7622                # dim 1
7623                M_clone = M.clone()
7624                M_clone[x2, :] = M_clone[x1, :]
7625                test_single_det(M_clone, target, 'two rows are same')
7626
7627                for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
7628                    det_scale = scale1 * scale2 * -1
7629                    if det_scale > 0:
7630                        target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale)
7631                    elif det_scale == 0:
7632                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
7633                    else:
7634                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale)
7635
7636                    # dim 0
7637                    M_clone = M.clone()
7638                    t = M_clone[:, x1] * scale1
7639                    M_clone[:, x1] += M_clone[:, x2] * scale2
7640                    M_clone[:, x2] = t
7641                    test_single_det(M_clone, target, 'exchanging columns')
7642                    # dim 1
7643                    M_clone = M.clone()
7644                    t = M_clone[x1, :] * scale1
7645                    M_clone[x1, :] += M_clone[x2, :] * scale2
7646                    M_clone[x2, :] = t
7647                    test_single_det(M_clone, target, 'exchanging rows')
7648
7649        def get_random_mat_scale(n):
7650            # For matrices with values i.i.d. with 0 mean, unit variance, and
7651            # subexponential tail, we have:
7652            #   E[log det(A^2)] \approx log((n-1)!)
7653            #
7654            # Notice:
7655            #   log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)]
7656            #
7657            # So:
7658            #   stddev[det(A)] >= sqrt( (n-1)! )
7659            #
7660            # We use this as an intuitive guideline to scale random generated
7661            # matrices so our closeness tests can work more robustly:
7662            #   scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n))
7663            #
7664            # source: https://arxiv.org/pdf/1112.0752.pdf
7665
7666            # TODO: technically we need subexponential distn for this to hold,
7667            #       but we mostly use gaussian entries below. Consider switching
7668            #       to Chi-sq if this turns out not stable enough, since Chi-sq
7669            #       is easy enough to sample from.
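            # For example, n = 5 gives (4!) ** (-1 / 10) ≈ 0.728.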
7670            return math.factorial(n - 1) ** (-1.0 / (2 * n))
7671
7672        for n in [5, 10, 25]:
7673            scale = get_random_mat_scale(n)
7674            test(torch.randn(n, n, dtype=dtype, device=device) * scale)
7675            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7676            # symmetric psd
7677            test(r.mm(r.t()))
7678            # symmetric pd
7679            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7680            test(r.mm(r.t()) + torch.eye(n, dtype=dtype, device=device) * 1e-6)
7681            # symmetric
7682            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7683            for i in range(n):
7684                for j in range(i):
7685                    r[i, j] = r[j, i]
7686            test(r)
7687            # non-contiguous
7688            test((torch.randn(n, n, n + 1, dtype=dtype, device=device) * scale)[:, 2, 1:])
7689            # det = 0
7690            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7691            u, s, v = r.svd()
7692            if reference_slogdet(u)[0] < 0:
7693                u = -u
7694            if reference_slogdet(v)[0] < 0:
7695                v = -v
7696            s[0] *= -1
7697            s[-1] = 0
7698            test(u.mm(s.diag()).mm(v))
7699
7700        # Small values to test numerical stability. Note that we don't scale
7701        # this matrix.
7702        r = torch.randn(512, 512, dtype=dtype, device=device)
7703        u, s, v = r.svd()
7704        s.fill_(1. / (100 * s.numel()))
7705        test(u.mm(s.diag()).mm(v))
7706
7707    @skipCUDAIfNoMagma
7708    @skipCPUIfNoLapack
7709    @dtypes(torch.double)
7710    def test_det_logdet_slogdet_batched(self, device, dtype):
7711        from torch.testing._internal.common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix,
7712                                                          random_symmetric_pd_matrix, random_square_matrix_of_rank)
7713
7714        # mat_chars denotes matrix characteristics
7715        # possible values are: sym, sym_psd, sym_pd, sing, non_sym
7716        def run_test(matsize, batchdims, mat_chars):
7717            num_matrices = reduce(operator.mul, batchdims, 1)
7718            list_of_matrices = []
7719
7720            for idx in range(num_matrices):
7721                mat_type = idx % len(mat_chars)
7722                if mat_chars[mat_type] == 'sym':
7723                    list_of_matrices.append(random_symmetric_matrix(matsize, dtype=dtype, device=device))
7724                elif mat_chars[mat_type] == 'sym_psd':
7725                    list_of_matrices.append(random_symmetric_psd_matrix(matsize, dtype=dtype, device=device))
7726                elif mat_chars[mat_type] == 'sym_pd':
7727                    list_of_matrices.append(random_symmetric_pd_matrix(matsize, dtype=dtype, device=device))
7728                elif mat_chars[mat_type] == 'sing':
7729                    list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
7730                elif mat_chars[mat_type] == 'non_sing':
7731                    list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
7732            full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
7733            # Scaling adapted from `get_random_mat_scale` in _test_det_logdet_slogdet
7734            # Scaling adapted from `get_random_mat_scale` in test_det_logdet_slogdet
7735
7736            for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]:
7737                expected_value = []
7738                actual_value = fn(full_tensor)
7739                for full_idx in itertools.product(*(list(range(x)) for x in batchdims)):
7740                    expected_value.append(fn(full_tensor[full_idx]))
7741
7742                if fn == torch.slogdet or fn == torch.linalg.slogdet:
7743                    sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims)
7744                    expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims)
7745                    self.assertEqual(sign_value, actual_value[0])
7746                    self.assertEqual(expected_value, actual_value[1])
7747                else:
7748                    expected_value = torch.stack(expected_value, dim=0).reshape(batchdims)
7749                    self.assertEqual(actual_value, expected_value)
7750
7751        for matsize, batchdims in itertools.product([3, 5], [(3,), (5, 3)]):
7752            run_test(matsize, batchdims, mat_chars=['sym_pd'])
7753            run_test(matsize, batchdims, mat_chars=['sing'])
7754            run_test(matsize, batchdims, mat_chars=['non_sing'])
7755            run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd'])
7756            run_test(matsize, batchdims, mat_chars=['sing', 'non_sing'])
7757
7758    @skipCUDAIfNoMagma
7759    @skipCPUIfNoLapack
7760    @dtypes(*floating_and_complex_types())
7761    def test_cholesky_inverse(self, device, dtype):
7762        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
7763
7764        def run_test(shape, batch, upper, contiguous):
7765            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
7766            if A.numel() > 0 and not contiguous:
7767                A = A.mT
7768                self.assertFalse(A.is_contiguous())
7769            L = torch.linalg.cholesky(A)
7770            expected_inverse = torch.inverse(A)
7771            L = L.mH if upper else L
7772            actual_inverse = torch.cholesky_inverse(L, upper)
7773            self.assertEqual(actual_inverse, expected_inverse)
7774
7775        shapes = (0, 3, 5)
7776        batches = ((), (0,), (3, ), (2, 2))
7777        for shape, batch, upper, contiguous in list(itertools.product(shapes, batches, (True, False), (True, False))):
7778            run_test(shape, batch, upper, contiguous)
7779
7780        # check the out= variant
7781        A = random_hermitian_pd_matrix(3, 2, dtype=dtype, device=device)
7782        L = torch.linalg.cholesky(A)
7783
7784        # There are two code paths currently for the out= variant
7785        # 1. When 'out' tensor is in Fortran (column-major) memory format
7786        # then the fast route is taken and the storage is reused directly in the computations
7787        # 2. When 'out' tensor is not in Fortran format then a temporary tensor is allocated internally
7788        # and the result is copied from the temporary tensor to 'out' tensor
7789
7790        # This test checks the first code path
7791        out = torch.empty_like(A)
7792        out_t = out.mT.clone(memory_format=torch.contiguous_format)
7793        out = out_t.mT
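        # out is now the transpose of a row-contiguous tensor, i.e. column-major
        # (Fortran-ordered), so the fast path can write into its storage directly.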
7794        ans = torch.cholesky_inverse(L, out=out)
7795        self.assertEqual(ans, out)
7796        expected = torch.inverse(A)
7797        self.assertEqual(expected, out)
7798
7799        # This test checks the second code path
7800        out = torch.empty_like(A)
7801        ans = torch.cholesky_inverse(L, out=out)
7802        self.assertEqual(ans, out)
7803        expected = torch.inverse(A)
7804        self.assertEqual(expected, out)
7805
7806    @skipCUDAIfNoMagma
7807    @skipCPUIfNoLapack
7808    @dtypes(*floating_and_complex_types())
7809    def test_cholesky_inverse_errors_and_warnings(self, device, dtype):
7810        # cholesky_inverse requires the input to be at least 2 dimensional tensor
7811        a = torch.randn(2, device=device, dtype=dtype)
7812        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
7813            torch.cholesky_inverse(a)
7814
7815        # cholesky_inverse requires a square matrix
7816        a = torch.randn(2, 3, device=device, dtype=dtype)
7817        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
7818            torch.cholesky_inverse(a)
7819
7820        # if non-empty out tensor with wrong shape is passed a warning is given
7821        a = torch.randn(3, 3, device=device, dtype=dtype)
7822        out = torch.empty(2, 3, device=device, dtype=dtype)
7823        with warnings.catch_warnings(record=True) as w:
7824            # Trigger warning
7825            torch.cholesky_inverse(a, out=out)
7826            # Check warning occurs
7827            self.assertEqual(len(w), 1)
7828            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
7829
7830        # dtypes should be safely castable
7831        out = torch.empty(*a.shape, dtype=torch.int, device=device)
7832        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
7833            torch.cholesky_inverse(a, out=out)
7834
7835        # device should match
7836        if torch.cuda.is_available():
7837            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
7838            out = torch.empty(0, device=wrong_device, dtype=dtype)
7839            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
7840                torch.cholesky_inverse(a, out=out)
7841
7842        # cholesky_inverse raises an error for invalid inputs on CPU
7843        # for example if at least one diagonal element is zero
7844        a = torch.randn(3, 3, device=device, dtype=dtype)
7845        a[1, 1] = 0
7846        if self.device_type == 'cpu':
7847            with self.assertRaisesRegex(torch.linalg.LinAlgError, r"cholesky_inverse: The diagonal element 2 is zero"):
7848                torch.cholesky_inverse(a)
7849        # cholesky_inverse on GPU does not raise an error for this case
7850        elif self.device_type == 'cuda':
7851            out = torch.cholesky_inverse(a)
7852            self.assertTrue(out.isinf().any() or out.isnan().any())
7853
7854    def _select_broadcastable_dims(self, dims_full=None):
7855        # select full dimensionality
7856        if dims_full is None:
7857            dims_full = []
7858            ndims = random.randint(1, 4)
7859            dims_full = [random.randint(1, 8) for _ in range(ndims)]
7860        else:
7861            ndims = len(dims_full)
7862
7863        # select actual dimensions for ops:
7864        # larger: full ndims, individual sizes may be reduced
7865        # smaller: possibly reduced ndims, sizes may be reduced
7866        smaller_ndims = random.randint(1, ndims)
7867        dims_small = []
7868        dims_large = []
7869        for i in range(ndims - 1, -1, -1):
7870            j = random.randint(1, 3)
7871            if j == 1:  # no reduced singleton dimension
7872                ds = dims_full[i]
7873                dl = dims_full[i]
7874            elif j == 2:  # larger may have reduced singleton dimension
7875                ds = dims_full[i]
7876                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
7877            elif j == 3:  # smaller may have reduced singleton dimension
7878                ds = 1
7879                dl = dims_full[i]
7880            dims_large = [dl] + dims_large
7881            if len(dims_small) < smaller_ndims:
7882                dims_small = [ds] + dims_small
7883        return (dims_small, dims_large, dims_full)
7884
7885    def test_broadcast_fused_matmul(self, device):
7886        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]
7887
7888        for fn in fns:
7889            batch_dim = random.randint(1, 8)
7890            n_dim = random.randint(1, 8)
7891            m_dim = random.randint(1, 8)
7892            p_dim = random.randint(1, 8)
7893
7894            def dims_full_for_fn():
7895                if fn == "baddbmm":
7896                    return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
7897                elif fn == "addbmm":
7898                    return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
7899                elif fn == "addmm":
7900                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
7901                elif fn == "addmv":
7902                    return ([n_dim], [n_dim, m_dim], [m_dim])
7903                elif fn == "addr":
7904                    return ([n_dim, m_dim], [n_dim], [m_dim])
7905                else:
7906                    raise AssertionError("unknown function")
7907
7908            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
7909            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)
7910
7911            t0_small = torch.randn(*t0_dims_small, device=device).float()
7912            t1 = torch.randn(*t1_dims, device=device).float()
7913            t2 = torch.randn(*t2_dims, device=device).float()
7914
7915            t0_full = t0_small.expand(*t0_dims_full).to(device)
7916
7917            fntorch = getattr(torch, fn)
7918            r0 = fntorch(t0_small, t1, t2)
7919            r1 = fntorch(t0_full, t1, t2)
7920            self.assertEqual(r0, r1)
7921
7922    @tf32_on_and_off(0.001)
7923    @bf32_on_and_off(0.001)
7924    def test_broadcast_batched_matmul(self, device):
7925        n_dim = random.randint(1, 8)
7926        m_dim = random.randint(1, 8)
7927        p_dim = random.randint(1, 8)
7928        full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))]
7929        (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims)
7930
7931        def verify_batched_matmul(full_lhs, one_dimensional):
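            # full_lhs selects which operand carries the full batch dimensions (lhs if True, rhs if False);
            # one_dimensional makes the non-full operand a 1-D vector to exercise matmul's vector cases.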
7932            if not one_dimensional:
7933                lhs_dims = [n_dim, m_dim]
7934                rhs_dims = [m_dim, p_dim]
7935                result_dims = [n_dim, p_dim]
7936            else:
7937                lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim]
7938                rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim]
7939                result_dims = [n_dim] if full_lhs else [p_dim]
7940
7941            lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim]
7942            rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1]
7943            full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims
7944            dim0_dims = rhs_dims if full_lhs else lhs_dims
7945            small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims)
7946
7947            small = torch.randn(*(small_dims), device=device).float()
7948            dim0 = torch.randn(*(dim0_dims), device=device).float()
7949            full = torch.randn(*(full_batch_dims + full_mat_dims), device=device).float()
7950            if not one_dimensional:
7951                (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,))
7952            else:
7953                (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,))
7954
7955            def maybe_squeeze_result(l, r, result):
7956                if len(lhs_dims) == 1 and l.dim() != 1:
7957                    return result.squeeze(-2)
7958                elif len(rhs_dims) == 1 and r.dim() != 1:
7959                    return result.squeeze(-1)
7960                else:
7961                    return result
7962
7963            for lhs in lhsTensors:
7964                lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims)))
7965                lhs_expanded_matmul_fn = lhs_expanded.matmul
7966                for rhs in rhsTensors:
7967                    rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)).
7968                                    expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims))))
7969                    truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded))
7970                    for l in (lhs, lhs_expanded):
7971                        for r in (rhs, rhs_expanded):
7972                            l_matmul_fn = l.matmul
7973                            result = maybe_squeeze_result(l, r, l_matmul_fn(r))
7974                            self.assertEqual(truth, result)
7975                            # test torch.matmul function as well
7976                            torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
7977                            self.assertEqual(truth, torch_result)
7978                            # test torch.matmul with out
7979                            out = torch.zeros_like(torch_result)
7980                            torch.matmul(l, r, out=out)
7981                            self.assertEqual(truth, maybe_squeeze_result(l, r, out))
7982
7983                # compare to bmm
7984                bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
7985                                        rhs_expanded.contiguous().view(-1, *rhs_mat_dims)))
7986                self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims))
7987
7988        for indices in itertools.product((True, False), repeat=2):
7989            verify_batched_matmul(*indices)
7990
7991    def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
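        # NOTE: `pivot` is accepted for the callers' sub_tests but is not forwarded here;
        # torch.linalg.lu_factor_ex uses partial pivoting by default (pivot=True).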
7992        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
7993        make_A = partial(make_fullrank, device=device, dtype=dtype)
7994
7995        b = torch.randn(*b_dims, dtype=dtype, device=device)
7996        A = make_A(*A_dims)
7997        LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A)
7998        self.assertEqual(info, torch.zeros_like(info))
7999        return b, A, LU_data, LU_pivots
8000
8001    @skipCPUIfNoLapack
8002    @skipCUDAIfNoMagmaAndNoCusolver
8003    @dtypes(*floating_and_complex_types())
8004    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
8005                        torch.float64: 1e-8, torch.complex128: 1e-8})
8006    def test_lu_solve(self, device, dtype):
8007        def sub_test(pivot):
8008            for k, n in zip([2, 3, 5], [3, 5, 7]):
8009                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n, n), (n, k), pivot, device, dtype)
8010                x = torch.lu_solve(b, LU_data, LU_pivots)
8011                self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
8012
8013        sub_test(True)
8014        if self.device_type == 'cuda':
8015            sub_test(False)
8016
8017    @skipCPUIfNoLapack
8018    @skipCUDAIfNoMagmaAndNoCusolver
8019    @dtypes(*floating_and_complex_types())
8020    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
8021                        torch.float64: 1e-8, torch.complex128: 1e-8})
8022    def test_lu_solve_batched(self, device, dtype):
8023        def sub_test(pivot):
8024            def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
8025                b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype)
8026                x_exp_list = []
8027                for i in range(b_dims[0]):
8028                    x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i]))
8029                x_exp = torch.stack(x_exp_list)  # Stacked output
8030                x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
8031                self.assertEqual(x_exp, x_act)  # Equality check
8032                Ax = np.matmul(A.cpu(), x_act.cpu())
8033                self.assertEqual(b, Ax)
8034
8035            for batchsize in [1, 3, 4]:
8036                lu_solve_batch_test_helper((batchsize, 5, 5), (batchsize, 5, 10), pivot)
8037
8038        # Tests tensors with 0 elements
8039        b = torch.randn(3, 0, 3, dtype=dtype, device=device)
8040        A = torch.randn(3, 0, 0, dtype=dtype, device=device)
8041        LU_data, LU_pivots = torch.linalg.lu_factor(A)
8042        self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))
8043
8044        sub_test(True)
8045        if self.device_type == 'cuda':
8046            sub_test(False)
8047
8048    @slowTest
8049    @skipCPUIfNoLapack
8050    @skipCUDAIfNoMagmaAndNoCusolver
8051    @dtypes(*floating_and_complex_types())
8052    def test_lu_solve_batched_many_batches(self, device, dtype):
8053        def run_test(A_dims, b_dims):
8054            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
8055            x = torch.lu_solve(b, LU_data, LU_pivots)
8056            Ax = torch.matmul(A, x)
8057            self.assertEqual(Ax, b.expand_as(Ax))
8058
8059        run_test((65536, 5, 5), (65536, 5, 10))
8060        run_test((262144, 5, 5), (262144, 5, 10))
8061
8062    @skipCPUIfNoLapack
8063    @skipCUDAIfNoMagmaAndNoCusolver
8064    @dtypes(*floating_and_complex_types())
8065    def test_lu_solve_batched_broadcasting(self, device, dtype):
8066        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
8067        make_A = partial(make_fullrank, device=device, dtype=dtype)
8068
8069        def run_test(A_dims, b_dims, pivot=True):
8070            A_matrix_size = A_dims[-1]
8071            A_batch_dims = A_dims[:-2]
8072            A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
8073            b = make_tensor(b_dims, dtype=dtype, device=device)
8074            x_exp = np.linalg.solve(A.cpu(), b.cpu())
8075            LU_data, LU_pivots = torch.linalg.lu_factor(A)
8076            x = torch.lu_solve(b, LU_data, LU_pivots)
8077            self.assertEqual(x, x_exp)
8078
8079        # test against numpy.linalg.solve
8080        run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6))  # no broadcasting
8081        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting b
8082        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
8083        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & b
8084
8085    @onlyCUDA
8086    @skipCUDAIfNoMagma
8087    @dtypes(*floating_and_complex_types())
8088    # this tests https://github.com/pytorch/pytorch/issues/36921
8089    def test_lu_solve_large_matrices(self, device, dtype):
8090        def run_test(A_dims, b_dims):
8091            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
8092            x = torch.lu_solve(b, LU_data, LU_pivots)
8093            Ax = torch.matmul(A, x)
8094            self.assertEqual(Ax, b.expand_as(Ax))
8095
8096        run_test((1, 1), (1, 1, 1025))
8097
8098    @skipCUDAIfNoCusolver
8099    @skipCPUIfNoLapack
8100    def test_pca_lowrank(self, device):
8101        from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix
8102
8103        dtype = torch.double
8104
8105        def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **options):
8106            density = options.pop('density', 1)
8107            use_svd_lowrank = options.pop('use_svd_lowrank', False)
8108            if isinstance(matrix_size, int):
8109                rows = columns = matrix_size
8110            else:
8111                rows, columns = matrix_size
8112            if density == 1:
8113                a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
8114                a = a_input
8115            else:
8116                a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
8117                a = a_input.to_dense()
8118
8119            if use_svd_lowrank:
8120                m = a_input.mean(dim=-2, keepdim=True)
8121                u, s, v = pca(a_input, q=guess_rank, M=m, **options)
8122            else:
8123                u, s, v = pca(a_input, q=guess_rank, **options)
8124
8125            self.assertEqual(s.shape[-1], guess_rank)
8126            self.assertEqual(u.shape[-2], rows)
8127            self.assertEqual(u.shape[-1], guess_rank)
8128            self.assertEqual(v.shape[-1], guess_rank)
8129            self.assertEqual(v.shape[-2], columns)
8130
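            # pca_lowrank (and svd_lowrank with M set to the column mean) factor the mean-centered
            # input, so U @ diag(S) @ V^H should reconstruct A2 = a - column_mean computed below.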
8131            A1 = u.matmul(s.diag_embed()).matmul(v.mT)
8132            ones_m1 = torch.ones(batches + (rows, 1), dtype=a.dtype, device=device)
8133            c = a.sum(axis=-2) / rows
8134            c = c.reshape(batches + (1, columns))
8135            A2 = a - ones_m1.matmul(c)
8136            self.assertEqual(A1, A2)
8137
8138            if density == 1:
8139                # actual rank is known only for dense input
8140                detect_rank = (s.abs() > 1e-5).sum(axis=-1)
8141                self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank)
8142                S = torch.linalg.svdvals(A2)
8143                self.assertEqual(s[..., :actual_rank], S[..., :actual_rank])
8144
8145        all_batches = [(), (1,), (3,), (2, 3)]
8146        for actual_rank, size, all_batches in [  # noqa: B020
8147                (2, (17, 4), all_batches),
8148                (2, (100, 4), all_batches),
8149                (6, (100, 40), all_batches),
8150                (12, (1000, 1000), [()]),
8151        ]:
8152            for batches in all_batches:
8153                for guess_rank in [
8154                        actual_rank,
8155                        actual_rank + 2,
8156                        actual_rank + 6,
8157                ]:
8158                    if guess_rank <= min(*size):
8159                        run_subtest(guess_rank, actual_rank, size, batches, device, torch.pca_lowrank)
8160                        run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.pca_lowrank)
8161                        run_subtest(guess_rank, actual_rank, size, batches, device, torch.svd_lowrank, use_svd_lowrank=True)
8162                        run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.svd_lowrank, use_svd_lowrank=True)
8163
8164        # sparse input
8165        for guess_rank, size in [
8166                (4, (17, 4)), (4, (4, 17)), (16, (17, 17)),
8167                (21, (100, 40)), (20, (40, 100)), (600, (1000, 1000))]:
8168            for density in [0.005, 0.1]:
8169                run_subtest(guess_rank, None, size, (), device, torch.pca_lowrank, density=density)
8170
8171        # jitting support
8172        jitted = torch.jit.script(torch.pca_lowrank)
8173        guess_rank, actual_rank, size, batches = 2, 2, (17, 4), ()
8174        run_subtest(guess_rank, actual_rank, size, batches, device, jitted)
8175
8176    # Ensure that nuclear_norm's out variant gives the same result as the non-out variant
8177    @onlyNativeDeviceTypes
8178    @skipCUDAIfNoMagma
8179    @skipCPUIfNoLapack
8180    @dtypes(torch.float32, torch.float64)
8181    def test_nuclear_norm_out(self, device, dtype):
8182        test_cases = [
8183            # input size, dim
8184            ((25, 25), None),
8185            ((25, 25), (0, 1)),
8186            ((25, 25), (1, 0)),
8187            ((25, 25, 25), (2, 0)),
8188            ((25, 25, 25), (0, 1)),
8189        ]
8190        for keepdim in [False, True]:
8191            for input_size, dim in test_cases:
8192                msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}'
8193                x = torch.randn(*input_size, device=device, dtype=dtype)
8194                result_out = torch.empty(0, device=device, dtype=dtype)
8195                if dim is None:
8196                    result = torch.nuclear_norm(x, keepdim=keepdim)
8197                    torch.nuclear_norm(x, keepdim=keepdim, out=result_out)
8198                else:
8199                    result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim)
8200                    torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out)
8201                self.assertEqual(result, result_out, msg=msg)
8202
8203    @skipCUDAIfNoMagmaAndNoCusolver
8204    @skipCPUIfNoLapack
8205    @dtypes(*floating_and_complex_types())
8206    def test_geqrf(self, device, dtype):
8207
8208        def run_test(shape):
8209            # numpy.linalg.qr with mode = 'raw' computes the same operation as torch.geqrf
8210            # so this test compares against that function
8211            A = make_tensor(shape, dtype=dtype, device=device)
8212
8213            # numpy.linalg.qr doesn't work with batched input
8214            m, n = A.shape[-2:]
8215            tau_size = "n" if m > n else "m"
8216            np_dtype = A.cpu().numpy().dtype
8217            ot = [np_dtype, np_dtype]
8218            numpy_geqrf_batched = np.vectorize(
8219                lambda x: np.linalg.qr(x, mode='raw'),
8220                otypes=ot,
8221                signature=f'(m,n)->(n,m),({tau_size})')
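            # np.vectorize maps the single-matrix np.linalg.qr(mode='raw') call over the batch
            # dimensions; 'raw' mode returns the LAPACK geqrf output (h, tau), with h transposed.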
8222
8223            expected = numpy_geqrf_batched(A.cpu())
8224            actual = torch.geqrf(A)
8225
8226            # numpy.linalg.qr returns transposed result
8227            self.assertEqual(expected[0].swapaxes(-2, -1), actual[0])
8228            self.assertEqual(expected[1], actual[1])
8229
8230        batches = [(), (0, ), (2, ), (2, 1)]
8231        ns = [5, 2, 0]
8232        for batch, (m, n) in product(batches, product(ns, ns)):
8233            run_test((*batch, m, n))
8234
8235    @skipCUDAIfNoMagma
8236    @skipCPUIfNoLapack
8237    def test_lapack_empty(self, device):
8238        # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here.
8239        # The LAPACK functions themselves generally do NOT work with zero sized dimensions, although
8240        # numpy/scipy often has a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing"
8241        # (e.g. lu).  We often name our functions identically to the LAPACK function, so it will take work
8242        # to name / migrate-to better wrappers.
8243        def fn(torchfn, *args):
8244            return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
8245                                  for shape in args))
8246
8247        # inverse, pinverse
8248        self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape)
8249        self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape)
8250        self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape)
8251        self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape)
8252
8253        # det, logdet, slogdet
8254        self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0)))
8255        self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0)))
8256        self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)),
8257                         fn(torch.slogdet, (0, 0)))
8258
8259    @tf32_on_and_off(0.005)
8260    @bf32_on_and_off(0.005)
8261    def test_tensordot(self, device):
8262        a = torch.arange(60., device=device).reshape(3, 4, 5)
8263        b = torch.arange(24., device=device).reshape(4, 3, 2)
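        # contracting a's dims (1, 0) with b's dims (0, 1) leaves a's last dim (5) and b's last dim (2),
        # so the result has shape (5, 2)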
8264        c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
8265        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
8266                                           axes=([1, 0], [0, 1])))
8267        self.assertEqual(c, cn)
8268
8269        cout = torch.zeros((5, 2), device=device)
8270        torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
8271        self.assertEqual(c, cout)
8272
8273        a = torch.randn(2, 3, 4, 5, device=device)
8274        b = torch.randn(4, 5, 6, 7, device=device)
8275        c = torch.tensordot(a, b, dims=2).cpu()
8276        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
8277                                           axes=2))
8278
8279        with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
8280            torch.tensordot(a, b, dims=-1)
8281
8282        self.assertEqual(c, cn)
8283        c = torch.tensordot(a, b).cpu()
8284        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
8285        self.assertEqual(c, cn)
8286
8287        a = torch.tensordot(torch.tensor(0.), torch.tensor(0.), 0)
8288        an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
8289        self.assertEqual(a, an)
8290
8291    @skipCUDAIfNoCusolver
8292    @skipCUDAIfNoMagma
8293    @skipCPUIfNoLapack
8294    @skipIfTorchDynamo("flaky, needs investigation")
8295    @dtypes(*floating_and_complex_types())
8296    def test_ldl_factor(self, device, dtype):
8297        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
8298
8299        def run_test(shape, batch, hermitian):
8300            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
8301            actual_factors, actual_pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
8302            actual_L = torch.tril(actual_factors, diagonal=-1)
8303            actual_L.diagonal(0, -2, -1).fill_(1.0)
8304
8305            # This test is designed only for inputs whose block diagonal matrix D has 1x1 blocks.
8306            # That is, for positive definite input matrices the pivots tensor is always > 0.
8307            # Negative pivots would mean that the input matrix is not positive definite
8308            # and that D contains 2x2 diagonal blocks.
8309            self.assertTrue((actual_pivots > 0).all())
8310
8311            # Construct a 1x1 block diagonal matrix D from factors.
8312            actual_D = torch.diag_embed(actual_factors.diagonal(0, -2, -1))
8313
8314            def T(x):
8315                return x.mH if hermitian else x.mT
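            # reconstruct A from the factorization: A should equal L @ D @ L^T (L^H in the hermitian case)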
8316            A_reconstructed = actual_L @ actual_D @ T(actual_L)
8317
8318            def symmetric(A):
8319                return A.tril() + A.tril(-1).mT
8320
8321            self.assertEqual(symmetric(A) if not hermitian else A, A_reconstructed)
8322
8323            # Now test against SciPy implementation
8324            if TEST_SCIPY:
8325                from scipy.linalg import ldl as scipy_ldl
8326                A_np = A.cpu().numpy()
8327                np_dtype = A_np.dtype
8328                scipy_ldl_batched = np.vectorize(
8329                    lambda x: scipy_ldl(x, hermitian=hermitian, lower=True),
8330                    otypes=[np_dtype, np_dtype, np.dtype('int64')],
8331                    signature='(m,m)->(m,m),(m,m),(m)')
8332
8333                expected = scipy_ldl_batched(A_np)
8334                expected_L, expected_D, expected_pivots = expected
8335
8336                if expected_pivots.ndim > 1:
8337                    permuted_expected_L = np.stack(
8338                        [expected_L[i][expected_pivots[i], :] for i in range(expected_pivots.shape[0])]
8339                    )
8340                else:
8341                    permuted_expected_L = expected_L[expected_pivots, :]
8342                self.assertEqual(actual_L, permuted_expected_L)
8343                self.assertEqual(actual_D, expected_D)
8344            else:
8345                self.assertEqual(actual_factors.shape, A.shape)
8346                self.assertEqual(actual_pivots.shape, A.shape[:-1])
8347                self.assertEqual(info.shape, A.shape[:-2])
8348
8349        # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+
8350        magma_254_available = self.device_type == 'cuda' and _get_magma_version() >= (2, 5, 4)
8351        hermitians = (True, False) if dtype.is_complex and (self.device_type == 'cpu' or magma_254_available) else (False,)
8352
8353        shapes = (5,)
8354        batches = ((), (4,),)
8355        for shape, batch, hermitian in itertools.product(shapes, batches, hermitians):
8356            run_test(shape, batch, hermitian)
8357
8358    @skipCUDAIfNoCusolver
8359    @skipCUDAIfNoMagma
8360    @skipCPUIfNoLapack
8361    @skipCUDAIfRocm
8362    @skipCUDAIf(_get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1")
8363    @dtypes(*floating_and_complex_types())
8364    def test_ldl_solve(self, device, dtype):
8365        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
8366
8367        def run_test(shape, batch, nrhs, hermitian):
8368            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
8369            B = make_tensor((*A.shape[:-1], nrhs), dtype=dtype, device=device)
8370            factors, pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
8371            X = torch.linalg.ldl_solve(factors, pivots, B, hermitian=hermitian)
8372
8373            def symmetric(A):
8374                return A.tril() + A.tril(-1).mT
8375
8376            # verify A @ X == B
8377            expected_B = symmetric(A) @ X if not hermitian else A @ X
8378            self.assertEqual(B, expected_B)
8379
8380        # hermitian=True is not supported on CUDA yet
8381        hermitians = (True, False) if dtype.is_complex and self.device_type == 'cpu' else (False,)
8382
8383        shapes = (5,)
8384        batches = ((), (4,), (2, 2))
8385        nrhss = (1, 7)
8386        for shape, batch, nrhs, hermitian in itertools.product(shapes, batches, nrhss, hermitians):
8387            run_test(shape, batch, nrhs, hermitian)
8388
8389    @onlyCUDA
8390    @skipCUDAIfNoMagma
8391    @skipCUDAIfNoCusolver
8392    @setLinalgBackendsToDefaultFinally
8393    def test_preferred_linalg_library(self):
8394        # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions.
8395        x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double)
8396
8397        torch.backends.cuda.preferred_linalg_library('cusolver')
8398        out1 = torch.linalg.inv(x)
8399
8400        torch.backends.cuda.preferred_linalg_library('magma')
8401        out2 = torch.linalg.inv(x)
8402
8403        torch.backends.cuda.preferred_linalg_library('default')
8404        # Although the preferred linalg backend flag doesn't affect CPU currently,
8405        # we set it to make sure the flag can switch back to default normally.
8406        out_ref = torch.linalg.inv(x.cpu())
8407
8408        self.assertEqual(out_ref, out1.cpu())
8409        self.assertEqual(out1, out2)
8410
8411    @onlyCUDA
8412    @unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device")
8413    @setBlasBackendsToDefaultFinally
8414    def test_preferred_blas_library(self):
8415        # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions.
8416        m1 = torch.randint(2, 5, (2048, 2400), device='cuda', dtype=torch.float)
8417        m2 = torch.randint(2, 5, (128, 2400), device='cuda', dtype=torch.float)
8418
8419        torch.backends.cuda.preferred_blas_library('cublaslt')
8420        out1 = torch.nn.functional.linear(m1, m2)
8421
8422        torch.backends.cuda.preferred_blas_library('cublas')
8423        out2 = torch.nn.functional.linear(m1, m2)
8424
8425        # Although the preferred BLAS backend flag doesn't affect CPU currently,
8426        # we set it to make sure the flag can switch back to default normally.
8427        out_ref = torch.nn.functional.linear(m1.cpu(), m2.cpu())
8428
8429        self.assertEqual(out1, out2)
8430        self.assertEqual(out_ref, out2.cpu())
8431
8432    def test_permute_matmul(self):
8433        a = torch.ones([2, 5, 24, 24])
8434        b = torch.ones([3, 2, 5, 24, 24])
8435        c = a.permute(0, 1, 3, 2).matmul(b)
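        # each output element is a dot product of 24 ones, i.e. 24; the result broadcasts to
        # shape (3, 2, 5, 24, 24), so the sum is 24 * 3 * 2 * 5 * 24 * 24 = 414720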
8436        self.assertEqual([c.min(), c.max(), c.sum()], [24, 24, 414720])
8437
8438    def test_lower_precision_accumulation_with_ref_path(self):
8439        # Regression test for https://github.com/pytorch/pytorch/issues/95125
8440        # and https://github.com/pytorch/pytorch/issues/83863
8441        # (low-precision accumulation in the gemm reference path)
8442        def check_correctness(fn, dtype, *args):
8443            expected = fn(*args).to(dtype=dtype)
8444            with torch.backends.mkldnn.flags(enabled=False):
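                # disabling mkldnn here exercises the (non-oneDNN) reference gemm path mentioned
                # above, where the linked issues reported accuracy loss from low-precision accumulation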
8445                def test():
8446                    lower_args = (arg.to(dtype=dtype) for arg in args)
8447                    tmp_result = fn(*lower_args)
8448                    return tmp_result
8449                c = test()
8450                assert (torch.all(c == expected)), "Incorrect result with\n" \
8451                                                   f"expected: {expected}\n" \
8452                                                   f"got: {c}\n"
8453        # test matmul
8454        for dtype in [torch.bfloat16, torch.half]:
8455            for transa in [True, False]:
8456                for transb in [True, False]:
8457                    a = torch.ones(300, 300)
8458                    b = torch.ones(300, 300)
8459                    if transa:
8460                        a = a.transpose(0, 1).contiguous().transpose(0, 1)
8461                    if transb:
8462                        b = b.transpose(0, 1).contiguous().transpose(0, 1)
8463                    check_correctness(torch.matmul, dtype, a, b)
8464        # test bmm
8465        a = torch.ones(1, 1, 300)
8466        b = torch.ones(1, 300, 1)
8467        check_correctness(torch.bmm, torch.bfloat16, a, b)
8468        check_correctness(torch.bmm, torch.half, a, b)
8469        # test baddbmm
8470        a = torch.ones(1, 1, 300)
8471        b = torch.ones(1, 300, 1)
8472        c = torch.ones(1, 1, 1)
8473        check_correctness(torch.baddbmm, torch.bfloat16, c, a, b)
8474        check_correctness(torch.baddbmm, torch.half, c, a, b)
8475        # test mv/addmv
8476        for dtype in [torch.bfloat16, torch.half]:
8477            for trans in [True, False]:
8478                c = torch.ones(300) * -300
8479                a = torch.ones(300, 300)
8480                if trans:
8481                    a = a.transpose(0, 1).contiguous().transpose(0, 1)
8482                b = torch.ones(300)
8483                check_correctness(torch.mv, dtype, a, b)
8484                check_correctness(torch.addmv, dtype, c, a, b)
8485        # test dot
8486        a = torch.ones(300)
8487        b = torch.ones(300)
8488        check_correctness(torch.dot, torch.bfloat16, a, b)
8489        check_correctness(torch.dot, torch.half, a, b)
8490
8491    @dtypes(torch.float, torch.half, torch.bfloat16)
8492    @parametrize("transpose_a", [True, False])
8493    @parametrize("transpose_b", [True, False])
8494    @parametrize("alpha", [0.0, 0.2, 1.0])
8495    @parametrize("beta", [0.0, 0.5, 1.0])
8496    def test_addmm_mv(self, device, dtype, transpose_a, transpose_b, alpha, beta):
8497        def gen_mat(w, h, use_transpose: bool = False):
8498            if not use_transpose:
8499                return torch.rand(w, h, dtype=dtype, device=device)
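            # transposed variant: same logical shape (w, h) but a non-contiguous, column-major layout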
8500            return torch.rand(h, w, dtype=dtype, device=device).t()
8501        # Regression tests for https://github.com/pytorch/pytorch/issues/136299
8502        # Should only expose problems on aarch64, but let's be thorough
8503        m, n, k = 1, 8, 32
8504        A = gen_mat(m, k, transpose_a)
8505        B = gen_mat(k, n, transpose_b)
8506        C = torch.ones(m, n, dtype=dtype, device=device)
8507        rc = torch.addmm(C, A, B, alpha=alpha, beta=beta)
8508        ref = alpha * A @ B + beta * C
8509        self.assertEqual(rc, ref)
8510
8511
8512    @dtypes(torch.float, torch.double)
8513    @precisionOverride({torch.float32: 1e-4})
8514    def test_1_sized_with_0_strided(self, device, dtype):
8515        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
8516        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
8517        b = make_tensor((8, 64, 512), dtype=dtype, device=device)
8518        b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
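        # a_strided has stride 0 in its size-1 dimension and b_strided uses a non-contiguous,
        # permuted stride layout; bmm should still match the dense numpy reference below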
8519        res = torch.bmm(a_strided, b_strided)
8520        expect = torch.from_numpy(
8521            a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(device=device, dtype=dtype)
8522        self.assertEqual(expect, res)
8523
8524instantiate_device_type_tests(TestLinalg, globals())
8525
8526if __name__ == '__main__':
8527    TestCase._default_dtype_check_enabled = True
8528    run_tests()
8529