xref: /aosp_15_r20/external/pytorch/torch/optim/sgd.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerr"""Implementation for Stochastic Gradient Descent optimizer."""
3*da0073e9SAndroid Build Coastguard Workerfrom typing import cast, List, Optional, Union
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerfrom .optimizer import (
9*da0073e9SAndroid Build Coastguard Worker    _default_to_fused_or_foreach,
10*da0073e9SAndroid Build Coastguard Worker    _device_dtype_check_for_fused,
11*da0073e9SAndroid Build Coastguard Worker    _differentiable_doc,
12*da0073e9SAndroid Build Coastguard Worker    _foreach_doc,
13*da0073e9SAndroid Build Coastguard Worker    _fused_doc,
14*da0073e9SAndroid Build Coastguard Worker    _maximize_doc,
15*da0073e9SAndroid Build Coastguard Worker    _use_grad_for_differentiable,
16*da0073e9SAndroid Build Coastguard Worker    DeviceDict,
17*da0073e9SAndroid Build Coastguard Worker    Optimizer,
18*da0073e9SAndroid Build Coastguard Worker)
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
# Public names of this module: the ``SGD`` optimizer class and its functional API ``sgd``.
__all__ = ["SGD", "sgd"]
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Workerclass SGD(Optimizer):  # noqa: D101
25*da0073e9SAndroid Build Coastguard Worker    def __init__(
26*da0073e9SAndroid Build Coastguard Worker        self,
27*da0073e9SAndroid Build Coastguard Worker        params,
28*da0073e9SAndroid Build Coastguard Worker        lr: Union[float, Tensor] = 1e-3,
29*da0073e9SAndroid Build Coastguard Worker        momentum: float = 0,
30*da0073e9SAndroid Build Coastguard Worker        dampening: float = 0,
31*da0073e9SAndroid Build Coastguard Worker        weight_decay: float = 0,
32*da0073e9SAndroid Build Coastguard Worker        nesterov=False,
33*da0073e9SAndroid Build Coastguard Worker        *,
34*da0073e9SAndroid Build Coastguard Worker        maximize: bool = False,
35*da0073e9SAndroid Build Coastguard Worker        foreach: Optional[bool] = None,
36*da0073e9SAndroid Build Coastguard Worker        differentiable: bool = False,
37*da0073e9SAndroid Build Coastguard Worker        fused: Optional[bool] = None,
38*da0073e9SAndroid Build Coastguard Worker    ):  # noqa: D107
39*da0073e9SAndroid Build Coastguard Worker        if isinstance(lr, Tensor) and lr.numel() != 1:
40*da0073e9SAndroid Build Coastguard Worker            raise ValueError("Tensor lr must be 1-element")
41*da0073e9SAndroid Build Coastguard Worker        if lr < 0.0:
42*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid learning rate: {lr}")
43*da0073e9SAndroid Build Coastguard Worker        if momentum < 0.0:
44*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid momentum value: {momentum}")
45*da0073e9SAndroid Build Coastguard Worker        if weight_decay < 0.0:
46*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker        defaults = dict(
49*da0073e9SAndroid Build Coastguard Worker            lr=lr,
50*da0073e9SAndroid Build Coastguard Worker            momentum=momentum,
51*da0073e9SAndroid Build Coastguard Worker            dampening=dampening,
52*da0073e9SAndroid Build Coastguard Worker            weight_decay=weight_decay,
53*da0073e9SAndroid Build Coastguard Worker            nesterov=nesterov,
54*da0073e9SAndroid Build Coastguard Worker            maximize=maximize,
55*da0073e9SAndroid Build Coastguard Worker            foreach=foreach,
56*da0073e9SAndroid Build Coastguard Worker            differentiable=differentiable,
57*da0073e9SAndroid Build Coastguard Worker            fused=fused,
58*da0073e9SAndroid Build Coastguard Worker        )
59*da0073e9SAndroid Build Coastguard Worker        if nesterov and (momentum <= 0 or dampening != 0):
60*da0073e9SAndroid Build Coastguard Worker            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
61*da0073e9SAndroid Build Coastguard Worker        super().__init__(params, defaults)
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker        if fused:
64*da0073e9SAndroid Build Coastguard Worker            self._step_supports_amp_scaling = True
65*da0073e9SAndroid Build Coastguard Worker            self._need_device_dtype_check_for_fused = True
66*da0073e9SAndroid Build Coastguard Worker            if differentiable:
67*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("`fused` does not support `differentiable`")
68*da0073e9SAndroid Build Coastguard Worker            if foreach:
69*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker    def __setstate__(self, state):  # noqa: D105
72*da0073e9SAndroid Build Coastguard Worker        super().__setstate__(state)
73*da0073e9SAndroid Build Coastguard Worker        for group in self.param_groups:
74*da0073e9SAndroid Build Coastguard Worker            group.setdefault("nesterov", False)
75*da0073e9SAndroid Build Coastguard Worker            group.setdefault("maximize", False)
76*da0073e9SAndroid Build Coastguard Worker            group.setdefault("foreach", None)
77*da0073e9SAndroid Build Coastguard Worker            group.setdefault("differentiable", False)
78*da0073e9SAndroid Build Coastguard Worker            group.setdefault("fused", False)
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker    def _init_group(self, group, params, grads, momentum_buffer_list):
81*da0073e9SAndroid Build Coastguard Worker        has_sparse_grad = False
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker        for p in group["params"]:
84*da0073e9SAndroid Build Coastguard Worker            if p.grad is not None:
85*da0073e9SAndroid Build Coastguard Worker                if group["fused"] and getattr(
86*da0073e9SAndroid Build Coastguard Worker                    self, "_need_device_dtype_check_for_fused", True
87*da0073e9SAndroid Build Coastguard Worker                ):
88*da0073e9SAndroid Build Coastguard Worker                    _device_dtype_check_for_fused(p)
89*da0073e9SAndroid Build Coastguard Worker                    self._need_device_dtype_check_for_fused = False
90*da0073e9SAndroid Build Coastguard Worker                params.append(p)
91*da0073e9SAndroid Build Coastguard Worker                grads.append(p.grad)
92*da0073e9SAndroid Build Coastguard Worker                if p.grad.is_sparse:
93*da0073e9SAndroid Build Coastguard Worker                    has_sparse_grad = True
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker                if group["momentum"] != 0:
96*da0073e9SAndroid Build Coastguard Worker                    state = self.state[p]
97*da0073e9SAndroid Build Coastguard Worker                    momentum_buffer_list.append(state.get("momentum_buffer"))
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker        return has_sparse_grad
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker    @_use_grad_for_differentiable
102*da0073e9SAndroid Build Coastguard Worker    def step(self, closure=None):
103*da0073e9SAndroid Build Coastguard Worker        """Perform a single optimization step.
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker        Args:
106*da0073e9SAndroid Build Coastguard Worker            closure (Callable, optional): A closure that reevaluates the model
107*da0073e9SAndroid Build Coastguard Worker                and returns the loss.
108*da0073e9SAndroid Build Coastguard Worker        """
109*da0073e9SAndroid Build Coastguard Worker        loss = None
110*da0073e9SAndroid Build Coastguard Worker        if closure is not None:
111*da0073e9SAndroid Build Coastguard Worker            with torch.enable_grad():
112*da0073e9SAndroid Build Coastguard Worker                loss = closure()
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker        for group in self.param_groups:
115*da0073e9SAndroid Build Coastguard Worker            params: List[Tensor] = []
116*da0073e9SAndroid Build Coastguard Worker            grads: List[Tensor] = []
117*da0073e9SAndroid Build Coastguard Worker            momentum_buffer_list: List[Optional[Tensor]] = []
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker            has_sparse_grad = self._init_group(
120*da0073e9SAndroid Build Coastguard Worker                group, params, grads, momentum_buffer_list
121*da0073e9SAndroid Build Coastguard Worker            )
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker            sgd(
124*da0073e9SAndroid Build Coastguard Worker                params,
125*da0073e9SAndroid Build Coastguard Worker                grads,
126*da0073e9SAndroid Build Coastguard Worker                momentum_buffer_list,
127*da0073e9SAndroid Build Coastguard Worker                weight_decay=group["weight_decay"],
128*da0073e9SAndroid Build Coastguard Worker                momentum=group["momentum"],
129*da0073e9SAndroid Build Coastguard Worker                lr=group["lr"],
130*da0073e9SAndroid Build Coastguard Worker                dampening=group["dampening"],
131*da0073e9SAndroid Build Coastguard Worker                nesterov=group["nesterov"],
132*da0073e9SAndroid Build Coastguard Worker                maximize=group["maximize"],
133*da0073e9SAndroid Build Coastguard Worker                has_sparse_grad=has_sparse_grad,
134*da0073e9SAndroid Build Coastguard Worker                foreach=group["foreach"],
135*da0073e9SAndroid Build Coastguard Worker                fused=group["fused"],
136*da0073e9SAndroid Build Coastguard Worker                grad_scale=getattr(self, "grad_scale", None),
137*da0073e9SAndroid Build Coastguard Worker                found_inf=getattr(self, "found_inf", None),
138*da0073e9SAndroid Build Coastguard Worker            )
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker            if group["momentum"] != 0:
141*da0073e9SAndroid Build Coastguard Worker                # update momentum_buffers in state
142*da0073e9SAndroid Build Coastguard Worker                for p, momentum_buffer in zip(params, momentum_buffer_list):
143*da0073e9SAndroid Build Coastguard Worker                    state = self.state[p]
144*da0073e9SAndroid Build Coastguard Worker                    state["momentum_buffer"] = momentum_buffer
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker        return loss
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker
# The class docstring is assembled outside the class body so the shared argument
# descriptions imported from ``.optimizer`` can be interpolated into it.
# The pseudo-code mirrors ``_single_tensor_sgd``: the gradient is negated for
# ``maximize`` *before* weight decay and momentum are applied, and the parameter
# update is then always a descent step on that (possibly negated) direction.
SGD.__doc__ = (
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize}                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}                                        \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1})                     \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1})                      \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0                                               \\
            &\hspace{10mm}\textbf{if} \: t > 1                                                   \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t           \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t                                           \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov}                                       \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t                             \\
            &\hspace{10mm}\textbf{else}                                                   \\[-1.ex]
            &\hspace{15mm} g_t  \leftarrow  \textbf{b}_t                                         \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t                           \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        nesterov (bool, optional): enables Nesterov momentum; requires
            ``momentum > 0`` and ``dampening == 0`` (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_differentiable_doc}
        {_fused_doc}
    """
    + r"""

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.

    """
)
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker
def sgd(
    params: List[Tensor],
    d_p_list: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.

    Selects one of three implementations and forwards all arguments to it:
    the multi-tensor (``foreach``) path, the fused path, or the single-tensor
    loop. When neither ``foreach`` nor ``fused`` is specified, the choice comes
    from ``_default_to_fused_or_foreach`` with ``use_fused=False`` outside
    TorchScript, and is the single-tensor loop under TorchScript. If only one
    of the two flags is specified, the other is treated as ``False``. If both
    are truthy, ``foreach`` takes precedence.

    Args:
        params: parameters to update.
        d_p_list: gradients, indexed in lockstep with ``params``.
        momentum_buffer_list: momentum buffers forwarded to the selected
            implementation; ``None`` entries mean no buffer exists yet. The
            single-tensor implementation writes newly created buffers back
            into this list.
        has_sparse_grad: whether any gradient in ``d_p_list`` is sparse.
        foreach: force (``True``) or forbid (``False``) the multi-tensor path.
        fused: force (``True``) or forbid (``False``) the fused path.
        grad_scale, found_inf: forwarded unchanged to the selected
            implementation.
        weight_decay, momentum, lr, dampening, nesterov, maximize: algorithm
            hyper-parameters, as documented on :class:`~torch.optim.SGD`.

    Raises:
        RuntimeError: if ``foreach`` or ``fused`` is truthy while being
            compiled/run under TorchScript.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if foreach is None and fused is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            fused, foreach = _default_to_fused_or_foreach(
                params, differentiable=False, use_fused=False
            )
        else:
            foreach = False
            fused = False
    # Exactly one flag was user-specified: treat the unspecified one as disabled.
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # ``foreach`` is tested first, so it wins if both flags are truthy. The
    # ``not torch.jit.is_scripting()`` guards presumably let TorchScript drop
    # the foreach/fused branches when scripting — keep this exact shape.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    elif fused and not torch.jit.is_scripting():
        func = _fused_sgd
    else:
        func = _single_tensor_sgd

    func(
        params,
        d_p_list,
        momentum_buffer_list,
        weight_decay=weight_decay,
        momentum=momentum,
        lr=lr,
        dampening=dampening,
        nesterov=nesterov,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Workerdef _single_tensor_sgd(
315*da0073e9SAndroid Build Coastguard Worker    params: List[Tensor],
316*da0073e9SAndroid Build Coastguard Worker    grads: List[Tensor],
317*da0073e9SAndroid Build Coastguard Worker    momentum_buffer_list: List[Optional[Tensor]],
318*da0073e9SAndroid Build Coastguard Worker    grad_scale: Optional[Tensor],
319*da0073e9SAndroid Build Coastguard Worker    found_inf: Optional[Tensor],
320*da0073e9SAndroid Build Coastguard Worker    *,
321*da0073e9SAndroid Build Coastguard Worker    weight_decay: float,
322*da0073e9SAndroid Build Coastguard Worker    momentum: float,
323*da0073e9SAndroid Build Coastguard Worker    lr: float,
324*da0073e9SAndroid Build Coastguard Worker    dampening: float,
325*da0073e9SAndroid Build Coastguard Worker    nesterov: bool,
326*da0073e9SAndroid Build Coastguard Worker    maximize: bool,
327*da0073e9SAndroid Build Coastguard Worker    has_sparse_grad: bool,
328*da0073e9SAndroid Build Coastguard Worker):
329*da0073e9SAndroid Build Coastguard Worker    assert grad_scale is None and found_inf is None
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker    for i, param in enumerate(params):
332*da0073e9SAndroid Build Coastguard Worker        grad = grads[i] if not maximize else -grads[i]
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker        if weight_decay != 0:
335*da0073e9SAndroid Build Coastguard Worker            grad = grad.add(param, alpha=weight_decay)
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker        if momentum != 0:
338*da0073e9SAndroid Build Coastguard Worker            buf = momentum_buffer_list[i]
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker            if buf is None:
341*da0073e9SAndroid Build Coastguard Worker                buf = torch.clone(grad).detach()
342*da0073e9SAndroid Build Coastguard Worker                momentum_buffer_list[i] = buf
343*da0073e9SAndroid Build Coastguard Worker            else:
344*da0073e9SAndroid Build Coastguard Worker                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
345*da0073e9SAndroid Build Coastguard Worker
346*da0073e9SAndroid Build Coastguard Worker            if nesterov:
347*da0073e9SAndroid Build Coastguard Worker                grad = grad.add(buf, alpha=momentum)
348*da0073e9SAndroid Build Coastguard Worker            else:
349*da0073e9SAndroid Build Coastguard Worker                grad = buf
350*da0073e9SAndroid Build Coastguard Worker
351*da0073e9SAndroid Build Coastguard Worker        param.add_(grad, alpha=-lr)
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Workerdef _multi_tensor_sgd(
355*da0073e9SAndroid Build Coastguard Worker    params: List[Tensor],
356*da0073e9SAndroid Build Coastguard Worker    grads: List[Tensor],
357*da0073e9SAndroid Build Coastguard Worker    momentum_buffer_list: List[Optional[Tensor]],
358*da0073e9SAndroid Build Coastguard Worker    grad_scale: Optional[Tensor],
359*da0073e9SAndroid Build Coastguard Worker    found_inf: Optional[Tensor],
360*da0073e9SAndroid Build Coastguard Worker    *,
361*da0073e9SAndroid Build Coastguard Worker    weight_decay: float,
362*da0073e9SAndroid Build Coastguard Worker    momentum: float,
363*da0073e9SAndroid Build Coastguard Worker    lr: float,
364*da0073e9SAndroid Build Coastguard Worker    dampening: float,
365*da0073e9SAndroid Build Coastguard Worker    nesterov: bool,
366*da0073e9SAndroid Build Coastguard Worker    maximize: bool,
367*da0073e9SAndroid Build Coastguard Worker    has_sparse_grad: bool,
368*da0073e9SAndroid Build Coastguard Worker):
369*da0073e9SAndroid Build Coastguard Worker    assert grad_scale is None and found_inf is None
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker    if len(params) == 0:
372*da0073e9SAndroid Build Coastguard Worker        return
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
375*da0073e9SAndroid Build Coastguard Worker        [params, grads, momentum_buffer_list], with_indices=True  # type: ignore[list-item]
376*da0073e9SAndroid Build Coastguard Worker    )
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker    for (
379*da0073e9SAndroid Build Coastguard Worker        device_params_,
380*da0073e9SAndroid Build Coastguard Worker        device_grads_,
381*da0073e9SAndroid Build Coastguard Worker        device_momentum_buffer_list,
382*da0073e9SAndroid Build Coastguard Worker    ), indices in grouped_tensors.values():
383*da0073e9SAndroid Build Coastguard Worker        device_params: List[Tensor] = cast(List[Tensor], device_params_)
384*da0073e9SAndroid Build Coastguard Worker        device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker        device_has_sparse_grad = has_sparse_grad and any(
387*da0073e9SAndroid Build Coastguard Worker            grad.is_sparse for grad in device_grads
388*da0073e9SAndroid Build Coastguard Worker        )
389*da0073e9SAndroid Build Coastguard Worker
390*da0073e9SAndroid Build Coastguard Worker        if maximize:
391*da0073e9SAndroid Build Coastguard Worker            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker        if weight_decay != 0:
394*da0073e9SAndroid Build Coastguard Worker            # Re-use the intermediate memory (device_grads) already allocated for maximize
395*da0073e9SAndroid Build Coastguard Worker            if maximize:
396*da0073e9SAndroid Build Coastguard Worker                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
397*da0073e9SAndroid Build Coastguard Worker            else:
398*da0073e9SAndroid Build Coastguard Worker                device_grads = torch._foreach_add(  # type: ignore[assignment]
399*da0073e9SAndroid Build Coastguard Worker                    device_grads, device_params, alpha=weight_decay
400*da0073e9SAndroid Build Coastguard Worker                )
401*da0073e9SAndroid Build Coastguard Worker
402*da0073e9SAndroid Build Coastguard Worker        if momentum != 0:
403*da0073e9SAndroid Build Coastguard Worker            bufs: List[Tensor] = []
404*da0073e9SAndroid Build Coastguard Worker
405*da0073e9SAndroid Build Coastguard Worker            all_states_with_momentum_buffer = True
406*da0073e9SAndroid Build Coastguard Worker            for i in range(len(device_momentum_buffer_list)):
407*da0073e9SAndroid Build Coastguard Worker                if device_momentum_buffer_list[i] is None:
408*da0073e9SAndroid Build Coastguard Worker                    all_states_with_momentum_buffer = False
409*da0073e9SAndroid Build Coastguard Worker                    break
410*da0073e9SAndroid Build Coastguard Worker                else:
411*da0073e9SAndroid Build Coastguard Worker                    bufs.append(cast(Tensor, device_momentum_buffer_list[i]))
412*da0073e9SAndroid Build Coastguard Worker
413*da0073e9SAndroid Build Coastguard Worker            if all_states_with_momentum_buffer:
414*da0073e9SAndroid Build Coastguard Worker                torch._foreach_mul_(bufs, momentum)
415*da0073e9SAndroid Build Coastguard Worker                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
416*da0073e9SAndroid Build Coastguard Worker            else:
417*da0073e9SAndroid Build Coastguard Worker                bufs = []
418*da0073e9SAndroid Build Coastguard Worker                for i in range(len(device_momentum_buffer_list)):
419*da0073e9SAndroid Build Coastguard Worker                    if device_momentum_buffer_list[i] is None:
420*da0073e9SAndroid Build Coastguard Worker                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[
421*da0073e9SAndroid Build Coastguard Worker                            indices[i]
422*da0073e9SAndroid Build Coastguard Worker                        ] = torch.clone(device_grads[i]).detach()
423*da0073e9SAndroid Build Coastguard Worker                    else:
424*da0073e9SAndroid Build Coastguard Worker                        buf = cast(Tensor, device_momentum_buffer_list[i])
425*da0073e9SAndroid Build Coastguard Worker                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
426*da0073e9SAndroid Build Coastguard Worker
427*da0073e9SAndroid Build Coastguard Worker                    bufs.append(buf)
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker            if nesterov:
430*da0073e9SAndroid Build Coastguard Worker                torch._foreach_add_(device_grads, bufs, alpha=momentum)
431*da0073e9SAndroid Build Coastguard Worker            else:
432*da0073e9SAndroid Build Coastguard Worker                device_grads = bufs
433*da0073e9SAndroid Build Coastguard Worker
434*da0073e9SAndroid Build Coastguard Worker        if not device_has_sparse_grad:
435*da0073e9SAndroid Build Coastguard Worker            # handle internal item() call if lr is a tensor
436*da0073e9SAndroid Build Coastguard Worker            if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
437*da0073e9SAndroid Build Coastguard Worker                grads_x_lr = torch._foreach_mul(device_grads, -lr)
438*da0073e9SAndroid Build Coastguard Worker                torch._foreach_add_(device_params, grads_x_lr)
439*da0073e9SAndroid Build Coastguard Worker            else:
440*da0073e9SAndroid Build Coastguard Worker                torch._foreach_add_(device_params, device_grads, alpha=-lr)
441*da0073e9SAndroid Build Coastguard Worker        else:
442*da0073e9SAndroid Build Coastguard Worker            # foreach APIs don't support sparse
443*da0073e9SAndroid Build Coastguard Worker            for i in range(len(device_params)):
444*da0073e9SAndroid Build Coastguard Worker                device_params[i].add_(device_grads[i], alpha=-lr)
445*da0073e9SAndroid Build Coastguard Worker
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Workerdef _fused_sgd(
448*da0073e9SAndroid Build Coastguard Worker    params: List[Tensor],
449*da0073e9SAndroid Build Coastguard Worker    grads: List[Tensor],
450*da0073e9SAndroid Build Coastguard Worker    momentum_buffer_list: List[Optional[Tensor]],
451*da0073e9SAndroid Build Coastguard Worker    grad_scale: Optional[Tensor],
452*da0073e9SAndroid Build Coastguard Worker    found_inf: Optional[Tensor],
453*da0073e9SAndroid Build Coastguard Worker    *,
454*da0073e9SAndroid Build Coastguard Worker    weight_decay: float,
455*da0073e9SAndroid Build Coastguard Worker    momentum: float,
456*da0073e9SAndroid Build Coastguard Worker    lr: float,
457*da0073e9SAndroid Build Coastguard Worker    dampening: float,
458*da0073e9SAndroid Build Coastguard Worker    nesterov: bool,
459*da0073e9SAndroid Build Coastguard Worker    maximize: bool,
460*da0073e9SAndroid Build Coastguard Worker    has_sparse_grad: bool,
461*da0073e9SAndroid Build Coastguard Worker) -> None:
462*da0073e9SAndroid Build Coastguard Worker    if not params:
463*da0073e9SAndroid Build Coastguard Worker        return
464*da0073e9SAndroid Build Coastguard Worker    if has_sparse_grad:
465*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError("`_fused_sgd` does not support sparse gradients")
466*da0073e9SAndroid Build Coastguard Worker    grad_scale_dict: DeviceDict = (
467*da0073e9SAndroid Build Coastguard Worker        {grad_scale.device: grad_scale} if grad_scale is not None else {}
468*da0073e9SAndroid Build Coastguard Worker    )
469*da0073e9SAndroid Build Coastguard Worker    found_inf_dict: DeviceDict = (
470*da0073e9SAndroid Build Coastguard Worker        {found_inf.device: found_inf} if found_inf is not None else {}
471*da0073e9SAndroid Build Coastguard Worker    )
472*da0073e9SAndroid Build Coastguard Worker
473*da0073e9SAndroid Build Coastguard Worker    no_momentum_buffer = momentum == 0
474*da0073e9SAndroid Build Coastguard Worker    is_first_step = (
475*da0073e9SAndroid Build Coastguard Worker        all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
476*da0073e9SAndroid Build Coastguard Worker    )
477*da0073e9SAndroid Build Coastguard Worker    if is_first_step:
478*da0073e9SAndroid Build Coastguard Worker        for i, g in enumerate(grads):
479*da0073e9SAndroid Build Coastguard Worker            momentum_buffer_list[i] = torch.empty_like(g)
480*da0073e9SAndroid Build Coastguard Worker    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
481*da0073e9SAndroid Build Coastguard Worker        [params, grads, momentum_buffer_list], with_indices=False  # type: ignore[list-item]
482*da0073e9SAndroid Build Coastguard Worker    )
483*da0073e9SAndroid Build Coastguard Worker    for (device, _), (
484*da0073e9SAndroid Build Coastguard Worker        (device_params_, device_grads_, device_momentum_buffer_list),
485*da0073e9SAndroid Build Coastguard Worker        _,
486*da0073e9SAndroid Build Coastguard Worker    ) in grouped_tensors.items():
487*da0073e9SAndroid Build Coastguard Worker        device_params: List[Tensor] = cast(List[Tensor], device_params_)
488*da0073e9SAndroid Build Coastguard Worker        device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
489*da0073e9SAndroid Build Coastguard Worker        device_grad_scale, device_found_inf = None, None
490*da0073e9SAndroid Build Coastguard Worker        if grad_scale is not None:
491*da0073e9SAndroid Build Coastguard Worker            device_grad_scale = grad_scale_dict.setdefault(
492*da0073e9SAndroid Build Coastguard Worker                device, grad_scale.to(device)
493*da0073e9SAndroid Build Coastguard Worker            )
494*da0073e9SAndroid Build Coastguard Worker        if found_inf_dict is not None and found_inf is not None:
495*da0073e9SAndroid Build Coastguard Worker            device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device))
496*da0073e9SAndroid Build Coastguard Worker        torch._fused_sgd_(
497*da0073e9SAndroid Build Coastguard Worker            device_params,
498*da0073e9SAndroid Build Coastguard Worker            device_grads,
499*da0073e9SAndroid Build Coastguard Worker            []
500*da0073e9SAndroid Build Coastguard Worker            if no_momentum_buffer
501*da0073e9SAndroid Build Coastguard Worker            else cast(List[Tensor], device_momentum_buffer_list),
502*da0073e9SAndroid Build Coastguard Worker            weight_decay=weight_decay,
503*da0073e9SAndroid Build Coastguard Worker            momentum=momentum,
504*da0073e9SAndroid Build Coastguard Worker            lr=lr,
505*da0073e9SAndroid Build Coastguard Worker            dampening=dampening,
506*da0073e9SAndroid Build Coastguard Worker            nesterov=nesterov,
507*da0073e9SAndroid Build Coastguard Worker            maximize=maximize,
508*da0073e9SAndroid Build Coastguard Worker            is_first_step=is_first_step,
509*da0073e9SAndroid Build Coastguard Worker            grad_scale=device_grad_scale,
510*da0073e9SAndroid Build Coastguard Worker            found_inf=device_found_inf,
511*da0073e9SAndroid Build Coastguard Worker        )
512