# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _fused_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _stack_if_compiling,
    _use_grad_for_differentiable,
    _view_as_real,
    DeviceDict,
    Optimizer,
    ParamsT,
)


__all__ = ["Adam", "adam"]


class Adam(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if isinstance(lr, Tensor):
            if foreach and not capturable:
                raise ValueError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )
            if lr.numel() != 1:
                raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val,
                            dtype=_get_scalar_dtype(is_fused=fused),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    if group["fused"]:
                        _device_dtype_check_for_fused(p)
                    # note(crcrpar): [special device hosting for step]
                    # Deliberately host `step` on CPU if both capturable and fused are off.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state["step"] = (
                        torch.zeros(
                            (),
                            dtype=_get_scalar_dtype(is_fused=group["fused"]),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if group["amsgrad"]:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if group["amsgrad"]:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
                if group["differentiable"] and state["step"].requires_grad:
                    raise RuntimeError(
                        "`requires_grad` is not supported for `step` in differentiable mode"
                    )

                # Foreach without capturable does not support a tensor lr
                if (
                    group["foreach"]
                    and torch.is_tensor(group["lr"])
                    and not group["capturable"]
                ):
                    raise RuntimeError(
                        "lr as a Tensor is not supported for capturable=False and foreach=True"
                    )

                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            max_exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = group["betas"]

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=group["amsgrad"],
                has_complex=has_complex,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss


Adam.__doc__ = (
    r"""Implements Adam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
            &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
                \:\textit{maximize} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
                v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\

            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
            &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
            &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
            &\hspace{5mm}\textbf{if} \: amsgrad \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                \widehat{v_t}) \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
            is not yet supported for all our implementations. Please use a float
            LR if you are not also specifying fused=True or capturable=True.
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}
        {_fused_doc}
    .. Note::
        A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """
)
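

# Illustrative usage sketch (an added reading aid, not part of the upstream API
# surface): it mirrors the constructor arguments documented above and the usual
# zero_grad/backward/step loop. The helper name `_adam_usage_example` and the toy
# model are assumptions made purely for illustration; nothing here is invoked by
# this module.
def _adam_usage_example() -> None:  # pragma: no cover - documentation aid only
    model = torch.nn.Linear(4, 2)
    optimizer = Adam(
        model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2
    )
    for _ in range(3):
        inp = torch.randn(8, 4)
        loss = model(inp).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()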


def _single_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            if amsgrad:
                max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
            param = torch.view_as_real(param)

        # Decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        if capturable or differentiable:
            step = step_t

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sq = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sq = max_exp_avg_sqs[i]

                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))

                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = bias_correction2**0.5

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)

        # Lastly, switch back to complex view
        if amsgrad and torch.is_complex(params[i]):
            max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
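

# Reading aid for the capturable/differentiable branch of _single_tensor_adam
# above (an added sketch, not upstream documentation). Folding the step size
# into `denom` keeps the final update a single addcdiv_:
#
#     step_size_neg = -lr / (1 - beta1**t)
#     denom = (sqrt(v_t) / sqrt(1 - beta2**t) + eps) / step_size_neg
#     param.addcdiv_(m_t, denom)
#         == param - lr / (1 - beta1**t) * m_t / (sqrt(v_t) / sqrt(1 - beta2**t) + eps)
#
# which matches the update the non-capturable branch applies via
# `param.addcdiv_(exp_avg, denom, value=-step_size)`.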


def _multi_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_max_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # Handle complex parameters
        if has_complex:
            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]

        if capturable:
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

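            # Reading aid (an added note): after the in-place sub_/neg_/div_/
            # reciprocal_ calls above,
            #   bias_correction1 == -lr / (1 - beta1**t)
            #   bias_correction2 == 1 - beta2**t
            # so the sqrt_ below leaves sqrt(1 - beta2**t) in bias_correction2.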
            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)  # type: ignore[assignment]

                # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)

            # at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2]  # type: ignore[arg-type]

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size  # type: ignore[arg-type]
            )
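

# Note on the fused path below (an added reading aid): `_fused_adam` hands the
# whole update to the `torch._fused_adam_` kernel and, unlike the other two
# implementations, consumes `grad_scale`/`found_inf` directly, which is what
# `_step_supports_amp_scaling = True` in `Adam.__init__` advertises to the AMP
# grad scaler. When `found_inf` signals an overflow, the step counters that
# were incremented before the kernel call are rolled back by the trailing
# `_foreach_sub_`.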


def _fused_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,  # Needed for consistency.
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
) -> None:
    if not params:
        return
    if differentiable:
        raise RuntimeError("Adam with fused=True does not support differentiable=True")

    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
    # treating it as a scalar.
    lr_dict: Optional[DeviceDict] = (
        {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
    )
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_exp_avgs_,
            device_exp_avg_sqs_,
            device_max_exp_avg_sqs,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        if device.type == "mps":  # type: ignore[union-attr]
            assert found_inf is None and grad_scale is None

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device, non_blocking=True)
            )
        if found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(
                device, found_inf.to(device, non_blocking=True)
            )
        if lr_dict is not None and device not in lr_dict:
            lr_dict[device] = lr.to(device=device, non_blocking=True)  # type: ignore[union-attr]
            lr = lr_dict[device]
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adam_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,  # type: ignore[arg-type]
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,  # type: ignore[arg-type]
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
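

# Illustrative sketch of calling the functional entry point above directly (an
# added example; the buffer names and shapes are assumptions for demonstration
# and this helper is never invoked by the module). Each parameter needs matching
# grad/exp_avg/exp_avg_sq buffers plus a singleton `step` tensor, mirroring what
# Adam._init_group assembles per parameter group.
def _functional_adam_example() -> None:  # pragma: no cover - documentation aid only
    param = torch.randn(3)
    grad = torch.randn(3)
    exp_avg = torch.zeros(3)
    exp_avg_sq = torch.zeros(3)
    step = torch.tensor(0.0)
    adam(
        [param],
        [grad],
        [exp_avg],
        [exp_avg_sq],
        [],  # max_exp_avg_sqs is only read when amsgrad=True
        [step],
        amsgrad=False,
        beta1=0.9,
        beta2=0.999,
        lr=1e-3,
        weight_decay=0.0,
        eps=1e-8,
        maximize=False,
    )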