xref: /aosp_15_r20/external/pytorch/torch/_lobpcg.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker"""Locally Optimal Block Preconditioned Conjugate Gradient methods."""
3*da0073e9SAndroid Build Coastguard Worker# Author: Pearu Peterson
4*da0073e9SAndroid Build Coastguard Worker# Created: February 2020
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, Optional, Tuple
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerfrom torch import _linalg_utils as _utils, Tensor
10*da0073e9SAndroid Build Coastguard Workerfrom torch.overrides import handle_torch_function, has_torch_function
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker__all__ = ["lobpcg"]
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Workerdef _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
17*da0073e9SAndroid Build Coastguard Worker    # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
18*da0073e9SAndroid Build Coastguard Worker    F = D.unsqueeze(-2) - D.unsqueeze(-1)
19*da0073e9SAndroid Build Coastguard Worker    F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
20*da0073e9SAndroid Build Coastguard Worker    F.pow_(-1)
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker    # A.grad = U (D.grad + (U^T U.grad * F)) U^T
23*da0073e9SAndroid Build Coastguard Worker    Ut = U.mT.contiguous()
24*da0073e9SAndroid Build Coastguard Worker    res = torch.matmul(
25*da0073e9SAndroid Build Coastguard Worker        U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
26*da0073e9SAndroid Build Coastguard Worker    )
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker    return res
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Workerdef _polynomial_coefficients_given_roots(roots):
32*da0073e9SAndroid Build Coastguard Worker    """
33*da0073e9SAndroid Build Coastguard Worker    Given the `roots` of a polynomial, find the polynomial's coefficients.
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker    If roots = (r_1, ..., r_n), then the method returns
36*da0073e9SAndroid Build Coastguard Worker    coefficients (a_0, a_1, ..., a_n (== 1)) so that
37*da0073e9SAndroid Build Coastguard Worker    p(x) = (x - r_1) * ... * (x - r_n)
38*da0073e9SAndroid Build Coastguard Worker         = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker    Note: for better performance requires writing a low-level kernel
41*da0073e9SAndroid Build Coastguard Worker    """
42*da0073e9SAndroid Build Coastguard Worker    poly_order = roots.shape[-1]
43*da0073e9SAndroid Build Coastguard Worker    poly_coeffs_shape = list(roots.shape)
44*da0073e9SAndroid Build Coastguard Worker    # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
45*da0073e9SAndroid Build Coastguard Worker    # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
46*da0073e9SAndroid Build Coastguard Worker    # but we insert one extra coefficient to enable better vectorization below
47*da0073e9SAndroid Build Coastguard Worker    poly_coeffs_shape[-1] += 2
48*da0073e9SAndroid Build Coastguard Worker    poly_coeffs = roots.new_zeros(poly_coeffs_shape)
49*da0073e9SAndroid Build Coastguard Worker    poly_coeffs[..., 0] = 1
50*da0073e9SAndroid Build Coastguard Worker    poly_coeffs[..., -1] = 1
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker    # perform the Horner's rule
53*da0073e9SAndroid Build Coastguard Worker    for i in range(1, poly_order + 1):
54*da0073e9SAndroid Build Coastguard Worker        # note that it is computationally hard to compute backward for this method,
55*da0073e9SAndroid Build Coastguard Worker        # because then given the coefficients it would require finding the roots and/or
56*da0073e9SAndroid Build Coastguard Worker        # calculating the sensitivity based on the Vieta's theorem.
57*da0073e9SAndroid Build Coastguard Worker        # So the code below tries to circumvent the explicit root finding by series
58*da0073e9SAndroid Build Coastguard Worker        # of operations on memory copies imitating the Horner's method.
59*da0073e9SAndroid Build Coastguard Worker        # The memory copies are required to construct nodes in the computational graph
60*da0073e9SAndroid Build Coastguard Worker        # by exploting the explicit (not in-place, separate node for each step)
61*da0073e9SAndroid Build Coastguard Worker        # recursion of the Horner's method.
62*da0073e9SAndroid Build Coastguard Worker        # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
63*da0073e9SAndroid Build Coastguard Worker        poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
64*da0073e9SAndroid Build Coastguard Worker        out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
65*da0073e9SAndroid Build Coastguard Worker        out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
66*da0073e9SAndroid Build Coastguard Worker            -1, poly_order - i + 1, i + 1
67*da0073e9SAndroid Build Coastguard Worker        )
68*da0073e9SAndroid Build Coastguard Worker        poly_coeffs = poly_coeffs_new
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker    return poly_coeffs.narrow(-1, 1, poly_order + 1)
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Workerdef _polynomial_value(poly, x, zero_power, transition):
74*da0073e9SAndroid Build Coastguard Worker    """
75*da0073e9SAndroid Build Coastguard Worker    A generic method for computing poly(x) using the Horner's rule.
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker    Args:
78*da0073e9SAndroid Build Coastguard Worker      poly (Tensor): the (possibly batched) 1D Tensor representing
79*da0073e9SAndroid Build Coastguard Worker                     polynomial coefficients such that
80*da0073e9SAndroid Build Coastguard Worker                     poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
81*da0073e9SAndroid Build Coastguard Worker                     poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker      x (Tensor): the value (possible batched) to evalate the polynomial `poly` at.
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker      zero_power (Tensor): the representation of `x^0`. It is application-specific.
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker      transition (Callable): the function that accepts some intermediate result `int_val`,
88*da0073e9SAndroid Build Coastguard Worker                             the `x` and a specific polynomial coefficient
89*da0073e9SAndroid Build Coastguard Worker                             `poly[..., k]` for some iteration `k`.
90*da0073e9SAndroid Build Coastguard Worker                             It basically performs one iteration of the Horner's rule
91*da0073e9SAndroid Build Coastguard Worker                             defined as `x * int_val + poly[..., k] * zero_power`.
92*da0073e9SAndroid Build Coastguard Worker                             Note that `zero_power` is not a parameter,
93*da0073e9SAndroid Build Coastguard Worker                             because the step `+ poly[..., k] * zero_power` depends on `x`,
94*da0073e9SAndroid Build Coastguard Worker                             whether it is a vector, a matrix, or something else, so this
95*da0073e9SAndroid Build Coastguard Worker                             functionality is delegated to the user.
96*da0073e9SAndroid Build Coastguard Worker    """
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker    res = zero_power.clone()
99*da0073e9SAndroid Build Coastguard Worker    for k in range(poly.size(-1) - 2, -1, -1):
100*da0073e9SAndroid Build Coastguard Worker        res = transition(res, x, poly[..., k])
101*da0073e9SAndroid Build Coastguard Worker    return res
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker
def _matrix_polynomial_value(poly, x, zero_power=None):
    """
    Evaluate `poly(x)` where `x` is a (batched) square matrix.

    When `zero_power` is not given, an identity matrix of the size of `x`
    (with singleton batch dimensions) is used as `x^0`.
    See `_polynomial_value` for the meaning of `poly`.
    """

    def horner_step(acc, mat, coeff):
        # mat @ acc + coeff * I; the coefficient goes straight onto the diagonal
        stepped = mat.matmul(acc)
        stepped.diagonal(dim1=-2, dim2=-1).add_(coeff.unsqueeze(-1))
        return stepped

    if zero_power is None:
        side = x.size(-1)
        # one singleton dimension per batch dimension of `x`
        singleton_batch = [1] * len(x.shape[:-2])
        identity = torch.eye(side, side, dtype=x.dtype, device=x.device)
        zero_power = identity.view(*singleton_batch, side, side)

    return _polynomial_value(poly, x, zero_power, horner_step)
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker
def _vector_polynomial_value(poly, x, zero_power=None):
    """
    Evaluate `poly(x)` elementwise where `x` is a (batched) vector.

    When `zero_power` is not given, a tensor of ones expanded to the shape
    of `x` is used as `x^0`.
    See `_polynomial_value` for the meaning of `poly`.
    """

    def horner_step(acc, vec, coeff):
        # coeff + vec * acc, with the coefficient broadcast over the last dimension
        return torch.addcmul(coeff.unsqueeze(-1), vec, acc)

    if zero_power is None:
        zero_power = x.new_ones(1).expand(x.shape)

    return _polynomial_value(poly, x, zero_power, horner_step)
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker
def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    """Backward of a symmetric eigendecomposition when `U` is not a full eigenbasis.

    The gradient is the span(U) part computed by
    `_symeig_backward_complete_eigenspace` minus a correction that lives in the
    orthogonal complement of span(U). The correction comes from the polynomial
    solution of a Sylvester equation (see the references in the body).

    Args:
      D_grad (Tensor): incoming gradient with respect to `D`.
      U_grad (Tensor): incoming gradient with respect to `U`.
      A (Tensor): the decomposed matrix.
      D (Tensor): the computed eigenvalues (last dimension indexes them).
      U (Tensor): the computed eigenvectors stored as columns.
      largest (bool): whether `D` holds the largest eigenvalues of `A`; used
        only to pick the sign that makes the matrix passed to Cholesky
        positive-definite.

    Returns:
      Tensor: the gradient with respect to `A`.
    """
    # Number of eigenpairs in (D, U). It is the degree of the characteristic
    # polynomial of D and therefore decides the sign of that polynomial on the
    # remaining spectrum of A (see the Cholesky step below).
    num_eigenpairs = D.size(-1)

    # compute the projection operator onto the orthogonal complement of the
    # subspace spanned by the columns of U, defined as (I - UU^T)
    Ut = U.mT.contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # compute U_ortho, a basis for the orthogonal complement to the span(U),
    # by projecting a random [..., m, m - k] matrix onto that complement.
    #
    # a freshly constructed generator is used so that the draw does not
    # depend on the global random state
    gen = torch.Generator(A.device)

    # orthogonal complement to the span(U)
    U_ortho = proj_U_ortho.matmul(
        torch.randn(
            (*A.shape[:-1], A.size(-1) - D.size(-1)),
            dtype=A.dtype,
            device=A.device,
            generator=gen,
        )
    )
    U_ortho_t = U_ortho.mT.contiguous()

    # compute the coefficients of the characteristic polynomial of the tensor D.
    # Note that D is diagonal, so the diagonal elements are exactly the roots
    # of the characteristic polynomial.
    chr_poly_D = _polynomial_coefficients_given_roots(D)

    # the code below finds the explicit solution to the Sylvester equation
    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
    # and incorporates it into the whole gradient stored in the `res` variable.
    #
    # Equivalent to the following naive implementation:
    # res = A.new_zeros(A.shape)
    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
    # for k in range(1, chr_poly_D.size(-1)):
    #     p_res.zero_()
    #     for i in range(0, k):
    #         p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
    #     res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @  p_res @ U.t())
    #
    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
    # and we need to compute g(U_grad, A, U, D)
    #
    # The naive implementation is based on the paper
    # Hu, Qingxi, and Daizhan Cheng.
    # "The polynomial solution to the Sylvester matrix equation."
    # Applied mathematics letters 19.9 (2006): 859-864.
    #
    # We can modify the computation of `p_res` from above in a more efficient way
    # p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
    #       + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
    #       + ...
    #       + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
    # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
    U_grad_projected = U_grad
    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
    for k in range(1, chr_poly_D.size(-1)):
        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
        U_grad_projected = A.matmul(U_grad_projected)

    # compute chr_poly_D(A) which essentially is:
    #
    # chr_poly_D_at_A = A.new_zeros(A.shape)
    # for k in range(chr_poly_D.size(-1)):
    #     chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
    #
    # Note, however, for better performance we use the Horner's rule
    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)

    # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
    chr_poly_D_at_A_to_U_ortho = torch.matmul(
        U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
    )
    # we need to invert `chr_poly_D_at_A_to_U_ortho`, for that we compute its
    # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
    # Cholesky decomposition requires the input to be positive-definite.
    # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
    # 1. `largest` == False, or
    # 2. `largest` == True and the number of eigenpairs is even
    # under the assumption that `A` has distinct eigenvalues.
    #
    # The parity is taken from `num_eigenpairs` explicitly rather than from
    # the loop variable of the series loop above, which only happens to end at
    # the same value and is unbound when that loop does not execute.
    #
    # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
    chr_poly_D_at_A_to_U_ortho_sign = (
        -1 if (largest and (num_eigenpairs % 2 == 1)) else +1
    )
    chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
    )

    # compute the gradient part in span(U)
    res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)

    # incorporate the Sylvester equation solution into the full gradient
    # it resides in span(U_ortho)
    res -= U_ortho.matmul(
        chr_poly_D_at_A_to_U_ortho_sign
        * torch.cholesky_solve(
            U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
        )
    ).matmul(Ut)

    return res
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker
def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    """Pick the backward formula that matches the shape of `U`.

    A square `U` (as many columns as rows) means its columns form a complete
    eigenspace; otherwise only part of the eigenspace is available and the
    partial-eigenspace formula, which additionally needs `largest`, is used.
    """
    rows, cols = U.size(-2), U.size(-1)
    if rows != cols:
        return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)
    return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker
class LOBPCGAutogradFunction(torch.autograd.Function):
    """Autograd wrapper around `_lobpcg`.

    `forward` runs `_lobpcg` and saves what `backward` needs; `backward`
    returns a gradient for `A` only, and raises `ValueError` for sparse or
    complex input and whenever `B` was given.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        A: Tensor,
        k: Optional[int] = None,
        B: Optional[Tensor] = None,
        X: Optional[Tensor] = None,
        n: Optional[int] = None,
        iK: Optional[Tensor] = None,
        niter: Optional[int] = None,
        tol: Optional[float] = None,
        largest: Optional[bool] = None,
        method: Optional[str] = None,
        tracker: None = None,
        ortho_iparams: Optional[Dict[str, int]] = None,
        ortho_fparams: Optional[Dict[str, float]] = None,
        ortho_bparams: Optional[Dict[str, bool]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Run `_lobpcg` with the given arguments and return its `(D, U)` result.

        `A`, `B`, `D`, `U` and `largest` are stored on `ctx` for `backward`.
        """
        # makes sure that input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A
        if B is not None:
            B = B.contiguous() if (not B.is_sparse) else B

        D, U = _lobpcg(
            A,
            k,
            B,
            X,
            n,
            iK,
            niter,
            tol,
            largest,
            method,
            tracker,
            ortho_iparams,
            ortho_fparams,
            ortho_bparams,
        )

        ctx.save_for_backward(A, B, D, U)
        ctx.largest = largest

        return D, U

    @staticmethod
    def backward(ctx, D_grad, U_grad):
        """Return one gradient slot per `forward` argument; only `A`'s is set.

        Raises:
          ValueError: if `A` is sparse, if `B` is sparse and its gradient is
            requested, if `A` or `B` is complex, or if `B` was provided at all.
        """
        # one slot per argument of `forward` (excluding `ctx`)
        grads = [None] * 14

        A, B, D, U = ctx.saved_tensors
        largest = ctx.largest

        # lobpcg.backward has some limitations. Checks for unsupported input
        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
            raise ValueError(
                "lobpcg.backward does not support sparse input yet. "
                "Note that lobpcg.forward does though."
            )
        complex_dtypes = (torch.complex64, torch.complex128)
        if (A.dtype in complex_dtypes) or (
            B is not None and B.dtype in complex_dtypes
        ):
            raise ValueError(
                "lobpcg.backward does not support complex input yet. "
                "Note that lobpcg.forward does though."
            )
        if B is not None:
            raise ValueError(
                "lobpcg.backward does not support backward with B != I yet."
            )

        # `largest` left unspecified is treated as True
        if largest is None:
            largest = True

        # symeig backward; B is None from here on, so only A (index 0) gets a
        # gradient and the slot for B (index 2) stays None
        grads[0] = _symeig_backward(D_grad, U_grad, A, D, U, largest)
        return tuple(grads)
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Workerdef lobpcg(
346*da0073e9SAndroid Build Coastguard Worker    A: Tensor,
347*da0073e9SAndroid Build Coastguard Worker    k: Optional[int] = None,
348*da0073e9SAndroid Build Coastguard Worker    B: Optional[Tensor] = None,
349*da0073e9SAndroid Build Coastguard Worker    X: Optional[Tensor] = None,
350*da0073e9SAndroid Build Coastguard Worker    n: Optional[int] = None,
351*da0073e9SAndroid Build Coastguard Worker    iK: Optional[Tensor] = None,
352*da0073e9SAndroid Build Coastguard Worker    niter: Optional[int] = None,
353*da0073e9SAndroid Build Coastguard Worker    tol: Optional[float] = None,
354*da0073e9SAndroid Build Coastguard Worker    largest: Optional[bool] = None,
355*da0073e9SAndroid Build Coastguard Worker    method: Optional[str] = None,
356*da0073e9SAndroid Build Coastguard Worker    tracker: None = None,
357*da0073e9SAndroid Build Coastguard Worker    ortho_iparams: Optional[Dict[str, int]] = None,
358*da0073e9SAndroid Build Coastguard Worker    ortho_fparams: Optional[Dict[str, float]] = None,
359*da0073e9SAndroid Build Coastguard Worker    ortho_bparams: Optional[Dict[str, bool]] = None,
360*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor]:
361*da0073e9SAndroid Build Coastguard Worker    """Find the k largest (or smallest) eigenvalues and the corresponding
362*da0073e9SAndroid Build Coastguard Worker    eigenvectors of a symmetric positive definite generalized
363*da0073e9SAndroid Build Coastguard Worker    eigenvalue problem using matrix-free LOBPCG methods.
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker    This function is a front-end to the following LOBPCG algorithms
366*da0073e9SAndroid Build Coastguard Worker    selectable via `method` argument:
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker      `method="basic"` - the LOBPCG method introduced by Andrew
369*da0073e9SAndroid Build Coastguard Worker      Knyazev, see [Knyazev2001]. A less robust method, may fail when
370*da0073e9SAndroid Build Coastguard Worker      Cholesky is applied to singular input.
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker      `method="ortho"` - the LOBPCG method with orthogonal basis
373*da0073e9SAndroid Build Coastguard Worker      selection [StathopoulosEtal2002]. A robust method.
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker    Supported inputs are dense, sparse, and batches of dense matrices.
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker    .. note:: In general, the basic method spends least time per
378*da0073e9SAndroid Build Coastguard Worker      iteration. However, the robust methods converge much faster and
379*da0073e9SAndroid Build Coastguard Worker      are more stable. So, the usage of the basic method is generally
380*da0073e9SAndroid Build Coastguard Worker      not recommended but there exist cases where the usage of the
381*da0073e9SAndroid Build Coastguard Worker      basic method may be preferred.
382*da0073e9SAndroid Build Coastguard Worker
383*da0073e9SAndroid Build Coastguard Worker    .. warning:: The backward method does not support sparse and complex inputs.
384*da0073e9SAndroid Build Coastguard Worker      It works only when `B` is not provided (i.e. `B == None`).
385*da0073e9SAndroid Build Coastguard Worker      We are actively working on extensions, and the details of
386*da0073e9SAndroid Build Coastguard Worker      the algorithms are going to be published promptly.
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker    .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
389*da0073e9SAndroid Build Coastguard Worker      To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
390*da0073e9SAndroid Build Coastguard Worker      in first-order optimization routines, prior to running `lobpcg`
391*da0073e9SAndroid Build Coastguard Worker      we do the following symmetrization map: `A -> (A + A.t()) / 2`.
392*da0073e9SAndroid Build Coastguard Worker      The map is performed only when the `A` requires gradients.
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker    Args:
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker      A (Tensor): the input tensor of size :math:`(*, m, m)`
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker      B (Tensor, optional): the input tensor of size :math:`(*, m,
399*da0073e9SAndroid Build Coastguard Worker                  m)`. When not specified, `B` is interpreted as
400*da0073e9SAndroid Build Coastguard Worker                  identity matrix.
401*da0073e9SAndroid Build Coastguard Worker
402*da0073e9SAndroid Build Coastguard Worker      X (tensor, optional): the input tensor of size :math:`(*, m, n)`
403*da0073e9SAndroid Build Coastguard Worker                  where `k <= n <= m`. When specified, it is used as
404*da0073e9SAndroid Build Coastguard Worker                  initial approximation of eigenvectors. X must be a
405*da0073e9SAndroid Build Coastguard Worker                  dense tensor.
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker      iK (tensor, optional): the input tensor of size :math:`(*, m,
408*da0073e9SAndroid Build Coastguard Worker                  m)`. When specified, it will be used as preconditioner.
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker      k (integer, optional): the number of requested
411*da0073e9SAndroid Build Coastguard Worker                  eigenpairs. Default is the number of :math:`X`
412*da0073e9SAndroid Build Coastguard Worker                  columns (when specified) or `1`.
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker      n (integer, optional): if :math:`X` is not specified then `n`
415*da0073e9SAndroid Build Coastguard Worker                  specifies the size of the generated random
416*da0073e9SAndroid Build Coastguard Worker                  approximation of eigenvectors. Default value for `n`
417*da0073e9SAndroid Build Coastguard Worker                  is `k`. If :math:`X` is specified, the value of `n`
418*da0073e9SAndroid Build Coastguard Worker                  (when specified) must be the number of :math:`X`
419*da0073e9SAndroid Build Coastguard Worker                  columns.
420*da0073e9SAndroid Build Coastguard Worker
421*da0073e9SAndroid Build Coastguard Worker      tol (float, optional): residual tolerance for stopping
422*da0073e9SAndroid Build Coastguard Worker                 criterion. Default is `feps ** 0.5` where `feps` is
423*da0073e9SAndroid Build Coastguard Worker                 smallest non-zero floating-point number of the given
424*da0073e9SAndroid Build Coastguard Worker                 input tensor `A` data type.
425*da0073e9SAndroid Build Coastguard Worker
426*da0073e9SAndroid Build Coastguard Worker      largest (bool, optional): when True, solve the eigenproblem for
427*da0073e9SAndroid Build Coastguard Worker                 the largest eigenvalues. Otherwise, solve the
428*da0073e9SAndroid Build Coastguard Worker                 eigenproblem for smallest eigenvalues. Default is
429*da0073e9SAndroid Build Coastguard Worker                 `True`.
430*da0073e9SAndroid Build Coastguard Worker
431*da0073e9SAndroid Build Coastguard Worker      method (str, optional): select LOBPCG method. See the
432*da0073e9SAndroid Build Coastguard Worker                 description of the function above. Default is
433*da0073e9SAndroid Build Coastguard Worker                 "ortho".
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Worker      niter (int, optional): maximum number of iterations. When
436*da0073e9SAndroid Build Coastguard Worker                 reached, the iteration process is hard-stopped and
437*da0073e9SAndroid Build Coastguard Worker                 the current approximation of eigenpairs is returned.
438*da0073e9SAndroid Build Coastguard Worker                 For infinite iteration but until convergence criteria
439*da0073e9SAndroid Build Coastguard Worker                 is met, use `-1`.
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker      tracker (callable, optional) : a function for tracing the
442*da0073e9SAndroid Build Coastguard Worker                 iteration process. When specified, it is called at
443*da0073e9SAndroid Build Coastguard Worker                 each iteration step with LOBPCG instance as an
444*da0073e9SAndroid Build Coastguard Worker                 argument. The LOBPCG instance holds the full state of
445*da0073e9SAndroid Build Coastguard Worker                 the iteration process in the following attributes:
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker                   `iparams`, `fparams`, `bparams` - dictionaries of
448*da0073e9SAndroid Build Coastguard Worker                   integer, float, and boolean valued input
449*da0073e9SAndroid Build Coastguard Worker                   parameters, respectively
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker                   `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
452*da0073e9SAndroid Build Coastguard Worker                   of integer, float, boolean, and Tensor valued
453*da0073e9SAndroid Build Coastguard Worker                   iteration variables, respectively.
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker                   `A`, `B`, `iK` - input Tensor arguments.
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker                   `E`, `X`, `S`, `R` - iteration Tensor variables.
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker                 For instance:
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Worker                   `ivars["istep"]` - the current iteration step
462*da0073e9SAndroid Build Coastguard Worker                   `X` - the current approximation of eigenvectors
463*da0073e9SAndroid Build Coastguard Worker                   `E` - the current approximation of eigenvalues
464*da0073e9SAndroid Build Coastguard Worker                   `R` - the current residual
465*da0073e9SAndroid Build Coastguard Worker                   `ivars["converged_count"]` - the current number of converged eigenpairs
466*da0073e9SAndroid Build Coastguard Worker                   `tvars["rerr"]` - the current state of convergence criteria
467*da0073e9SAndroid Build Coastguard Worker
468*da0073e9SAndroid Build Coastguard Worker                 Note that when `tracker` stores Tensor objects from
469*da0073e9SAndroid Build Coastguard Worker                 the LOBPCG instance, it must make copies of these.
470*da0073e9SAndroid Build Coastguard Worker
471*da0073e9SAndroid Build Coastguard Worker                 If `tracker` sets `bvars["force_stop"] = True`, the
472*da0073e9SAndroid Build Coastguard Worker                 iteration process will be hard-stopped.
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker      ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
475*da0073e9SAndroid Build Coastguard Worker                 various parameters to LOBPCG algorithm when using
476*da0073e9SAndroid Build Coastguard Worker                 `method="ortho"`.
477*da0073e9SAndroid Build Coastguard Worker
478*da0073e9SAndroid Build Coastguard Worker    Returns:
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker      E (Tensor): tensor of eigenvalues of size :math:`(*, k)`
481*da0073e9SAndroid Build Coastguard Worker
482*da0073e9SAndroid Build Coastguard Worker      X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`
483*da0073e9SAndroid Build Coastguard Worker
484*da0073e9SAndroid Build Coastguard Worker    References:
485*da0073e9SAndroid Build Coastguard Worker
486*da0073e9SAndroid Build Coastguard Worker      [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
487*da0073e9SAndroid Build Coastguard Worker      Preconditioned Eigensolver: Locally Optimal Block Preconditioned
488*da0073e9SAndroid Build Coastguard Worker      Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
489*da0073e9SAndroid Build Coastguard Worker      517-541. (25 pages)
490*da0073e9SAndroid Build Coastguard Worker      https://epubs.siam.org/doi/abs/10.1137/S1064827500366124
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker      [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
493*da0073e9SAndroid Build Coastguard Worker      Wu. (2002) A Block Orthogonalization Procedure with Constant
494*da0073e9SAndroid Build Coastguard Worker      Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
495*da0073e9SAndroid Build Coastguard Worker      2165-2182. (18 pages)
496*da0073e9SAndroid Build Coastguard Worker      https://epubs.siam.org/doi/10.1137/S1064827500370883
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker      [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
499*da0073e9SAndroid Build Coastguard Worker      Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
500*da0073e9SAndroid Build Coastguard Worker      SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
501*da0073e9SAndroid Build Coastguard Worker      https://epubs.siam.org/doi/abs/10.1137/17M1129830
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker    """
504*da0073e9SAndroid Build Coastguard Worker
505*da0073e9SAndroid Build Coastguard Worker    if not torch.jit.is_scripting():
506*da0073e9SAndroid Build Coastguard Worker        tensor_ops = (A, B, X, iK)
507*da0073e9SAndroid Build Coastguard Worker        if not set(map(type, tensor_ops)).issubset(
508*da0073e9SAndroid Build Coastguard Worker            (torch.Tensor, type(None))
509*da0073e9SAndroid Build Coastguard Worker        ) and has_torch_function(tensor_ops):
510*da0073e9SAndroid Build Coastguard Worker            return handle_torch_function(
511*da0073e9SAndroid Build Coastguard Worker                lobpcg,
512*da0073e9SAndroid Build Coastguard Worker                tensor_ops,
513*da0073e9SAndroid Build Coastguard Worker                A,
514*da0073e9SAndroid Build Coastguard Worker                k=k,
515*da0073e9SAndroid Build Coastguard Worker                B=B,
516*da0073e9SAndroid Build Coastguard Worker                X=X,
517*da0073e9SAndroid Build Coastguard Worker                n=n,
518*da0073e9SAndroid Build Coastguard Worker                iK=iK,
519*da0073e9SAndroid Build Coastguard Worker                niter=niter,
520*da0073e9SAndroid Build Coastguard Worker                tol=tol,
521*da0073e9SAndroid Build Coastguard Worker                largest=largest,
522*da0073e9SAndroid Build Coastguard Worker                method=method,
523*da0073e9SAndroid Build Coastguard Worker                tracker=tracker,
524*da0073e9SAndroid Build Coastguard Worker                ortho_iparams=ortho_iparams,
525*da0073e9SAndroid Build Coastguard Worker                ortho_fparams=ortho_fparams,
526*da0073e9SAndroid Build Coastguard Worker                ortho_bparams=ortho_bparams,
527*da0073e9SAndroid Build Coastguard Worker            )
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Worker    if not torch._jit_internal.is_scripting():
530*da0073e9SAndroid Build Coastguard Worker        if A.requires_grad or (B is not None and B.requires_grad):
531*da0073e9SAndroid Build Coastguard Worker            # While it is expected that `A` is symmetric,
532*da0073e9SAndroid Build Coastguard Worker            # the `A_grad` might be not. Therefore we perform the trick below,
533*da0073e9SAndroid Build Coastguard Worker            # so that `A_grad` becomes symmetric.
534*da0073e9SAndroid Build Coastguard Worker            # The symmetrization is important for first-order optimization methods,
535*da0073e9SAndroid Build Coastguard Worker            # so that (A - alpha * A_grad) is still a symmetric matrix.
536*da0073e9SAndroid Build Coastguard Worker            # Same holds for `B`.
537*da0073e9SAndroid Build Coastguard Worker            A_sym = (A + A.mT) / 2
538*da0073e9SAndroid Build Coastguard Worker            B_sym = (B + B.mT) / 2 if (B is not None) else None
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker            return LOBPCGAutogradFunction.apply(
541*da0073e9SAndroid Build Coastguard Worker                A_sym,
542*da0073e9SAndroid Build Coastguard Worker                k,
543*da0073e9SAndroid Build Coastguard Worker                B_sym,
544*da0073e9SAndroid Build Coastguard Worker                X,
545*da0073e9SAndroid Build Coastguard Worker                n,
546*da0073e9SAndroid Build Coastguard Worker                iK,
547*da0073e9SAndroid Build Coastguard Worker                niter,
548*da0073e9SAndroid Build Coastguard Worker                tol,
549*da0073e9SAndroid Build Coastguard Worker                largest,
550*da0073e9SAndroid Build Coastguard Worker                method,
551*da0073e9SAndroid Build Coastguard Worker                tracker,
552*da0073e9SAndroid Build Coastguard Worker                ortho_iparams,
553*da0073e9SAndroid Build Coastguard Worker                ortho_fparams,
554*da0073e9SAndroid Build Coastguard Worker                ortho_bparams,
555*da0073e9SAndroid Build Coastguard Worker            )
556*da0073e9SAndroid Build Coastguard Worker    else:
557*da0073e9SAndroid Build Coastguard Worker        if A.requires_grad or (B is not None and B.requires_grad):
558*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
559*da0073e9SAndroid Build Coastguard Worker                "Script and require grads is not supported atm."
560*da0073e9SAndroid Build Coastguard Worker                "If you just want to do the forward, use .detach()"
561*da0073e9SAndroid Build Coastguard Worker                "on A and B before calling into lobpcg"
562*da0073e9SAndroid Build Coastguard Worker            )
563*da0073e9SAndroid Build Coastguard Worker
564*da0073e9SAndroid Build Coastguard Worker    return _lobpcg(
565*da0073e9SAndroid Build Coastguard Worker        A,
566*da0073e9SAndroid Build Coastguard Worker        k,
567*da0073e9SAndroid Build Coastguard Worker        B,
568*da0073e9SAndroid Build Coastguard Worker        X,
569*da0073e9SAndroid Build Coastguard Worker        n,
570*da0073e9SAndroid Build Coastguard Worker        iK,
571*da0073e9SAndroid Build Coastguard Worker        niter,
572*da0073e9SAndroid Build Coastguard Worker        tol,
573*da0073e9SAndroid Build Coastguard Worker        largest,
574*da0073e9SAndroid Build Coastguard Worker        method,
575*da0073e9SAndroid Build Coastguard Worker        tracker,
576*da0073e9SAndroid Build Coastguard Worker        ortho_iparams,
577*da0073e9SAndroid Build Coastguard Worker        ortho_fparams,
578*da0073e9SAndroid Build Coastguard Worker        ortho_bparams,
579*da0073e9SAndroid Build Coastguard Worker    )
580*da0073e9SAndroid Build Coastguard Worker
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Workerdef _lobpcg(
583*da0073e9SAndroid Build Coastguard Worker    A: Tensor,
584*da0073e9SAndroid Build Coastguard Worker    k: Optional[int] = None,
585*da0073e9SAndroid Build Coastguard Worker    B: Optional[Tensor] = None,
586*da0073e9SAndroid Build Coastguard Worker    X: Optional[Tensor] = None,
587*da0073e9SAndroid Build Coastguard Worker    n: Optional[int] = None,
588*da0073e9SAndroid Build Coastguard Worker    iK: Optional[Tensor] = None,
589*da0073e9SAndroid Build Coastguard Worker    niter: Optional[int] = None,
590*da0073e9SAndroid Build Coastguard Worker    tol: Optional[float] = None,
591*da0073e9SAndroid Build Coastguard Worker    largest: Optional[bool] = None,
592*da0073e9SAndroid Build Coastguard Worker    method: Optional[str] = None,
593*da0073e9SAndroid Build Coastguard Worker    tracker: None = None,
594*da0073e9SAndroid Build Coastguard Worker    ortho_iparams: Optional[Dict[str, int]] = None,
595*da0073e9SAndroid Build Coastguard Worker    ortho_fparams: Optional[Dict[str, float]] = None,
596*da0073e9SAndroid Build Coastguard Worker    ortho_bparams: Optional[Dict[str, bool]] = None,
597*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor]:
598*da0073e9SAndroid Build Coastguard Worker    # A must be square:
599*da0073e9SAndroid Build Coastguard Worker    assert A.shape[-2] == A.shape[-1], A.shape
600*da0073e9SAndroid Build Coastguard Worker    if B is not None:
601*da0073e9SAndroid Build Coastguard Worker        # A and B must have the same shapes:
602*da0073e9SAndroid Build Coastguard Worker        assert A.shape == B.shape, (A.shape, B.shape)
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker    dtype = _utils.get_floating_dtype(A)
605*da0073e9SAndroid Build Coastguard Worker    device = A.device
606*da0073e9SAndroid Build Coastguard Worker    if tol is None:
607*da0073e9SAndroid Build Coastguard Worker        feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
608*da0073e9SAndroid Build Coastguard Worker        tol = feps**0.5
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker    m = A.shape[-1]
611*da0073e9SAndroid Build Coastguard Worker    k = (1 if X is None else X.shape[-1]) if k is None else k
612*da0073e9SAndroid Build Coastguard Worker    n = (k if n is None else n) if X is None else X.shape[-1]
613*da0073e9SAndroid Build Coastguard Worker
614*da0073e9SAndroid Build Coastguard Worker    if m < 3 * n:
615*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
616*da0073e9SAndroid Build Coastguard Worker            f"LPBPCG algorithm is not applicable when the number of A rows (={m})"
617*da0073e9SAndroid Build Coastguard Worker            f" is smaller than 3 x the number of requested eigenpairs (={n})"
618*da0073e9SAndroid Build Coastguard Worker        )
619*da0073e9SAndroid Build Coastguard Worker
620*da0073e9SAndroid Build Coastguard Worker    method = "ortho" if method is None else method
621*da0073e9SAndroid Build Coastguard Worker
622*da0073e9SAndroid Build Coastguard Worker    iparams = {
623*da0073e9SAndroid Build Coastguard Worker        "m": m,
624*da0073e9SAndroid Build Coastguard Worker        "n": n,
625*da0073e9SAndroid Build Coastguard Worker        "k": k,
626*da0073e9SAndroid Build Coastguard Worker        "niter": 1000 if niter is None else niter,
627*da0073e9SAndroid Build Coastguard Worker    }
628*da0073e9SAndroid Build Coastguard Worker
629*da0073e9SAndroid Build Coastguard Worker    fparams = {
630*da0073e9SAndroid Build Coastguard Worker        "tol": tol,
631*da0073e9SAndroid Build Coastguard Worker    }
632*da0073e9SAndroid Build Coastguard Worker
633*da0073e9SAndroid Build Coastguard Worker    bparams = {"largest": True if largest is None else largest}
634*da0073e9SAndroid Build Coastguard Worker
635*da0073e9SAndroid Build Coastguard Worker    if method == "ortho":
636*da0073e9SAndroid Build Coastguard Worker        if ortho_iparams is not None:
637*da0073e9SAndroid Build Coastguard Worker            iparams.update(ortho_iparams)
638*da0073e9SAndroid Build Coastguard Worker        if ortho_fparams is not None:
639*da0073e9SAndroid Build Coastguard Worker            fparams.update(ortho_fparams)
640*da0073e9SAndroid Build Coastguard Worker        if ortho_bparams is not None:
641*da0073e9SAndroid Build Coastguard Worker            bparams.update(ortho_bparams)
642*da0073e9SAndroid Build Coastguard Worker        iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
643*da0073e9SAndroid Build Coastguard Worker        iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
644*da0073e9SAndroid Build Coastguard Worker        fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
645*da0073e9SAndroid Build Coastguard Worker        fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
646*da0073e9SAndroid Build Coastguard Worker        fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
647*da0073e9SAndroid Build Coastguard Worker        bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker    if not torch.jit.is_scripting():
650*da0073e9SAndroid Build Coastguard Worker        LOBPCG.call_tracker = LOBPCG_call_tracker  # type: ignore[method-assign]
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker    if len(A.shape) > 2:
653*da0073e9SAndroid Build Coastguard Worker        N = int(torch.prod(torch.tensor(A.shape[:-2])))
654*da0073e9SAndroid Build Coastguard Worker        bA = A.reshape((N,) + A.shape[-2:])
655*da0073e9SAndroid Build Coastguard Worker        bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
656*da0073e9SAndroid Build Coastguard Worker        bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
657*da0073e9SAndroid Build Coastguard Worker        bE = torch.empty((N, k), dtype=dtype, device=device)
658*da0073e9SAndroid Build Coastguard Worker        bXret = torch.empty((N, m, k), dtype=dtype, device=device)
659*da0073e9SAndroid Build Coastguard Worker
660*da0073e9SAndroid Build Coastguard Worker        for i in range(N):
661*da0073e9SAndroid Build Coastguard Worker            A_ = bA[i]
662*da0073e9SAndroid Build Coastguard Worker            B_ = bB[i] if bB is not None else None
663*da0073e9SAndroid Build Coastguard Worker            X_ = (
664*da0073e9SAndroid Build Coastguard Worker                torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
665*da0073e9SAndroid Build Coastguard Worker            )
666*da0073e9SAndroid Build Coastguard Worker            assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
667*da0073e9SAndroid Build Coastguard Worker            iparams["batch_index"] = i
668*da0073e9SAndroid Build Coastguard Worker            worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
669*da0073e9SAndroid Build Coastguard Worker            worker.run()
670*da0073e9SAndroid Build Coastguard Worker            bE[i] = worker.E[:k]
671*da0073e9SAndroid Build Coastguard Worker            bXret[i] = worker.X[:, :k]
672*da0073e9SAndroid Build Coastguard Worker
673*da0073e9SAndroid Build Coastguard Worker        if not torch.jit.is_scripting():
674*da0073e9SAndroid Build Coastguard Worker            LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker        return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker    X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
679*da0073e9SAndroid Build Coastguard Worker    assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker    worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker    worker.run()
684*da0073e9SAndroid Build Coastguard Worker
685*da0073e9SAndroid Build Coastguard Worker    if not torch.jit.is_scripting():
686*da0073e9SAndroid Build Coastguard Worker        LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]
687*da0073e9SAndroid Build Coastguard Worker
688*da0073e9SAndroid Build Coastguard Worker    return worker.E[:k], worker.X[:, :k]
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Workerclass LOBPCG:
692*da0073e9SAndroid Build Coastguard Worker    """Worker class of LOBPCG methods."""
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker    def __init__(
695*da0073e9SAndroid Build Coastguard Worker        self,
696*da0073e9SAndroid Build Coastguard Worker        A: Optional[Tensor],
697*da0073e9SAndroid Build Coastguard Worker        B: Optional[Tensor],
698*da0073e9SAndroid Build Coastguard Worker        X: Tensor,
699*da0073e9SAndroid Build Coastguard Worker        iK: Optional[Tensor],
700*da0073e9SAndroid Build Coastguard Worker        iparams: Dict[str, int],
701*da0073e9SAndroid Build Coastguard Worker        fparams: Dict[str, float],
702*da0073e9SAndroid Build Coastguard Worker        bparams: Dict[str, bool],
703*da0073e9SAndroid Build Coastguard Worker        method: str,
704*da0073e9SAndroid Build Coastguard Worker        tracker: None,
705*da0073e9SAndroid Build Coastguard Worker    ) -> None:
706*da0073e9SAndroid Build Coastguard Worker        # constant parameters
707*da0073e9SAndroid Build Coastguard Worker        self.A = A
708*da0073e9SAndroid Build Coastguard Worker        self.B = B
709*da0073e9SAndroid Build Coastguard Worker        self.iK = iK
710*da0073e9SAndroid Build Coastguard Worker        self.iparams = iparams
711*da0073e9SAndroid Build Coastguard Worker        self.fparams = fparams
712*da0073e9SAndroid Build Coastguard Worker        self.bparams = bparams
713*da0073e9SAndroid Build Coastguard Worker        self.method = method
714*da0073e9SAndroid Build Coastguard Worker        self.tracker = tracker
715*da0073e9SAndroid Build Coastguard Worker        m = iparams["m"]
716*da0073e9SAndroid Build Coastguard Worker        n = iparams["n"]
717*da0073e9SAndroid Build Coastguard Worker
718*da0073e9SAndroid Build Coastguard Worker        # variable parameters
719*da0073e9SAndroid Build Coastguard Worker        self.X = X
720*da0073e9SAndroid Build Coastguard Worker        self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
721*da0073e9SAndroid Build Coastguard Worker        self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
722*da0073e9SAndroid Build Coastguard Worker        self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
723*da0073e9SAndroid Build Coastguard Worker        self.tvars: Dict[str, Tensor] = {}
724*da0073e9SAndroid Build Coastguard Worker        self.ivars: Dict[str, int] = {"istep": 0}
725*da0073e9SAndroid Build Coastguard Worker        self.fvars: Dict[str, float] = {"_": 0.0}
726*da0073e9SAndroid Build Coastguard Worker        self.bvars: Dict[str, bool] = {"_": False}
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker    def __str__(self):
729*da0073e9SAndroid Build Coastguard Worker        lines = ["LOPBCG:"]
730*da0073e9SAndroid Build Coastguard Worker        lines += [f"  iparams={self.iparams}"]
731*da0073e9SAndroid Build Coastguard Worker        lines += [f"  fparams={self.fparams}"]
732*da0073e9SAndroid Build Coastguard Worker        lines += [f"  bparams={self.bparams}"]
733*da0073e9SAndroid Build Coastguard Worker        lines += [f"  ivars={self.ivars}"]
734*da0073e9SAndroid Build Coastguard Worker        lines += [f"  fvars={self.fvars}"]
735*da0073e9SAndroid Build Coastguard Worker        lines += [f"  bvars={self.bvars}"]
736*da0073e9SAndroid Build Coastguard Worker        lines += [f"  tvars={self.tvars}"]
737*da0073e9SAndroid Build Coastguard Worker        lines += [f"  A={self.A}"]
738*da0073e9SAndroid Build Coastguard Worker        lines += [f"  B={self.B}"]
739*da0073e9SAndroid Build Coastguard Worker        lines += [f"  iK={self.iK}"]
740*da0073e9SAndroid Build Coastguard Worker        lines += [f"  X={self.X}"]
741*da0073e9SAndroid Build Coastguard Worker        lines += [f"  E={self.E}"]
742*da0073e9SAndroid Build Coastguard Worker        r = ""
743*da0073e9SAndroid Build Coastguard Worker        for line in lines:
744*da0073e9SAndroid Build Coastguard Worker            r += line + "\n"
745*da0073e9SAndroid Build Coastguard Worker        return r
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Worker    def update(self):
748*da0073e9SAndroid Build Coastguard Worker        """Set and update iteration variables."""
749*da0073e9SAndroid Build Coastguard Worker        if self.ivars["istep"] == 0:
750*da0073e9SAndroid Build Coastguard Worker            X_norm = float(torch.norm(self.X))
751*da0073e9SAndroid Build Coastguard Worker            iX_norm = X_norm**-1
752*da0073e9SAndroid Build Coastguard Worker            A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
753*da0073e9SAndroid Build Coastguard Worker            B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
754*da0073e9SAndroid Build Coastguard Worker            self.fvars["X_norm"] = X_norm
755*da0073e9SAndroid Build Coastguard Worker            self.fvars["A_norm"] = A_norm
756*da0073e9SAndroid Build Coastguard Worker            self.fvars["B_norm"] = B_norm
757*da0073e9SAndroid Build Coastguard Worker            self.ivars["iterations_left"] = self.iparams["niter"]
758*da0073e9SAndroid Build Coastguard Worker            self.ivars["converged_count"] = 0
759*da0073e9SAndroid Build Coastguard Worker            self.ivars["converged_end"] = 0
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker        if self.method == "ortho":
762*da0073e9SAndroid Build Coastguard Worker            self._update_ortho()
763*da0073e9SAndroid Build Coastguard Worker        else:
764*da0073e9SAndroid Build Coastguard Worker            self._update_basic()
765*da0073e9SAndroid Build Coastguard Worker
766*da0073e9SAndroid Build Coastguard Worker        self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
767*da0073e9SAndroid Build Coastguard Worker        self.ivars["istep"] = self.ivars["istep"] + 1
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker    def update_residual(self):
770*da0073e9SAndroid Build Coastguard Worker        """Update residual R from A, B, X, E."""
771*da0073e9SAndroid Build Coastguard Worker        mm = _utils.matmul
772*da0073e9SAndroid Build Coastguard Worker        self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E
773*da0073e9SAndroid Build Coastguard Worker
774*da0073e9SAndroid Build Coastguard Worker    def update_converged_count(self):
775*da0073e9SAndroid Build Coastguard Worker        """Determine the number of converged eigenpairs using backward stable
776*da0073e9SAndroid Build Coastguard Worker        convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018].
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker        Users may redefine this method for custom convergence criteria.
779*da0073e9SAndroid Build Coastguard Worker        """
780*da0073e9SAndroid Build Coastguard Worker        # (...) -> int
781*da0073e9SAndroid Build Coastguard Worker        prev_count = self.ivars["converged_count"]
782*da0073e9SAndroid Build Coastguard Worker        tol = self.fparams["tol"]
783*da0073e9SAndroid Build Coastguard Worker        A_norm = self.fvars["A_norm"]
784*da0073e9SAndroid Build Coastguard Worker        B_norm = self.fvars["B_norm"]
785*da0073e9SAndroid Build Coastguard Worker        E, X, R = self.E, self.X, self.R
786*da0073e9SAndroid Build Coastguard Worker        rerr = (
787*da0073e9SAndroid Build Coastguard Worker            torch.norm(R, 2, (0,))
788*da0073e9SAndroid Build Coastguard Worker            * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1
789*da0073e9SAndroid Build Coastguard Worker        )
790*da0073e9SAndroid Build Coastguard Worker        converged = rerr.real < tol  # this is a norm so imag is 0.0
791*da0073e9SAndroid Build Coastguard Worker        count = 0
792*da0073e9SAndroid Build Coastguard Worker        for b in converged:
793*da0073e9SAndroid Build Coastguard Worker            if not b:
794*da0073e9SAndroid Build Coastguard Worker                # ignore convergence of following pairs to ensure
795*da0073e9SAndroid Build Coastguard Worker                # strict ordering of eigenpairs
796*da0073e9SAndroid Build Coastguard Worker                break
797*da0073e9SAndroid Build Coastguard Worker            count += 1
798*da0073e9SAndroid Build Coastguard Worker        assert (
799*da0073e9SAndroid Build Coastguard Worker            count >= prev_count
800*da0073e9SAndroid Build Coastguard Worker        ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
801*da0073e9SAndroid Build Coastguard Worker        self.ivars["converged_count"] = count
802*da0073e9SAndroid Build Coastguard Worker        self.tvars["rerr"] = rerr
803*da0073e9SAndroid Build Coastguard Worker        return count
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker    def stop_iteration(self):
806*da0073e9SAndroid Build Coastguard Worker        """Return True to stop iterations.
807*da0073e9SAndroid Build Coastguard Worker
808*da0073e9SAndroid Build Coastguard Worker        Note that tracker (if defined) can force-stop iterations by
809*da0073e9SAndroid Build Coastguard Worker        setting ``worker.bvars['force_stop'] = True``.
810*da0073e9SAndroid Build Coastguard Worker        """
811*da0073e9SAndroid Build Coastguard Worker        return (
812*da0073e9SAndroid Build Coastguard Worker            self.bvars.get("force_stop", False)
813*da0073e9SAndroid Build Coastguard Worker            or self.ivars["iterations_left"] == 0
814*da0073e9SAndroid Build Coastguard Worker            or self.ivars["converged_count"] >= self.iparams["k"]
815*da0073e9SAndroid Build Coastguard Worker        )
816*da0073e9SAndroid Build Coastguard Worker
817*da0073e9SAndroid Build Coastguard Worker    def run(self):
818*da0073e9SAndroid Build Coastguard Worker        """Run LOBPCG iterations.
819*da0073e9SAndroid Build Coastguard Worker
820*da0073e9SAndroid Build Coastguard Worker        Use this method as a template for implementing LOBPCG
821*da0073e9SAndroid Build Coastguard Worker        iteration scheme with custom tracker that is compatible with
822*da0073e9SAndroid Build Coastguard Worker        TorchScript.
823*da0073e9SAndroid Build Coastguard Worker        """
824*da0073e9SAndroid Build Coastguard Worker        self.update()
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker        if not torch.jit.is_scripting() and self.tracker is not None:
827*da0073e9SAndroid Build Coastguard Worker            self.call_tracker()
828*da0073e9SAndroid Build Coastguard Worker
829*da0073e9SAndroid Build Coastguard Worker        while not self.stop_iteration():
830*da0073e9SAndroid Build Coastguard Worker            self.update()
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker            if not torch.jit.is_scripting() and self.tracker is not None:
833*da0073e9SAndroid Build Coastguard Worker                self.call_tracker()
834*da0073e9SAndroid Build Coastguard Worker
835*da0073e9SAndroid Build Coastguard Worker    @torch.jit.unused
836*da0073e9SAndroid Build Coastguard Worker    def call_tracker(self):
837*da0073e9SAndroid Build Coastguard Worker        """Interface for tracking iteration process in Python mode.
838*da0073e9SAndroid Build Coastguard Worker
839*da0073e9SAndroid Build Coastguard Worker        Tracking the iteration process is disabled in TorchScript
840*da0073e9SAndroid Build Coastguard Worker        mode. In fact, one should specify tracker=None when JIT
841*da0073e9SAndroid Build Coastguard Worker        compiling functions using lobpcg.
842*da0073e9SAndroid Build Coastguard Worker        """
843*da0073e9SAndroid Build Coastguard Worker        # do nothing when in TorchScript mode
844*da0073e9SAndroid Build Coastguard Worker
845*da0073e9SAndroid Build Coastguard Worker    # Internal methods
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker    def _update_basic(self):
848*da0073e9SAndroid Build Coastguard Worker        """
849*da0073e9SAndroid Build Coastguard Worker        Update or initialize iteration variables when `method == "basic"`.
850*da0073e9SAndroid Build Coastguard Worker        """
851*da0073e9SAndroid Build Coastguard Worker        mm = torch.matmul
852*da0073e9SAndroid Build Coastguard Worker        ns = self.ivars["converged_end"]
853*da0073e9SAndroid Build Coastguard Worker        nc = self.ivars["converged_count"]
854*da0073e9SAndroid Build Coastguard Worker        n = self.iparams["n"]
855*da0073e9SAndroid Build Coastguard Worker        largest = self.bparams["largest"]
856*da0073e9SAndroid Build Coastguard Worker
857*da0073e9SAndroid Build Coastguard Worker        if self.ivars["istep"] == 0:
858*da0073e9SAndroid Build Coastguard Worker            Ri = self._get_rayleigh_ritz_transform(self.X)
859*da0073e9SAndroid Build Coastguard Worker            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
860*da0073e9SAndroid Build Coastguard Worker            E, Z = _utils.symeig(M, largest)
861*da0073e9SAndroid Build Coastguard Worker            self.X[:] = mm(self.X, mm(Ri, Z))
862*da0073e9SAndroid Build Coastguard Worker            self.E[:] = E
863*da0073e9SAndroid Build Coastguard Worker            np = 0
864*da0073e9SAndroid Build Coastguard Worker            self.update_residual()
865*da0073e9SAndroid Build Coastguard Worker            nc = self.update_converged_count()
866*da0073e9SAndroid Build Coastguard Worker            self.S[..., :n] = self.X
867*da0073e9SAndroid Build Coastguard Worker
868*da0073e9SAndroid Build Coastguard Worker            W = _utils.matmul(self.iK, self.R)
869*da0073e9SAndroid Build Coastguard Worker            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
870*da0073e9SAndroid Build Coastguard Worker            self.S[:, n + np : ns] = W
871*da0073e9SAndroid Build Coastguard Worker        else:
872*da0073e9SAndroid Build Coastguard Worker            S_ = self.S[:, nc:ns]
873*da0073e9SAndroid Build Coastguard Worker            Ri = self._get_rayleigh_ritz_transform(S_)
874*da0073e9SAndroid Build Coastguard Worker            M = _utils.qform(_utils.qform(self.A, S_), Ri)
875*da0073e9SAndroid Build Coastguard Worker            E_, Z = _utils.symeig(M, largest)
876*da0073e9SAndroid Build Coastguard Worker            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
877*da0073e9SAndroid Build Coastguard Worker            self.E[nc:] = E_[: n - nc]
878*da0073e9SAndroid Build Coastguard Worker            P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
879*da0073e9SAndroid Build Coastguard Worker            np = P.shape[-1]
880*da0073e9SAndroid Build Coastguard Worker
881*da0073e9SAndroid Build Coastguard Worker            self.update_residual()
882*da0073e9SAndroid Build Coastguard Worker            nc = self.update_converged_count()
883*da0073e9SAndroid Build Coastguard Worker            self.S[..., :n] = self.X
884*da0073e9SAndroid Build Coastguard Worker            self.S[:, n : n + np] = P
885*da0073e9SAndroid Build Coastguard Worker            W = _utils.matmul(self.iK, self.R[:, nc:])
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
888*da0073e9SAndroid Build Coastguard Worker            self.S[:, n + np : ns] = W
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker    def _update_ortho(self):
891*da0073e9SAndroid Build Coastguard Worker        """
892*da0073e9SAndroid Build Coastguard Worker        Update or initialize iteration variables when `method == "ortho"`.
893*da0073e9SAndroid Build Coastguard Worker        """
894*da0073e9SAndroid Build Coastguard Worker        mm = torch.matmul
895*da0073e9SAndroid Build Coastguard Worker        ns = self.ivars["converged_end"]
896*da0073e9SAndroid Build Coastguard Worker        nc = self.ivars["converged_count"]
897*da0073e9SAndroid Build Coastguard Worker        n = self.iparams["n"]
898*da0073e9SAndroid Build Coastguard Worker        largest = self.bparams["largest"]
899*da0073e9SAndroid Build Coastguard Worker
900*da0073e9SAndroid Build Coastguard Worker        if self.ivars["istep"] == 0:
901*da0073e9SAndroid Build Coastguard Worker            Ri = self._get_rayleigh_ritz_transform(self.X)
902*da0073e9SAndroid Build Coastguard Worker            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
903*da0073e9SAndroid Build Coastguard Worker            E, Z = _utils.symeig(M, largest)
904*da0073e9SAndroid Build Coastguard Worker            self.X = mm(self.X, mm(Ri, Z))
905*da0073e9SAndroid Build Coastguard Worker            self.update_residual()
906*da0073e9SAndroid Build Coastguard Worker            np = 0
907*da0073e9SAndroid Build Coastguard Worker            nc = self.update_converged_count()
908*da0073e9SAndroid Build Coastguard Worker            self.S[:, :n] = self.X
909*da0073e9SAndroid Build Coastguard Worker            W = self._get_ortho(self.R, self.X)
910*da0073e9SAndroid Build Coastguard Worker            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
911*da0073e9SAndroid Build Coastguard Worker            self.S[:, n + np : ns] = W
912*da0073e9SAndroid Build Coastguard Worker
913*da0073e9SAndroid Build Coastguard Worker        else:
914*da0073e9SAndroid Build Coastguard Worker            S_ = self.S[:, nc:ns]
915*da0073e9SAndroid Build Coastguard Worker            # Rayleigh-Ritz procedure
916*da0073e9SAndroid Build Coastguard Worker            E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker            # Update E, X, P
919*da0073e9SAndroid Build Coastguard Worker            self.X[:, nc:] = mm(S_, Z[:, : n - nc])
920*da0073e9SAndroid Build Coastguard Worker            self.E[nc:] = E_[: n - nc]
921*da0073e9SAndroid Build Coastguard Worker            P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT)))
922*da0073e9SAndroid Build Coastguard Worker            np = P.shape[-1]
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Worker            # check convergence
925*da0073e9SAndroid Build Coastguard Worker            self.update_residual()
926*da0073e9SAndroid Build Coastguard Worker            nc = self.update_converged_count()
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker            # update S
929*da0073e9SAndroid Build Coastguard Worker            self.S[:, :n] = self.X
930*da0073e9SAndroid Build Coastguard Worker            self.S[:, n : n + np] = P
931*da0073e9SAndroid Build Coastguard Worker            W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
932*da0073e9SAndroid Build Coastguard Worker            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
933*da0073e9SAndroid Build Coastguard Worker            self.S[:, n + np : ns] = W
934*da0073e9SAndroid Build Coastguard Worker
935*da0073e9SAndroid Build Coastguard Worker    def _get_rayleigh_ritz_transform(self, S):
936*da0073e9SAndroid Build Coastguard Worker        """Return a transformation matrix that is used in Rayleigh-Ritz
937*da0073e9SAndroid Build Coastguard Worker        procedure for reducing a general eigenvalue problem :math:`(S^TAS)
938*da0073e9SAndroid Build Coastguard Worker        C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T
939*da0073e9SAndroid Build Coastguard Worker        S^TAS Ri) Z = Z E` where `C = Ri Z`.
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker        .. note:: In the original Rayleight-Ritz procedure in
942*da0073e9SAndroid Build Coastguard Worker          [DuerschEtal2018], the problem is formulated as follows::
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker            SAS = S^T A S
945*da0073e9SAndroid Build Coastguard Worker            SBS = S^T B S
946*da0073e9SAndroid Build Coastguard Worker            D = (<diagonal matrix of SBS>) ** -1/2
947*da0073e9SAndroid Build Coastguard Worker            R^T R = Cholesky(D SBS D)
948*da0073e9SAndroid Build Coastguard Worker            Ri = D R^-1
949*da0073e9SAndroid Build Coastguard Worker            solve symeig problem Ri^T SAS Ri Z = Theta Z
950*da0073e9SAndroid Build Coastguard Worker            C = Ri Z
951*da0073e9SAndroid Build Coastguard Worker
952*da0073e9SAndroid Build Coastguard Worker          To reduce the number of matrix products (denoted by empty
953*da0073e9SAndroid Build Coastguard Worker          space between matrices), here we introduce element-wise
954*da0073e9SAndroid Build Coastguard Worker          products (denoted by symbol `*`) so that the Rayleight-Ritz
955*da0073e9SAndroid Build Coastguard Worker          procedure becomes::
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker            SAS = S^T A S
958*da0073e9SAndroid Build Coastguard Worker            SBS = S^T B S
959*da0073e9SAndroid Build Coastguard Worker            d = (<diagonal of SBS>) ** -1/2    # this is 1-d column vector
960*da0073e9SAndroid Build Coastguard Worker            dd = d d^T                         # this is 2-d matrix
961*da0073e9SAndroid Build Coastguard Worker            R^T R = Cholesky(dd * SBS)
962*da0073e9SAndroid Build Coastguard Worker            Ri = R^-1 * d                      # broadcasting
963*da0073e9SAndroid Build Coastguard Worker            solve symeig problem Ri^T SAS Ri Z = Theta Z
964*da0073e9SAndroid Build Coastguard Worker            C = Ri Z
965*da0073e9SAndroid Build Coastguard Worker
966*da0073e9SAndroid Build Coastguard Worker          where `dd` is 2-d matrix that replaces matrix products `D M
967*da0073e9SAndroid Build Coastguard Worker          D` with one element-wise product `M * dd`; and `d` replaces
968*da0073e9SAndroid Build Coastguard Worker          matrix product `D M` with element-wise product `M *
969*da0073e9SAndroid Build Coastguard Worker          d`. Also, creating the diagonal matrix `D` is avoided.
970*da0073e9SAndroid Build Coastguard Worker
971*da0073e9SAndroid Build Coastguard Worker        Args:
972*da0073e9SAndroid Build Coastguard Worker        S (Tensor): the matrix basis for the search subspace, size is
973*da0073e9SAndroid Build Coastguard Worker                    :math:`(m, n)`.
974*da0073e9SAndroid Build Coastguard Worker
975*da0073e9SAndroid Build Coastguard Worker        Returns:
976*da0073e9SAndroid Build Coastguard Worker        Ri (tensor): upper-triangular transformation matrix of size
977*da0073e9SAndroid Build Coastguard Worker                     :math:`(n, n)`.
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker        """
980*da0073e9SAndroid Build Coastguard Worker        B = self.B
981*da0073e9SAndroid Build Coastguard Worker        mm = torch.matmul
982*da0073e9SAndroid Build Coastguard Worker        SBS = _utils.qform(B, S)
983*da0073e9SAndroid Build Coastguard Worker        d_row = SBS.diagonal(0, -2, -1) ** -0.5
984*da0073e9SAndroid Build Coastguard Worker        d_col = d_row.reshape(d_row.shape[0], 1)
985*da0073e9SAndroid Build Coastguard Worker        # TODO use torch.linalg.cholesky_solve once it is implemented
986*da0073e9SAndroid Build Coastguard Worker        R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
987*da0073e9SAndroid Build Coastguard Worker        return torch.linalg.solve_triangular(
988*da0073e9SAndroid Build Coastguard Worker            R, d_row.diag_embed(), upper=True, left=False
989*da0073e9SAndroid Build Coastguard Worker        )
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker    def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor:
992*da0073e9SAndroid Build Coastguard Worker        """Return B-orthonormal U.
993*da0073e9SAndroid Build Coastguard Worker
994*da0073e9SAndroid Build Coastguard Worker        .. note:: When `drop` is `False` then `svqb` is based on the
995*da0073e9SAndroid Build Coastguard Worker                  Algorithm 4 from [DuerschPhD2015] that is a slight
996*da0073e9SAndroid Build Coastguard Worker                  modification of the corresponding algorithm
997*da0073e9SAndroid Build Coastguard Worker                  introduced in [StathopolousWu2002].
998*da0073e9SAndroid Build Coastguard Worker
999*da0073e9SAndroid Build Coastguard Worker        Args:
1000*da0073e9SAndroid Build Coastguard Worker
1001*da0073e9SAndroid Build Coastguard Worker          U (Tensor) : initial approximation, size is (m, n)
1002*da0073e9SAndroid Build Coastguard Worker          drop (bool) : when True, drop columns that
1003*da0073e9SAndroid Build Coastguard Worker                     contribution to the `span([U])` is small.
1004*da0073e9SAndroid Build Coastguard Worker          tau (float) : positive tolerance
1005*da0073e9SAndroid Build Coastguard Worker
1006*da0073e9SAndroid Build Coastguard Worker        Returns:
1007*da0073e9SAndroid Build Coastguard Worker
1008*da0073e9SAndroid Build Coastguard Worker          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
1009*da0073e9SAndroid Build Coastguard Worker                       is (m, n1), where `n1 = n` if `drop` is `False,
1010*da0073e9SAndroid Build Coastguard Worker                       otherwise `n1 <= n`.
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Worker        """
1013*da0073e9SAndroid Build Coastguard Worker        if torch.numel(U) == 0:
1014*da0073e9SAndroid Build Coastguard Worker            return U
1015*da0073e9SAndroid Build Coastguard Worker        UBU = _utils.qform(self.B, U)
1016*da0073e9SAndroid Build Coastguard Worker        d = UBU.diagonal(0, -2, -1)
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker        # Detect and drop exact zero columns from U. While the test
1019*da0073e9SAndroid Build Coastguard Worker        # `abs(d) == 0` is unlikely to be True for random data, it is
1020*da0073e9SAndroid Build Coastguard Worker        # possible to construct input data to lobpcg where it will be
1021*da0073e9SAndroid Build Coastguard Worker        # True leading to a failure (notice the `d ** -0.5` operation
1022*da0073e9SAndroid Build Coastguard Worker        # in the original algorithm). To prevent the failure, we drop
1023*da0073e9SAndroid Build Coastguard Worker        # the exact zero columns here and then continue with the
1024*da0073e9SAndroid Build Coastguard Worker        # original algorithm below.
1025*da0073e9SAndroid Build Coastguard Worker        nz = torch.where(abs(d) != 0.0)
1026*da0073e9SAndroid Build Coastguard Worker        assert len(nz) == 1, nz
1027*da0073e9SAndroid Build Coastguard Worker        if len(nz[0]) < len(d):
1028*da0073e9SAndroid Build Coastguard Worker            U = U[:, nz[0]]
1029*da0073e9SAndroid Build Coastguard Worker            if torch.numel(U) == 0:
1030*da0073e9SAndroid Build Coastguard Worker                return U
1031*da0073e9SAndroid Build Coastguard Worker            UBU = _utils.qform(self.B, U)
1032*da0073e9SAndroid Build Coastguard Worker            d = UBU.diagonal(0, -2, -1)
1033*da0073e9SAndroid Build Coastguard Worker            nz = torch.where(abs(d) != 0.0)
1034*da0073e9SAndroid Build Coastguard Worker            assert len(nz[0]) == len(d)
1035*da0073e9SAndroid Build Coastguard Worker
1036*da0073e9SAndroid Build Coastguard Worker        # The original algorithm 4 from [DuerschPhD2015].
1037*da0073e9SAndroid Build Coastguard Worker        d_col = (d**-0.5).reshape(d.shape[0], 1)
1038*da0073e9SAndroid Build Coastguard Worker        DUBUD = (UBU * d_col) * d_col.mT
1039*da0073e9SAndroid Build Coastguard Worker        E, Z = _utils.symeig(DUBUD)
1040*da0073e9SAndroid Build Coastguard Worker        t = tau * abs(E).max()
1041*da0073e9SAndroid Build Coastguard Worker        if drop:
1042*da0073e9SAndroid Build Coastguard Worker            keep = torch.where(E > t)
1043*da0073e9SAndroid Build Coastguard Worker            assert len(keep) == 1, keep
1044*da0073e9SAndroid Build Coastguard Worker            E = E[keep[0]]
1045*da0073e9SAndroid Build Coastguard Worker            Z = Z[:, keep[0]]
1046*da0073e9SAndroid Build Coastguard Worker            d_col = d_col[keep[0]]
1047*da0073e9SAndroid Build Coastguard Worker        else:
1048*da0073e9SAndroid Build Coastguard Worker            E[(torch.where(E < t))[0]] = t
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker        return torch.matmul(U * d_col.mT, Z * E**-0.5)
1051*da0073e9SAndroid Build Coastguard Worker
1052*da0073e9SAndroid Build Coastguard Worker    def _get_ortho(self, U, V):
1053*da0073e9SAndroid Build Coastguard Worker        """Return B-orthonormal U with columns are B-orthogonal to V.
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker        .. note:: When `bparams["ortho_use_drop"] == False` then
1056*da0073e9SAndroid Build Coastguard Worker                  `_get_ortho` is based on the Algorithm 3 from
1057*da0073e9SAndroid Build Coastguard Worker                  [DuerschPhD2015] that is a slight modification of
1058*da0073e9SAndroid Build Coastguard Worker                  the corresponding algorithm introduced in
1059*da0073e9SAndroid Build Coastguard Worker                  [StathopolousWu2002]. Otherwise, the method
1060*da0073e9SAndroid Build Coastguard Worker                  implements Algorithm 6 from [DuerschPhD2015]
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Worker        .. note:: If all U columns are B-collinear to V then the
1063*da0073e9SAndroid Build Coastguard Worker                  returned tensor U will be empty.
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Worker        Args:
1066*da0073e9SAndroid Build Coastguard Worker
1067*da0073e9SAndroid Build Coastguard Worker          U (Tensor) : initial approximation, size is (m, n)
1068*da0073e9SAndroid Build Coastguard Worker          V (Tensor) : B-orthogonal external basis, size is (m, k)
1069*da0073e9SAndroid Build Coastguard Worker
1070*da0073e9SAndroid Build Coastguard Worker        Returns:
1071*da0073e9SAndroid Build Coastguard Worker
1072*da0073e9SAndroid Build Coastguard Worker          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
1073*da0073e9SAndroid Build Coastguard Worker                       such that :math:`V^T B U=0`, size is (m, n1),
1074*da0073e9SAndroid Build Coastguard Worker                       where `n1 = n` if `drop` is `False, otherwise
1075*da0073e9SAndroid Build Coastguard Worker                       `n1 <= n`.
1076*da0073e9SAndroid Build Coastguard Worker        """
1077*da0073e9SAndroid Build Coastguard Worker        mm = torch.matmul
1078*da0073e9SAndroid Build Coastguard Worker        mm_B = _utils.matmul
1079*da0073e9SAndroid Build Coastguard Worker        m = self.iparams["m"]
1080*da0073e9SAndroid Build Coastguard Worker        tau_ortho = self.fparams["ortho_tol"]
1081*da0073e9SAndroid Build Coastguard Worker        tau_drop = self.fparams["ortho_tol_drop"]
1082*da0073e9SAndroid Build Coastguard Worker        tau_replace = self.fparams["ortho_tol_replace"]
1083*da0073e9SAndroid Build Coastguard Worker        i_max = self.iparams["ortho_i_max"]
1084*da0073e9SAndroid Build Coastguard Worker        j_max = self.iparams["ortho_j_max"]
1085*da0073e9SAndroid Build Coastguard Worker        # when use_drop==True, enable dropping U columns that have
1086*da0073e9SAndroid Build Coastguard Worker        # small contribution to the `span([U, V])`.
1087*da0073e9SAndroid Build Coastguard Worker        use_drop = self.bparams["ortho_use_drop"]
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker        # clean up variables from the previous call
1090*da0073e9SAndroid Build Coastguard Worker        for vkey in list(self.fvars.keys()):
1091*da0073e9SAndroid Build Coastguard Worker            if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
1092*da0073e9SAndroid Build Coastguard Worker                self.fvars.pop(vkey)
1093*da0073e9SAndroid Build Coastguard Worker        self.ivars.pop("ortho_i", 0)
1094*da0073e9SAndroid Build Coastguard Worker        self.ivars.pop("ortho_j", 0)
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker        BV_norm = torch.norm(mm_B(self.B, V))
1097*da0073e9SAndroid Build Coastguard Worker        BU = mm_B(self.B, U)
1098*da0073e9SAndroid Build Coastguard Worker        VBU = mm(V.mT, BU)
1099*da0073e9SAndroid Build Coastguard Worker        i = j = 0
1100*da0073e9SAndroid Build Coastguard Worker        stats = ""
1101*da0073e9SAndroid Build Coastguard Worker        for i in range(i_max):
1102*da0073e9SAndroid Build Coastguard Worker            U = U - mm(V, VBU)
1103*da0073e9SAndroid Build Coastguard Worker            drop = False
1104*da0073e9SAndroid Build Coastguard Worker            tau_svqb = tau_drop
1105*da0073e9SAndroid Build Coastguard Worker            for j in range(j_max):
1106*da0073e9SAndroid Build Coastguard Worker                if use_drop:
1107*da0073e9SAndroid Build Coastguard Worker                    U = self._get_svqb(U, drop, tau_svqb)
1108*da0073e9SAndroid Build Coastguard Worker                    drop = True
1109*da0073e9SAndroid Build Coastguard Worker                    tau_svqb = tau_replace
1110*da0073e9SAndroid Build Coastguard Worker                else:
1111*da0073e9SAndroid Build Coastguard Worker                    U = self._get_svqb(U, False, tau_replace)
1112*da0073e9SAndroid Build Coastguard Worker                if torch.numel(U) == 0:
1113*da0073e9SAndroid Build Coastguard Worker                    # all initial U columns are B-collinear to V
1114*da0073e9SAndroid Build Coastguard Worker                    self.ivars["ortho_i"] = i
1115*da0073e9SAndroid Build Coastguard Worker                    self.ivars["ortho_j"] = j
1116*da0073e9SAndroid Build Coastguard Worker                    return U
1117*da0073e9SAndroid Build Coastguard Worker                BU = mm_B(self.B, U)
1118*da0073e9SAndroid Build Coastguard Worker                UBU = mm(U.mT, BU)
1119*da0073e9SAndroid Build Coastguard Worker                U_norm = torch.norm(U)
1120*da0073e9SAndroid Build Coastguard Worker                BU_norm = torch.norm(BU)
1121*da0073e9SAndroid Build Coastguard Worker                R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
1122*da0073e9SAndroid Build Coastguard Worker                R_norm = torch.norm(R)
1123*da0073e9SAndroid Build Coastguard Worker                # https://github.com/pytorch/pytorch/issues/33810 workaround:
1124*da0073e9SAndroid Build Coastguard Worker                rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
1125*da0073e9SAndroid Build Coastguard Worker                vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
1126*da0073e9SAndroid Build Coastguard Worker                self.fvars[vkey] = rerr
1127*da0073e9SAndroid Build Coastguard Worker                if rerr < tau_ortho:
1128*da0073e9SAndroid Build Coastguard Worker                    break
1129*da0073e9SAndroid Build Coastguard Worker            VBU = mm(V.mT, BU)
1130*da0073e9SAndroid Build Coastguard Worker            VBU_norm = torch.norm(VBU)
1131*da0073e9SAndroid Build Coastguard Worker            U_norm = torch.norm(U)
1132*da0073e9SAndroid Build Coastguard Worker            rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
1133*da0073e9SAndroid Build Coastguard Worker            vkey = f"ortho_VBU_rerr[{i}]"
1134*da0073e9SAndroid Build Coastguard Worker            self.fvars[vkey] = rerr
1135*da0073e9SAndroid Build Coastguard Worker            if rerr < tau_ortho:
1136*da0073e9SAndroid Build Coastguard Worker                break
1137*da0073e9SAndroid Build Coastguard Worker            if m < U.shape[-1] + V.shape[-1]:
1138*da0073e9SAndroid Build Coastguard Worker                # TorchScript needs the class var to be assigned to a local to
1139*da0073e9SAndroid Build Coastguard Worker                # do optional type refinement
1140*da0073e9SAndroid Build Coastguard Worker                B = self.B
1141*da0073e9SAndroid Build Coastguard Worker                assert B is not None
1142*da0073e9SAndroid Build Coastguard Worker                raise ValueError(
1143*da0073e9SAndroid Build Coastguard Worker                    "Overdetermined shape of U:"
1144*da0073e9SAndroid Build Coastguard Worker                    f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold"
1145*da0073e9SAndroid Build Coastguard Worker                )
1146*da0073e9SAndroid Build Coastguard Worker        self.ivars["ortho_i"] = i
1147*da0073e9SAndroid Build Coastguard Worker        self.ivars["ortho_j"] = j
1148*da0073e9SAndroid Build Coastguard Worker        return U
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker
# Calling tracker is separated from LOBPCG definitions because
# TorchScript does not support user-defined callback arguments.
# Keep a reference to the method as defined in the class body (a
# no-op marked with ``torch.jit.unused``) under a module-level name:
LOBPCG_call_tracker_orig = LOBPCG.call_tracker
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker
def LOBPCG_call_tracker(self):
    """Invoke the tracker callback stored on the worker, passing the worker itself."""
    callback = self.tracker
    callback(self)
1158