1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-decorators 2*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 3*da0073e9SAndroid Build Coastguard Workerr"""Implementation for the NAdam algorithm.""" 4*da0073e9SAndroid Build Coastguard Workerfrom typing import cast, List, Optional, Tuple, Union 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerfrom .optimizer import ( 10*da0073e9SAndroid Build Coastguard Worker _capturable_doc, 11*da0073e9SAndroid Build Coastguard Worker _default_to_fused_or_foreach, 12*da0073e9SAndroid Build Coastguard Worker _differentiable_doc, 13*da0073e9SAndroid Build Coastguard Worker _disable_dynamo_if_unsupported, 14*da0073e9SAndroid Build Coastguard Worker _foreach_doc, 15*da0073e9SAndroid Build Coastguard Worker _get_capturable_supported_devices, 16*da0073e9SAndroid Build Coastguard Worker _get_scalar_dtype, 17*da0073e9SAndroid Build Coastguard Worker _get_value, 18*da0073e9SAndroid Build Coastguard Worker _maximize_doc, 19*da0073e9SAndroid Build Coastguard Worker _stack_if_compiling, 20*da0073e9SAndroid Build Coastguard Worker _use_grad_for_differentiable, 21*da0073e9SAndroid Build Coastguard Worker _view_as_real, 22*da0073e9SAndroid Build Coastguard Worker Optimizer, 23*da0073e9SAndroid Build Coastguard Worker ParamsT, 24*da0073e9SAndroid Build Coastguard Worker) 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker__all__ = ["NAdam", "nadam"] 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Workerclass NAdam(Optimizer): # noqa: D101 31*da0073e9SAndroid Build Coastguard Worker def __init__( 32*da0073e9SAndroid Build Coastguard Worker self, 33*da0073e9SAndroid Build Coastguard Worker params: ParamsT, 34*da0073e9SAndroid Build Coastguard Worker lr: Union[float, Tensor] = 2e-3, 35*da0073e9SAndroid Build Coastguard Worker betas: Tuple[float, float] = (0.9, 0.999), 36*da0073e9SAndroid Build Coastguard Worker eps: float = 1e-8, 37*da0073e9SAndroid Build Coastguard Worker weight_decay: float = 0, 38*da0073e9SAndroid Build Coastguard Worker momentum_decay: float = 4e-3, 39*da0073e9SAndroid Build Coastguard Worker decoupled_weight_decay: bool = False, 40*da0073e9SAndroid Build Coastguard Worker *, 41*da0073e9SAndroid Build Coastguard Worker foreach: Optional[bool] = None, 42*da0073e9SAndroid Build Coastguard Worker maximize: bool = False, 43*da0073e9SAndroid Build Coastguard Worker capturable: bool = False, 44*da0073e9SAndroid Build Coastguard Worker differentiable: bool = False, 45*da0073e9SAndroid Build Coastguard Worker ): # noqa: D107 46*da0073e9SAndroid Build Coastguard Worker if isinstance(lr, Tensor) and lr.numel() != 1: 47*da0073e9SAndroid Build Coastguard Worker raise ValueError("Tensor lr must be 1-element") 48*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= lr: 49*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid learning rate: {lr}") 50*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= eps: 51*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid epsilon value: {eps}") 52*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= betas[0] < 1.0: 53*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") 54*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= betas[1] < 1.0: 55*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") 56*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= weight_decay: 57*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid weight_decay value: {weight_decay}") 58*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= momentum_decay: 59*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid momentum_decay value: {momentum_decay}") 60*da0073e9SAndroid Build Coastguard Worker defaults = dict( 61*da0073e9SAndroid Build Coastguard Worker lr=lr, 62*da0073e9SAndroid Build Coastguard Worker betas=betas, 63*da0073e9SAndroid Build Coastguard Worker eps=eps, 64*da0073e9SAndroid Build Coastguard Worker weight_decay=weight_decay, 65*da0073e9SAndroid Build Coastguard Worker momentum_decay=momentum_decay, 66*da0073e9SAndroid Build Coastguard Worker decoupled_weight_decay=decoupled_weight_decay, 67*da0073e9SAndroid Build Coastguard Worker maximize=maximize, 68*da0073e9SAndroid Build Coastguard Worker foreach=foreach, 69*da0073e9SAndroid Build Coastguard Worker capturable=capturable, 70*da0073e9SAndroid Build Coastguard Worker differentiable=differentiable, 71*da0073e9SAndroid Build Coastguard Worker ) 72*da0073e9SAndroid Build Coastguard Worker super().__init__(params, defaults) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state): # noqa: D105 75*da0073e9SAndroid Build Coastguard Worker super().__setstate__(state) 76*da0073e9SAndroid Build Coastguard Worker for group in self.param_groups: 77*da0073e9SAndroid Build Coastguard Worker group.setdefault("maximize", False) 78*da0073e9SAndroid Build Coastguard Worker group.setdefault("foreach", None) 79*da0073e9SAndroid Build Coastguard Worker group.setdefault("capturable", False) 80*da0073e9SAndroid Build Coastguard Worker group.setdefault("differentiable", False) 81*da0073e9SAndroid Build Coastguard Worker group.setdefault("decoupled_weight_decay", False) 82*da0073e9SAndroid Build Coastguard Worker for p in group["params"]: 83*da0073e9SAndroid Build Coastguard Worker p_state = self.state.get(p, []) 84*da0073e9SAndroid Build Coastguard Worker if len(p_state) != 0: 85*da0073e9SAndroid Build Coastguard Worker if not torch.is_tensor(p_state["step"]): 86*da0073e9SAndroid Build Coastguard Worker step_val = float(p_state["step"]) 87*da0073e9SAndroid Build Coastguard Worker p_state["step"] = ( 88*da0073e9SAndroid Build Coastguard Worker torch.tensor( 89*da0073e9SAndroid Build Coastguard Worker step_val, dtype=_get_scalar_dtype(), device=p.device 90*da0073e9SAndroid Build Coastguard Worker ) 91*da0073e9SAndroid Build Coastguard Worker if group["capturable"] 92*da0073e9SAndroid Build Coastguard Worker else torch.tensor(step_val, dtype=_get_scalar_dtype()) 93*da0073e9SAndroid Build Coastguard Worker ) 94*da0073e9SAndroid Build Coastguard Worker if not torch.is_tensor(p_state["mu_product"]): 95*da0073e9SAndroid Build Coastguard Worker mu_prod_val = p_state["mu_product"] 96*da0073e9SAndroid Build Coastguard Worker p_state["mu_product"] = ( 97*da0073e9SAndroid Build Coastguard Worker torch.tensor( 98*da0073e9SAndroid Build Coastguard Worker mu_prod_val, dtype=_get_scalar_dtype(), device=p.device 99*da0073e9SAndroid Build Coastguard Worker ) 100*da0073e9SAndroid Build Coastguard Worker if group["capturable"] 101*da0073e9SAndroid Build Coastguard Worker else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype()) 102*da0073e9SAndroid Build Coastguard Worker ) 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker def _init_group( 105*da0073e9SAndroid Build Coastguard Worker self, 106*da0073e9SAndroid Build Coastguard Worker group, 107*da0073e9SAndroid Build Coastguard Worker params_with_grad, 108*da0073e9SAndroid Build Coastguard Worker grads, 109*da0073e9SAndroid Build Coastguard Worker exp_avgs, 110*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs, 111*da0073e9SAndroid Build Coastguard Worker mu_products, 112*da0073e9SAndroid Build Coastguard Worker state_steps, 113*da0073e9SAndroid Build Coastguard Worker ): 114*da0073e9SAndroid Build Coastguard Worker has_complex = False 115*da0073e9SAndroid Build Coastguard Worker for p in group["params"]: 116*da0073e9SAndroid Build Coastguard Worker if p.grad is not None: 117*da0073e9SAndroid Build Coastguard Worker has_complex |= torch.is_complex(p) 118*da0073e9SAndroid Build Coastguard Worker params_with_grad.append(p) 119*da0073e9SAndroid Build Coastguard Worker if p.grad.is_sparse: 120*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("NAdam does not support sparse gradients") 121*da0073e9SAndroid Build Coastguard Worker grads.append(p.grad) 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker state = self.state[p] 124*da0073e9SAndroid Build Coastguard Worker # Lazy state initialization 125*da0073e9SAndroid Build Coastguard Worker if len(state) == 0: 126*da0073e9SAndroid Build Coastguard Worker # note(crcrpar): [special device hosting for step] 127*da0073e9SAndroid Build Coastguard Worker # Deliberately host `step` and `mu_product` on CPU if capturable is False. 128*da0073e9SAndroid Build Coastguard Worker # This is because kernel launches are costly on CUDA and XLA. 129*da0073e9SAndroid Build Coastguard Worker state["step"] = ( 130*da0073e9SAndroid Build Coastguard Worker torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) 131*da0073e9SAndroid Build Coastguard Worker if group["capturable"] 132*da0073e9SAndroid Build Coastguard Worker else torch.tensor(0.0, dtype=_get_scalar_dtype()) 133*da0073e9SAndroid Build Coastguard Worker ) 134*da0073e9SAndroid Build Coastguard Worker state["mu_product"] = ( 135*da0073e9SAndroid Build Coastguard Worker torch.ones((), dtype=_get_scalar_dtype(), device=p.device) 136*da0073e9SAndroid Build Coastguard Worker if group["capturable"] 137*da0073e9SAndroid Build Coastguard Worker else torch.tensor(1.0, dtype=_get_scalar_dtype()) 138*da0073e9SAndroid Build Coastguard Worker ) 139*da0073e9SAndroid Build Coastguard Worker # Exponential moving average of gradient values 140*da0073e9SAndroid Build Coastguard Worker state["exp_avg"] = torch.zeros_like( 141*da0073e9SAndroid Build Coastguard Worker p, memory_format=torch.preserve_format 142*da0073e9SAndroid Build Coastguard Worker ) 143*da0073e9SAndroid Build Coastguard Worker # Exponential moving average of squared gradient values 144*da0073e9SAndroid Build Coastguard Worker state["exp_avg_sq"] = torch.zeros_like( 145*da0073e9SAndroid Build Coastguard Worker p, memory_format=torch.preserve_format 146*da0073e9SAndroid Build Coastguard Worker ) 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker exp_avgs.append(state["exp_avg"]) 149*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs.append(state["exp_avg_sq"]) 150*da0073e9SAndroid Build Coastguard Worker mu_products.append(state["mu_product"]) 151*da0073e9SAndroid Build Coastguard Worker state_steps.append(state["step"]) 152*da0073e9SAndroid Build Coastguard Worker return has_complex 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker @_use_grad_for_differentiable 155*da0073e9SAndroid Build Coastguard Worker def step(self, closure=None): 156*da0073e9SAndroid Build Coastguard Worker """Perform a single optimization step. 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker Args: 159*da0073e9SAndroid Build Coastguard Worker closure (Callable, optional): A closure that reevaluates the model 160*da0073e9SAndroid Build Coastguard Worker and returns the loss. 161*da0073e9SAndroid Build Coastguard Worker """ 162*da0073e9SAndroid Build Coastguard Worker self._cuda_graph_capture_health_check() 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker loss = None 165*da0073e9SAndroid Build Coastguard Worker if closure is not None: 166*da0073e9SAndroid Build Coastguard Worker with torch.enable_grad(): 167*da0073e9SAndroid Build Coastguard Worker loss = closure() 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker for group in self.param_groups: 170*da0073e9SAndroid Build Coastguard Worker params_with_grad: List[Tensor] = [] 171*da0073e9SAndroid Build Coastguard Worker grads: List[Tensor] = [] 172*da0073e9SAndroid Build Coastguard Worker exp_avgs: List[Tensor] = [] 173*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs: List[Tensor] = [] 174*da0073e9SAndroid Build Coastguard Worker mu_products: List[Tensor] = [] 175*da0073e9SAndroid Build Coastguard Worker state_steps: List[Tensor] = [] 176*da0073e9SAndroid Build Coastguard Worker beta1, beta2 = cast(Tuple[float, float], group["betas"]) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker has_complex = self._init_group( 179*da0073e9SAndroid Build Coastguard Worker group, 180*da0073e9SAndroid Build Coastguard Worker params_with_grad, 181*da0073e9SAndroid Build Coastguard Worker grads, 182*da0073e9SAndroid Build Coastguard Worker exp_avgs, 183*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs, 184*da0073e9SAndroid Build Coastguard Worker mu_products, 185*da0073e9SAndroid Build Coastguard Worker state_steps, 186*da0073e9SAndroid Build Coastguard Worker ) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker nadam( 189*da0073e9SAndroid Build Coastguard Worker params_with_grad, 190*da0073e9SAndroid Build Coastguard Worker grads, 191*da0073e9SAndroid Build Coastguard Worker exp_avgs, 192*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs, 193*da0073e9SAndroid Build Coastguard Worker mu_products, 194*da0073e9SAndroid Build Coastguard Worker state_steps, 195*da0073e9SAndroid Build Coastguard Worker beta1=beta1, 196*da0073e9SAndroid Build Coastguard Worker beta2=beta2, 197*da0073e9SAndroid Build Coastguard Worker lr=group["lr"], 198*da0073e9SAndroid Build Coastguard Worker weight_decay=group["weight_decay"], 199*da0073e9SAndroid Build Coastguard Worker momentum_decay=group["momentum_decay"], 200*da0073e9SAndroid Build Coastguard Worker eps=group["eps"], 201*da0073e9SAndroid Build Coastguard Worker maximize=group["maximize"], 202*da0073e9SAndroid Build Coastguard Worker decoupled_weight_decay=group["decoupled_weight_decay"], 203*da0073e9SAndroid Build Coastguard Worker foreach=group["foreach"], 204*da0073e9SAndroid Build Coastguard Worker capturable=group["capturable"], 205*da0073e9SAndroid Build Coastguard Worker differentiable=group["differentiable"], 206*da0073e9SAndroid Build Coastguard Worker has_complex=has_complex, 207*da0073e9SAndroid Build Coastguard Worker ) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker return loss 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard WorkerNAdam.__doc__ = ( 213*da0073e9SAndroid Build Coastguard Worker r"""Implements NAdam algorithm. 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker .. math:: 216*da0073e9SAndroid Build Coastguard Worker \begin{aligned} 217*da0073e9SAndroid Build Coastguard Worker &\rule{110mm}{0.4pt} \\ 218*da0073e9SAndroid Build Coastguard Worker &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, 219*da0073e9SAndroid Build Coastguard Worker \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ 220*da0073e9SAndroid Build Coastguard Worker &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\ 221*da0073e9SAndroid Build Coastguard Worker &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize} \\ 222*da0073e9SAndroid Build Coastguard Worker &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, 223*da0073e9SAndroid Build Coastguard Worker v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex] 224*da0073e9SAndroid Build Coastguard Worker &\rule{110mm}{0.4pt} \\ 225*da0073e9SAndroid Build Coastguard Worker &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ 226*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ 227*da0073e9SAndroid Build Coastguard Worker &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ 228*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\textbf{else} \\ 229*da0073e9SAndroid Build Coastguard Worker &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ 230*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} \\ 231*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\ 232*da0073e9SAndroid Build Coastguard Worker &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\ 233*da0073e9SAndroid Build Coastguard Worker &\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ 234*da0073e9SAndroid Build Coastguard Worker &\hspace{10mm}\textbf{else} \\ 235*da0073e9SAndroid Build Coastguard Worker &\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ 236*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\ 237*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\ 238*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ 239*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ 240*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex] 241*da0073e9SAndroid Build Coastguard Worker & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\ 242*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ 243*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ 244*da0073e9SAndroid Build Coastguard Worker \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ 245*da0073e9SAndroid Build Coastguard Worker &\rule{110mm}{0.4pt} \\[-1.ex] 246*da0073e9SAndroid Build Coastguard Worker &\bf{return} \: \theta_t \\[-1.ex] 247*da0073e9SAndroid Build Coastguard Worker &\rule{110mm}{0.4pt} \\[-1.ex] 248*da0073e9SAndroid Build Coastguard Worker \end{aligned} 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_. 251*da0073e9SAndroid Build Coastguard Worker """ 252*da0073e9SAndroid Build Coastguard Worker + rf""" 253*da0073e9SAndroid Build Coastguard Worker Args: 254*da0073e9SAndroid Build Coastguard Worker params (iterable): iterable of parameters to optimize or dicts defining 255*da0073e9SAndroid Build Coastguard Worker parameter groups 256*da0073e9SAndroid Build Coastguard Worker lr (float, Tensor, optional): learning rate (default: 2e-3) 257*da0073e9SAndroid Build Coastguard Worker betas (Tuple[float, float], optional): coefficients used for computing 258*da0073e9SAndroid Build Coastguard Worker running averages of gradient and its square (default: (0.9, 0.999)) 259*da0073e9SAndroid Build Coastguard Worker eps (float, optional): term added to the denominator to improve 260*da0073e9SAndroid Build Coastguard Worker numerical stability (default: 1e-8) 261*da0073e9SAndroid Build Coastguard Worker weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 262*da0073e9SAndroid Build Coastguard Worker momentum_decay (float, optional): momentum momentum_decay (default: 4e-3) 263*da0073e9SAndroid Build Coastguard Worker decoupled_weight_decay (bool, optional): whether to use decoupled weight 264*da0073e9SAndroid Build Coastguard Worker decay as in AdamW to obtain NAdamW (default: False) 265*da0073e9SAndroid Build Coastguard Worker {_foreach_doc} 266*da0073e9SAndroid Build Coastguard Worker {_maximize_doc} 267*da0073e9SAndroid Build Coastguard Worker {_capturable_doc} 268*da0073e9SAndroid Build Coastguard Worker {_differentiable_doc} 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker .. _Incorporating Nesterov Momentum into Adam: 271*da0073e9SAndroid Build Coastguard Worker https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ 272*da0073e9SAndroid Build Coastguard Worker .. _Decoupled Weight Decay Regularization: 273*da0073e9SAndroid Build Coastguard Worker https://arxiv.org/abs/1711.05101 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker """ 276*da0073e9SAndroid Build Coastguard Worker) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Workerdef _single_tensor_nadam( 280*da0073e9SAndroid Build Coastguard Worker params: List[Tensor], 281*da0073e9SAndroid Build Coastguard Worker grads: List[Tensor], 282*da0073e9SAndroid Build Coastguard Worker exp_avgs: List[Tensor], 283*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs: List[Tensor], 284*da0073e9SAndroid Build Coastguard Worker mu_products: List[Tensor], 285*da0073e9SAndroid Build Coastguard Worker state_steps: List[Tensor], 286*da0073e9SAndroid Build Coastguard Worker *, 287*da0073e9SAndroid Build Coastguard Worker beta1: float, 288*da0073e9SAndroid Build Coastguard Worker beta2: float, 289*da0073e9SAndroid Build Coastguard Worker lr: float, 290*da0073e9SAndroid Build Coastguard Worker weight_decay: float, 291*da0073e9SAndroid Build Coastguard Worker momentum_decay: float, 292*da0073e9SAndroid Build Coastguard Worker eps: float, 293*da0073e9SAndroid Build Coastguard Worker decoupled_weight_decay: bool, 294*da0073e9SAndroid Build Coastguard Worker maximize: bool, 295*da0073e9SAndroid Build Coastguard Worker capturable: bool, 296*da0073e9SAndroid Build Coastguard Worker differentiable: bool, 297*da0073e9SAndroid Build Coastguard Worker has_complex: bool, 298*da0073e9SAndroid Build Coastguard Worker): 299*da0073e9SAndroid Build Coastguard Worker for i, param in enumerate(params): 300*da0073e9SAndroid Build Coastguard Worker grad = grads[i] if not maximize else -grads[i] 301*da0073e9SAndroid Build Coastguard Worker exp_avg = exp_avgs[i] 302*da0073e9SAndroid Build Coastguard Worker exp_avg_sq = exp_avg_sqs[i] 303*da0073e9SAndroid Build Coastguard Worker mu_product = mu_products[i] 304*da0073e9SAndroid Build Coastguard Worker step_t = state_steps[i] 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker if torch.is_complex(param): 307*da0073e9SAndroid Build Coastguard Worker param = torch.view_as_real(param) 308*da0073e9SAndroid Build Coastguard Worker grad = torch.view_as_real(grad) 309*da0073e9SAndroid Build Coastguard Worker exp_avg = torch.view_as_real(exp_avg) 310*da0073e9SAndroid Build Coastguard Worker exp_avg_sq = torch.view_as_real(exp_avg_sq) 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 313*da0073e9SAndroid Build Coastguard Worker if not torch._utils.is_compiling() and capturable: 314*da0073e9SAndroid Build Coastguard Worker capturable_supported_devices = _get_capturable_supported_devices() 315*da0073e9SAndroid Build Coastguard Worker assert ( 316*da0073e9SAndroid Build Coastguard Worker param.device.type == mu_product.device.type == step_t.device.type 317*da0073e9SAndroid Build Coastguard Worker and param.device.type in capturable_supported_devices 318*da0073e9SAndroid Build Coastguard Worker ), ( 319*da0073e9SAndroid Build Coastguard Worker f"If capturable=True, params, mu_products and state_steps must be " 320*da0073e9SAndroid Build Coastguard Worker f"on supported devices: {capturable_supported_devices}." 321*da0073e9SAndroid Build Coastguard Worker ) 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker # update step 324*da0073e9SAndroid Build Coastguard Worker step_t += 1 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker if capturable: 327*da0073e9SAndroid Build Coastguard Worker step = step_t 328*da0073e9SAndroid Build Coastguard Worker else: 329*da0073e9SAndroid Build Coastguard Worker step = _get_value(step_t) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker bias_correction2 = 1 - beta2**step 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker if weight_decay != 0: 334*da0073e9SAndroid Build Coastguard Worker if decoupled_weight_decay: 335*da0073e9SAndroid Build Coastguard Worker # Perform stepweight decay 336*da0073e9SAndroid Build Coastguard Worker param.mul_(1 - lr * weight_decay) 337*da0073e9SAndroid Build Coastguard Worker else: 338*da0073e9SAndroid Build Coastguard Worker grad = grad.add(param, alpha=weight_decay) 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker # calculate the momentum cache \mu^{t} and \mu^{t+1} 341*da0073e9SAndroid Build Coastguard Worker mu = beta1 * (1.0 - 0.5 * (0.96 ** (step * momentum_decay))) 342*da0073e9SAndroid Build Coastguard Worker mu_next = beta1 * (1.0 - 0.5 * (0.96 ** ((step + 1) * momentum_decay))) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker # update mu_product 345*da0073e9SAndroid Build Coastguard Worker mu_product *= mu 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker # decay the first and second moment running average coefficient 348*da0073e9SAndroid Build Coastguard Worker exp_avg.lerp_(grad, 1 - beta1) 349*da0073e9SAndroid Build Coastguard Worker exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 350*da0073e9SAndroid Build Coastguard Worker denom = exp_avg_sq.div(bias_correction2).sqrt() 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker if differentiable or capturable: 353*da0073e9SAndroid Build Coastguard Worker denom = denom.add(eps) 354*da0073e9SAndroid Build Coastguard Worker # Make autograd track the operations 355*da0073e9SAndroid Build Coastguard Worker # by updating the grad and exp_avg directly and not using the 356*da0073e9SAndroid Build Coastguard Worker # scalar "value" argument of addcdiv. 357*da0073e9SAndroid Build Coastguard Worker mu_product_next = mu_product * mu_next 358*da0073e9SAndroid Build Coastguard Worker grad = grad * (-lr * (1.0 - mu) / (1.0 - mu_product)) 359*da0073e9SAndroid Build Coastguard Worker exp_avg = exp_avg * (-lr * mu_next / (1.0 - mu_product_next)) 360*da0073e9SAndroid Build Coastguard Worker param.addcdiv_(grad, denom) 361*da0073e9SAndroid Build Coastguard Worker param.addcdiv_(exp_avg, denom) 362*da0073e9SAndroid Build Coastguard Worker else: 363*da0073e9SAndroid Build Coastguard Worker mu_product_next = _get_value(mu_product) * mu_next 364*da0073e9SAndroid Build Coastguard Worker denom.add_(eps) 365*da0073e9SAndroid Build Coastguard Worker param.addcdiv_( 366*da0073e9SAndroid Build Coastguard Worker grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product))) 367*da0073e9SAndroid Build Coastguard Worker ) 368*da0073e9SAndroid Build Coastguard Worker param.addcdiv_( 369*da0073e9SAndroid Build Coastguard Worker exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next) 370*da0073e9SAndroid Build Coastguard Worker ) 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Workerdef _multi_tensor_nadam( 374*da0073e9SAndroid Build Coastguard Worker params: List[Tensor], 375*da0073e9SAndroid Build Coastguard Worker grads: List[Tensor], 376*da0073e9SAndroid Build Coastguard Worker exp_avgs: List[Tensor], 377*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs: List[Tensor], 378*da0073e9SAndroid Build Coastguard Worker mu_products: List[Tensor], 379*da0073e9SAndroid Build Coastguard Worker state_steps: List[Tensor], 380*da0073e9SAndroid Build Coastguard Worker *, 381*da0073e9SAndroid Build Coastguard Worker beta1: float, 382*da0073e9SAndroid Build Coastguard Worker beta2: float, 383*da0073e9SAndroid Build Coastguard Worker lr: float, 384*da0073e9SAndroid Build Coastguard Worker weight_decay: float, 385*da0073e9SAndroid Build Coastguard Worker momentum_decay: float, 386*da0073e9SAndroid Build Coastguard Worker eps: float, 387*da0073e9SAndroid Build Coastguard Worker decoupled_weight_decay: bool, 388*da0073e9SAndroid Build Coastguard Worker maximize: bool, 389*da0073e9SAndroid Build Coastguard Worker capturable: bool, 390*da0073e9SAndroid Build Coastguard Worker differentiable: bool, 391*da0073e9SAndroid Build Coastguard Worker has_complex: bool, 392*da0073e9SAndroid Build Coastguard Worker): 393*da0073e9SAndroid Build Coastguard Worker if len(params) == 0: 394*da0073e9SAndroid Build Coastguard Worker return 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker assert not differentiable, "_foreach ops don't support autograd" 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 399*da0073e9SAndroid Build Coastguard Worker if not torch._utils.is_compiling() and capturable: 400*da0073e9SAndroid Build Coastguard Worker capturable_supported_devices = _get_capturable_supported_devices( 401*da0073e9SAndroid Build Coastguard Worker supports_xla=False 402*da0073e9SAndroid Build Coastguard Worker ) 403*da0073e9SAndroid Build Coastguard Worker assert all( 404*da0073e9SAndroid Build Coastguard Worker p.device.type == mp.device.type == step.device.type 405*da0073e9SAndroid Build Coastguard Worker and p.device.type in capturable_supported_devices 406*da0073e9SAndroid Build Coastguard Worker for p, mp, step in zip(params, mu_products, state_steps) 407*da0073e9SAndroid Build Coastguard Worker ), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}." 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( 410*da0073e9SAndroid Build Coastguard Worker [params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps] # type: ignore[list-item] 411*da0073e9SAndroid Build Coastguard Worker ) 412*da0073e9SAndroid Build Coastguard Worker for ( 413*da0073e9SAndroid Build Coastguard Worker grouped_params_, 414*da0073e9SAndroid Build Coastguard Worker grouped_grads_, 415*da0073e9SAndroid Build Coastguard Worker grouped_exp_avgs_, 416*da0073e9SAndroid Build Coastguard Worker grouped_exp_avg_sqs_, 417*da0073e9SAndroid Build Coastguard Worker grouped_mu_products_, 418*da0073e9SAndroid Build Coastguard Worker grouped_state_steps_, 419*da0073e9SAndroid Build Coastguard Worker ), _ in grouped_tensors.values(): 420*da0073e9SAndroid Build Coastguard Worker grouped_params = cast(List[Tensor], grouped_params_) 421*da0073e9SAndroid Build Coastguard Worker grouped_grads = cast(List[Tensor], grouped_grads_) 422*da0073e9SAndroid Build Coastguard Worker grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_) 423*da0073e9SAndroid Build Coastguard Worker grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_) 424*da0073e9SAndroid Build Coastguard Worker grouped_mu_products = cast(List[Tensor], grouped_mu_products_) 425*da0073e9SAndroid Build Coastguard Worker grouped_state_steps = cast(List[Tensor], grouped_state_steps_) 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker # handle complex 428*da0073e9SAndroid Build Coastguard Worker if has_complex: 429*da0073e9SAndroid Build Coastguard Worker _view_as_real( 430*da0073e9SAndroid Build Coastguard Worker grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs 431*da0073e9SAndroid Build Coastguard Worker ) 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker if maximize: 434*da0073e9SAndroid Build Coastguard Worker grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker # Update steps 437*da0073e9SAndroid Build Coastguard Worker # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over 438*da0073e9SAndroid Build Coastguard Worker # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just 439*da0073e9SAndroid Build Coastguard Worker # wrapped it once now. The alpha is required to assure we go to the right overload. 440*da0073e9SAndroid Build Coastguard Worker if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu: 441*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_( 442*da0073e9SAndroid Build Coastguard Worker grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 443*da0073e9SAndroid Build Coastguard Worker ) 444*da0073e9SAndroid Build Coastguard Worker else: 445*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_(grouped_state_steps, 1) 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker if weight_decay != 0: 448*da0073e9SAndroid Build Coastguard Worker if decoupled_weight_decay: 449*da0073e9SAndroid Build Coastguard Worker # Perform stepweight decay 450*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) 451*da0073e9SAndroid Build Coastguard Worker else: 452*da0073e9SAndroid Build Coastguard Worker # Re-use the intermediate memory (grouped_grads) already allocated for maximize 453*da0073e9SAndroid Build Coastguard Worker if maximize: 454*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_( 455*da0073e9SAndroid Build Coastguard Worker grouped_grads, grouped_params, alpha=weight_decay 456*da0073e9SAndroid Build Coastguard Worker ) 457*da0073e9SAndroid Build Coastguard Worker else: 458*da0073e9SAndroid Build Coastguard Worker grouped_grads = torch._foreach_add( # type: ignore[assignment] 459*da0073e9SAndroid Build Coastguard Worker grouped_grads, grouped_params, alpha=weight_decay 460*da0073e9SAndroid Build Coastguard Worker ) 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker # Decay the first and second moment running average coefficient 463*da0073e9SAndroid Build Coastguard Worker torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(grouped_exp_avg_sqs, beta2) 466*da0073e9SAndroid Build Coastguard Worker torch._foreach_addcmul_( 467*da0073e9SAndroid Build Coastguard Worker grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2 468*da0073e9SAndroid Build Coastguard Worker ) 469*da0073e9SAndroid Build Coastguard Worker 470*da0073e9SAndroid Build Coastguard Worker exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs) 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] 473*da0073e9SAndroid Build Coastguard Worker mus: Union[Tuple[Tensor, ...], List[Tensor]] 474*da0073e9SAndroid Build Coastguard Worker mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]] 475*da0073e9SAndroid Build Coastguard Worker if capturable: 476*da0073e9SAndroid Build Coastguard Worker # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay)) 477*da0073e9SAndroid Build Coastguard Worker exponent = torch._foreach_mul(grouped_state_steps, momentum_decay) 478*da0073e9SAndroid Build Coastguard Worker mus = torch._foreach_pow(0.96, exponent) 479*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(mus, -0.5) 480*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_(mus, 1.0) 481*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(mus, beta1) 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker # mu_nexts will be beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay)) 484*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_(exponent, momentum_decay) 485*da0073e9SAndroid Build Coastguard Worker mu_nexts = torch._foreach_pow(0.96, exponent) 486*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(mu_nexts, -0.5) 487*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_(mu_nexts, 1.0) 488*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(mu_nexts, beta1) 489*da0073e9SAndroid Build Coastguard Worker 490*da0073e9SAndroid Build Coastguard Worker # save peak memory as we don't need exponent anymore 491*da0073e9SAndroid Build Coastguard Worker del exponent 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps) 494*da0073e9SAndroid Build Coastguard Worker # foreach_sub doesn't allow a scalar as the first arg 495*da0073e9SAndroid Build Coastguard Worker torch._foreach_sub_(bias_correction_sqrt, 1.0) 496*da0073e9SAndroid Build Coastguard Worker torch._foreach_neg_(bias_correction_sqrt) 497*da0073e9SAndroid Build Coastguard Worker torch._foreach_sqrt_(bias_correction_sqrt) 498*da0073e9SAndroid Build Coastguard Worker else: 499*da0073e9SAndroid Build Coastguard Worker bias_correction_sqrt = [ 500*da0073e9SAndroid Build Coastguard Worker (1 - beta2 ** _get_value(step)) ** 0.5 for step in grouped_state_steps 501*da0073e9SAndroid Build Coastguard Worker ] 502*da0073e9SAndroid Build Coastguard Worker mus = [ 503*da0073e9SAndroid Build Coastguard Worker beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) 504*da0073e9SAndroid Build Coastguard Worker for step in grouped_state_steps 505*da0073e9SAndroid Build Coastguard Worker ] 506*da0073e9SAndroid Build Coastguard Worker mu_nexts = [ 507*da0073e9SAndroid Build Coastguard Worker beta1 508*da0073e9SAndroid Build Coastguard Worker * (1.0 - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay))) 509*da0073e9SAndroid Build Coastguard Worker for step in grouped_state_steps 510*da0073e9SAndroid Build Coastguard Worker ] 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker # update mu_products 513*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(grouped_mu_products, mus) 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt) 516*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_(exp_avg_sq_sqrt, eps) 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker # explicitly delete bias_correction refs to save memory 519*da0073e9SAndroid Build Coastguard Worker del bias_correction_sqrt 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Worker if capturable: 522*da0073e9SAndroid Build Coastguard Worker # Build up the step_size multiplier for grad, reusing mus' memory 523*da0073e9SAndroid Build Coastguard Worker torch._foreach_sub_(mus, 1.0) 524*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(mus, lr) 525*da0073e9SAndroid Build Coastguard Worker # foreach_sub doesn't allow a scalar as the first arg 526*da0073e9SAndroid Build Coastguard Worker denom = torch._foreach_sub(grouped_mu_products, 1.0) 527*da0073e9SAndroid Build Coastguard Worker torch._foreach_neg_(denom) 528*da0073e9SAndroid Build Coastguard Worker torch._foreach_div_(mus, denom) 529*da0073e9SAndroid Build Coastguard Worker # - lr * (1 - mu) / (1 - mu_product) 530*da0073e9SAndroid Build Coastguard Worker step_size_grads = mus 531*da0073e9SAndroid Build Coastguard Worker # explicitly delete denom to save memory 532*da0073e9SAndroid Build Coastguard Worker del denom 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build Coastguard Worker # Build up the step_size multiplier for exp_avg, reusing mu_nexts' memory 535*da0073e9SAndroid Build Coastguard Worker denom = torch._foreach_mul(grouped_mu_products, mu_nexts) 536*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(mu_nexts, lr) 537*da0073e9SAndroid Build Coastguard Worker # foreach_sub doesn't allow a scalar as the first arg, but it's okay because 538*da0073e9SAndroid Build Coastguard Worker # we need a negative here anyway 539*da0073e9SAndroid Build Coastguard Worker torch._foreach_sub_(denom, 1.0) 540*da0073e9SAndroid Build Coastguard Worker torch._foreach_div_(mu_nexts, denom) 541*da0073e9SAndroid Build Coastguard Worker # - lr * mu_next / (1 - mu_product * mu_next) 542*da0073e9SAndroid Build Coastguard Worker step_size_expavg = mu_nexts 543*da0073e9SAndroid Build Coastguard Worker # explicitly delete denom to save memory 544*da0073e9SAndroid Build Coastguard Worker del denom 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker # we cannot inplace into step_size_grads cuz it is a list of ScalarTensors 547*da0073e9SAndroid Build Coastguard Worker # and mul'ing with grouped_grads will result in a list of bigger Tensors 548*da0073e9SAndroid Build Coastguard Worker numerator = torch._foreach_mul(step_size_grads, grouped_grads) 549*da0073e9SAndroid Build Coastguard Worker torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs) 550*da0073e9SAndroid Build Coastguard Worker 551*da0073e9SAndroid Build Coastguard Worker # finally, update params 552*da0073e9SAndroid Build Coastguard Worker torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt) 553*da0073e9SAndroid Build Coastguard Worker else: 554*da0073e9SAndroid Build Coastguard Worker step_size_grads = _stack_if_compiling( 555*da0073e9SAndroid Build Coastguard Worker [ 556*da0073e9SAndroid Build Coastguard Worker (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1 557*da0073e9SAndroid Build Coastguard Worker for mu_product, mu in zip(grouped_mu_products, mus) 558*da0073e9SAndroid Build Coastguard Worker ] 559*da0073e9SAndroid Build Coastguard Worker ) 560*da0073e9SAndroid Build Coastguard Worker step_size_expavg = _stack_if_compiling( 561*da0073e9SAndroid Build Coastguard Worker [ 562*da0073e9SAndroid Build Coastguard Worker ( 563*da0073e9SAndroid Build Coastguard Worker _get_value(lr) 564*da0073e9SAndroid Build Coastguard Worker * mu_next 565*da0073e9SAndroid Build Coastguard Worker / (1.0 - _get_value(mu_product) * mu_next) 566*da0073e9SAndroid Build Coastguard Worker ) 567*da0073e9SAndroid Build Coastguard Worker * -1 568*da0073e9SAndroid Build Coastguard Worker for mu_product, mu_next in zip(grouped_mu_products, mu_nexts) 569*da0073e9SAndroid Build Coastguard Worker ] 570*da0073e9SAndroid Build Coastguard Worker ) 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker torch._foreach_addcdiv_( 573*da0073e9SAndroid Build Coastguard Worker grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads # type: ignore[arg-type] 574*da0073e9SAndroid Build Coastguard Worker ) 575*da0073e9SAndroid Build Coastguard Worker torch._foreach_addcdiv_( 576*da0073e9SAndroid Build Coastguard Worker grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg # type: ignore[arg-type] 577*da0073e9SAndroid Build Coastguard Worker ) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam) 581*da0073e9SAndroid Build Coastguard Workerdef nadam( 582*da0073e9SAndroid Build Coastguard Worker params: List[Tensor], 583*da0073e9SAndroid Build Coastguard Worker grads: List[Tensor], 584*da0073e9SAndroid Build Coastguard Worker exp_avgs: List[Tensor], 585*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs: List[Tensor], 586*da0073e9SAndroid Build Coastguard Worker mu_products: List[Tensor], 587*da0073e9SAndroid Build Coastguard Worker state_steps: List[Tensor], 588*da0073e9SAndroid Build Coastguard Worker # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 589*da0073e9SAndroid Build Coastguard Worker # setting this as kwarg for now as functional API is compiled by torch/distributed/optim 590*da0073e9SAndroid Build Coastguard Worker decoupled_weight_decay: bool = False, 591*da0073e9SAndroid Build Coastguard Worker foreach: Optional[bool] = None, 592*da0073e9SAndroid Build Coastguard Worker capturable: bool = False, 593*da0073e9SAndroid Build Coastguard Worker differentiable: bool = False, 594*da0073e9SAndroid Build Coastguard Worker has_complex: bool = False, 595*da0073e9SAndroid Build Coastguard Worker maximize: bool = False, 596*da0073e9SAndroid Build Coastguard Worker *, 597*da0073e9SAndroid Build Coastguard Worker beta1: float, 598*da0073e9SAndroid Build Coastguard Worker beta2: float, 599*da0073e9SAndroid Build Coastguard Worker lr: float, 600*da0073e9SAndroid Build Coastguard Worker weight_decay: float, 601*da0073e9SAndroid Build Coastguard Worker momentum_decay: float, 602*da0073e9SAndroid Build Coastguard Worker eps: float, 603*da0073e9SAndroid Build Coastguard Worker): 604*da0073e9SAndroid Build Coastguard Worker r"""Functional API that performs NAdam algorithm computation. 605*da0073e9SAndroid Build Coastguard Worker 606*da0073e9SAndroid Build Coastguard Worker See :class:`~torch.optim.NAdam` for details. 607*da0073e9SAndroid Build Coastguard Worker """ 608*da0073e9SAndroid Build Coastguard Worker if not all(isinstance(t, torch.Tensor) for t in state_steps): 609*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 610*da0073e9SAndroid Build Coastguard Worker "API has changed, `state_steps` argument must contain a list of singleton tensors" 611*da0073e9SAndroid Build Coastguard Worker ) 612*da0073e9SAndroid Build Coastguard Worker 613*da0073e9SAndroid Build Coastguard Worker if not all(isinstance(t, torch.Tensor) for t in mu_products): 614*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 615*da0073e9SAndroid Build Coastguard Worker "API has changed, `mu_products` argument must contain a list of singleton tensors" 616*da0073e9SAndroid Build Coastguard Worker ) 617*da0073e9SAndroid Build Coastguard Worker 618*da0073e9SAndroid Build Coastguard Worker if foreach is None: 619*da0073e9SAndroid Build Coastguard Worker _, foreach = _default_to_fused_or_foreach( 620*da0073e9SAndroid Build Coastguard Worker params, differentiable, use_fused=False 621*da0073e9SAndroid Build Coastguard Worker ) 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker if foreach and torch.jit.is_scripting(): 624*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("torch.jit.script not supported with foreach optimizers") 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Worker if foreach and not torch.jit.is_scripting(): 627*da0073e9SAndroid Build Coastguard Worker func = _multi_tensor_nadam 628*da0073e9SAndroid Build Coastguard Worker else: 629*da0073e9SAndroid Build Coastguard Worker func = _single_tensor_nadam 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker func( 632*da0073e9SAndroid Build Coastguard Worker params, 633*da0073e9SAndroid Build Coastguard Worker grads, 634*da0073e9SAndroid Build Coastguard Worker exp_avgs, 635*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs, 636*da0073e9SAndroid Build Coastguard Worker mu_products, 637*da0073e9SAndroid Build Coastguard Worker state_steps, 638*da0073e9SAndroid Build Coastguard Worker beta1=beta1, 639*da0073e9SAndroid Build Coastguard Worker beta2=beta2, 640*da0073e9SAndroid Build Coastguard Worker lr=lr, 641*da0073e9SAndroid Build Coastguard Worker weight_decay=weight_decay, 642*da0073e9SAndroid Build Coastguard Worker momentum_decay=momentum_decay, 643*da0073e9SAndroid Build Coastguard Worker maximize=maximize, 644*da0073e9SAndroid Build Coastguard Worker decoupled_weight_decay=decoupled_weight_decay, 645*da0073e9SAndroid Build Coastguard Worker eps=eps, 646*da0073e9SAndroid Build Coastguard Worker capturable=capturable, 647*da0073e9SAndroid Build Coastguard Worker differentiable=differentiable, 648*da0073e9SAndroid Build Coastguard Worker has_complex=has_complex, 649*da0073e9SAndroid Build Coastguard Worker ) 650