"""Implement various linear algebra algorithms for low rank matrices."""

__all__ = ["svd_lowrank", "pca_lowrank"]

from typing import Optional, Tuple

import torch
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function


def get_approximate_basis(
    A: Tensor,
    q: int,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tensor:
    """Return tensor :math:`Q` with :math:`q` orthonormal columns such
    that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
    specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
    approximates :math:`A - M`. The difference :math:`A - M` itself is
    never materialized; :math:`A` and :math:`M` are only used in
    matrix products.

    .. note:: The implementation is based on the Algorithm 4.4 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int): the dimension of subspace spanned by :math:`Q`
                 columns.

        niter (int, optional): the number of subspace iterations to
                               conduct; ``niter`` must be a
                               nonnegative integer. In most cases, the
                               default value 2 is more than enough.
                               ``None`` is treated as 2.

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`.

    Returns:
        Tensor: the matrix :math:`Q` obtained from the last QR
        factorization of the subspace iteration.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).
    """

    niter = 2 if niter is None else niter
    # Complex inputs keep their dtype; everything else is mapped to a
    # floating dtype so that the random test matrix can be sampled.
    dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype
    matmul = _utils.matmul

    # Random test matrix with q columns, used to sample the range of A.
    R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)

    # The following code could be made faster using torch.geqrf + torch.ormqr
    # but geqrf is not differentiable

    X = matmul(A, R)
    if M is not None:
        X = X - matmul(M, R)
    Q = torch.linalg.qr(X).Q
    # Subspace (power) iterations: alternately multiply by (A - M)^H and
    # (A - M), re-orthonormalizing with QR after every product.
    for _ in range(niter):
        X = matmul(A.mH, Q)
        if M is not None:
            X = X - matmul(M.mH, Q)
        Q = torch.linalg.qr(X).Q
        X = matmul(A, Q)
        if M is not None:
            X = X - matmul(M, Q)
        Q = torch.linalg.qr(X).Q
    return Q


def svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
    batches of matrices, or a sparse matrix :math:`A` such that
    :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
    SVD is computed for the matrix :math:`A - M`.

    .. note:: The implementation is based on the Algorithm 5.1 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: This is a randomized method. To obtain repeatable results,
              set the seed for the pseudorandom number generator

    .. note:: In general, use the full-rank SVD implementation
              :func:`torch.linalg.svd` for dense matrices due to its 10x
              higher performance characteristics. The low-rank SVD
              will be useful for huge sparse matrices that
              :func:`torch.linalg.svd` cannot handle.

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of A.
                           Defaults to 6; ``None`` is treated as 6.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`, which will be broadcasted
                              to the size of A in this function.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: the tuple ``(U, S, V)``. For
        :math:`q <= min(m, n)`, :math:`U` has size :math:`(*, m, q)`,
        :math:`S` has size :math:`(*, q)` and :math:`V` has size
        :math:`(*, n, q)`.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).

    """
    if not torch.jit.is_scripting():
        tensor_ops = (A, M)
        # Dispatch to __torch_function__ overrides only when an operand is
        # something other than an exact torch.Tensor or None.
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
            )
    return _svd_lowrank(A, q=q, niter=niter, M=M)


def _svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Compute the low-rank SVD ``(U, S, V)`` of ``A`` (or of ``A - M``
    when ``M`` is given) without ``__torch_function__`` dispatch.

    Takes the same arguments as :func:`svd_lowrank`; ``q=None`` is
    treated as 6.
    """
    # Algorithm 5.1 in Halko et al., 2009

    q = 6 if q is None else q
    m, n = A.shape[-2:]
    matmul = _utils.matmul
    if M is not None:
        M = M.broadcast_to(A.size())

    # Assume that A is tall
    if m < n:
        # Work on the conjugate transpose so that the matrix is tall; the
        # roles of U and V are swapped back before returning.
        A = A.mH
        if M is not None:
            M = M.mH

    Q = get_approximate_basis(A, q, niter=niter, M=M)
    # Project (A - M) onto the basis Q and decompose the small matrix B.
    B = matmul(Q.mH, A)
    if M is not None:
        B = B - matmul(Q.mH, M)
    U, S, Vh = torch.linalg.svd(B, full_matrices=False)
    V = Vh.mH
    # Lift the left singular vectors of B back to the original space.
    U = Q.matmul(U)

    if m < n:
        U, V = V, U

    return U, S, V


def pca_lowrank(
    A: Tensor,
    q: Optional[int] = None,
    center: bool = True,
    niter: int = 2,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or sparse matrix.

    This function returns a tuple ``(U, S, V)`` which is the
    nearly optimal approximation of a singular value decomposition of
    a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

                - :math:`A` is a data matrix with ``m`` samples and
                  ``n`` features

                - the :math:`V` columns represent the principal directions

                - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                  :math:`A^T A / (m - 1)` which is the covariance of
                  ``A`` when ``center=True`` is provided.

                - ``matmul(A, V[:, :k])`` projects data to the first k
                  principal components

    .. note:: Different from the standard SVD, the size of returned
              matrices depends on the specified rank and q
              values as follows:

                - :math:`U` is m x q matrix

                - :math:`S` is q-vector

                - :math:`V` is n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator

    Args:

        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default, ``q = min(6, m, n)``.

        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: the tuple ``(U, S, V)``.

    Raises:
        ValueError: if ``q`` is negative or greater than ``min(m, n)``,
            if ``niter`` is negative, or if ``center`` is True and ``A``
            is a sparse tensor that is not 2-dimensional.

    References::

        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).

    """

    if not torch.jit.is_scripting():
        if type(A) is not torch.Tensor and has_torch_function((A,)):
            return handle_torch_function(
                pca_lowrank, (A,), A, q=q, center=center, niter=niter
            )

    (m, n) = A.shape[-2:]

    if q is None:
        q = min(6, m, n)
    elif not (q >= 0 and q <= min(m, n)):
        raise ValueError(
            f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
        )
    if not (niter >= 0):
        raise ValueError(f"niter(={niter}) must be non-negative integer")

    dtype = _utils.get_floating_dtype(A)

    if not center:
        return _svd_lowrank(A, q, niter=niter, M=None)

    if _utils.is_sparse(A):
        if len(A.shape) != 2:
            raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
        # Column means of A as a sparse 1-D tensor.
        c = torch.sparse.sum(A, dim=(-2,)) / m
        # reshape c
        # Rebuild the means as an (n, 1) sparse column: row indices are the
        # column indices of c, the second index row stays zero.
        column_indices = c.indices()[0]
        indices = torch.zeros(
            2,
            len(column_indices),
            dtype=column_indices.dtype,
            device=column_indices.device,
        )
        indices[0] = column_indices
        C_t = torch.sparse_coo_tensor(
            indices, c.values(), (n, 1), dtype=dtype, device=A.device
        )

        # (n, 1) @ (1, m) repeats the means for every sample; transposing
        # gives the (m, n) mean matrix that is passed as M, so A itself is
        # never centered explicitly.
        ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
        M = torch.sparse.mm(C_t, ones_m1_t).mT
        return _svd_lowrank(A, q, niter=niter, M=M)
    else:
        # Dense input: subtract the per-column mean directly.
        C = A.mean(dim=(-2,), keepdim=True)
        return _svd_lowrank(A - C, q, niter=niter, M=None)