# mypy: ignore-errors

from __future__ import annotations

import functools
import math
from typing import Sequence

import torch

from . import _dtypes_impl, _util
from ._normalizations import ArrayLike, KeepDims, normalizer


class LinAlgError(Exception):
    pass


def _atleast_float_1(a):
    if not (a.dtype.is_floating_point or a.dtype.is_complex):
        a = a.to(_dtypes_impl.default_dtypes().float_dtype)
    return a


def _atleast_float_2(a, b):
    dtyp = _dtypes_impl.result_type_impl(a, b)
    if not (dtyp.is_floating_point or dtyp.is_complex):
        dtyp = _dtypes_impl.default_dtypes().float_dtype

    a = _util.cast_if_needed(a, dtyp)
    b = _util.cast_if_needed(b, dtyp)
    return a, b


def linalg_errors(func):
    @functools.wraps(func)
    def wrapped(*args, **kwds):
        try:
            return func(*args, **kwds)
        except torch._C._LinAlgError as e:
            raise LinAlgError(*e.args)  # noqa: B904

    return wrapped


# ### Matrix and vector products ###


@normalizer
@linalg_errors
def matrix_power(a: ArrayLike, n):
    a = _atleast_float_1(a)
    return torch.linalg.matrix_power(a, n)


@normalizer
@linalg_errors
def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
    return torch.linalg.multi_dot(inputs)


# ### Solving equations and inverting matrices ###


@normalizer
@linalg_errors
def solve(a: ArrayLike, b: ArrayLike):
    a, b = _atleast_float_2(a, b)
    return torch.linalg.solve(a, b)


@normalizer
@linalg_errors
def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
    a, b = _atleast_float_2(a, b)
    # NumPy is using gelsd: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
    # on CUDA, only `gels` is available though, so use it instead
    driver = "gels" if a.is_cuda or b.is_cuda else "gelsd"
    return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)


@normalizer
@linalg_errors
def inv(a: ArrayLike):
    a = _atleast_float_1(a)
    result = torch.linalg.inv(a)
    return result


@normalizer
@linalg_errors
def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
    a = _atleast_float_1(a)
    return torch.linalg.pinv(a, rtol=rcond, hermitian=hermitian)


@normalizer
@linalg_errors
def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
    a, b = _atleast_float_2(a, b)
    return torch.linalg.tensorsolve(a, b, dims=axes)


@normalizer
@linalg_errors
def tensorinv(a: ArrayLike, ind=2):
    a = _atleast_float_1(a)
    return torch.linalg.tensorinv(a, ind=ind)


# ### Norms and other numbers ###


@normalizer
@linalg_errors
def det(a: ArrayLike):
    a = _atleast_float_1(a)
    return torch.linalg.det(a)


@normalizer
@linalg_errors
def slogdet(a: ArrayLike):
    a = _atleast_float_1(a)
    return torch.linalg.slogdet(a)


@normalizer
@linalg_errors
def cond(x: ArrayLike, p=None):
    x = _atleast_float_1(x)

    # check if empty
    # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
        raise LinAlgError("cond is not defined on empty arrays")

    result = torch.linalg.cond(x, p=p)

    # Convert nans to infs (numpy does it in a data-dependent way, depending on
    # whether the input array has nans or not)
    # XXX: NumPy does this: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    return torch.where(torch.isnan(result), float("inf"), result)

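# Illustrative sketch only (kept as comments, not executed): how the
# ``linalg_errors`` translation above is expected to surface to callers,
# assuming this module is reached through the ``torch._numpy`` namespace.
#
#     import torch._numpy as np
#
#     singular = np.asarray([[1.0, 2.0], [2.0, 4.0]])
#     try:
#         np.linalg.inv(singular)
#     except np.linalg.LinAlgError:
#         # torch raises torch._C._LinAlgError; the decorator re-raises it as
#         # the NumPy-compatible LinAlgError defined at the top of this file.
#         ...

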
@normalizer
@linalg_errors
def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
    a = _atleast_float_1(a)

    if a.ndim < 2:
        return int((a != 0).any())

    if tol is None:
        # follow https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
        atol = 0
        rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps
    else:
        atol, rtol = tol, 0
    return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)


@normalizer
@linalg_errors
def norm(x: ArrayLike, ord=None, axis=None, keepdims: KeepDims = False):
    x = _atleast_float_1(x)
    return torch.linalg.norm(x, ord=ord, dim=axis)


# ### Decompositions ###


@normalizer
@linalg_errors
def cholesky(a: ArrayLike):
    a = _atleast_float_1(a)
    return torch.linalg.cholesky(a)


@normalizer
@linalg_errors
def qr(a: ArrayLike, mode="reduced"):
    a = _atleast_float_1(a)
    result = torch.linalg.qr(a, mode=mode)
    if mode == "r":
        # match NumPy
        result = result.R
    return result


@normalizer
@linalg_errors
def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
    a = _atleast_float_1(a)
    if not compute_uv:
        return torch.linalg.svdvals(a)

    # NB: ignore the hermitian= argument (no pytorch equivalent)
    result = torch.linalg.svd(a, full_matrices=full_matrices)
    return result


# ### Eigenvalues and eigenvectors ###


@normalizer
@linalg_errors
def eig(a: ArrayLike):
    a = _atleast_float_1(a)
    w, vt = torch.linalg.eig(a)

    if not a.is_complex() and w.is_complex() and (w.imag == 0).all():
        w = w.real
        vt = vt.real
    return w, vt


@normalizer
@linalg_errors
def eigh(a: ArrayLike, UPLO="L"):
    a = _atleast_float_1(a)
    return torch.linalg.eigh(a, UPLO=UPLO)


@normalizer
@linalg_errors
def eigvals(a: ArrayLike):
    a = _atleast_float_1(a)
    result = torch.linalg.eigvals(a)
    if not a.is_complex() and result.is_complex() and (result.imag == 0).all():
        result = result.real
    return result


@normalizer
@linalg_errors
def eigvalsh(a: ArrayLike, UPLO="L"):
    a = _atleast_float_1(a)
    return torch.linalg.eigvalsh(a, UPLO=UPLO)
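
# A minimal usage sketch (comments only, not executed). Assumes the package is
# importable as ``torch._numpy`` and that ``@normalizer`` coerces array-likes
# such as nested lists or ndarrays to tensors before the wrappers run:
#
#     import torch._numpy as np
#
#     a = np.asarray([[3, 1], [1, 2]])      # integer input
#     np.linalg.det(a)                      # promoted to the default float
#                                           # dtype by _atleast_float_1
#     w, v = np.linalg.eigh(np.asarray([[2.0, 1.0], [1.0, 2.0]]))
#     x = np.linalg.solve(np.asarray([[3.0, 1.0], [1.0, 2.0]]),
#                         np.asarray([9.0, 8.0]))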