# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _foreach_doc,
    _fused_doc,
    _maximize_doc,
    _use_grad_for_differentiable,
    DeviceDict,
    Optimizer,
)


__all__ = ["SGD", "sgd"]


class SGD(Optimizer):  # noqa: D101
    def __init__(
        self,
        params,
        lr: Union[float, Tensor] = 1e-3,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov=False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):  # noqa: D107
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            maximize=maximize,
            foreach=foreach,
            differentiable=differentiable,
            fused=fused,
        )
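        # Nesterov momentum is only well-defined on top of classical momentum
        # with zero dampening, hence the extra validation below.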
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

        if fused:
            self._step_supports_amp_scaling = True
            self._need_device_dtype_check_for_fused = True
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("differentiable", False)
            group.setdefault("fused", False)

    def _init_group(self, group, params, grads, momentum_buffer_list):
        has_sparse_grad = False

        for p in group["params"]:
            if p.grad is not None:
                if group["fused"] and getattr(
                    self, "_need_device_dtype_check_for_fused", True
                ):
                    _device_dtype_check_for_fused(p)
                    self._need_device_dtype_check_for_fused = False
                params.append(p)
                grads.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                if group["momentum"] != 0:
                    state = self.state[p]
                    momentum_buffer_list.append(state.get("momentum_buffer"))

        return has_sparse_grad

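    # `_use_grad_for_differentiable` runs `step` under `torch.no_grad()` in the
    # common case, but keeps grad mode enabled when the optimizer was built with
    # ``differentiable=True`` so the update itself can be differentiated through.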
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params: List[Tensor] = []
            grads: List[Tensor] = []
            momentum_buffer_list: List[Optional[Tensor]] = []

            has_sparse_grad = self._init_group(
                group, params, grads, momentum_buffer_list
            )

            sgd(
                params,
                grads,
                momentum_buffer_list,
                weight_decay=group["weight_decay"],
                momentum=group["momentum"],
                lr=group["lr"],
                dampening=group["dampening"],
                nesterov=group["nesterov"],
                maximize=group["maximize"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

            if group["momentum"] != 0:
                # update momentum_buffers in state
                for p, momentum_buffer in zip(params, momentum_buffer_list):
                    state = self.state[p]
                    state["momentum_buffer"] = momentum_buffer

        return loss


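# The class docstring is assembled below (rather than written inline) so that
# the argument descriptions shared across optimizers (``_maximize_doc``,
# ``_foreach_doc``, ...) can be interpolated via the ``rf"""`` fragment.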
SGD.__doc__ = (
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
            &\hspace{10mm}\textbf{if} \: t > 1 \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
            &\hspace{10mm}\textbf{else} \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
            &\hspace{10mm}\textbf{else} \\[-1.ex]
            &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
            &\hspace{5mm}\textbf{else} \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_differentiable_doc}
        {_fused_doc}
    """
    + r"""

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form
        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.

    """
)


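# A minimal sketch of driving the functional API below directly (assuming a
# single dense parameter; most users should go through the SGD class instead):
#
#   param = torch.zeros(3)
#   grad = torch.ones(3)
#   buffers: List[Optional[Tensor]] = [None]
#   sgd([param], [grad], buffers, weight_decay=0.0, momentum=0.9, lr=0.1,
#       dampening=0.0, nesterov=False, maximize=False)
#   # `buffers[0]` now holds the initialized momentum buffer.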
def sgd(
    params: List[Tensor],
    d_p_list: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if foreach is None and fused is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            fused, foreach = _default_to_fused_or_foreach(
                params, differentiable=False, use_fused=False
            )
        else:
            foreach = False
            fused = False
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    elif fused and not torch.jit.is_scripting():
        func = _fused_sgd
    else:
        func = _single_tensor_sgd

    func(
        params,
        d_p_list,
        momentum_buffer_list,
        weight_decay=weight_decay,
        momentum=momentum,
        lr=lr,
        dampening=dampening,
        nesterov=nesterov,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )


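# The three implementations below are numerically equivalent and differ only in
# how they batch the work: a per-parameter Python loop, multi-tensor
# ``torch._foreach_*`` kernels grouped by device/dtype, and a single fused kernel.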
def _single_tensor_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
):
    assert grad_scale is None and found_inf is None

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(grad).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)

            if nesterov:
                grad = grad.add(buf, alpha=momentum)
            else:
                grad = buf

        param.add_(grad, alpha=-lr)


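# A worked sanity check for the single-tensor path (a sketch, not a shipped
# test): with lr=0.1, momentum=0.9, dampening=0, and a constant gradient of
# 1.0, the buffer after two steps is b_2 = 0.9 * 1.0 + 1.0 = 1.9, so in total
# the parameter moves by -0.1 * (1.0 + 1.9):
#
#   p = torch.zeros(1)
#   bufs: List[Optional[Tensor]] = [None]
#   for _ in range(2):
#       _single_tensor_sgd(
#           [p], [torch.ones(1)], bufs, None, None,
#           weight_decay=0.0, momentum=0.9, lr=0.1, dampening=0.0,
#           nesterov=False, maximize=False, has_sparse_grad=False,
#       )
#   # p == tensor([-0.29])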
def _multi_tensor_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
):
    assert grad_scale is None and found_inf is None

    if len(params) == 0:
        return

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=True  # type: ignore[list-item]
    )

    for (
        device_params_,
        device_grads_,
        device_momentum_buffer_list,
    ), indices in grouped_tensors.values():
        device_params: List[Tensor] = cast(List[Tensor], device_params_)
        device_grads: List[Tensor] = cast(List[Tensor], device_grads_)

        device_has_sparse_grad = has_sparse_grad and any(
            grad.is_sparse for grad in device_grads
        )

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        if momentum != 0:
            bufs: List[Tensor] = []

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    bufs.append(cast(Tensor, device_momentum_buffer_list[i]))

            if all_states_with_momentum_buffer:
                torch._foreach_mul_(bufs, momentum)
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                for i in range(len(device_momentum_buffer_list)):
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[
                            indices[i]
                        ] = torch.clone(device_grads[i]).detach()
                    else:
                        buf = cast(Tensor, device_momentum_buffer_list[i])
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                    bufs.append(buf)

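            # Note: the chained assignment in the `else` branch above also
            # threads each newly created buffer back into the caller's flat
            # `momentum_buffer_list` (via `indices`), so `SGD.step` can persist
            # it into `self.state` afterwards.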
            if nesterov:
                torch._foreach_add_(device_grads, bufs, alpha=momentum)
            else:
                device_grads = bufs

        if not device_has_sparse_grad:
            # handle internal item() call if lr is a tensor
            if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
                grads_x_lr = torch._foreach_mul(device_grads, -lr)
                torch._foreach_add_(device_params, grads_x_lr)
            else:
                torch._foreach_add_(device_params, device_grads, alpha=-lr)
        else:
            # foreach APIs don't support sparse
            for i in range(len(device_params)):
                device_params[i].add_(device_grads[i], alpha=-lr)


def _fused_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
) -> None:
    if not params:
        return
    if has_sparse_grad:
        raise RuntimeError("`_fused_sgd` does not support sparse gradients")
    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    no_momentum_buffer = momentum == 0
    is_first_step = (
        all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
    )
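    # No buffers exist yet on the very first momentum step; materialize them
    # here so a single `torch._fused_sgd_` call can both initialize and apply
    # them (signaled via `is_first_step`).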
    if is_first_step:
        for i, g in enumerate(grads):
            momentum_buffer_list[i] = torch.empty_like(g)
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=False  # type: ignore[list-item]
    )
    for (device, _), (
        (device_params_, device_grads_, device_momentum_buffer_list),
        _,
    ) in grouped_tensors.items():
        device_params: List[Tensor] = cast(List[Tensor], device_params_)
        device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device)
            )
        if found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device))
        torch._fused_sgd_(
            device_params,
            device_grads,
            []
            if no_momentum_buffer
            else cast(List[Tensor], device_momentum_buffer_list),
            weight_decay=weight_decay,
            momentum=momentum,
            lr=lr,
            dampening=dampening,
            nesterov=nesterov,
            maximize=maximize,
            is_first_step=is_first_step,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
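
# NOTE: `grad_scale` and `found_inf` are only consumed by the fused path above;
# together with `_step_supports_amp_scaling` they let a GradScaler unscale
# gradients and skip inf/NaN steps inside `torch._fused_sgd_` itself instead of
# in separate passes over the parameters.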