# xref: /aosp_15_r20/external/pytorch/torch/_lowrank.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""Implement various linear algebra algorithms for low rank matrices."""

__all__ = ["svd_lowrank", "pca_lowrank"]

from typing import Optional, Tuple

import torch
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function


def get_approximate_basis(
    A: Tensor,
    q: int,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tensor:
    """Return tensor :math:`Q` with :math:`q` orthonormal columns such
    that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
    specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
    approximates :math:`A - M`, without instantiating any tensors
    of the size of :math:`A` or :math:`M`.

    .. note:: The implementation is based on the Algorithm 4.4 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator

    Args::
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int): the dimension of subspace spanned by :math:`Q`
                 columns.

        niter (int, optional): the number of subspace iterations to
                               conduct; ``niter`` must be a
                               nonnegative integer. In most cases, the
                               default value 2 is more than enough.
                               ``None`` is treated as the default 2.

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`.

    Returns::
        Tensor: the basis :math:`Q` produced by the final QR
        factorization.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).
    """

    niter = 2 if niter is None else niter
    # Complex inputs keep their own dtype; everything else is mapped to a
    # floating dtype so that the random test matrix can be sampled.
    dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype
    matmul = _utils.matmul

    # Random Gaussian test matrix with q columns.
    R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)

    # The following code could be made faster using torch.geqrf + torch.ormqr
    # but geqrf is not differentiable

    # (A - M) @ R is evaluated as A @ R - M @ R, so the difference A - M is
    # never materialized.
    X = matmul(A, R)
    if M is not None:
        X = X - matmul(M, R)
    Q = torch.linalg.qr(X).Q

    # Subspace iterations: alternately apply (A - M)^H and (A - M), taking
    # the Q factor of a QR factorization after every product.
    for _ in range(niter):
        X = matmul(A.mH, Q)
        if M is not None:
            X = X - matmul(M.mH, Q)
        Q = torch.linalg.qr(X).Q
        X = matmul(A, Q)
        if M is not None:
            X = X - matmul(M, Q)
        Q = torch.linalg.qr(X).Q
    return Q


def svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
    batches of matrices, or a sparse matrix :math:`A` such that
    :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
    SVD is computed for the matrix :math:`A - M`.

    .. note:: The implementation is based on the Algorithm 5.1 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: This is a randomized method. To obtain repeatable results,
              set the seed for the pseudorandom number generator

    .. note:: In general, use the full-rank SVD implementation
              :func:`torch.linalg.svd` for dense matrices due to its 10x
              higher performance characteristics. The low-rank SVD
              will be useful for huge sparse matrices that
              :func:`torch.linalg.svd` cannot handle.

    Args::
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of A;
                           defaults to 6.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`, which will be broadcasted
                              to the size of A in this function.

    Returns::
        a tuple ``(U, S, V)`` of tensors.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).

    """
    if not torch.jit.is_scripting():
        tensor_ops = (A, M)
        # Defer to ``__torch_function__`` only when at least one operand is
        # something other than a plain ``torch.Tensor`` or ``None``.
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
            )
    return _svd_lowrank(A, q=q, niter=niter, M=M)


def _svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Compute the low-rank SVD ``(U, S, V)`` of ``A`` (or of ``A - M``).

    This is the implementation behind :func:`svd_lowrank`; it performs no
    ``__torch_function__`` dispatch.

    Args::
        A (Tensor): the input tensor; its last two dimensions are
                    treated as the matrix dimensions ``(m, n)``.

        q (int, optional): number of basis columns requested from
                           :func:`get_approximate_basis`; ``None`` is
                           treated as 6.

        niter (int, optional): forwarded to
                               :func:`get_approximate_basis`.

        M (Tensor, optional): if given, it is broadcast to the size of
                              ``A`` and the decomposition is computed
                              for ``A - M``.

    Returns::
        a tuple ``(U, S, V)`` of tensors.
    """
    # Algorithm 5.1 in Halko et al., 2009

    q = 6 if q is None else q
    m, n = A.shape[-2:]
    matmul = _utils.matmul
    if M is not None:
        M = M.broadcast_to(A.size())

    # Assume that A is tall
    # (a wide input is replaced by its conjugate transpose here and the
    # roles of U and V are exchanged again before returning).
    if m < n:
        A = A.mH
        if M is not None:
            M = M.mH

    Q = get_approximate_basis(A, q, niter=niter, M=M)
    # Project onto the basis: B = Q^H (A - M), evaluated without forming
    # the difference A - M.
    B = matmul(Q.mH, A)
    if M is not None:
        B = B - matmul(Q.mH, M)
    U, S, Vh = torch.linalg.svd(B, full_matrices=False)
    V = Vh.mH
    # Map the left singular vectors of B back through the basis Q.
    U = Q.matmul(U)

    # The factorization above is of A^H when the input was wide, so the
    # left and right factors trade places.
    if m < n:
        U, V = V, U

    return U, S, V


def pca_lowrank(
    A: Tensor,
    q: Optional[int] = None,
    center: bool = True,
    niter: int = 2,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or sparse matrix.

    This function returns a tuple ``(U, S, V)`` which is the
    nearly optimal approximation of a singular value decomposition of
    a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

                - :math:`A` is a data matrix with ``m`` samples and
                  ``n`` features

                - the :math:`V` columns represent the principal directions

                - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                  :math:`A^T A / (m - 1)` which is the covariance of
                  ``A`` when ``center=True`` is provided.

                - ``matmul(A, V[:, :k])`` projects data to the first k
                  principal components

    .. note:: Different from the standard SVD, the size of the returned
              matrices depends on the specified rank and q
              values as follows:

                - :math:`U` is m x q matrix

                - :math:`S` is q-vector

                - :math:`V` is n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator

    Args:

        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default, ``q = min(6, m,
                           n)``.

        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    Raises:

        ValueError: if ``q`` is negative or greater than ``min(m, n)``,
                    if ``niter`` is negative, or if ``center`` is True
                    and ``A`` is a sparse tensor that is not
                    2-dimensional.

    References::

        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).

    """

    if not torch.jit.is_scripting():
        # Let non-plain-Tensor inputs handle the call via __torch_function__.
        if type(A) is not torch.Tensor and has_torch_function((A,)):
            return handle_torch_function(
                pca_lowrank, (A,), A, q=q, center=center, niter=niter
            )

    (m, n) = A.shape[-2:]

    if q is None:
        q = min(6, m, n)
    elif not (q >= 0 and q <= min(m, n)):
        raise ValueError(
            f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
        )
    if not (niter >= 0):
        raise ValueError(f"niter(={niter}) must be non-negative integer")

    if not center:
        return _svd_lowrank(A, q, niter=niter, M=None)

    if _utils.is_sparse(A):
        if len(A.shape) != 2:
            raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
        # Only the sparse path needs an explicit dtype for the tensors it
        # constructs below.
        dtype = _utils.get_floating_dtype(A)
        # Column means of A as a sparse vector of length n.
        c = torch.sparse.sum(A, dim=(-2,)) / m
        # Reshape c into an (n, 1) sparse column: row index = original
        # column index, column index = 0.
        column_indices = c.indices()[0]
        indices = torch.zeros(
            2,
            len(column_indices),
            dtype=column_indices.dtype,
            device=column_indices.device,
        )
        indices[0] = column_indices
        C_t = torch.sparse_coo_tensor(
            indices, c.values(), (n, 1), dtype=dtype, device=A.device
        )

        # (n, 1) @ (1, m), transposed: an (m, n) mean matrix in which every
        # row equals the column means. It is passed as M so that A itself
        # stays sparse instead of being centered explicitly.
        ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
        M = torch.sparse.mm(C_t, ones_m1_t).mT
        return _svd_lowrank(A, q, niter=niter, M=M)
    else:
        # Dense input: subtract the column means directly.
        C = A.mean(dim=(-2,), keepdim=True)
        return _svd_lowrank(A - C, q, niter=niter, M=None)