# mypy: allow-untyped-defs
"""Locally Optimal Block Preconditioned Conjugate Gradient methods."""
# Author: Pearu Peterson
# Created: February 2020

from typing import Dict, Optional, Tuple

import torch
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function


__all__ = ["lobpcg"]


def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
    """Gradient w.r.t. ``A`` of a symmetric eigendecomposition ``U diag(D) U^T``,
    treating the columns of ``U`` as the eigenbasis.

    Implements ``A.grad = U (diag(D.grad) + (U^T U.grad) * F) U^T`` with
    ``F_ij = 1 / (d_j - d_i)`` for ``i != j`` and ``F_ii = 0``.

    ``A`` is accepted only for signature uniformity; it is not read here.
    """
    # gaps[..., i, j] = d_j - d_i
    gaps = D.unsqueeze(-2) - D.unsqueeze(-1)
    # an infinite diagonal becomes exactly 0 once inverted, giving F_ii = 0
    gaps.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
    inv_gaps = gaps.pow_(-1)

    U_t = U.mT.contiguous()
    inner = torch.diag_embed(D_grad) + torch.matmul(U_t, U_grad) * inv_gaps
    return torch.matmul(U, torch.matmul(inner, U_t))
Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Workerdef _polynomial_coefficients_given_roots(roots): 32*da0073e9SAndroid Build Coastguard Worker """ 33*da0073e9SAndroid Build Coastguard Worker Given the `roots` of a polynomial, find the polynomial's coefficients. 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker If roots = (r_1, ..., r_n), then the method returns 36*da0073e9SAndroid Build Coastguard Worker coefficients (a_0, a_1, ..., a_n (== 1)) so that 37*da0073e9SAndroid Build Coastguard Worker p(x) = (x - r_1) * ... * (x - r_n) 38*da0073e9SAndroid Build Coastguard Worker = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker Note: for better performance requires writing a low-level kernel 41*da0073e9SAndroid Build Coastguard Worker """ 42*da0073e9SAndroid Build Coastguard Worker poly_order = roots.shape[-1] 43*da0073e9SAndroid Build Coastguard Worker poly_coeffs_shape = list(roots.shape) 44*da0073e9SAndroid Build Coastguard Worker # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... 
+ a_1 * x + a_0, 45*da0073e9SAndroid Build Coastguard Worker # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)}, 46*da0073e9SAndroid Build Coastguard Worker # but we insert one extra coefficient to enable better vectorization below 47*da0073e9SAndroid Build Coastguard Worker poly_coeffs_shape[-1] += 2 48*da0073e9SAndroid Build Coastguard Worker poly_coeffs = roots.new_zeros(poly_coeffs_shape) 49*da0073e9SAndroid Build Coastguard Worker poly_coeffs[..., 0] = 1 50*da0073e9SAndroid Build Coastguard Worker poly_coeffs[..., -1] = 1 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker # perform the Horner's rule 53*da0073e9SAndroid Build Coastguard Worker for i in range(1, poly_order + 1): 54*da0073e9SAndroid Build Coastguard Worker # note that it is computationally hard to compute backward for this method, 55*da0073e9SAndroid Build Coastguard Worker # because then given the coefficients it would require finding the roots and/or 56*da0073e9SAndroid Build Coastguard Worker # calculating the sensitivity based on the Vieta's theorem. 57*da0073e9SAndroid Build Coastguard Worker # So the code below tries to circumvent the explicit root finding by series 58*da0073e9SAndroid Build Coastguard Worker # of operations on memory copies imitating the Horner's method. 59*da0073e9SAndroid Build Coastguard Worker # The memory copies are required to construct nodes in the computational graph 60*da0073e9SAndroid Build Coastguard Worker # by exploting the explicit (not in-place, separate node for each step) 61*da0073e9SAndroid Build Coastguard Worker # recursion of the Horner's method. 62*da0073e9SAndroid Build Coastguard Worker # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity. 
63*da0073e9SAndroid Build Coastguard Worker poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs 64*da0073e9SAndroid Build Coastguard Worker out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1) 65*da0073e9SAndroid Build Coastguard Worker out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow( 66*da0073e9SAndroid Build Coastguard Worker -1, poly_order - i + 1, i + 1 67*da0073e9SAndroid Build Coastguard Worker ) 68*da0073e9SAndroid Build Coastguard Worker poly_coeffs = poly_coeffs_new 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker return poly_coeffs.narrow(-1, 1, poly_order + 1) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Workerdef _polynomial_value(poly, x, zero_power, transition): 74*da0073e9SAndroid Build Coastguard Worker """ 75*da0073e9SAndroid Build Coastguard Worker A generic method for computing poly(x) using the Horner's rule. 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker Args: 78*da0073e9SAndroid Build Coastguard Worker poly (Tensor): the (possibly batched) 1D Tensor representing 79*da0073e9SAndroid Build Coastguard Worker polynomial coefficients such that 80*da0073e9SAndroid Build Coastguard Worker poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and 81*da0073e9SAndroid Build Coastguard Worker poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker x (Tensor): the value (possible batched) to evalate the polynomial `poly` at. 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker zero_power (Tensor): the representation of `x^0`. It is application-specific. 
86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker transition (Callable): the function that accepts some intermediate result `int_val`, 88*da0073e9SAndroid Build Coastguard Worker the `x` and a specific polynomial coefficient 89*da0073e9SAndroid Build Coastguard Worker `poly[..., k]` for some iteration `k`. 90*da0073e9SAndroid Build Coastguard Worker It basically performs one iteration of the Horner's rule 91*da0073e9SAndroid Build Coastguard Worker defined as `x * int_val + poly[..., k] * zero_power`. 92*da0073e9SAndroid Build Coastguard Worker Note that `zero_power` is not a parameter, 93*da0073e9SAndroid Build Coastguard Worker because the step `+ poly[..., k] * zero_power` depends on `x`, 94*da0073e9SAndroid Build Coastguard Worker whether it is a vector, a matrix, or something else, so this 95*da0073e9SAndroid Build Coastguard Worker functionality is delegated to the user. 96*da0073e9SAndroid Build Coastguard Worker """ 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker res = zero_power.clone() 99*da0073e9SAndroid Build Coastguard Worker for k in range(poly.size(-1) - 2, -1, -1): 100*da0073e9SAndroid Build Coastguard Worker res = transition(res, x, poly[..., k]) 101*da0073e9SAndroid Build Coastguard Worker return res 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Workerdef _matrix_polynomial_value(poly, x, zero_power=None): 105*da0073e9SAndroid Build Coastguard Worker """ 106*da0073e9SAndroid Build Coastguard Worker Evaluates `poly(x)` for the (batched) matrix input `x`. 107*da0073e9SAndroid Build Coastguard Worker Check out `_polynomial_value` function for more details. 
108*da0073e9SAndroid Build Coastguard Worker """ 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker # matrix-aware Horner's rule iteration 111*da0073e9SAndroid Build Coastguard Worker def transition(curr_poly_val, x, poly_coeff): 112*da0073e9SAndroid Build Coastguard Worker res = x.matmul(curr_poly_val) 113*da0073e9SAndroid Build Coastguard Worker res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1)) 114*da0073e9SAndroid Build Coastguard Worker return res 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker if zero_power is None: 117*da0073e9SAndroid Build Coastguard Worker zero_power = torch.eye( 118*da0073e9SAndroid Build Coastguard Worker x.size(-1), x.size(-1), dtype=x.dtype, device=x.device 119*da0073e9SAndroid Build Coastguard Worker ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1)) 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker return _polynomial_value(poly, x, zero_power, transition) 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Workerdef _vector_polynomial_value(poly, x, zero_power=None): 125*da0073e9SAndroid Build Coastguard Worker """ 126*da0073e9SAndroid Build Coastguard Worker Evaluates `poly(x)` for the (batched) vector input `x`. 127*da0073e9SAndroid Build Coastguard Worker Check out `_polynomial_value` function for more details. 
128*da0073e9SAndroid Build Coastguard Worker """ 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker # vector-aware Horner's rule iteration 131*da0073e9SAndroid Build Coastguard Worker def transition(curr_poly_val, x, poly_coeff): 132*da0073e9SAndroid Build Coastguard Worker res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val) 133*da0073e9SAndroid Build Coastguard Worker return res 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker if zero_power is None: 136*da0073e9SAndroid Build Coastguard Worker zero_power = x.new_ones(1).expand(x.shape) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker return _polynomial_value(poly, x, zero_power, transition) 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Workerdef _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest): 142*da0073e9SAndroid Build Coastguard Worker # compute a projection operator onto an orthogonal subspace spanned by the 143*da0073e9SAndroid Build Coastguard Worker # columns of U defined as (I - UU^T) 144*da0073e9SAndroid Build Coastguard Worker Ut = U.mT.contiguous() 145*da0073e9SAndroid Build Coastguard Worker proj_U_ortho = -U.matmul(Ut) 146*da0073e9SAndroid Build Coastguard Worker proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1) 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker # compute U_ortho, a basis for the orthogonal complement to the span(U), 149*da0073e9SAndroid Build Coastguard Worker # by projecting a random [..., m, m - k] matrix onto the subspace spanned 150*da0073e9SAndroid Build Coastguard Worker # by the columns of U. 
151*da0073e9SAndroid Build Coastguard Worker # 152*da0073e9SAndroid Build Coastguard Worker # fix generator for determinism 153*da0073e9SAndroid Build Coastguard Worker gen = torch.Generator(A.device) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker # orthogonal complement to the span(U) 156*da0073e9SAndroid Build Coastguard Worker U_ortho = proj_U_ortho.matmul( 157*da0073e9SAndroid Build Coastguard Worker torch.randn( 158*da0073e9SAndroid Build Coastguard Worker (*A.shape[:-1], A.size(-1) - D.size(-1)), 159*da0073e9SAndroid Build Coastguard Worker dtype=A.dtype, 160*da0073e9SAndroid Build Coastguard Worker device=A.device, 161*da0073e9SAndroid Build Coastguard Worker generator=gen, 162*da0073e9SAndroid Build Coastguard Worker ) 163*da0073e9SAndroid Build Coastguard Worker ) 164*da0073e9SAndroid Build Coastguard Worker U_ortho_t = U_ortho.mT.contiguous() 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker # compute the coefficients of the characteristic polynomial of the tensor D. 167*da0073e9SAndroid Build Coastguard Worker # Note that D is diagonal, so the diagonal elements are exactly the roots 168*da0073e9SAndroid Build Coastguard Worker # of the characteristic polynomial. 169*da0073e9SAndroid Build Coastguard Worker chr_poly_D = _polynomial_coefficients_given_roots(D) 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker # the code belows finds the explicit solution to the Sylvester equation 172*da0073e9SAndroid Build Coastguard Worker # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U 173*da0073e9SAndroid Build Coastguard Worker # and incorporates it into the whole gradient stored in the `res` variable. 
174*da0073e9SAndroid Build Coastguard Worker # 175*da0073e9SAndroid Build Coastguard Worker # Equivalent to the following naive implementation: 176*da0073e9SAndroid Build Coastguard Worker # res = A.new_zeros(A.shape) 177*da0073e9SAndroid Build Coastguard Worker # p_res = A.new_zeros(*A.shape[:-1], D.size(-1)) 178*da0073e9SAndroid Build Coastguard Worker # for k in range(1, chr_poly_D.size(-1)): 179*da0073e9SAndroid Build Coastguard Worker # p_res.zero_() 180*da0073e9SAndroid Build Coastguard Worker # for i in range(0, k): 181*da0073e9SAndroid Build Coastguard Worker # p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2) 182*da0073e9SAndroid Build Coastguard Worker # res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t()) 183*da0073e9SAndroid Build Coastguard Worker # 184*da0073e9SAndroid Build Coastguard Worker # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity 185*da0073e9SAndroid Build Coastguard Worker # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g, 186*da0073e9SAndroid Build Coastguard Worker # and we need to compute g(U_grad, A, U, D) 187*da0073e9SAndroid Build Coastguard Worker # 188*da0073e9SAndroid Build Coastguard Worker # The naive implementation is based on the paper 189*da0073e9SAndroid Build Coastguard Worker # Hu, Qingxi, and Daizhan Cheng. 190*da0073e9SAndroid Build Coastguard Worker # "The polynomial solution to the Sylvester matrix equation." 191*da0073e9SAndroid Build Coastguard Worker # Applied mathematics letters 19.9 (2006): 859-864. 192*da0073e9SAndroid Build Coastguard Worker # 193*da0073e9SAndroid Build Coastguard Worker # We can modify the computation of `p_res` from above in a more efficient way 194*da0073e9SAndroid Build Coastguard Worker # p_res = U_grad * (chr_poly_D[1] * D.pow(0) + ... 
+ chr_poly_D[k] * D.pow(k)).unsqueeze(-2) 195*da0073e9SAndroid Build Coastguard Worker # + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2) 196*da0073e9SAndroid Build Coastguard Worker # + ... 197*da0073e9SAndroid Build Coastguard Worker # + A.matrix_power(k - 1) U_grad * chr_poly_D[k] 198*da0073e9SAndroid Build Coastguard Worker # Note that this saves us from redundant matrix products with A (elimination of matrix_power) 199*da0073e9SAndroid Build Coastguard Worker U_grad_projected = U_grad 200*da0073e9SAndroid Build Coastguard Worker series_acc = U_grad_projected.new_zeros(U_grad_projected.shape) 201*da0073e9SAndroid Build Coastguard Worker for k in range(1, chr_poly_D.size(-1)): 202*da0073e9SAndroid Build Coastguard Worker poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D) 203*da0073e9SAndroid Build Coastguard Worker series_acc += U_grad_projected * poly_D.unsqueeze(-2) 204*da0073e9SAndroid Build Coastguard Worker U_grad_projected = A.matmul(U_grad_projected) 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker # compute chr_poly_D(A) which essentially is: 207*da0073e9SAndroid Build Coastguard Worker # 208*da0073e9SAndroid Build Coastguard Worker # chr_poly_D_at_A = A.new_zeros(A.shape) 209*da0073e9SAndroid Build Coastguard Worker # for k in range(chr_poly_D.size(-1)): 210*da0073e9SAndroid Build Coastguard Worker # chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k) 211*da0073e9SAndroid Build Coastguard Worker # 212*da0073e9SAndroid Build Coastguard Worker # Note, however, for better performance we use the Horner's rule 213*da0073e9SAndroid Build Coastguard Worker chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A) 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t 216*da0073e9SAndroid Build Coastguard Worker chr_poly_D_at_A_to_U_ortho = torch.matmul( 
217*da0073e9SAndroid Build Coastguard Worker U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho) 218*da0073e9SAndroid Build Coastguard Worker ) 219*da0073e9SAndroid Build Coastguard Worker # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its 220*da0073e9SAndroid Build Coastguard Worker # Cholesky decomposition and then use `torch.cholesky_solve` for better stability. 221*da0073e9SAndroid Build Coastguard Worker # Cholesky decomposition requires the input to be positive-definite. 222*da0073e9SAndroid Build Coastguard Worker # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if 223*da0073e9SAndroid Build Coastguard Worker # 1. `largest` == False, or 224*da0073e9SAndroid Build Coastguard Worker # 2. `largest` == True and `k` is even 225*da0073e9SAndroid Build Coastguard Worker # under the assumption that `A` has distinct eigenvalues. 226*da0073e9SAndroid Build Coastguard Worker # 227*da0073e9SAndroid Build Coastguard Worker # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite 228*da0073e9SAndroid Build Coastguard Worker chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1 229*da0073e9SAndroid Build Coastguard Worker chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky( 230*da0073e9SAndroid Build Coastguard Worker chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho 231*da0073e9SAndroid Build Coastguard Worker ) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker # compute the gradient part in span(U) 234*da0073e9SAndroid Build Coastguard Worker res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U) 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker # incorporate the Sylvester equation solution into the full gradient 237*da0073e9SAndroid Build Coastguard Worker # it resides in span(U_ortho) 238*da0073e9SAndroid Build Coastguard Worker res -= U_ortho.matmul( 239*da0073e9SAndroid Build 
Coastguard Worker chr_poly_D_at_A_to_U_ortho_sign 240*da0073e9SAndroid Build Coastguard Worker * torch.cholesky_solve( 241*da0073e9SAndroid Build Coastguard Worker U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L 242*da0073e9SAndroid Build Coastguard Worker ) 243*da0073e9SAndroid Build Coastguard Worker ).matmul(Ut) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker return res 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Workerdef _symeig_backward(D_grad, U_grad, A, D, U, largest): 249*da0073e9SAndroid Build Coastguard Worker # if `U` is square, then the columns of `U` is a complete eigenspace 250*da0073e9SAndroid Build Coastguard Worker if U.size(-1) == U.size(-2): 251*da0073e9SAndroid Build Coastguard Worker return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U) 252*da0073e9SAndroid Build Coastguard Worker else: 253*da0073e9SAndroid Build Coastguard Worker return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest) 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Workerclass LOBPCGAutogradFunction(torch.autograd.Function): 257*da0073e9SAndroid Build Coastguard Worker @staticmethod 258*da0073e9SAndroid Build Coastguard Worker def forward( # type: ignore[override] 259*da0073e9SAndroid Build Coastguard Worker ctx, 260*da0073e9SAndroid Build Coastguard Worker A: Tensor, 261*da0073e9SAndroid Build Coastguard Worker k: Optional[int] = None, 262*da0073e9SAndroid Build Coastguard Worker B: Optional[Tensor] = None, 263*da0073e9SAndroid Build Coastguard Worker X: Optional[Tensor] = None, 264*da0073e9SAndroid Build Coastguard Worker n: Optional[int] = None, 265*da0073e9SAndroid Build Coastguard Worker iK: Optional[Tensor] = None, 266*da0073e9SAndroid Build Coastguard Worker niter: Optional[int] = None, 267*da0073e9SAndroid Build 
Coastguard Worker tol: Optional[float] = None, 268*da0073e9SAndroid Build Coastguard Worker largest: Optional[bool] = None, 269*da0073e9SAndroid Build Coastguard Worker method: Optional[str] = None, 270*da0073e9SAndroid Build Coastguard Worker tracker: None = None, 271*da0073e9SAndroid Build Coastguard Worker ortho_iparams: Optional[Dict[str, int]] = None, 272*da0073e9SAndroid Build Coastguard Worker ortho_fparams: Optional[Dict[str, float]] = None, 273*da0073e9SAndroid Build Coastguard Worker ortho_bparams: Optional[Dict[str, bool]] = None, 274*da0073e9SAndroid Build Coastguard Worker ) -> Tuple[Tensor, Tensor]: 275*da0073e9SAndroid Build Coastguard Worker # makes sure that input is contiguous for efficiency. 276*da0073e9SAndroid Build Coastguard Worker # Note: autograd does not support dense gradients for sparse input yet. 277*da0073e9SAndroid Build Coastguard Worker A = A.contiguous() if (not A.is_sparse) else A 278*da0073e9SAndroid Build Coastguard Worker if B is not None: 279*da0073e9SAndroid Build Coastguard Worker B = B.contiguous() if (not B.is_sparse) else B 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker D, U = _lobpcg( 282*da0073e9SAndroid Build Coastguard Worker A, 283*da0073e9SAndroid Build Coastguard Worker k, 284*da0073e9SAndroid Build Coastguard Worker B, 285*da0073e9SAndroid Build Coastguard Worker X, 286*da0073e9SAndroid Build Coastguard Worker n, 287*da0073e9SAndroid Build Coastguard Worker iK, 288*da0073e9SAndroid Build Coastguard Worker niter, 289*da0073e9SAndroid Build Coastguard Worker tol, 290*da0073e9SAndroid Build Coastguard Worker largest, 291*da0073e9SAndroid Build Coastguard Worker method, 292*da0073e9SAndroid Build Coastguard Worker tracker, 293*da0073e9SAndroid Build Coastguard Worker ortho_iparams, 294*da0073e9SAndroid Build Coastguard Worker ortho_fparams, 295*da0073e9SAndroid Build Coastguard Worker ortho_bparams, 296*da0073e9SAndroid Build Coastguard Worker ) 297*da0073e9SAndroid Build 
Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(A, B, D, U) 299*da0073e9SAndroid Build Coastguard Worker ctx.largest = largest 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker return D, U 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker @staticmethod 304*da0073e9SAndroid Build Coastguard Worker def backward(ctx, D_grad, U_grad): 305*da0073e9SAndroid Build Coastguard Worker A_grad = B_grad = None 306*da0073e9SAndroid Build Coastguard Worker grads = [None] * 14 307*da0073e9SAndroid Build Coastguard Worker 308*da0073e9SAndroid Build Coastguard Worker A, B, D, U = ctx.saved_tensors 309*da0073e9SAndroid Build Coastguard Worker largest = ctx.largest 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker # lobpcg.backward has some limitations. Checks for unsupported input 312*da0073e9SAndroid Build Coastguard Worker if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]): 313*da0073e9SAndroid Build Coastguard Worker raise ValueError( 314*da0073e9SAndroid Build Coastguard Worker "lobpcg.backward does not support sparse input yet." 315*da0073e9SAndroid Build Coastguard Worker "Note that lobpcg.forward does though." 316*da0073e9SAndroid Build Coastguard Worker ) 317*da0073e9SAndroid Build Coastguard Worker if ( 318*da0073e9SAndroid Build Coastguard Worker A.dtype in (torch.complex64, torch.complex128) 319*da0073e9SAndroid Build Coastguard Worker or B is not None 320*da0073e9SAndroid Build Coastguard Worker and B.dtype in (torch.complex64, torch.complex128) 321*da0073e9SAndroid Build Coastguard Worker ): 322*da0073e9SAndroid Build Coastguard Worker raise ValueError( 323*da0073e9SAndroid Build Coastguard Worker "lobpcg.backward does not support complex input yet." 324*da0073e9SAndroid Build Coastguard Worker "Note that lobpcg.forward does though." 
325*da0073e9SAndroid Build Coastguard Worker ) 326*da0073e9SAndroid Build Coastguard Worker if B is not None: 327*da0073e9SAndroid Build Coastguard Worker raise ValueError( 328*da0073e9SAndroid Build Coastguard Worker "lobpcg.backward does not support backward with B != I yet." 329*da0073e9SAndroid Build Coastguard Worker ) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker if largest is None: 332*da0073e9SAndroid Build Coastguard Worker largest = True 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker # symeig backward 335*da0073e9SAndroid Build Coastguard Worker if B is None: 336*da0073e9SAndroid Build Coastguard Worker A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest) 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker # A has index 0 339*da0073e9SAndroid Build Coastguard Worker grads[0] = A_grad 340*da0073e9SAndroid Build Coastguard Worker # B has index 2 341*da0073e9SAndroid Build Coastguard Worker grads[2] = B_grad 342*da0073e9SAndroid Build Coastguard Worker return tuple(grads) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Workerdef lobpcg( 346*da0073e9SAndroid Build Coastguard Worker A: Tensor, 347*da0073e9SAndroid Build Coastguard Worker k: Optional[int] = None, 348*da0073e9SAndroid Build Coastguard Worker B: Optional[Tensor] = None, 349*da0073e9SAndroid Build Coastguard Worker X: Optional[Tensor] = None, 350*da0073e9SAndroid Build Coastguard Worker n: Optional[int] = None, 351*da0073e9SAndroid Build Coastguard Worker iK: Optional[Tensor] = None, 352*da0073e9SAndroid Build Coastguard Worker niter: Optional[int] = None, 353*da0073e9SAndroid Build Coastguard Worker tol: Optional[float] = None, 354*da0073e9SAndroid Build Coastguard Worker largest: Optional[bool] = None, 355*da0073e9SAndroid Build Coastguard Worker method: Optional[str] = None, 
356*da0073e9SAndroid Build Coastguard Worker tracker: None = None, 357*da0073e9SAndroid Build Coastguard Worker ortho_iparams: Optional[Dict[str, int]] = None, 358*da0073e9SAndroid Build Coastguard Worker ortho_fparams: Optional[Dict[str, float]] = None, 359*da0073e9SAndroid Build Coastguard Worker ortho_bparams: Optional[Dict[str, bool]] = None, 360*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor]: 361*da0073e9SAndroid Build Coastguard Worker """Find the k largest (or smallest) eigenvalues and the corresponding 362*da0073e9SAndroid Build Coastguard Worker eigenvectors of a symmetric positive definite generalized 363*da0073e9SAndroid Build Coastguard Worker eigenvalue problem using matrix-free LOBPCG methods. 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker This function is a front-end to the following LOBPCG algorithms 366*da0073e9SAndroid Build Coastguard Worker selectable via `method` argument: 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker `method="basic"` - the LOBPCG method introduced by Andrew 369*da0073e9SAndroid Build Coastguard Worker Knyazev, see [Knyazev2001]. A less robust method, may fail when 370*da0073e9SAndroid Build Coastguard Worker Cholesky is applied to singular input. 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker `method="ortho"` - the LOBPCG method with orthogonal basis 373*da0073e9SAndroid Build Coastguard Worker selection [StathopoulosEtal2002]. A robust method. 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker Supported inputs are dense, sparse, and batches of dense matrices. 376*da0073e9SAndroid Build Coastguard Worker 377*da0073e9SAndroid Build Coastguard Worker .. note:: In general, the basic method spends least time per 378*da0073e9SAndroid Build Coastguard Worker iteration. 
However, the robust methods converge much faster and 379*da0073e9SAndroid Build Coastguard Worker are more stable. So, the usage of the basic method is generally 380*da0073e9SAndroid Build Coastguard Worker not recommended but there exist cases where the usage of the 381*da0073e9SAndroid Build Coastguard Worker basic method may be preferred. 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker .. warning:: The backward method does not support sparse and complex inputs. 384*da0073e9SAndroid Build Coastguard Worker It works only when `B` is not provided (i.e. `B == None`). 385*da0073e9SAndroid Build Coastguard Worker We are actively working on extensions, and the details of 386*da0073e9SAndroid Build Coastguard Worker the algorithms are going to be published promptly. 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not. 389*da0073e9SAndroid Build Coastguard Worker To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric 390*da0073e9SAndroid Build Coastguard Worker in first-order optimization routines, prior to running `lobpcg` 391*da0073e9SAndroid Build Coastguard Worker we do the following symmetrization map: `A -> (A + A.t()) / 2`. 392*da0073e9SAndroid Build Coastguard Worker The map is performed only when the `A` requires gradients. 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker Args: 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker A (Tensor): the input tensor of size :math:`(*, m, m)` 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker B (Tensor, optional): the input tensor of size :math:`(*, m, 399*da0073e9SAndroid Build Coastguard Worker m)`. When not specified, `B` is interpreted as 400*da0073e9SAndroid Build Coastguard Worker identity matrix. 
401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build Coastguard Worker X (tensor, optional): the input tensor of size :math:`(*, m, n)` 403*da0073e9SAndroid Build Coastguard Worker where `k <= n <= m`. When specified, it is used as 404*da0073e9SAndroid Build Coastguard Worker initial approximation of eigenvectors. X must be a 405*da0073e9SAndroid Build Coastguard Worker dense tensor. 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker iK (tensor, optional): the input tensor of size :math:`(*, m, 408*da0073e9SAndroid Build Coastguard Worker m)`. When specified, it will be used as preconditioner. 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker k (integer, optional): the number of requested 411*da0073e9SAndroid Build Coastguard Worker eigenpairs. Default is the number of :math:`X` 412*da0073e9SAndroid Build Coastguard Worker columns (when specified) or `1`. 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker n (integer, optional): if :math:`X` is not specified then `n` 415*da0073e9SAndroid Build Coastguard Worker specifies the size of the generated random 416*da0073e9SAndroid Build Coastguard Worker approximation of eigenvectors. Default value for `n` 417*da0073e9SAndroid Build Coastguard Worker is `k`. If :math:`X` is specified, the value of `n` 418*da0073e9SAndroid Build Coastguard Worker (when specified) must be the number of :math:`X` 419*da0073e9SAndroid Build Coastguard Worker columns. 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker tol (float, optional): residual tolerance for stopping 422*da0073e9SAndroid Build Coastguard Worker criterion. Default is `feps ** 0.5` where `feps` is 423*da0073e9SAndroid Build Coastguard Worker smallest non-zero floating-point number of the given 424*da0073e9SAndroid Build Coastguard Worker input tensor `A` data type. 
425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Worker largest (bool, optional): when True, solve the eigenproblem for 427*da0073e9SAndroid Build Coastguard Worker the largest eigenvalues. Otherwise, solve the 428*da0073e9SAndroid Build Coastguard Worker eigenproblem for the smallest eigenvalues. Default is 429*da0073e9SAndroid Build Coastguard Worker `True`. 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker method (str, optional): select the LOBPCG method. See the 432*da0073e9SAndroid Build Coastguard Worker description of the function above. Default is 433*da0073e9SAndroid Build Coastguard Worker "ortho". 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker niter (int, optional): maximum number of iterations. When 436*da0073e9SAndroid Build Coastguard Worker reached, the iteration process is hard-stopped and 437*da0073e9SAndroid Build Coastguard Worker the current approximation of eigenpairs is returned. 438*da0073e9SAndroid Build Coastguard Worker To iterate without a limit until the convergence criterion 439*da0073e9SAndroid Build Coastguard Worker is met, use `-1`. 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker tracker (callable, optional): a function for tracing the 442*da0073e9SAndroid Build Coastguard Worker iteration process. When specified, it is called at 443*da0073e9SAndroid Build Coastguard Worker each iteration step with the LOBPCG instance as an 444*da0073e9SAndroid Build Coastguard Worker argument.
The LOBPCG instance holds the full state of 445*da0073e9SAndroid Build Coastguard Worker the iteration process in the following attributes: 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker `iparams`, `fparams`, `bparams` - dictionaries of 448*da0073e9SAndroid Build Coastguard Worker integer, float, and boolean valued input 449*da0073e9SAndroid Build Coastguard Worker parameters, respectively 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker `ivars`, `fvars`, `bvars`, `tvars` - dictionaries 452*da0073e9SAndroid Build Coastguard Worker of integer, float, boolean, and Tensor valued 453*da0073e9SAndroid Build Coastguard Worker iteration variables, respectively. 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker `A`, `B`, `iK` - input Tensor arguments. 456*da0073e9SAndroid Build Coastguard Worker 457*da0073e9SAndroid Build Coastguard Worker `E`, `X`, `S`, `R` - iteration Tensor variables. 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker For instance: 460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Worker `ivars["istep"]` - the current iteration step 462*da0073e9SAndroid Build Coastguard Worker `X` - the current approximation of eigenvectors 463*da0073e9SAndroid Build Coastguard Worker `E` - the current approximation of eigenvalues 464*da0073e9SAndroid Build Coastguard Worker `R` - the current residual 465*da0073e9SAndroid Build Coastguard Worker `ivars["converged_count"]` - the current number of converged eigenpairs 466*da0073e9SAndroid Build Coastguard Worker `tvars["rerr"]` - the current state of convergence criteria 467*da0073e9SAndroid Build Coastguard Worker 468*da0073e9SAndroid Build Coastguard Worker Note that when `tracker` stores Tensor objects from 469*da0073e9SAndroid Build Coastguard Worker the LOBPCG instance, it must make copies of these. 
470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker If `tracker` sets `bvars["force_stop"] = True`, the 472*da0073e9SAndroid Build Coastguard Worker iteration process will be hard-stopped. 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker ortho_iparams, ortho_fparams, ortho_bparams (dict, optional): 475*da0073e9SAndroid Build Coastguard Worker various parameters to LOBPCG algorithm when using 476*da0073e9SAndroid Build Coastguard Worker `method="ortho"`. 477*da0073e9SAndroid Build Coastguard Worker 478*da0073e9SAndroid Build Coastguard Worker Returns: 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker E (Tensor): tensor of eigenvalues of size :math:`(*, k)` 481*da0073e9SAndroid Build Coastguard Worker 482*da0073e9SAndroid Build Coastguard Worker X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)` 483*da0073e9SAndroid Build Coastguard Worker 484*da0073e9SAndroid Build Coastguard Worker References: 485*da0073e9SAndroid Build Coastguard Worker 486*da0073e9SAndroid Build Coastguard Worker [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal 487*da0073e9SAndroid Build Coastguard Worker Preconditioned Eigensolver: Locally Optimal Block Preconditioned 488*da0073e9SAndroid Build Coastguard Worker Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2), 489*da0073e9SAndroid Build Coastguard Worker 517-541. (25 pages) 490*da0073e9SAndroid Build Coastguard Worker https://epubs.siam.org/doi/abs/10.1137/S1064827500366124 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng 493*da0073e9SAndroid Build Coastguard Worker Wu. (2002) A Block Orthogonalization Procedure with Constant 494*da0073e9SAndroid Build Coastguard Worker Synchronization Requirements. SIAM J. Sci. Comput., 23(6), 495*da0073e9SAndroid Build Coastguard Worker 2165-2182. 
(18 pages) 496*da0073e9SAndroid Build Coastguard Worker https://epubs.siam.org/doi/10.1137/S1064827500370883 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming 499*da0073e9SAndroid Build Coastguard Worker Gu. (2018) A Robust and Efficient Implementation of LOBPCG. 500*da0073e9SAndroid Build Coastguard Worker SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages) 501*da0073e9SAndroid Build Coastguard Worker https://epubs.siam.org/doi/abs/10.1137/17M1129830 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker """ 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker if not torch.jit.is_scripting(): 506*da0073e9SAndroid Build Coastguard Worker tensor_ops = (A, B, X, iK) 507*da0073e9SAndroid Build Coastguard Worker if not set(map(type, tensor_ops)).issubset( 508*da0073e9SAndroid Build Coastguard Worker (torch.Tensor, type(None)) 509*da0073e9SAndroid Build Coastguard Worker ) and has_torch_function(tensor_ops): 510*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 511*da0073e9SAndroid Build Coastguard Worker lobpcg, 512*da0073e9SAndroid Build Coastguard Worker tensor_ops, 513*da0073e9SAndroid Build Coastguard Worker A, 514*da0073e9SAndroid Build Coastguard Worker k=k, 515*da0073e9SAndroid Build Coastguard Worker B=B, 516*da0073e9SAndroid Build Coastguard Worker X=X, 517*da0073e9SAndroid Build Coastguard Worker n=n, 518*da0073e9SAndroid Build Coastguard Worker iK=iK, 519*da0073e9SAndroid Build Coastguard Worker niter=niter, 520*da0073e9SAndroid Build Coastguard Worker tol=tol, 521*da0073e9SAndroid Build Coastguard Worker largest=largest, 522*da0073e9SAndroid Build Coastguard Worker method=method, 523*da0073e9SAndroid Build Coastguard Worker tracker=tracker, 524*da0073e9SAndroid Build Coastguard Worker ortho_iparams=ortho_iparams, 525*da0073e9SAndroid Build Coastguard Worker 
ortho_fparams=ortho_fparams, 526*da0073e9SAndroid Build Coastguard Worker ortho_bparams=ortho_bparams, 527*da0073e9SAndroid Build Coastguard Worker ) 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker if not torch._jit_internal.is_scripting(): 530*da0073e9SAndroid Build Coastguard Worker if A.requires_grad or (B is not None and B.requires_grad): 531*da0073e9SAndroid Build Coastguard Worker # While it is expected that `A` is symmetric, 532*da0073e9SAndroid Build Coastguard Worker # the `A_grad` might be not. Therefore we perform the trick below, 533*da0073e9SAndroid Build Coastguard Worker # so that `A_grad` becomes symmetric. 534*da0073e9SAndroid Build Coastguard Worker # The symmetrization is important for first-order optimization methods, 535*da0073e9SAndroid Build Coastguard Worker # so that (A - alpha * A_grad) is still a symmetric matrix. 536*da0073e9SAndroid Build Coastguard Worker # Same holds for `B`. 537*da0073e9SAndroid Build Coastguard Worker A_sym = (A + A.mT) / 2 538*da0073e9SAndroid Build Coastguard Worker B_sym = (B + B.mT) / 2 if (B is not None) else None 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker return LOBPCGAutogradFunction.apply( 541*da0073e9SAndroid Build Coastguard Worker A_sym, 542*da0073e9SAndroid Build Coastguard Worker k, 543*da0073e9SAndroid Build Coastguard Worker B_sym, 544*da0073e9SAndroid Build Coastguard Worker X, 545*da0073e9SAndroid Build Coastguard Worker n, 546*da0073e9SAndroid Build Coastguard Worker iK, 547*da0073e9SAndroid Build Coastguard Worker niter, 548*da0073e9SAndroid Build Coastguard Worker tol, 549*da0073e9SAndroid Build Coastguard Worker largest, 550*da0073e9SAndroid Build Coastguard Worker method, 551*da0073e9SAndroid Build Coastguard Worker tracker, 552*da0073e9SAndroid Build Coastguard Worker ortho_iparams, 553*da0073e9SAndroid Build Coastguard Worker ortho_fparams, 554*da0073e9SAndroid Build Coastguard Worker ortho_bparams, 
555*da0073e9SAndroid Build Coastguard Worker ) 556*da0073e9SAndroid Build Coastguard Worker else: 557*da0073e9SAndroid Build Coastguard Worker if A.requires_grad or (B is not None and B.requires_grad): 558*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 559*da0073e9SAndroid Build Coastguard Worker "Script and require grads is not supported atm." 560*da0073e9SAndroid Build Coastguard Worker "If you just want to do the forward, use .detach()" 561*da0073e9SAndroid Build Coastguard Worker "on A and B before calling into lobpcg" 562*da0073e9SAndroid Build Coastguard Worker ) 563*da0073e9SAndroid Build Coastguard Worker 564*da0073e9SAndroid Build Coastguard Worker return _lobpcg( 565*da0073e9SAndroid Build Coastguard Worker A, 566*da0073e9SAndroid Build Coastguard Worker k, 567*da0073e9SAndroid Build Coastguard Worker B, 568*da0073e9SAndroid Build Coastguard Worker X, 569*da0073e9SAndroid Build Coastguard Worker n, 570*da0073e9SAndroid Build Coastguard Worker iK, 571*da0073e9SAndroid Build Coastguard Worker niter, 572*da0073e9SAndroid Build Coastguard Worker tol, 573*da0073e9SAndroid Build Coastguard Worker largest, 574*da0073e9SAndroid Build Coastguard Worker method, 575*da0073e9SAndroid Build Coastguard Worker tracker, 576*da0073e9SAndroid Build Coastguard Worker ortho_iparams, 577*da0073e9SAndroid Build Coastguard Worker ortho_fparams, 578*da0073e9SAndroid Build Coastguard Worker ortho_bparams, 579*da0073e9SAndroid Build Coastguard Worker ) 580*da0073e9SAndroid Build Coastguard Worker 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Workerdef _lobpcg( 583*da0073e9SAndroid Build Coastguard Worker A: Tensor, 584*da0073e9SAndroid Build Coastguard Worker k: Optional[int] = None, 585*da0073e9SAndroid Build Coastguard Worker B: Optional[Tensor] = None, 586*da0073e9SAndroid Build Coastguard Worker X: Optional[Tensor] = None, 587*da0073e9SAndroid Build Coastguard Worker n: Optional[int] = None, 588*da0073e9SAndroid Build 
Coastguard Worker iK: Optional[Tensor] = None, 589*da0073e9SAndroid Build Coastguard Worker niter: Optional[int] = None, 590*da0073e9SAndroid Build Coastguard Worker tol: Optional[float] = None, 591*da0073e9SAndroid Build Coastguard Worker largest: Optional[bool] = None, 592*da0073e9SAndroid Build Coastguard Worker method: Optional[str] = None, 593*da0073e9SAndroid Build Coastguard Worker tracker: None = None, 594*da0073e9SAndroid Build Coastguard Worker ortho_iparams: Optional[Dict[str, int]] = None, 595*da0073e9SAndroid Build Coastguard Worker ortho_fparams: Optional[Dict[str, float]] = None, 596*da0073e9SAndroid Build Coastguard Worker ortho_bparams: Optional[Dict[str, bool]] = None, 597*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor]: 598*da0073e9SAndroid Build Coastguard Worker # A must be square: 599*da0073e9SAndroid Build Coastguard Worker assert A.shape[-2] == A.shape[-1], A.shape 600*da0073e9SAndroid Build Coastguard Worker if B is not None: 601*da0073e9SAndroid Build Coastguard Worker # A and B must have the same shapes: 602*da0073e9SAndroid Build Coastguard Worker assert A.shape == B.shape, (A.shape, B.shape) 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker dtype = _utils.get_floating_dtype(A) 605*da0073e9SAndroid Build Coastguard Worker device = A.device 606*da0073e9SAndroid Build Coastguard Worker if tol is None: 607*da0073e9SAndroid Build Coastguard Worker feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype] 608*da0073e9SAndroid Build Coastguard Worker tol = feps**0.5 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker m = A.shape[-1] 611*da0073e9SAndroid Build Coastguard Worker k = (1 if X is None else X.shape[-1]) if k is None else k 612*da0073e9SAndroid Build Coastguard Worker n = (k if n is None else n) if X is None else X.shape[-1] 613*da0073e9SAndroid Build Coastguard Worker 614*da0073e9SAndroid Build Coastguard Worker if m < 3 * n: 
615*da0073e9SAndroid Build Coastguard Worker raise ValueError( 616*da0073e9SAndroid Build Coastguard Worker f"LPBPCG algorithm is not applicable when the number of A rows (={m})" 617*da0073e9SAndroid Build Coastguard Worker f" is smaller than 3 x the number of requested eigenpairs (={n})" 618*da0073e9SAndroid Build Coastguard Worker ) 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Worker method = "ortho" if method is None else method 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker iparams = { 623*da0073e9SAndroid Build Coastguard Worker "m": m, 624*da0073e9SAndroid Build Coastguard Worker "n": n, 625*da0073e9SAndroid Build Coastguard Worker "k": k, 626*da0073e9SAndroid Build Coastguard Worker "niter": 1000 if niter is None else niter, 627*da0073e9SAndroid Build Coastguard Worker } 628*da0073e9SAndroid Build Coastguard Worker 629*da0073e9SAndroid Build Coastguard Worker fparams = { 630*da0073e9SAndroid Build Coastguard Worker "tol": tol, 631*da0073e9SAndroid Build Coastguard Worker } 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker bparams = {"largest": True if largest is None else largest} 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Worker if method == "ortho": 636*da0073e9SAndroid Build Coastguard Worker if ortho_iparams is not None: 637*da0073e9SAndroid Build Coastguard Worker iparams.update(ortho_iparams) 638*da0073e9SAndroid Build Coastguard Worker if ortho_fparams is not None: 639*da0073e9SAndroid Build Coastguard Worker fparams.update(ortho_fparams) 640*da0073e9SAndroid Build Coastguard Worker if ortho_bparams is not None: 641*da0073e9SAndroid Build Coastguard Worker bparams.update(ortho_bparams) 642*da0073e9SAndroid Build Coastguard Worker iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3) 643*da0073e9SAndroid Build Coastguard Worker iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3) 
644*da0073e9SAndroid Build Coastguard Worker fparams["ortho_tol"] = fparams.get("ortho_tol", tol) 645*da0073e9SAndroid Build Coastguard Worker fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol) 646*da0073e9SAndroid Build Coastguard Worker fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol) 647*da0073e9SAndroid Build Coastguard Worker bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False) 648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Worker if not torch.jit.is_scripting(): 650*da0073e9SAndroid Build Coastguard Worker LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore[method-assign] 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker if len(A.shape) > 2: 653*da0073e9SAndroid Build Coastguard Worker N = int(torch.prod(torch.tensor(A.shape[:-2]))) 654*da0073e9SAndroid Build Coastguard Worker bA = A.reshape((N,) + A.shape[-2:]) 655*da0073e9SAndroid Build Coastguard Worker bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None 656*da0073e9SAndroid Build Coastguard Worker bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None 657*da0073e9SAndroid Build Coastguard Worker bE = torch.empty((N, k), dtype=dtype, device=device) 658*da0073e9SAndroid Build Coastguard Worker bXret = torch.empty((N, m, k), dtype=dtype, device=device) 659*da0073e9SAndroid Build Coastguard Worker 660*da0073e9SAndroid Build Coastguard Worker for i in range(N): 661*da0073e9SAndroid Build Coastguard Worker A_ = bA[i] 662*da0073e9SAndroid Build Coastguard Worker B_ = bB[i] if bB is not None else None 663*da0073e9SAndroid Build Coastguard Worker X_ = ( 664*da0073e9SAndroid Build Coastguard Worker torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i] 665*da0073e9SAndroid Build Coastguard Worker ) 666*da0073e9SAndroid Build Coastguard Worker assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n)) 667*da0073e9SAndroid Build Coastguard Worker 
iparams["batch_index"] = i 668*da0073e9SAndroid Build Coastguard Worker worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker) 669*da0073e9SAndroid Build Coastguard Worker worker.run() 670*da0073e9SAndroid Build Coastguard Worker bE[i] = worker.E[:k] 671*da0073e9SAndroid Build Coastguard Worker bXret[i] = worker.X[:, :k] 672*da0073e9SAndroid Build Coastguard Worker 673*da0073e9SAndroid Build Coastguard Worker if not torch.jit.is_scripting(): 674*da0073e9SAndroid Build Coastguard Worker LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign] 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k)) 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X 679*da0073e9SAndroid Build Coastguard Worker assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n)) 680*da0073e9SAndroid Build Coastguard Worker 681*da0073e9SAndroid Build Coastguard Worker worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker) 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker worker.run() 684*da0073e9SAndroid Build Coastguard Worker 685*da0073e9SAndroid Build Coastguard Worker if not torch.jit.is_scripting(): 686*da0073e9SAndroid Build Coastguard Worker LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign] 687*da0073e9SAndroid Build Coastguard Worker 688*da0073e9SAndroid Build Coastguard Worker return worker.E[:k], worker.X[:, :k] 689*da0073e9SAndroid Build Coastguard Worker 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Workerclass LOBPCG: 692*da0073e9SAndroid Build Coastguard Worker """Worker class of LOBPCG methods.""" 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker def 
__init__( 695*da0073e9SAndroid Build Coastguard Worker self, 696*da0073e9SAndroid Build Coastguard Worker A: Optional[Tensor], 697*da0073e9SAndroid Build Coastguard Worker B: Optional[Tensor], 698*da0073e9SAndroid Build Coastguard Worker X: Tensor, 699*da0073e9SAndroid Build Coastguard Worker iK: Optional[Tensor], 700*da0073e9SAndroid Build Coastguard Worker iparams: Dict[str, int], 701*da0073e9SAndroid Build Coastguard Worker fparams: Dict[str, float], 702*da0073e9SAndroid Build Coastguard Worker bparams: Dict[str, bool], 703*da0073e9SAndroid Build Coastguard Worker method: str, 704*da0073e9SAndroid Build Coastguard Worker tracker: None, 705*da0073e9SAndroid Build Coastguard Worker ) -> None: 706*da0073e9SAndroid Build Coastguard Worker # constant parameters 707*da0073e9SAndroid Build Coastguard Worker self.A = A 708*da0073e9SAndroid Build Coastguard Worker self.B = B 709*da0073e9SAndroid Build Coastguard Worker self.iK = iK 710*da0073e9SAndroid Build Coastguard Worker self.iparams = iparams 711*da0073e9SAndroid Build Coastguard Worker self.fparams = fparams 712*da0073e9SAndroid Build Coastguard Worker self.bparams = bparams 713*da0073e9SAndroid Build Coastguard Worker self.method = method 714*da0073e9SAndroid Build Coastguard Worker self.tracker = tracker 715*da0073e9SAndroid Build Coastguard Worker m = iparams["m"] 716*da0073e9SAndroid Build Coastguard Worker n = iparams["n"] 717*da0073e9SAndroid Build Coastguard Worker 718*da0073e9SAndroid Build Coastguard Worker # variable parameters 719*da0073e9SAndroid Build Coastguard Worker self.X = X 720*da0073e9SAndroid Build Coastguard Worker self.E = torch.zeros((n,), dtype=X.dtype, device=X.device) 721*da0073e9SAndroid Build Coastguard Worker self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device) 722*da0073e9SAndroid Build Coastguard Worker self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device) 723*da0073e9SAndroid Build Coastguard Worker self.tvars: Dict[str, Tensor] = {} 724*da0073e9SAndroid Build 
Coastguard Worker self.ivars: Dict[str, int] = {"istep": 0} 725*da0073e9SAndroid Build Coastguard Worker self.fvars: Dict[str, float] = {"_": 0.0} 726*da0073e9SAndroid Build Coastguard Worker self.bvars: Dict[str, bool] = {"_": False} 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Worker def __str__(self): 729*da0073e9SAndroid Build Coastguard Worker lines = ["LOPBCG:"] 730*da0073e9SAndroid Build Coastguard Worker lines += [f" iparams={self.iparams}"] 731*da0073e9SAndroid Build Coastguard Worker lines += [f" fparams={self.fparams}"] 732*da0073e9SAndroid Build Coastguard Worker lines += [f" bparams={self.bparams}"] 733*da0073e9SAndroid Build Coastguard Worker lines += [f" ivars={self.ivars}"] 734*da0073e9SAndroid Build Coastguard Worker lines += [f" fvars={self.fvars}"] 735*da0073e9SAndroid Build Coastguard Worker lines += [f" bvars={self.bvars}"] 736*da0073e9SAndroid Build Coastguard Worker lines += [f" tvars={self.tvars}"] 737*da0073e9SAndroid Build Coastguard Worker lines += [f" A={self.A}"] 738*da0073e9SAndroid Build Coastguard Worker lines += [f" B={self.B}"] 739*da0073e9SAndroid Build Coastguard Worker lines += [f" iK={self.iK}"] 740*da0073e9SAndroid Build Coastguard Worker lines += [f" X={self.X}"] 741*da0073e9SAndroid Build Coastguard Worker lines += [f" E={self.E}"] 742*da0073e9SAndroid Build Coastguard Worker r = "" 743*da0073e9SAndroid Build Coastguard Worker for line in lines: 744*da0073e9SAndroid Build Coastguard Worker r += line + "\n" 745*da0073e9SAndroid Build Coastguard Worker return r 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Worker def update(self): 748*da0073e9SAndroid Build Coastguard Worker """Set and update iteration variables.""" 749*da0073e9SAndroid Build Coastguard Worker if self.ivars["istep"] == 0: 750*da0073e9SAndroid Build Coastguard Worker X_norm = float(torch.norm(self.X)) 751*da0073e9SAndroid Build Coastguard Worker iX_norm = X_norm**-1 752*da0073e9SAndroid 
Build Coastguard Worker A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm 753*da0073e9SAndroid Build Coastguard Worker B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm 754*da0073e9SAndroid Build Coastguard Worker self.fvars["X_norm"] = X_norm 755*da0073e9SAndroid Build Coastguard Worker self.fvars["A_norm"] = A_norm 756*da0073e9SAndroid Build Coastguard Worker self.fvars["B_norm"] = B_norm 757*da0073e9SAndroid Build Coastguard Worker self.ivars["iterations_left"] = self.iparams["niter"] 758*da0073e9SAndroid Build Coastguard Worker self.ivars["converged_count"] = 0 759*da0073e9SAndroid Build Coastguard Worker self.ivars["converged_end"] = 0 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Worker if self.method == "ortho": 762*da0073e9SAndroid Build Coastguard Worker self._update_ortho() 763*da0073e9SAndroid Build Coastguard Worker else: 764*da0073e9SAndroid Build Coastguard Worker self._update_basic() 765*da0073e9SAndroid Build Coastguard Worker 766*da0073e9SAndroid Build Coastguard Worker self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1 767*da0073e9SAndroid Build Coastguard Worker self.ivars["istep"] = self.ivars["istep"] + 1 768*da0073e9SAndroid Build Coastguard Worker 769*da0073e9SAndroid Build Coastguard Worker def update_residual(self): 770*da0073e9SAndroid Build Coastguard Worker """Update residual R from A, B, X, E.""" 771*da0073e9SAndroid Build Coastguard Worker mm = _utils.matmul 772*da0073e9SAndroid Build Coastguard Worker self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E 773*da0073e9SAndroid Build Coastguard Worker 774*da0073e9SAndroid Build Coastguard Worker def update_converged_count(self): 775*da0073e9SAndroid Build Coastguard Worker """Determine the number of converged eigenpairs using backward stable 776*da0073e9SAndroid Build Coastguard Worker convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018]. 
777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Worker Users may redefine this method for custom convergence criteria. 779*da0073e9SAndroid Build Coastguard Worker """ 780*da0073e9SAndroid Build Coastguard Worker # (...) -> int 781*da0073e9SAndroid Build Coastguard Worker prev_count = self.ivars["converged_count"] 782*da0073e9SAndroid Build Coastguard Worker tol = self.fparams["tol"] 783*da0073e9SAndroid Build Coastguard Worker A_norm = self.fvars["A_norm"] 784*da0073e9SAndroid Build Coastguard Worker B_norm = self.fvars["B_norm"] 785*da0073e9SAndroid Build Coastguard Worker E, X, R = self.E, self.X, self.R 786*da0073e9SAndroid Build Coastguard Worker rerr = ( 787*da0073e9SAndroid Build Coastguard Worker torch.norm(R, 2, (0,)) 788*da0073e9SAndroid Build Coastguard Worker * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1 789*da0073e9SAndroid Build Coastguard Worker ) 790*da0073e9SAndroid Build Coastguard Worker converged = rerr.real < tol # this is a norm so imag is 0.0 791*da0073e9SAndroid Build Coastguard Worker count = 0 792*da0073e9SAndroid Build Coastguard Worker for b in converged: 793*da0073e9SAndroid Build Coastguard Worker if not b: 794*da0073e9SAndroid Build Coastguard Worker # ignore convergence of following pairs to ensure 795*da0073e9SAndroid Build Coastguard Worker # strict ordering of eigenpairs 796*da0073e9SAndroid Build Coastguard Worker break 797*da0073e9SAndroid Build Coastguard Worker count += 1 798*da0073e9SAndroid Build Coastguard Worker assert ( 799*da0073e9SAndroid Build Coastguard Worker count >= prev_count 800*da0073e9SAndroid Build Coastguard Worker ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease" 801*da0073e9SAndroid Build Coastguard Worker self.ivars["converged_count"] = count 802*da0073e9SAndroid Build Coastguard Worker self.tvars["rerr"] = rerr 803*da0073e9SAndroid Build Coastguard Worker return count 804*da0073e9SAndroid Build Coastguard 
Worker 805*da0073e9SAndroid Build Coastguard Worker def stop_iteration(self): 806*da0073e9SAndroid Build Coastguard Worker """Return True to stop iterations. 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Worker Note that tracker (if defined) can force-stop iterations by 809*da0073e9SAndroid Build Coastguard Worker setting ``worker.bvars['force_stop'] = True``. 810*da0073e9SAndroid Build Coastguard Worker """ 811*da0073e9SAndroid Build Coastguard Worker return ( 812*da0073e9SAndroid Build Coastguard Worker self.bvars.get("force_stop", False) 813*da0073e9SAndroid Build Coastguard Worker or self.ivars["iterations_left"] == 0 814*da0073e9SAndroid Build Coastguard Worker or self.ivars["converged_count"] >= self.iparams["k"] 815*da0073e9SAndroid Build Coastguard Worker ) 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker def run(self): 818*da0073e9SAndroid Build Coastguard Worker """Run LOBPCG iterations. 819*da0073e9SAndroid Build Coastguard Worker 820*da0073e9SAndroid Build Coastguard Worker Use this method as a template for implementing LOBPCG 821*da0073e9SAndroid Build Coastguard Worker iteration scheme with custom tracker that is compatible with 822*da0073e9SAndroid Build Coastguard Worker TorchScript. 
823*da0073e9SAndroid Build Coastguard Worker """ 824*da0073e9SAndroid Build Coastguard Worker self.update() 825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker if not torch.jit.is_scripting() and self.tracker is not None: 827*da0073e9SAndroid Build Coastguard Worker self.call_tracker() 828*da0073e9SAndroid Build Coastguard Worker 829*da0073e9SAndroid Build Coastguard Worker while not self.stop_iteration(): 830*da0073e9SAndroid Build Coastguard Worker self.update() 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker if not torch.jit.is_scripting() and self.tracker is not None: 833*da0073e9SAndroid Build Coastguard Worker self.call_tracker() 834*da0073e9SAndroid Build Coastguard Worker 835*da0073e9SAndroid Build Coastguard Worker @torch.jit.unused 836*da0073e9SAndroid Build Coastguard Worker def call_tracker(self): 837*da0073e9SAndroid Build Coastguard Worker """Interface for tracking iteration process in Python mode. 838*da0073e9SAndroid Build Coastguard Worker 839*da0073e9SAndroid Build Coastguard Worker Tracking the iteration process is disabled in TorchScript 840*da0073e9SAndroid Build Coastguard Worker mode. In fact, one should specify tracker=None when JIT 841*da0073e9SAndroid Build Coastguard Worker compiling functions using lobpcg. 842*da0073e9SAndroid Build Coastguard Worker """ 843*da0073e9SAndroid Build Coastguard Worker # do nothing when in TorchScript mode 844*da0073e9SAndroid Build Coastguard Worker 845*da0073e9SAndroid Build Coastguard Worker # Internal methods 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker def _update_basic(self): 848*da0073e9SAndroid Build Coastguard Worker """ 849*da0073e9SAndroid Build Coastguard Worker Update or initialize iteration variables when `method == "basic"`. 
850*da0073e9SAndroid Build Coastguard Worker """ 851*da0073e9SAndroid Build Coastguard Worker mm = torch.matmul 852*da0073e9SAndroid Build Coastguard Worker ns = self.ivars["converged_end"] 853*da0073e9SAndroid Build Coastguard Worker nc = self.ivars["converged_count"] 854*da0073e9SAndroid Build Coastguard Worker n = self.iparams["n"] 855*da0073e9SAndroid Build Coastguard Worker largest = self.bparams["largest"] 856*da0073e9SAndroid Build Coastguard Worker 857*da0073e9SAndroid Build Coastguard Worker if self.ivars["istep"] == 0: 858*da0073e9SAndroid Build Coastguard Worker Ri = self._get_rayleigh_ritz_transform(self.X) 859*da0073e9SAndroid Build Coastguard Worker M = _utils.qform(_utils.qform(self.A, self.X), Ri) 860*da0073e9SAndroid Build Coastguard Worker E, Z = _utils.symeig(M, largest) 861*da0073e9SAndroid Build Coastguard Worker self.X[:] = mm(self.X, mm(Ri, Z)) 862*da0073e9SAndroid Build Coastguard Worker self.E[:] = E 863*da0073e9SAndroid Build Coastguard Worker np = 0 864*da0073e9SAndroid Build Coastguard Worker self.update_residual() 865*da0073e9SAndroid Build Coastguard Worker nc = self.update_converged_count() 866*da0073e9SAndroid Build Coastguard Worker self.S[..., :n] = self.X 867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker W = _utils.matmul(self.iK, self.R) 869*da0073e9SAndroid Build Coastguard Worker self.ivars["converged_end"] = ns = n + np + W.shape[-1] 870*da0073e9SAndroid Build Coastguard Worker self.S[:, n + np : ns] = W 871*da0073e9SAndroid Build Coastguard Worker else: 872*da0073e9SAndroid Build Coastguard Worker S_ = self.S[:, nc:ns] 873*da0073e9SAndroid Build Coastguard Worker Ri = self._get_rayleigh_ritz_transform(S_) 874*da0073e9SAndroid Build Coastguard Worker M = _utils.qform(_utils.qform(self.A, S_), Ri) 875*da0073e9SAndroid Build Coastguard Worker E_, Z = _utils.symeig(M, largest) 876*da0073e9SAndroid Build Coastguard Worker self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc])) 877*da0073e9SAndroid 
Build Coastguard Worker self.E[nc:] = E_[: n - nc] 878*da0073e9SAndroid Build Coastguard Worker P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc])) 879*da0073e9SAndroid Build Coastguard Worker np = P.shape[-1] 880*da0073e9SAndroid Build Coastguard Worker 881*da0073e9SAndroid Build Coastguard Worker self.update_residual() 882*da0073e9SAndroid Build Coastguard Worker nc = self.update_converged_count() 883*da0073e9SAndroid Build Coastguard Worker self.S[..., :n] = self.X 884*da0073e9SAndroid Build Coastguard Worker self.S[:, n : n + np] = P 885*da0073e9SAndroid Build Coastguard Worker W = _utils.matmul(self.iK, self.R[:, nc:]) 886*da0073e9SAndroid Build Coastguard Worker 887*da0073e9SAndroid Build Coastguard Worker self.ivars["converged_end"] = ns = n + np + W.shape[-1] 888*da0073e9SAndroid Build Coastguard Worker self.S[:, n + np : ns] = W 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker def _update_ortho(self): 891*da0073e9SAndroid Build Coastguard Worker """ 892*da0073e9SAndroid Build Coastguard Worker Update or initialize iteration variables when `method == "ortho"`. 
893*da0073e9SAndroid Build Coastguard Worker """ 894*da0073e9SAndroid Build Coastguard Worker mm = torch.matmul 895*da0073e9SAndroid Build Coastguard Worker ns = self.ivars["converged_end"] 896*da0073e9SAndroid Build Coastguard Worker nc = self.ivars["converged_count"] 897*da0073e9SAndroid Build Coastguard Worker n = self.iparams["n"] 898*da0073e9SAndroid Build Coastguard Worker largest = self.bparams["largest"] 899*da0073e9SAndroid Build Coastguard Worker 900*da0073e9SAndroid Build Coastguard Worker if self.ivars["istep"] == 0: 901*da0073e9SAndroid Build Coastguard Worker Ri = self._get_rayleigh_ritz_transform(self.X) 902*da0073e9SAndroid Build Coastguard Worker M = _utils.qform(_utils.qform(self.A, self.X), Ri) 903*da0073e9SAndroid Build Coastguard Worker E, Z = _utils.symeig(M, largest) 904*da0073e9SAndroid Build Coastguard Worker self.X = mm(self.X, mm(Ri, Z)) 905*da0073e9SAndroid Build Coastguard Worker self.update_residual() 906*da0073e9SAndroid Build Coastguard Worker np = 0 907*da0073e9SAndroid Build Coastguard Worker nc = self.update_converged_count() 908*da0073e9SAndroid Build Coastguard Worker self.S[:, :n] = self.X 909*da0073e9SAndroid Build Coastguard Worker W = self._get_ortho(self.R, self.X) 910*da0073e9SAndroid Build Coastguard Worker ns = self.ivars["converged_end"] = n + np + W.shape[-1] 911*da0073e9SAndroid Build Coastguard Worker self.S[:, n + np : ns] = W 912*da0073e9SAndroid Build Coastguard Worker 913*da0073e9SAndroid Build Coastguard Worker else: 914*da0073e9SAndroid Build Coastguard Worker S_ = self.S[:, nc:ns] 915*da0073e9SAndroid Build Coastguard Worker # Rayleigh-Ritz procedure 916*da0073e9SAndroid Build Coastguard Worker E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest) 917*da0073e9SAndroid Build Coastguard Worker 918*da0073e9SAndroid Build Coastguard Worker # Update E, X, P 919*da0073e9SAndroid Build Coastguard Worker self.X[:, nc:] = mm(S_, Z[:, : n - nc]) 920*da0073e9SAndroid Build Coastguard Worker self.E[nc:] = E_[: n - nc] 
921*da0073e9SAndroid Build Coastguard Worker P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT))) 922*da0073e9SAndroid Build Coastguard Worker np = P.shape[-1] 923*da0073e9SAndroid Build Coastguard Worker 924*da0073e9SAndroid Build Coastguard Worker # check convergence 925*da0073e9SAndroid Build Coastguard Worker self.update_residual() 926*da0073e9SAndroid Build Coastguard Worker nc = self.update_converged_count() 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker # update S 929*da0073e9SAndroid Build Coastguard Worker self.S[:, :n] = self.X 930*da0073e9SAndroid Build Coastguard Worker self.S[:, n : n + np] = P 931*da0073e9SAndroid Build Coastguard Worker W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np]) 932*da0073e9SAndroid Build Coastguard Worker ns = self.ivars["converged_end"] = n + np + W.shape[-1] 933*da0073e9SAndroid Build Coastguard Worker self.S[:, n + np : ns] = W 934*da0073e9SAndroid Build Coastguard Worker 935*da0073e9SAndroid Build Coastguard Worker def _get_rayleigh_ritz_transform(self, S): 936*da0073e9SAndroid Build Coastguard Worker """Return a transformation matrix that is used in Rayleigh-Ritz 937*da0073e9SAndroid Build Coastguard Worker procedure for reducing a general eigenvalue problem :math:`(S^TAS) 938*da0073e9SAndroid Build Coastguard Worker C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T 939*da0073e9SAndroid Build Coastguard Worker S^TAS Ri) Z = Z E` where `C = Ri Z`. 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker .. 
note:: In the original Rayleight-Ritz procedure in 942*da0073e9SAndroid Build Coastguard Worker [DuerschEtal2018], the problem is formulated as follows:: 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker SAS = S^T A S 945*da0073e9SAndroid Build Coastguard Worker SBS = S^T B S 946*da0073e9SAndroid Build Coastguard Worker D = (<diagonal matrix of SBS>) ** -1/2 947*da0073e9SAndroid Build Coastguard Worker R^T R = Cholesky(D SBS D) 948*da0073e9SAndroid Build Coastguard Worker Ri = D R^-1 949*da0073e9SAndroid Build Coastguard Worker solve symeig problem Ri^T SAS Ri Z = Theta Z 950*da0073e9SAndroid Build Coastguard Worker C = Ri Z 951*da0073e9SAndroid Build Coastguard Worker 952*da0073e9SAndroid Build Coastguard Worker To reduce the number of matrix products (denoted by empty 953*da0073e9SAndroid Build Coastguard Worker space between matrices), here we introduce element-wise 954*da0073e9SAndroid Build Coastguard Worker products (denoted by symbol `*`) so that the Rayleight-Ritz 955*da0073e9SAndroid Build Coastguard Worker procedure becomes:: 956*da0073e9SAndroid Build Coastguard Worker 957*da0073e9SAndroid Build Coastguard Worker SAS = S^T A S 958*da0073e9SAndroid Build Coastguard Worker SBS = S^T B S 959*da0073e9SAndroid Build Coastguard Worker d = (<diagonal of SBS>) ** -1/2 # this is 1-d column vector 960*da0073e9SAndroid Build Coastguard Worker dd = d d^T # this is 2-d matrix 961*da0073e9SAndroid Build Coastguard Worker R^T R = Cholesky(dd * SBS) 962*da0073e9SAndroid Build Coastguard Worker Ri = R^-1 * d # broadcasting 963*da0073e9SAndroid Build Coastguard Worker solve symeig problem Ri^T SAS Ri Z = Theta Z 964*da0073e9SAndroid Build Coastguard Worker C = Ri Z 965*da0073e9SAndroid Build Coastguard Worker 966*da0073e9SAndroid Build Coastguard Worker where `dd` is 2-d matrix that replaces matrix products `D M 967*da0073e9SAndroid Build Coastguard Worker D` with one element-wise product `M * dd`; and `d` replaces 
968*da0073e9SAndroid Build Coastguard Worker matrix product `D M` with element-wise product `M * 969*da0073e9SAndroid Build Coastguard Worker d`. Also, creating the diagonal matrix `D` is avoided. 970*da0073e9SAndroid Build Coastguard Worker 971*da0073e9SAndroid Build Coastguard Worker Args: 972*da0073e9SAndroid Build Coastguard Worker S (Tensor): the matrix basis for the search subspace, size is 973*da0073e9SAndroid Build Coastguard Worker :math:`(m, n)`. 974*da0073e9SAndroid Build Coastguard Worker 975*da0073e9SAndroid Build Coastguard Worker Returns: 976*da0073e9SAndroid Build Coastguard Worker Ri (tensor): upper-triangular transformation matrix of size 977*da0073e9SAndroid Build Coastguard Worker :math:`(n, n)`. 978*da0073e9SAndroid Build Coastguard Worker 979*da0073e9SAndroid Build Coastguard Worker """ 980*da0073e9SAndroid Build Coastguard Worker B = self.B 981*da0073e9SAndroid Build Coastguard Worker mm = torch.matmul 982*da0073e9SAndroid Build Coastguard Worker SBS = _utils.qform(B, S) 983*da0073e9SAndroid Build Coastguard Worker d_row = SBS.diagonal(0, -2, -1) ** -0.5 984*da0073e9SAndroid Build Coastguard Worker d_col = d_row.reshape(d_row.shape[0], 1) 985*da0073e9SAndroid Build Coastguard Worker # TODO use torch.linalg.cholesky_solve once it is implemented 986*da0073e9SAndroid Build Coastguard Worker R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True) 987*da0073e9SAndroid Build Coastguard Worker return torch.linalg.solve_triangular( 988*da0073e9SAndroid Build Coastguard Worker R, d_row.diag_embed(), upper=True, left=False 989*da0073e9SAndroid Build Coastguard Worker ) 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor: 992*da0073e9SAndroid Build Coastguard Worker """Return B-orthonormal U. 993*da0073e9SAndroid Build Coastguard Worker 994*da0073e9SAndroid Build Coastguard Worker .. 
note:: When `drop` is `False` then `svqb` is based on the 995*da0073e9SAndroid Build Coastguard Worker Algorithm 4 from [DuerschPhD2015] that is a slight 996*da0073e9SAndroid Build Coastguard Worker modification of the corresponding algorithm 997*da0073e9SAndroid Build Coastguard Worker introduced in [StathopolousWu2002]. 998*da0073e9SAndroid Build Coastguard Worker 999*da0073e9SAndroid Build Coastguard Worker Args: 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Worker U (Tensor) : initial approximation, size is (m, n) 1002*da0073e9SAndroid Build Coastguard Worker drop (bool) : when True, drop columns that 1003*da0073e9SAndroid Build Coastguard Worker contribution to the `span([U])` is small. 1004*da0073e9SAndroid Build Coastguard Worker tau (float) : positive tolerance 1005*da0073e9SAndroid Build Coastguard Worker 1006*da0073e9SAndroid Build Coastguard Worker Returns: 1007*da0073e9SAndroid Build Coastguard Worker 1008*da0073e9SAndroid Build Coastguard Worker U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size 1009*da0073e9SAndroid Build Coastguard Worker is (m, n1), where `n1 = n` if `drop` is `False, 1010*da0073e9SAndroid Build Coastguard Worker otherwise `n1 <= n`. 1011*da0073e9SAndroid Build Coastguard Worker 1012*da0073e9SAndroid Build Coastguard Worker """ 1013*da0073e9SAndroid Build Coastguard Worker if torch.numel(U) == 0: 1014*da0073e9SAndroid Build Coastguard Worker return U 1015*da0073e9SAndroid Build Coastguard Worker UBU = _utils.qform(self.B, U) 1016*da0073e9SAndroid Build Coastguard Worker d = UBU.diagonal(0, -2, -1) 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker # Detect and drop exact zero columns from U. 
While the test 1019*da0073e9SAndroid Build Coastguard Worker # `abs(d) == 0` is unlikely to be True for random data, it is 1020*da0073e9SAndroid Build Coastguard Worker # possible to construct input data to lobpcg where it will be 1021*da0073e9SAndroid Build Coastguard Worker # True leading to a failure (notice the `d ** -0.5` operation 1022*da0073e9SAndroid Build Coastguard Worker # in the original algorithm). To prevent the failure, we drop 1023*da0073e9SAndroid Build Coastguard Worker # the exact zero columns here and then continue with the 1024*da0073e9SAndroid Build Coastguard Worker # original algorithm below. 1025*da0073e9SAndroid Build Coastguard Worker nz = torch.where(abs(d) != 0.0) 1026*da0073e9SAndroid Build Coastguard Worker assert len(nz) == 1, nz 1027*da0073e9SAndroid Build Coastguard Worker if len(nz[0]) < len(d): 1028*da0073e9SAndroid Build Coastguard Worker U = U[:, nz[0]] 1029*da0073e9SAndroid Build Coastguard Worker if torch.numel(U) == 0: 1030*da0073e9SAndroid Build Coastguard Worker return U 1031*da0073e9SAndroid Build Coastguard Worker UBU = _utils.qform(self.B, U) 1032*da0073e9SAndroid Build Coastguard Worker d = UBU.diagonal(0, -2, -1) 1033*da0073e9SAndroid Build Coastguard Worker nz = torch.where(abs(d) != 0.0) 1034*da0073e9SAndroid Build Coastguard Worker assert len(nz[0]) == len(d) 1035*da0073e9SAndroid Build Coastguard Worker 1036*da0073e9SAndroid Build Coastguard Worker # The original algorithm 4 from [DuerschPhD2015]. 
1037*da0073e9SAndroid Build Coastguard Worker d_col = (d**-0.5).reshape(d.shape[0], 1) 1038*da0073e9SAndroid Build Coastguard Worker DUBUD = (UBU * d_col) * d_col.mT 1039*da0073e9SAndroid Build Coastguard Worker E, Z = _utils.symeig(DUBUD) 1040*da0073e9SAndroid Build Coastguard Worker t = tau * abs(E).max() 1041*da0073e9SAndroid Build Coastguard Worker if drop: 1042*da0073e9SAndroid Build Coastguard Worker keep = torch.where(E > t) 1043*da0073e9SAndroid Build Coastguard Worker assert len(keep) == 1, keep 1044*da0073e9SAndroid Build Coastguard Worker E = E[keep[0]] 1045*da0073e9SAndroid Build Coastguard Worker Z = Z[:, keep[0]] 1046*da0073e9SAndroid Build Coastguard Worker d_col = d_col[keep[0]] 1047*da0073e9SAndroid Build Coastguard Worker else: 1048*da0073e9SAndroid Build Coastguard Worker E[(torch.where(E < t))[0]] = t 1049*da0073e9SAndroid Build Coastguard Worker 1050*da0073e9SAndroid Build Coastguard Worker return torch.matmul(U * d_col.mT, Z * E**-0.5) 1051*da0073e9SAndroid Build Coastguard Worker 1052*da0073e9SAndroid Build Coastguard Worker def _get_ortho(self, U, V): 1053*da0073e9SAndroid Build Coastguard Worker """Return B-orthonormal U with columns are B-orthogonal to V. 1054*da0073e9SAndroid Build Coastguard Worker 1055*da0073e9SAndroid Build Coastguard Worker .. note:: When `bparams["ortho_use_drop"] == False` then 1056*da0073e9SAndroid Build Coastguard Worker `_get_ortho` is based on the Algorithm 3 from 1057*da0073e9SAndroid Build Coastguard Worker [DuerschPhD2015] that is a slight modification of 1058*da0073e9SAndroid Build Coastguard Worker the corresponding algorithm introduced in 1059*da0073e9SAndroid Build Coastguard Worker [StathopolousWu2002]. Otherwise, the method 1060*da0073e9SAndroid Build Coastguard Worker implements Algorithm 6 from [DuerschPhD2015] 1061*da0073e9SAndroid Build Coastguard Worker 1062*da0073e9SAndroid Build Coastguard Worker .. 
note:: If all U columns are B-collinear to V then the 1063*da0073e9SAndroid Build Coastguard Worker returned tensor U will be empty. 1064*da0073e9SAndroid Build Coastguard Worker 1065*da0073e9SAndroid Build Coastguard Worker Args: 1066*da0073e9SAndroid Build Coastguard Worker 1067*da0073e9SAndroid Build Coastguard Worker U (Tensor) : initial approximation, size is (m, n) 1068*da0073e9SAndroid Build Coastguard Worker V (Tensor) : B-orthogonal external basis, size is (m, k) 1069*da0073e9SAndroid Build Coastguard Worker 1070*da0073e9SAndroid Build Coastguard Worker Returns: 1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`) 1073*da0073e9SAndroid Build Coastguard Worker such that :math:`V^T B U=0`, size is (m, n1), 1074*da0073e9SAndroid Build Coastguard Worker where `n1 = n` if `drop` is `False, otherwise 1075*da0073e9SAndroid Build Coastguard Worker `n1 <= n`. 1076*da0073e9SAndroid Build Coastguard Worker """ 1077*da0073e9SAndroid Build Coastguard Worker mm = torch.matmul 1078*da0073e9SAndroid Build Coastguard Worker mm_B = _utils.matmul 1079*da0073e9SAndroid Build Coastguard Worker m = self.iparams["m"] 1080*da0073e9SAndroid Build Coastguard Worker tau_ortho = self.fparams["ortho_tol"] 1081*da0073e9SAndroid Build Coastguard Worker tau_drop = self.fparams["ortho_tol_drop"] 1082*da0073e9SAndroid Build Coastguard Worker tau_replace = self.fparams["ortho_tol_replace"] 1083*da0073e9SAndroid Build Coastguard Worker i_max = self.iparams["ortho_i_max"] 1084*da0073e9SAndroid Build Coastguard Worker j_max = self.iparams["ortho_j_max"] 1085*da0073e9SAndroid Build Coastguard Worker # when use_drop==True, enable dropping U columns that have 1086*da0073e9SAndroid Build Coastguard Worker # small contribution to the `span([U, V])`. 
1087*da0073e9SAndroid Build Coastguard Worker use_drop = self.bparams["ortho_use_drop"] 1088*da0073e9SAndroid Build Coastguard Worker 1089*da0073e9SAndroid Build Coastguard Worker # clean up variables from the previous call 1090*da0073e9SAndroid Build Coastguard Worker for vkey in list(self.fvars.keys()): 1091*da0073e9SAndroid Build Coastguard Worker if vkey.startswith("ortho_") and vkey.endswith("_rerr"): 1092*da0073e9SAndroid Build Coastguard Worker self.fvars.pop(vkey) 1093*da0073e9SAndroid Build Coastguard Worker self.ivars.pop("ortho_i", 0) 1094*da0073e9SAndroid Build Coastguard Worker self.ivars.pop("ortho_j", 0) 1095*da0073e9SAndroid Build Coastguard Worker 1096*da0073e9SAndroid Build Coastguard Worker BV_norm = torch.norm(mm_B(self.B, V)) 1097*da0073e9SAndroid Build Coastguard Worker BU = mm_B(self.B, U) 1098*da0073e9SAndroid Build Coastguard Worker VBU = mm(V.mT, BU) 1099*da0073e9SAndroid Build Coastguard Worker i = j = 0 1100*da0073e9SAndroid Build Coastguard Worker stats = "" 1101*da0073e9SAndroid Build Coastguard Worker for i in range(i_max): 1102*da0073e9SAndroid Build Coastguard Worker U = U - mm(V, VBU) 1103*da0073e9SAndroid Build Coastguard Worker drop = False 1104*da0073e9SAndroid Build Coastguard Worker tau_svqb = tau_drop 1105*da0073e9SAndroid Build Coastguard Worker for j in range(j_max): 1106*da0073e9SAndroid Build Coastguard Worker if use_drop: 1107*da0073e9SAndroid Build Coastguard Worker U = self._get_svqb(U, drop, tau_svqb) 1108*da0073e9SAndroid Build Coastguard Worker drop = True 1109*da0073e9SAndroid Build Coastguard Worker tau_svqb = tau_replace 1110*da0073e9SAndroid Build Coastguard Worker else: 1111*da0073e9SAndroid Build Coastguard Worker U = self._get_svqb(U, False, tau_replace) 1112*da0073e9SAndroid Build Coastguard Worker if torch.numel(U) == 0: 1113*da0073e9SAndroid Build Coastguard Worker # all initial U columns are B-collinear to V 1114*da0073e9SAndroid Build Coastguard Worker self.ivars["ortho_i"] = i 1115*da0073e9SAndroid 
Build Coastguard Worker self.ivars["ortho_j"] = j 1116*da0073e9SAndroid Build Coastguard Worker return U 1117*da0073e9SAndroid Build Coastguard Worker BU = mm_B(self.B, U) 1118*da0073e9SAndroid Build Coastguard Worker UBU = mm(U.mT, BU) 1119*da0073e9SAndroid Build Coastguard Worker U_norm = torch.norm(U) 1120*da0073e9SAndroid Build Coastguard Worker BU_norm = torch.norm(BU) 1121*da0073e9SAndroid Build Coastguard Worker R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype) 1122*da0073e9SAndroid Build Coastguard Worker R_norm = torch.norm(R) 1123*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/33810 workaround: 1124*da0073e9SAndroid Build Coastguard Worker rerr = float(R_norm) * float(BU_norm * U_norm) ** -1 1125*da0073e9SAndroid Build Coastguard Worker vkey = f"ortho_UBUmI_rerr[{i}, {j}]" 1126*da0073e9SAndroid Build Coastguard Worker self.fvars[vkey] = rerr 1127*da0073e9SAndroid Build Coastguard Worker if rerr < tau_ortho: 1128*da0073e9SAndroid Build Coastguard Worker break 1129*da0073e9SAndroid Build Coastguard Worker VBU = mm(V.mT, BU) 1130*da0073e9SAndroid Build Coastguard Worker VBU_norm = torch.norm(VBU) 1131*da0073e9SAndroid Build Coastguard Worker U_norm = torch.norm(U) 1132*da0073e9SAndroid Build Coastguard Worker rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1 1133*da0073e9SAndroid Build Coastguard Worker vkey = f"ortho_VBU_rerr[{i}]" 1134*da0073e9SAndroid Build Coastguard Worker self.fvars[vkey] = rerr 1135*da0073e9SAndroid Build Coastguard Worker if rerr < tau_ortho: 1136*da0073e9SAndroid Build Coastguard Worker break 1137*da0073e9SAndroid Build Coastguard Worker if m < U.shape[-1] + V.shape[-1]: 1138*da0073e9SAndroid Build Coastguard Worker # TorchScript needs the class var to be assigned to a local to 1139*da0073e9SAndroid Build Coastguard Worker # do optional type refinement 1140*da0073e9SAndroid Build Coastguard Worker B = self.B 1141*da0073e9SAndroid Build Coastguard Worker assert B is not 
None 1142*da0073e9SAndroid Build Coastguard Worker raise ValueError( 1143*da0073e9SAndroid Build Coastguard Worker "Overdetermined shape of U:" 1144*da0073e9SAndroid Build Coastguard Worker f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold" 1145*da0073e9SAndroid Build Coastguard Worker ) 1146*da0073e9SAndroid Build Coastguard Worker self.ivars["ortho_i"] = i 1147*da0073e9SAndroid Build Coastguard Worker self.ivars["ortho_j"] = j 1148*da0073e9SAndroid Build Coastguard Worker return U 1149*da0073e9SAndroid Build Coastguard Worker 1150*da0073e9SAndroid Build Coastguard Worker 1151*da0073e9SAndroid Build Coastguard Worker# Calling tracker is separated from LOBPCG definitions because 1152*da0073e9SAndroid Build Coastguard Worker# TorchScript does not support user-defined callback arguments: 1153*da0073e9SAndroid Build Coastguard WorkerLOBPCG_call_tracker_orig = LOBPCG.call_tracker 1154*da0073e9SAndroid Build Coastguard Worker 1155*da0073e9SAndroid Build Coastguard Worker 1156*da0073e9SAndroid Build Coastguard Workerdef LOBPCG_call_tracker(self): 1157*da0073e9SAndroid Build Coastguard Worker self.tracker(self) 1158