# mypy: allow-untyped-defs
from typing import Optional, Union

import torch
from torch import Tensor

from .optimizer import Optimizer, ParamsT


__all__ = ["LBFGS"]


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    """Return the minimizer of a cubic fitted through two points, clamped to bounds.

    The cubic matches function values ``f1``/``f2`` and derivatives ``g1``/``g2``
    at ``x1``/``x2``. Its minimizer is clamped to ``bounds`` (a ``(lo, hi)``
    pair), which defaults to ``(min(x1, x2), max(x1, x2))``. When the cubic has
    no real critical point (``d2_square < 0``), the midpoint of the bounds is
    returned instead.

    At least one of ``g1``/``g2`` must be a Tensor, because ``d2_square.sqrt()``
    is a Tensor method. In this module they are directional derivatives
    produced by ``Tensor.dot``.
    """
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        # no real minimizer of the cubic: bisect the interval
        return (xmin_bound + xmax_bound) / 2.0


def _strong_wolfe(
    obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
    """Search for a step length satisfying the strong Wolfe conditions.

    The search runs in two phases:

    1. Bracketing: grow the step until an interval containing an acceptable
       step is found (or an acceptable step is hit directly).
    2. Zoom: shrink that interval with safeguarded cubic interpolation.

    Args:
        obj_func: callable ``obj_func(x, t, d) -> (loss, flat_grad)``. As used
            by :class:`LBFGS`, it evaluates the objective at ``x + t * d`` and
            restores the parameters to ``x`` afterwards.
        x: value passed through unchanged to every ``obj_func`` call.
        t: initial trial step length.
        d: flat search direction.
        f: loss at step 0.
        g: flat gradient at step 0 (cloned here, not modified).
        gtd: directional derivative ``g . d`` at step 0.
        c1: sufficient-decrease (Armijo) constant. A step fails this test when
            ``f_new > f + c1 * t * gtd``.
        c2: curvature constant. A step is accepted when
            ``abs(gtd_new) <= -c2 * gtd``.
        tolerance_change: zooming stops once the bracket width times
            ``max(abs(d))`` drops below this value.
        max_ls: maximum number of line-search iterations. The evaluation at
            the initial ``t`` is not counted.

    Returns:
        ``(f_new, g_new, t, ls_func_evals)``:
        - ``f_new``, ``g_new``: loss and flat gradient at the lowest-loss
          point of the final bracket.
        - ``t``: that point's step length.
        - ``ls_func_evals``: the number of ``obj_func`` calls made.
    """
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            # sufficient decrease failed (or loss went up): minimum lies in [t_prev, t]
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            # strong Wolfe conditions already satisfied at t
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            # slope turned non-negative: minimum lies in [t_prev, t]
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate (extrapolate beyond t, within [t + 1% of last step, 10 * t])
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        # bracket_gtd is not needed here: the zoom loop below cannot run
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)  # type: ignore[possibly-undefined]
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:  # type: ignore[possibly-undefined]
            break

        # compute new trial value
        t = _cubic_interpolate(
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
        )

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    t = bracket[low_pos]  # type: ignore[possibly-undefined]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]
    return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
    """Implements L-BFGS algorithm.

    Heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        params (iterable): iterable of parameters to optimize. Parameters must be real.
        lr (float, Tensor, optional): learning rate (default: 1). A Tensor
            learning rate must have exactly one element.
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: ``max_iter * 5 // 4``).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
    """

    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1,
        max_iter: int = 20,
        max_eval: Optional[int] = None,
        tolerance_grad: float = 1e-7,
        tolerance_change: float = 1e-9,
        history_size: int = 100,
        line_search_fn: Optional[str] = None,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn,
        )
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError(
                "LBFGS doesn't support per-parameter options (parameter groups)"
            )

        self._params = self.param_groups[0]["params"]
        self._numel_cache = None

    def _numel(self):
        """Return the total number of real scalars across all parameters (cached).

        Complex parameters count twice: once for the real part and once for
        the imaginary part.
        """
        if self._numel_cache is None:
            self._numel_cache = sum(
                2 * p.numel() if torch.is_complex(p) else p.numel()
                for p in self._params
            )

        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenate every parameter's gradient into a single flat real tensor.

        - Parameters without a gradient contribute zeros.
        - Sparse gradients are densified.
        - Complex gradients are laid out as interleaved (real, imag) pairs.

        Element order is each parameter's logical (row-major) order, matching
        the slicing in :meth:`_add_grad`.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new_zeros(p.numel())
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                # reshape rather than view: a non-contiguous gradient (e.g. of a
                # channels_last parameter) cannot always be flattened as a view.
                view = p.grad.reshape(-1)
            if torch.is_complex(view):
                view = torch.view_as_real(view).view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        """Add ``step_size * update`` to the parameters, in place.

        ``update`` is a flat vector laid out like the output of
        :meth:`_gather_flat_grad`.
        """
        offset = 0
        for p in self._params:
            if torch.is_complex(p):
                p = torch.view_as_real(p)
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        """Return contiguous copies of all parameters (a snapshot for :meth:`_set_param`)."""
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        """Copy ``params_data`` (e.g. from :meth:`_clone_param`) back into the parameters in place."""
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        """Evaluate loss and flat gradient at ``params + t * d``, then restore params to ``x``."""
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Perform a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group["lr"]
        max_iter = group["max_iter"]
        max_eval = group["max_eval"]
        tolerance_grad = group["tolerance_grad"]
        tolerance_change = group["tolerance_change"]
        line_search_fn = group["line_search_fn"]
        history_size = group["history_size"]

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault("func_evals", 0)
        state.setdefault("n_iter", 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state["func_evals"] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get("d")
        t = state.get("t")
        old_dirs = state.get("old_dirs")
        old_stps = state.get("old_stps")
        ro = state.get("ro")
        H_diag = state.get("H_diag")
        prev_flat_grad = state.get("prev_flat_grad")
        prev_loss = state.get("prev_loss")

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state["n_iter"] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state["n_iter"] == 1:
                # very first iteration ever: steepest descent, empty history
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                # only store the pair when the curvature condition y.s > 0 holds
                # (with a small margin), keeping the implicit Hessian positive definite
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1.0 / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if "al" not in state:
                    state["al"] = [None] * history_size
                al = state["al"]

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state["n_iter"] == 1:
                t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd
                    )
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state["func_evals"] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        # persist the L-BFGS memory so the next step() call continues from here
        state["d"] = d
        state["t"] = t
        state["old_dirs"] = old_dirs
        state["old_stps"] = old_stps
        state["ro"] = ro
        state["H_diag"] = H_diag
        state["prev_flat_grad"] = prev_flat_grad
        state["prev_loss"] = prev_loss

        return orig_loss