# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Adamax", "adamax"]


class Adamax(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 2e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("Adamax does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                state["exp_inf"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(state["exp_avg"])
            exp_infs.append(state["exp_inf"])
            state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_infs: List[Tensor] = []
            state_steps: List[Tensor] = []

            beta1, beta2 = group["betas"]
            eps = group["eps"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            foreach = group["foreach"]
            maximize = group["maximize"]
            differentiable = group["differentiable"]
            capturable = group["capturable"]

            has_complex = self._init_group(
                group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
            )

            adamax(
                params_with_grad,
                grads,
                exp_avgs,
                exp_infs,
                state_steps,
                eps=eps,
                beta1=beta1,
                beta2=beta2,
                lr=lr,
                weight_decay=weight_decay,
                foreach=foreach,
                maximize=maximize,
                differentiable=differentiable,
                capturable=capturable,
                has_complex=has_complex,
            )

        return loss


Adamax.__doc__ = (
    r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
                \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \epsilon \text{ (epsilon)} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
                u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}if \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980

    """
)


def _single_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        exp_avg = exp_avgs[i]
        exp_inf = exp_infs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_inf = torch.view_as_real(exp_inf)

        # Update biased first moment estimate.
        exp_avg.lerp_(grad, 1 - beta1)
        # Update the exponentially weighted infinity norm.
        if not differentiable:
            torch.maximum(
                exp_inf.mul_(beta2),
                grad.abs().add_(eps),
                out=exp_inf,
            )
        else:
            norm_buf = torch.cat(
                [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)],
                0,
            )
            exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False))

        if capturable:
            # why jump through extra hoops and negate bias_correction? check out #121238
            # once fixed, we should use bias_correction with addcdiv value=-1 for readability
            neg_bias_correction = beta1**step_t - 1
            neg_bias_correction.div_(lr)
            denom = exp_inf * neg_bias_correction
            param.addcdiv_(exp_avg, denom)
        else:
            bias_correction = 1 - beta1 ** _get_value(step_t)
            clr = lr / bias_correction

            param.addcdiv_(exp_avg, exp_inf, value=-clr)


def _multi_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    assert not differentiable, "_foreach ops don't support autograd"

    if len(params) == 0:
        return

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
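
    # Group the tensors by (device, dtype) so that each _foreach_* call below operates
    # on a homogeneous list of tensors, as the multi-tensor kernels require.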
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_infs, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_infs_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_infs = cast(List[Tensor], grouped_exp_infs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs
            )

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            if maximize:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                    grouped_grads, grouped_params, alpha=weight_decay
                )

        # Update biased first moment estimate.
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        # Update the exponentially weighted infinity norm.
        torch._foreach_mul_(grouped_exp_infs, beta2)

        # in this case, we need to introduce a copy of the grads
        # since one has not been introduced previously
        if not maximize and weight_decay == 0:
            grouped_grads = torch._foreach_abs(grouped_grads)  # type: ignore[assignment]
        else:
            torch._foreach_abs_(grouped_grads)

        torch._foreach_add_(grouped_grads, eps)
        torch._foreach_maximum_(grouped_exp_infs, grouped_grads)

        bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_corrections, 1)
            torch._foreach_div_(bias_corrections, lr)

            denom = torch._foreach_mul(grouped_exp_infs, bias_corrections)
            torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom)
        else:
            bias_corrections = [
                1 - beta1 ** _get_value(step) for step in grouped_state_steps
            ]
            step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections]
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
def adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamax
    else:
        func = _single_tensor_adamax

    func(
        params,
        grads,
        exp_avgs,
        exp_infs,
        state_steps,
        eps=eps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        capturable=capturable,
    )
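

# A minimal usage sketch (commented out; not part of this module). The model, data, and
# hyperparameters below are illustrative assumptions rather than anything defined here:
#
#     import torch
#
#     model = torch.nn.Linear(10, 1)
#     optimizer = torch.optim.Adamax(model.parameters(), lr=2e-3, betas=(0.9, 0.999))
#
#     for _ in range(5):
#         optimizer.zero_grad()
#         loss = model(torch.randn(8, 10)).pow(2).mean()
#         loss.backward()
#         optimizer.step()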