xref: /aosp_15_r20/external/pytorch/torch/_numpy/linalg.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3from __future__ import annotations
4
5import functools
6import math
7from typing import Sequence
8
9import torch
10
11from . import _dtypes_impl, _util
12from ._normalizations import ArrayLike, KeepDims, normalizer
13
14
class LinAlgError(Exception):
    """Linear-algebra error raised by this module's wrappers.

    Raised directly (e.g. by ``cond`` on empty input) and also used by
    ``linalg_errors`` to re-raise ``torch._C._LinAlgError`` under this type.
    """
17
18
def _atleast_float_1(a):
    """Return ``a`` promoted to at least a floating dtype.

    Tensors that are already floating point or complex are returned as-is (the
    same object); anything else is cast to the default float dtype reported by
    ``_dtypes_impl.default_dtypes()``.
    """
    if a.dtype.is_floating_point or a.dtype.is_complex:
        return a
    return a.to(_dtypes_impl.default_dtypes().float_dtype)
23
24
def _atleast_float_2(a, b):
    """Cast ``a`` and ``b`` to one shared dtype that is at least floating.

    The shared dtype is ``_dtypes_impl.result_type_impl(a, b)``. If that is
    neither floating point nor complex, the default float dtype is used
    instead. Both operands go through ``_util.cast_if_needed``.
    """
    common = _dtypes_impl.result_type_impl(a, b)
    is_inexact = common.is_floating_point or common.is_complex
    target = common if is_inexact else _dtypes_impl.default_dtypes().float_dtype
    return _util.cast_if_needed(a, target), _util.cast_if_needed(b, target)
33
34
def linalg_errors(func):
    """Decorate ``func`` so torch linear-algebra failures surface as ``LinAlgError``.

    Any ``torch._C._LinAlgError`` raised by ``func`` is re-raised as this
    module's ``LinAlgError`` with the same ``args``. All other exceptions
    propagate unchanged. ``functools.wraps`` preserves ``func``'s name and
    docstring on the wrapper.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        try:
            return func(*args, **kwds)
        except torch._C._LinAlgError as e:
            # Chain explicitly so the underlying torch error is kept as
            # ``__cause__`` and shown as the direct cause in tracebacks.
            raise LinAlgError(*e.args) from e

    return wrapped
44
45
# ### Matrix and vector products ###
47
48
@normalizer
@linalg_errors
def matrix_power(a: ArrayLike, n):
    """Return ``a`` raised to the power ``n`` via ``torch.linalg.matrix_power``.

    Inputs that are not floating point or complex are promoted to the default
    float dtype first.
    """
    return torch.linalg.matrix_power(_atleast_float_1(a), n)
54
55
@normalizer
@linalg_errors
def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
    """Multiply the sequence ``inputs`` together with ``torch.linalg.multi_dot``.

    ``out`` is not read in this body; presumably the ``normalizer`` decorator
    handles it. Confirm against ``_normalizations``.
    """
    product = torch.linalg.multi_dot(inputs)
    return product
60
61
# ### Solving equations and inverting matrices ###
63
64
@normalizer
@linalg_errors
def solve(a: ArrayLike, b: ArrayLike):
    """Solve the linear system with coefficients ``a`` and right-hand side ``b``.

    Both operands are first cast to a shared dtype that is at least floating.
    """
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.solve(lhs, rhs)
70
71
@normalizer
@linalg_errors
def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
    """Least-squares solution via ``torch.linalg.lstsq``.

    NumPy uses LAPACK ``gelsd``:
    https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
    CUDA only provides ``gels``, so that driver is chosen whenever either
    operand is on a CUDA device.
    """
    lhs, rhs = _atleast_float_2(a, b)
    if lhs.is_cuda or rhs.is_cuda:
        driver = "gels"
    else:
        driver = "gelsd"
    return torch.linalg.lstsq(lhs, rhs, rcond=rcond, driver=driver)
80
81
@normalizer
@linalg_errors
def inv(a: ArrayLike):
    """Inverse of ``a`` via ``torch.linalg.inv``.

    Inputs that are not floating point or complex are promoted first.
    """
    return torch.linalg.inv(_atleast_float_1(a))
88
89
@normalizer
@linalg_errors
def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
    """Moore-Penrose pseudo-inverse of ``a`` via ``torch.linalg.pinv``.

    NumPy's ``rcond`` is forwarded as torch's relative tolerance ``rtol``.
    """
    promoted = _atleast_float_1(a)
    return torch.linalg.pinv(promoted, rtol=rcond, hermitian=hermitian)
95
96
@normalizer
@linalg_errors
def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
    """Tensor equation solve via ``torch.linalg.tensorsolve``.

    NumPy's ``axes`` is forwarded as torch's ``dims``. Both operands are first
    cast to a shared dtype that is at least floating.
    """
    return torch.linalg.tensorsolve(*_atleast_float_2(a, b), dims=axes)
102
103
@normalizer
@linalg_errors
def tensorinv(a: ArrayLike, ind=2):
    """Tensor inverse of ``a`` via ``torch.linalg.tensorinv`` (``ind`` forwarded)."""
    return torch.linalg.tensorinv(_atleast_float_1(a), ind=ind)
109
110
# ### Norms and other numbers ###
112
113
@normalizer
@linalg_errors
def det(a: ArrayLike):
    """Determinant of ``a`` via ``torch.linalg.det`` (non-float inputs promoted)."""
    return torch.linalg.det(_atleast_float_1(a))
119
120
@normalizer
@linalg_errors
def slogdet(a: ArrayLike):
    """Sign and log-absolute-determinant of ``a`` via ``torch.linalg.slogdet``."""
    return torch.linalg.slogdet(_atleast_float_1(a))
126
127
@normalizer
@linalg_errors
def cond(x: ArrayLike, p=None):
    """Condition number of ``x`` (or of each matrix in a stack) in norm ``p``.

    NaN entries of the result are reported as ``inf``, except for matrices
    that themselves contain a NaN; their NaN result is passed through. This
    is the same rule NumPy applies.

    Raises:
        LinAlgError: if ``x`` is empty in its last two dimensions.
    """
    x = _atleast_float_1(x)

    # check if empty
    # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
        raise LinAlgError("cond is not defined on empty arrays")

    result = torch.linalg.cond(x, p=p)

    # NumPy converts NaNs to infs only for matrices with no NaN entries of
    # their own: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    # NumPy branches on the data. Here the same rule is written as a mask, so
    # there is no data-dependent control flow. ``flatten(-2)`` merges the two
    # matrix dimensions, and ``any(dim=-1)`` leaves one flag per matrix with
    # the batch shape x.shape[:-2]. torch.linalg.cond has already rejected
    # inputs with fewer than two dimensions, so flatten(-2) is well defined.
    input_has_nan = torch.isnan(x).flatten(-2).any(dim=-1)
    return torch.where(torch.isnan(result) & ~input_has_nan, float("inf"), result)
144
145
@normalizer
@linalg_errors
def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
    """Numerical rank of ``a``.

    Inputs with fewer than two dimensions give ``1`` if any element is
    non-zero and ``0`` otherwise. For matrices, an explicit ``tol`` is an
    absolute tolerance (``rtol=0``). Without it, ``atol=0`` and
    ``rtol = max(M, N) * eps`` of ``a``'s dtype are used, following NumPy:
    https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
    """
    a = _atleast_float_1(a)

    if a.ndim >= 2:
        if tol is not None:
            atol, rtol = tol, 0
        else:
            atol, rtol = 0, max(a.shape[-2:]) * torch.finfo(a.dtype).eps
        return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)

    return int((a != 0).any())
161
162
@normalizer
@linalg_errors
def norm(x: ArrayLike, ord=None, axis=None, keepdims: KeepDims = False):
    """Vector or matrix norm of ``x``; NumPy's ``axis`` is forwarded as ``dim``.

    ``keepdims`` is not read in this body; presumably ``normalizer`` applies
    it through the ``KeepDims`` annotation. Confirm.
    """
    promoted = _atleast_float_1(x)
    return torch.linalg.norm(promoted, ord=ord, dim=axis)
168
169
# ### Decompositions ###
171
172
@normalizer
@linalg_errors
def cholesky(a: ArrayLike):
    """Cholesky factor of ``a`` via ``torch.linalg.cholesky`` (default options)."""
    return torch.linalg.cholesky(_atleast_float_1(a))
178
179
@normalizer
@linalg_errors
def qr(a: ArrayLike, mode="reduced"):
    """QR decomposition of ``a`` via ``torch.linalg.qr``.

    With ``mode="r"`` only the ``R`` factor is returned, to match NumPy.
    Every other mode returns torch's result object unchanged.
    """
    decomposition = torch.linalg.qr(_atleast_float_1(a), mode=mode)
    return decomposition.R if mode == "r" else decomposition
189
190
@normalizer
@linalg_errors
def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
    """Singular value decomposition of ``a``.

    With ``compute_uv`` false only the singular values are returned
    (``torch.linalg.svdvals``). Otherwise the ``torch.linalg.svd`` result is
    returned. NB: ``hermitian`` is accepted but ignored (no PyTorch
    equivalent).
    """
    a = _atleast_float_1(a)
    if compute_uv:
        return torch.linalg.svd(a, full_matrices=full_matrices)
    return torch.linalg.svdvals(a)
201
202
# ### Eigenvalues and eigenvectors ###
204
205
@normalizer
@linalg_errors
def eig(a: ArrayLike):
    """Eigenvalues and eigenvectors of ``a`` via ``torch.linalg.eig``.

    For a real input whose complex eigenvalues all have zero imaginary part,
    both the eigenvalues and the eigenvectors are narrowed to their real parts.
    """
    a = _atleast_float_1(a)
    eigenvalues, eigenvectors = torch.linalg.eig(a)

    narrow_to_real = (
        not a.is_complex()
        and eigenvalues.is_complex()
        and (eigenvalues.imag == 0).all()
    )
    if narrow_to_real:
        return eigenvalues.real, eigenvectors.real
    return eigenvalues, eigenvectors
216
217
@normalizer
@linalg_errors
def eigh(a: ArrayLike, UPLO="L"):
    """Eigen-decomposition via ``torch.linalg.eigh``; ``UPLO`` is forwarded."""
    return torch.linalg.eigh(_atleast_float_1(a), UPLO=UPLO)
223
224
@normalizer
@linalg_errors
def eigvals(a: ArrayLike):
    """Eigenvalues of ``a`` via ``torch.linalg.eigvals``.

    For a real input, a complex result whose imaginary parts are all zero is
    narrowed to its real part.
    """
    a = _atleast_float_1(a)
    values = torch.linalg.eigvals(a)
    real_input = not a.is_complex()
    if real_input and values.is_complex() and (values.imag == 0).all():
        return values.real
    return values
233
234
@normalizer
@linalg_errors
def eigvalsh(a: ArrayLike, UPLO="L"):
    """Eigenvalues via ``torch.linalg.eigvalsh``; ``UPLO`` is forwarded."""
    return torch.linalg.eigvalsh(_atleast_float_1(a), UPLO=UPLO)
240