xref: /aosp_15_r20/external/pytorch/torch/optim/lbfgs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional, Union
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerfrom .optimizer import Optimizer, ParamsT
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker__all__ = ["LBFGS"]
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Workerdef _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
14*da0073e9SAndroid Build Coastguard Worker    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
15*da0073e9SAndroid Build Coastguard Worker    # Compute bounds of interpolation area
16*da0073e9SAndroid Build Coastguard Worker    if bounds is not None:
17*da0073e9SAndroid Build Coastguard Worker        xmin_bound, xmax_bound = bounds
18*da0073e9SAndroid Build Coastguard Worker    else:
19*da0073e9SAndroid Build Coastguard Worker        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker    # Code for most common case: cubic interpolation of 2 points
22*da0073e9SAndroid Build Coastguard Worker    #   w/ function and derivative values for both
23*da0073e9SAndroid Build Coastguard Worker    # Solution in this case (where x2 is the farthest point):
24*da0073e9SAndroid Build Coastguard Worker    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
25*da0073e9SAndroid Build Coastguard Worker    #   d2 = sqrt(d1^2 - g1*g2);
26*da0073e9SAndroid Build Coastguard Worker    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
27*da0073e9SAndroid Build Coastguard Worker    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
28*da0073e9SAndroid Build Coastguard Worker    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
29*da0073e9SAndroid Build Coastguard Worker    d2_square = d1**2 - g1 * g2
30*da0073e9SAndroid Build Coastguard Worker    if d2_square >= 0:
31*da0073e9SAndroid Build Coastguard Worker        d2 = d2_square.sqrt()
32*da0073e9SAndroid Build Coastguard Worker        if x1 <= x2:
33*da0073e9SAndroid Build Coastguard Worker            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
34*da0073e9SAndroid Build Coastguard Worker        else:
35*da0073e9SAndroid Build Coastguard Worker            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
36*da0073e9SAndroid Build Coastguard Worker        return min(max(min_pos, xmin_bound), xmax_bound)
37*da0073e9SAndroid Build Coastguard Worker    else:
38*da0073e9SAndroid Build Coastguard Worker        return (xmin_bound + xmax_bound) / 2.0
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
def _strong_wolfe(
    obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
    """Line search for a step size ``t`` satisfying the strong Wolfe conditions.

    Runs in two phases. The bracketing phase tries successive steps until it
    finds an acceptable one or an interval ``bracket`` expected to contain one.
    The zoom phase then shrinks that interval by cubic interpolation.

    Args:
        obj_func: callable invoked as ``obj_func(x, t, d)`` that returns a
            ``(loss, flat_grad)`` pair for step ``t``. Presumably it evaluates
            the objective at ``x + t * d`` (e.g. ``LBFGS._directional_evaluate``);
            confirm at the call site.
        x: forwarded unchanged to every ``obj_func`` call.
        t: initial trial step.
        d: search direction; a tensor (``d.abs().max()`` and ``g_new.dot(d)``
            are used).
        f: objective value at the starting point (step 0).
        g: gradient at the starting point. It is cloned here, never modified.
        gtd: directional derivative ``g.dot(d)`` at the starting point. The
            curvature test ``abs(gtd_new) <= -c2 * gtd`` assumes it is negative
            (a descent direction).
        c1: sufficient-decrease (Armijo) constant.
        c2: curvature-condition constant.
        tolerance_change: the zoom phase stops once the bracket width times
            ``max(abs(d))`` drops below this.
        max_ls: maximum number of line-search iterations after the initial
            evaluation, shared by both phases.

    Returns:
        ``(f_new, g_new, t, ls_func_evals)``: the loss, gradient and step of the
        selected point, plus the number of ``obj_func`` calls made (at most
        ``max_ls + 1``). If ``max_ls`` is exhausted while bracketing and the
        last trial's loss is not below ``f``, the result is ``t = 0`` with
        ``f`` and a copy of ``g``.
    """
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    # largest |d_i|; scales the bracket width in the zoom-phase stopping test
    d_norm = d.abs().max()
    # work on a copy: g may be stored in bracket_g and returned as g_new
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    # (the "previous" trial starts out as the step-0 point)
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        # Armijo (sufficient decrease) fails, or no improvement over the
        # previous trial: use [t_prev, t] as the bracket
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # strong Wolfe curvature condition holds: accept t and skip the zoom
        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        # directional derivative is no longer negative: use [t_prev, t]
        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        # (the next trial comes from cubic interpolation restricted to
        # [t + 0.01 * (t - t_prev), 10 * t])
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    # Fall back to the bracket [0, t]. bracket_gtd is not set on this path;
    # that is safe because ls_iter == max_ls means the zoom loop never runs.
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    # low_pos: end with the smaller f (ties -> index 0); high_pos: the other.
    # For a single-element bracket (done=True) this yields low_pos = 0.
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)  # type: ignore[possibly-undefined]
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:  # type: ignore[possibly-undefined]
            break

        # compute new trial value
        t = _cubic_interpolate(
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
        )

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[high_pos] = gtd_new
            # the replaced end may now be the lower one: re-rank the ends
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    # (the low end of the bracket, i.e. the end with the smaller loss)
    t = bracket[low_pos]  # type: ignore[possibly-undefined]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]
    return f_new, g_new, t, ls_func_evals
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Workerclass LBFGS(Optimizer):
184*da0073e9SAndroid Build Coastguard Worker    """Implements L-BFGS algorithm.
185*da0073e9SAndroid Build Coastguard Worker
186*da0073e9SAndroid Build Coastguard Worker    Heavily inspired by `minFunc
187*da0073e9SAndroid Build Coastguard Worker    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.
188*da0073e9SAndroid Build Coastguard Worker
189*da0073e9SAndroid Build Coastguard Worker    .. warning::
190*da0073e9SAndroid Build Coastguard Worker        This optimizer doesn't support per-parameter options and parameter
191*da0073e9SAndroid Build Coastguard Worker        groups (there can be only one).
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker    .. warning::
194*da0073e9SAndroid Build Coastguard Worker        Right now all parameters have to be on a single device. This will be
195*da0073e9SAndroid Build Coastguard Worker        improved in the future.
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker    .. note::
198*da0073e9SAndroid Build Coastguard Worker        This is a very memory intensive optimizer (it requires additional
199*da0073e9SAndroid Build Coastguard Worker        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
200*da0073e9SAndroid Build Coastguard Worker        try reducing the history size, or use a different algorithm.
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker    Args:
203*da0073e9SAndroid Build Coastguard Worker        params (iterable): iterable of parameters to optimize. Parameters must be real.
204*da0073e9SAndroid Build Coastguard Worker        lr (float): learning rate (default: 1)
205*da0073e9SAndroid Build Coastguard Worker        max_iter (int): maximal number of iterations per optimization step
206*da0073e9SAndroid Build Coastguard Worker            (default: 20)
207*da0073e9SAndroid Build Coastguard Worker        max_eval (int): maximal number of function evaluations per optimization
208*da0073e9SAndroid Build Coastguard Worker            step (default: max_iter * 1.25).
209*da0073e9SAndroid Build Coastguard Worker        tolerance_grad (float): termination tolerance on first order optimality
210*da0073e9SAndroid Build Coastguard Worker            (default: 1e-7).
211*da0073e9SAndroid Build Coastguard Worker        tolerance_change (float): termination tolerance on function
212*da0073e9SAndroid Build Coastguard Worker            value/parameter changes (default: 1e-9).
213*da0073e9SAndroid Build Coastguard Worker        history_size (int): update history size (default: 100).
214*da0073e9SAndroid Build Coastguard Worker        line_search_fn (str): either 'strong_wolfe' or None (default: None).
215*da0073e9SAndroid Build Coastguard Worker    """
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker    def __init__(
218*da0073e9SAndroid Build Coastguard Worker        self,
219*da0073e9SAndroid Build Coastguard Worker        params: ParamsT,
220*da0073e9SAndroid Build Coastguard Worker        lr: Union[float, Tensor] = 1,
221*da0073e9SAndroid Build Coastguard Worker        max_iter: int = 20,
222*da0073e9SAndroid Build Coastguard Worker        max_eval: Optional[int] = None,
223*da0073e9SAndroid Build Coastguard Worker        tolerance_grad: float = 1e-7,
224*da0073e9SAndroid Build Coastguard Worker        tolerance_change: float = 1e-9,
225*da0073e9SAndroid Build Coastguard Worker        history_size: int = 100,
226*da0073e9SAndroid Build Coastguard Worker        line_search_fn: Optional[str] = None,
227*da0073e9SAndroid Build Coastguard Worker    ):
228*da0073e9SAndroid Build Coastguard Worker        if isinstance(lr, Tensor) and lr.numel() != 1:
229*da0073e9SAndroid Build Coastguard Worker            raise ValueError("Tensor lr must be 1-element")
230*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= lr:
231*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid learning rate: {lr}")
232*da0073e9SAndroid Build Coastguard Worker        if max_eval is None:
233*da0073e9SAndroid Build Coastguard Worker            max_eval = max_iter * 5 // 4
234*da0073e9SAndroid Build Coastguard Worker        defaults = dict(
235*da0073e9SAndroid Build Coastguard Worker            lr=lr,
236*da0073e9SAndroid Build Coastguard Worker            max_iter=max_iter,
237*da0073e9SAndroid Build Coastguard Worker            max_eval=max_eval,
238*da0073e9SAndroid Build Coastguard Worker            tolerance_grad=tolerance_grad,
239*da0073e9SAndroid Build Coastguard Worker            tolerance_change=tolerance_change,
240*da0073e9SAndroid Build Coastguard Worker            history_size=history_size,
241*da0073e9SAndroid Build Coastguard Worker            line_search_fn=line_search_fn,
242*da0073e9SAndroid Build Coastguard Worker        )
243*da0073e9SAndroid Build Coastguard Worker        super().__init__(params, defaults)
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker        if len(self.param_groups) != 1:
246*da0073e9SAndroid Build Coastguard Worker            raise ValueError(
247*da0073e9SAndroid Build Coastguard Worker                "LBFGS doesn't support per-parameter options " "(parameter groups)"
248*da0073e9SAndroid Build Coastguard Worker            )
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker        self._params = self.param_groups[0]["params"]
251*da0073e9SAndroid Build Coastguard Worker        self._numel_cache = None
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker    def _numel(self):
254*da0073e9SAndroid Build Coastguard Worker        if self._numel_cache is None:
255*da0073e9SAndroid Build Coastguard Worker            self._numel_cache = sum(
256*da0073e9SAndroid Build Coastguard Worker                2 * p.numel() if torch.is_complex(p) else p.numel()
257*da0073e9SAndroid Build Coastguard Worker                for p in self._params
258*da0073e9SAndroid Build Coastguard Worker            )
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        return self._numel_cache
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker    def _gather_flat_grad(self):
263*da0073e9SAndroid Build Coastguard Worker        views = []
264*da0073e9SAndroid Build Coastguard Worker        for p in self._params:
265*da0073e9SAndroid Build Coastguard Worker            if p.grad is None:
266*da0073e9SAndroid Build Coastguard Worker                view = p.new(p.numel()).zero_()
267*da0073e9SAndroid Build Coastguard Worker            elif p.grad.is_sparse:
268*da0073e9SAndroid Build Coastguard Worker                view = p.grad.to_dense().view(-1)
269*da0073e9SAndroid Build Coastguard Worker            else:
270*da0073e9SAndroid Build Coastguard Worker                view = p.grad.view(-1)
271*da0073e9SAndroid Build Coastguard Worker            if torch.is_complex(view):
272*da0073e9SAndroid Build Coastguard Worker                view = torch.view_as_real(view).view(-1)
273*da0073e9SAndroid Build Coastguard Worker            views.append(view)
274*da0073e9SAndroid Build Coastguard Worker        return torch.cat(views, 0)
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker    def _add_grad(self, step_size, update):
277*da0073e9SAndroid Build Coastguard Worker        offset = 0
278*da0073e9SAndroid Build Coastguard Worker        for p in self._params:
279*da0073e9SAndroid Build Coastguard Worker            if torch.is_complex(p):
280*da0073e9SAndroid Build Coastguard Worker                p = torch.view_as_real(p)
281*da0073e9SAndroid Build Coastguard Worker            numel = p.numel()
282*da0073e9SAndroid Build Coastguard Worker            # view as to avoid deprecated pointwise semantics
283*da0073e9SAndroid Build Coastguard Worker            p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
284*da0073e9SAndroid Build Coastguard Worker            offset += numel
285*da0073e9SAndroid Build Coastguard Worker        assert offset == self._numel()
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker    def _clone_param(self):
288*da0073e9SAndroid Build Coastguard Worker        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker    def _set_param(self, params_data):
291*da0073e9SAndroid Build Coastguard Worker        for p, pdata in zip(self._params, params_data):
292*da0073e9SAndroid Build Coastguard Worker            p.copy_(pdata)
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker    def _directional_evaluate(self, closure, x, t, d):
295*da0073e9SAndroid Build Coastguard Worker        self._add_grad(t, d)
296*da0073e9SAndroid Build Coastguard Worker        loss = float(closure())
297*da0073e9SAndroid Build Coastguard Worker        flat_grad = self._gather_flat_grad()
298*da0073e9SAndroid Build Coastguard Worker        self._set_param(x)
299*da0073e9SAndroid Build Coastguard Worker        return loss, flat_grad
300*da0073e9SAndroid Build Coastguard Worker
301*da0073e9SAndroid Build Coastguard Worker    @torch.no_grad()
302*da0073e9SAndroid Build Coastguard Worker    def step(self, closure):
303*da0073e9SAndroid Build Coastguard Worker        """Perform a single optimization step.
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker        Args:
306*da0073e9SAndroid Build Coastguard Worker            closure (Callable): A closure that reevaluates the model
307*da0073e9SAndroid Build Coastguard Worker                and returns the loss.
308*da0073e9SAndroid Build Coastguard Worker        """
309*da0073e9SAndroid Build Coastguard Worker        assert len(self.param_groups) == 1
310*da0073e9SAndroid Build Coastguard Worker
311*da0073e9SAndroid Build Coastguard Worker        # Make sure the closure is always called with grad enabled
312*da0073e9SAndroid Build Coastguard Worker        closure = torch.enable_grad()(closure)
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker        group = self.param_groups[0]
315*da0073e9SAndroid Build Coastguard Worker        lr = group["lr"]
316*da0073e9SAndroid Build Coastguard Worker        max_iter = group["max_iter"]
317*da0073e9SAndroid Build Coastguard Worker        max_eval = group["max_eval"]
318*da0073e9SAndroid Build Coastguard Worker        tolerance_grad = group["tolerance_grad"]
319*da0073e9SAndroid Build Coastguard Worker        tolerance_change = group["tolerance_change"]
320*da0073e9SAndroid Build Coastguard Worker        line_search_fn = group["line_search_fn"]
321*da0073e9SAndroid Build Coastguard Worker        history_size = group["history_size"]
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker        # NOTE: LBFGS has only global state, but we register it as state for
324*da0073e9SAndroid Build Coastguard Worker        # the first param, because this helps with casting in load_state_dict
325*da0073e9SAndroid Build Coastguard Worker        state = self.state[self._params[0]]
326*da0073e9SAndroid Build Coastguard Worker        state.setdefault("func_evals", 0)
327*da0073e9SAndroid Build Coastguard Worker        state.setdefault("n_iter", 0)
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker        # evaluate initial f(x) and df/dx
330*da0073e9SAndroid Build Coastguard Worker        orig_loss = closure()
331*da0073e9SAndroid Build Coastguard Worker        loss = float(orig_loss)
332*da0073e9SAndroid Build Coastguard Worker        current_evals = 1
333*da0073e9SAndroid Build Coastguard Worker        state["func_evals"] += 1
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker        flat_grad = self._gather_flat_grad()
336*da0073e9SAndroid Build Coastguard Worker        opt_cond = flat_grad.abs().max() <= tolerance_grad
337*da0073e9SAndroid Build Coastguard Worker
338*da0073e9SAndroid Build Coastguard Worker        # optimal condition
339*da0073e9SAndroid Build Coastguard Worker        if opt_cond:
340*da0073e9SAndroid Build Coastguard Worker            return orig_loss
341*da0073e9SAndroid Build Coastguard Worker
342*da0073e9SAndroid Build Coastguard Worker        # tensors cached in state (for tracing)
343*da0073e9SAndroid Build Coastguard Worker        d = state.get("d")
344*da0073e9SAndroid Build Coastguard Worker        t = state.get("t")
345*da0073e9SAndroid Build Coastguard Worker        old_dirs = state.get("old_dirs")
346*da0073e9SAndroid Build Coastguard Worker        old_stps = state.get("old_stps")
347*da0073e9SAndroid Build Coastguard Worker        ro = state.get("ro")
348*da0073e9SAndroid Build Coastguard Worker        H_diag = state.get("H_diag")
349*da0073e9SAndroid Build Coastguard Worker        prev_flat_grad = state.get("prev_flat_grad")
350*da0073e9SAndroid Build Coastguard Worker        prev_loss = state.get("prev_loss")
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker        n_iter = 0
353*da0073e9SAndroid Build Coastguard Worker        # optimize for a max of max_iter iterations
354*da0073e9SAndroid Build Coastguard Worker        while n_iter < max_iter:
355*da0073e9SAndroid Build Coastguard Worker            # keep track of nb of iterations
356*da0073e9SAndroid Build Coastguard Worker            n_iter += 1
357*da0073e9SAndroid Build Coastguard Worker            state["n_iter"] += 1
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker            ############################################################
360*da0073e9SAndroid Build Coastguard Worker            # compute gradient descent direction
361*da0073e9SAndroid Build Coastguard Worker            ############################################################
362*da0073e9SAndroid Build Coastguard Worker            if state["n_iter"] == 1:
363*da0073e9SAndroid Build Coastguard Worker                d = flat_grad.neg()
364*da0073e9SAndroid Build Coastguard Worker                old_dirs = []
365*da0073e9SAndroid Build Coastguard Worker                old_stps = []
366*da0073e9SAndroid Build Coastguard Worker                ro = []
367*da0073e9SAndroid Build Coastguard Worker                H_diag = 1
368*da0073e9SAndroid Build Coastguard Worker            else:
369*da0073e9SAndroid Build Coastguard Worker                # do lbfgs update (update memory)
370*da0073e9SAndroid Build Coastguard Worker                y = flat_grad.sub(prev_flat_grad)
371*da0073e9SAndroid Build Coastguard Worker                s = d.mul(t)
372*da0073e9SAndroid Build Coastguard Worker                ys = y.dot(s)  # y*s
373*da0073e9SAndroid Build Coastguard Worker                if ys > 1e-10:
374*da0073e9SAndroid Build Coastguard Worker                    # updating memory
375*da0073e9SAndroid Build Coastguard Worker                    if len(old_dirs) == history_size:
376*da0073e9SAndroid Build Coastguard Worker                        # shift history by one (limited-memory)
377*da0073e9SAndroid Build Coastguard Worker                        old_dirs.pop(0)
378*da0073e9SAndroid Build Coastguard Worker                        old_stps.pop(0)
379*da0073e9SAndroid Build Coastguard Worker                        ro.pop(0)
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker                    # store new direction/step
382*da0073e9SAndroid Build Coastguard Worker                    old_dirs.append(y)
383*da0073e9SAndroid Build Coastguard Worker                    old_stps.append(s)
384*da0073e9SAndroid Build Coastguard Worker                    ro.append(1.0 / ys)
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker                    # update scale of initial Hessian approximation
387*da0073e9SAndroid Build Coastguard Worker                    H_diag = ys / y.dot(y)  # (y*y)
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker                # compute the approximate (L-BFGS) inverse Hessian
390*da0073e9SAndroid Build Coastguard Worker                # multiplied by the gradient
391*da0073e9SAndroid Build Coastguard Worker                num_old = len(old_dirs)
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker                if "al" not in state:
394*da0073e9SAndroid Build Coastguard Worker                    state["al"] = [None] * history_size
395*da0073e9SAndroid Build Coastguard Worker                al = state["al"]
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker                # iteration in L-BFGS loop collapsed to use just one buffer
398*da0073e9SAndroid Build Coastguard Worker                q = flat_grad.neg()
399*da0073e9SAndroid Build Coastguard Worker                for i in range(num_old - 1, -1, -1):
400*da0073e9SAndroid Build Coastguard Worker                    al[i] = old_stps[i].dot(q) * ro[i]
401*da0073e9SAndroid Build Coastguard Worker                    q.add_(old_dirs[i], alpha=-al[i])
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker                # multiply by initial Hessian
404*da0073e9SAndroid Build Coastguard Worker                # r/d is the final direction
405*da0073e9SAndroid Build Coastguard Worker                d = r = torch.mul(q, H_diag)
406*da0073e9SAndroid Build Coastguard Worker                for i in range(num_old):
407*da0073e9SAndroid Build Coastguard Worker                    be_i = old_dirs[i].dot(r) * ro[i]
408*da0073e9SAndroid Build Coastguard Worker                    r.add_(old_stps[i], alpha=al[i] - be_i)
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker            if prev_flat_grad is None:
411*da0073e9SAndroid Build Coastguard Worker                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
412*da0073e9SAndroid Build Coastguard Worker            else:
413*da0073e9SAndroid Build Coastguard Worker                prev_flat_grad.copy_(flat_grad)
414*da0073e9SAndroid Build Coastguard Worker            prev_loss = loss
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker            ############################################################
417*da0073e9SAndroid Build Coastguard Worker            # compute step length
418*da0073e9SAndroid Build Coastguard Worker            ############################################################
419*da0073e9SAndroid Build Coastguard Worker            # reset initial guess for step size
420*da0073e9SAndroid Build Coastguard Worker            if state["n_iter"] == 1:
421*da0073e9SAndroid Build Coastguard Worker                t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
422*da0073e9SAndroid Build Coastguard Worker            else:
423*da0073e9SAndroid Build Coastguard Worker                t = lr
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker            # directional derivative
426*da0073e9SAndroid Build Coastguard Worker            gtd = flat_grad.dot(d)  # g * d
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker            # directional derivative is below tolerance
429*da0073e9SAndroid Build Coastguard Worker            if gtd > -tolerance_change:
430*da0073e9SAndroid Build Coastguard Worker                break
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker            # optional line search: user function
433*da0073e9SAndroid Build Coastguard Worker            ls_func_evals = 0
434*da0073e9SAndroid Build Coastguard Worker            if line_search_fn is not None:
435*da0073e9SAndroid Build Coastguard Worker                # perform line search, using user function
436*da0073e9SAndroid Build Coastguard Worker                if line_search_fn != "strong_wolfe":
437*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("only 'strong_wolfe' is supported")
438*da0073e9SAndroid Build Coastguard Worker                else:
439*da0073e9SAndroid Build Coastguard Worker                    x_init = self._clone_param()
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker                    def obj_func(x, t, d):
442*da0073e9SAndroid Build Coastguard Worker                        return self._directional_evaluate(closure, x, t, d)
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
445*da0073e9SAndroid Build Coastguard Worker                        obj_func, x_init, t, d, loss, flat_grad, gtd
446*da0073e9SAndroid Build Coastguard Worker                    )
447*da0073e9SAndroid Build Coastguard Worker                self._add_grad(t, d)
448*da0073e9SAndroid Build Coastguard Worker                opt_cond = flat_grad.abs().max() <= tolerance_grad
449*da0073e9SAndroid Build Coastguard Worker            else:
450*da0073e9SAndroid Build Coastguard Worker                # no line search, simply move with fixed-step
451*da0073e9SAndroid Build Coastguard Worker                self._add_grad(t, d)
452*da0073e9SAndroid Build Coastguard Worker                if n_iter != max_iter:
453*da0073e9SAndroid Build Coastguard Worker                    # re-evaluate function only if not in last iteration
454*da0073e9SAndroid Build Coastguard Worker                    # the reason we do this: in a stochastic setting,
455*da0073e9SAndroid Build Coastguard Worker                    # no use to re-evaluate that function here
456*da0073e9SAndroid Build Coastguard Worker                    with torch.enable_grad():
457*da0073e9SAndroid Build Coastguard Worker                        loss = float(closure())
458*da0073e9SAndroid Build Coastguard Worker                    flat_grad = self._gather_flat_grad()
459*da0073e9SAndroid Build Coastguard Worker                    opt_cond = flat_grad.abs().max() <= tolerance_grad
460*da0073e9SAndroid Build Coastguard Worker                    ls_func_evals = 1
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker            # update func eval
463*da0073e9SAndroid Build Coastguard Worker            current_evals += ls_func_evals
464*da0073e9SAndroid Build Coastguard Worker            state["func_evals"] += ls_func_evals
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker            ############################################################
467*da0073e9SAndroid Build Coastguard Worker            # check conditions
468*da0073e9SAndroid Build Coastguard Worker            ############################################################
469*da0073e9SAndroid Build Coastguard Worker            if n_iter == max_iter:
470*da0073e9SAndroid Build Coastguard Worker                break
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker            if current_evals >= max_eval:
473*da0073e9SAndroid Build Coastguard Worker                break
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker            # optimal condition
476*da0073e9SAndroid Build Coastguard Worker            if opt_cond:
477*da0073e9SAndroid Build Coastguard Worker                break
478*da0073e9SAndroid Build Coastguard Worker
479*da0073e9SAndroid Build Coastguard Worker            # lack of progress
480*da0073e9SAndroid Build Coastguard Worker            if d.mul(t).abs().max() <= tolerance_change:
481*da0073e9SAndroid Build Coastguard Worker                break
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker            if abs(loss - prev_loss) < tolerance_change:
484*da0073e9SAndroid Build Coastguard Worker                break
485*da0073e9SAndroid Build Coastguard Worker
486*da0073e9SAndroid Build Coastguard Worker        state["d"] = d
487*da0073e9SAndroid Build Coastguard Worker        state["t"] = t
488*da0073e9SAndroid Build Coastguard Worker        state["old_dirs"] = old_dirs
489*da0073e9SAndroid Build Coastguard Worker        state["old_stps"] = old_stps
490*da0073e9SAndroid Build Coastguard Worker        state["ro"] = ro
491*da0073e9SAndroid Build Coastguard Worker        state["H_diag"] = H_diag
492*da0073e9SAndroid Build Coastguard Worker        state["prev_flat_grad"] = prev_flat_grad
493*da0073e9SAndroid Build Coastguard Worker        state["prev_loss"] = prev_loss
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker        return orig_loss
496