xref: /aosp_15_r20/external/pytorch/torch/optim/nadam.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-decorators
2*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
3*da0073e9SAndroid Build Coastguard Workerr"""Implementation for the NAdam algorithm."""
4*da0073e9SAndroid Build Coastguard Workerfrom typing import cast, List, Optional, Tuple, Union
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerfrom .optimizer import (
10*da0073e9SAndroid Build Coastguard Worker    _capturable_doc,
11*da0073e9SAndroid Build Coastguard Worker    _default_to_fused_or_foreach,
12*da0073e9SAndroid Build Coastguard Worker    _differentiable_doc,
13*da0073e9SAndroid Build Coastguard Worker    _disable_dynamo_if_unsupported,
14*da0073e9SAndroid Build Coastguard Worker    _foreach_doc,
15*da0073e9SAndroid Build Coastguard Worker    _get_capturable_supported_devices,
16*da0073e9SAndroid Build Coastguard Worker    _get_scalar_dtype,
17*da0073e9SAndroid Build Coastguard Worker    _get_value,
18*da0073e9SAndroid Build Coastguard Worker    _maximize_doc,
19*da0073e9SAndroid Build Coastguard Worker    _stack_if_compiling,
20*da0073e9SAndroid Build Coastguard Worker    _use_grad_for_differentiable,
21*da0073e9SAndroid Build Coastguard Worker    _view_as_real,
22*da0073e9SAndroid Build Coastguard Worker    Optimizer,
23*da0073e9SAndroid Build Coastguard Worker    ParamsT,
24*da0073e9SAndroid Build Coastguard Worker)
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
# Names exported by a star-import of this module: the optimizer class and the
# ``nadam`` function that ``NAdam.step`` delegates the per-group update to.
__all__ = ["NAdam", "nadam"]
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Workerclass NAdam(Optimizer):  # noqa: D101
31*da0073e9SAndroid Build Coastguard Worker    def __init__(
32*da0073e9SAndroid Build Coastguard Worker        self,
33*da0073e9SAndroid Build Coastguard Worker        params: ParamsT,
34*da0073e9SAndroid Build Coastguard Worker        lr: Union[float, Tensor] = 2e-3,
35*da0073e9SAndroid Build Coastguard Worker        betas: Tuple[float, float] = (0.9, 0.999),
36*da0073e9SAndroid Build Coastguard Worker        eps: float = 1e-8,
37*da0073e9SAndroid Build Coastguard Worker        weight_decay: float = 0,
38*da0073e9SAndroid Build Coastguard Worker        momentum_decay: float = 4e-3,
39*da0073e9SAndroid Build Coastguard Worker        decoupled_weight_decay: bool = False,
40*da0073e9SAndroid Build Coastguard Worker        *,
41*da0073e9SAndroid Build Coastguard Worker        foreach: Optional[bool] = None,
42*da0073e9SAndroid Build Coastguard Worker        maximize: bool = False,
43*da0073e9SAndroid Build Coastguard Worker        capturable: bool = False,
44*da0073e9SAndroid Build Coastguard Worker        differentiable: bool = False,
45*da0073e9SAndroid Build Coastguard Worker    ):  # noqa: D107
46*da0073e9SAndroid Build Coastguard Worker        if isinstance(lr, Tensor) and lr.numel() != 1:
47*da0073e9SAndroid Build Coastguard Worker            raise ValueError("Tensor lr must be 1-element")
48*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= lr:
49*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid learning rate: {lr}")
50*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= eps:
51*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid epsilon value: {eps}")
52*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= betas[0] < 1.0:
53*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
54*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= betas[1] < 1.0:
55*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
56*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= weight_decay:
57*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
58*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= momentum_decay:
59*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
60*da0073e9SAndroid Build Coastguard Worker        defaults = dict(
61*da0073e9SAndroid Build Coastguard Worker            lr=lr,
62*da0073e9SAndroid Build Coastguard Worker            betas=betas,
63*da0073e9SAndroid Build Coastguard Worker            eps=eps,
64*da0073e9SAndroid Build Coastguard Worker            weight_decay=weight_decay,
65*da0073e9SAndroid Build Coastguard Worker            momentum_decay=momentum_decay,
66*da0073e9SAndroid Build Coastguard Worker            decoupled_weight_decay=decoupled_weight_decay,
67*da0073e9SAndroid Build Coastguard Worker            maximize=maximize,
68*da0073e9SAndroid Build Coastguard Worker            foreach=foreach,
69*da0073e9SAndroid Build Coastguard Worker            capturable=capturable,
70*da0073e9SAndroid Build Coastguard Worker            differentiable=differentiable,
71*da0073e9SAndroid Build Coastguard Worker        )
72*da0073e9SAndroid Build Coastguard Worker        super().__init__(params, defaults)
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker    def __setstate__(self, state):  # noqa: D105
75*da0073e9SAndroid Build Coastguard Worker        super().__setstate__(state)
76*da0073e9SAndroid Build Coastguard Worker        for group in self.param_groups:
77*da0073e9SAndroid Build Coastguard Worker            group.setdefault("maximize", False)
78*da0073e9SAndroid Build Coastguard Worker            group.setdefault("foreach", None)
79*da0073e9SAndroid Build Coastguard Worker            group.setdefault("capturable", False)
80*da0073e9SAndroid Build Coastguard Worker            group.setdefault("differentiable", False)
81*da0073e9SAndroid Build Coastguard Worker            group.setdefault("decoupled_weight_decay", False)
82*da0073e9SAndroid Build Coastguard Worker            for p in group["params"]:
83*da0073e9SAndroid Build Coastguard Worker                p_state = self.state.get(p, [])
84*da0073e9SAndroid Build Coastguard Worker                if len(p_state) != 0:
85*da0073e9SAndroid Build Coastguard Worker                    if not torch.is_tensor(p_state["step"]):
86*da0073e9SAndroid Build Coastguard Worker                        step_val = float(p_state["step"])
87*da0073e9SAndroid Build Coastguard Worker                        p_state["step"] = (
88*da0073e9SAndroid Build Coastguard Worker                            torch.tensor(
89*da0073e9SAndroid Build Coastguard Worker                                step_val, dtype=_get_scalar_dtype(), device=p.device
90*da0073e9SAndroid Build Coastguard Worker                            )
91*da0073e9SAndroid Build Coastguard Worker                            if group["capturable"]
92*da0073e9SAndroid Build Coastguard Worker                            else torch.tensor(step_val, dtype=_get_scalar_dtype())
93*da0073e9SAndroid Build Coastguard Worker                        )
94*da0073e9SAndroid Build Coastguard Worker                    if not torch.is_tensor(p_state["mu_product"]):
95*da0073e9SAndroid Build Coastguard Worker                        mu_prod_val = p_state["mu_product"]
96*da0073e9SAndroid Build Coastguard Worker                        p_state["mu_product"] = (
97*da0073e9SAndroid Build Coastguard Worker                            torch.tensor(
98*da0073e9SAndroid Build Coastguard Worker                                mu_prod_val, dtype=_get_scalar_dtype(), device=p.device
99*da0073e9SAndroid Build Coastguard Worker                            )
100*da0073e9SAndroid Build Coastguard Worker                            if group["capturable"]
101*da0073e9SAndroid Build Coastguard Worker                            else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype())
102*da0073e9SAndroid Build Coastguard Worker                        )
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker    def _init_group(
105*da0073e9SAndroid Build Coastguard Worker        self,
106*da0073e9SAndroid Build Coastguard Worker        group,
107*da0073e9SAndroid Build Coastguard Worker        params_with_grad,
108*da0073e9SAndroid Build Coastguard Worker        grads,
109*da0073e9SAndroid Build Coastguard Worker        exp_avgs,
110*da0073e9SAndroid Build Coastguard Worker        exp_avg_sqs,
111*da0073e9SAndroid Build Coastguard Worker        mu_products,
112*da0073e9SAndroid Build Coastguard Worker        state_steps,
113*da0073e9SAndroid Build Coastguard Worker    ):
114*da0073e9SAndroid Build Coastguard Worker        has_complex = False
115*da0073e9SAndroid Build Coastguard Worker        for p in group["params"]:
116*da0073e9SAndroid Build Coastguard Worker            if p.grad is not None:
117*da0073e9SAndroid Build Coastguard Worker                has_complex |= torch.is_complex(p)
118*da0073e9SAndroid Build Coastguard Worker                params_with_grad.append(p)
119*da0073e9SAndroid Build Coastguard Worker                if p.grad.is_sparse:
120*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("NAdam does not support sparse gradients")
121*da0073e9SAndroid Build Coastguard Worker                grads.append(p.grad)
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker                state = self.state[p]
124*da0073e9SAndroid Build Coastguard Worker                # Lazy state initialization
125*da0073e9SAndroid Build Coastguard Worker                if len(state) == 0:
126*da0073e9SAndroid Build Coastguard Worker                    # note(crcrpar): [special device hosting for step]
127*da0073e9SAndroid Build Coastguard Worker                    # Deliberately host `step` and `mu_product` on CPU if capturable is False.
128*da0073e9SAndroid Build Coastguard Worker                    # This is because kernel launches are costly on CUDA and XLA.
129*da0073e9SAndroid Build Coastguard Worker                    state["step"] = (
130*da0073e9SAndroid Build Coastguard Worker                        torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
131*da0073e9SAndroid Build Coastguard Worker                        if group["capturable"]
132*da0073e9SAndroid Build Coastguard Worker                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
133*da0073e9SAndroid Build Coastguard Worker                    )
134*da0073e9SAndroid Build Coastguard Worker                    state["mu_product"] = (
135*da0073e9SAndroid Build Coastguard Worker                        torch.ones((), dtype=_get_scalar_dtype(), device=p.device)
136*da0073e9SAndroid Build Coastguard Worker                        if group["capturable"]
137*da0073e9SAndroid Build Coastguard Worker                        else torch.tensor(1.0, dtype=_get_scalar_dtype())
138*da0073e9SAndroid Build Coastguard Worker                    )
139*da0073e9SAndroid Build Coastguard Worker                    # Exponential moving average of gradient values
140*da0073e9SAndroid Build Coastguard Worker                    state["exp_avg"] = torch.zeros_like(
141*da0073e9SAndroid Build Coastguard Worker                        p, memory_format=torch.preserve_format
142*da0073e9SAndroid Build Coastguard Worker                    )
143*da0073e9SAndroid Build Coastguard Worker                    # Exponential moving average of squared gradient values
144*da0073e9SAndroid Build Coastguard Worker                    state["exp_avg_sq"] = torch.zeros_like(
145*da0073e9SAndroid Build Coastguard Worker                        p, memory_format=torch.preserve_format
146*da0073e9SAndroid Build Coastguard Worker                    )
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker                exp_avgs.append(state["exp_avg"])
149*da0073e9SAndroid Build Coastguard Worker                exp_avg_sqs.append(state["exp_avg_sq"])
150*da0073e9SAndroid Build Coastguard Worker                mu_products.append(state["mu_product"])
151*da0073e9SAndroid Build Coastguard Worker                state_steps.append(state["step"])
152*da0073e9SAndroid Build Coastguard Worker        return has_complex
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker    @_use_grad_for_differentiable
155*da0073e9SAndroid Build Coastguard Worker    def step(self, closure=None):
156*da0073e9SAndroid Build Coastguard Worker        """Perform a single optimization step.
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker        Args:
159*da0073e9SAndroid Build Coastguard Worker            closure (Callable, optional): A closure that reevaluates the model
160*da0073e9SAndroid Build Coastguard Worker                and returns the loss.
161*da0073e9SAndroid Build Coastguard Worker        """
162*da0073e9SAndroid Build Coastguard Worker        self._cuda_graph_capture_health_check()
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker        loss = None
165*da0073e9SAndroid Build Coastguard Worker        if closure is not None:
166*da0073e9SAndroid Build Coastguard Worker            with torch.enable_grad():
167*da0073e9SAndroid Build Coastguard Worker                loss = closure()
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker        for group in self.param_groups:
170*da0073e9SAndroid Build Coastguard Worker            params_with_grad: List[Tensor] = []
171*da0073e9SAndroid Build Coastguard Worker            grads: List[Tensor] = []
172*da0073e9SAndroid Build Coastguard Worker            exp_avgs: List[Tensor] = []
173*da0073e9SAndroid Build Coastguard Worker            exp_avg_sqs: List[Tensor] = []
174*da0073e9SAndroid Build Coastguard Worker            mu_products: List[Tensor] = []
175*da0073e9SAndroid Build Coastguard Worker            state_steps: List[Tensor] = []
176*da0073e9SAndroid Build Coastguard Worker            beta1, beta2 = cast(Tuple[float, float], group["betas"])
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker            has_complex = self._init_group(
179*da0073e9SAndroid Build Coastguard Worker                group,
180*da0073e9SAndroid Build Coastguard Worker                params_with_grad,
181*da0073e9SAndroid Build Coastguard Worker                grads,
182*da0073e9SAndroid Build Coastguard Worker                exp_avgs,
183*da0073e9SAndroid Build Coastguard Worker                exp_avg_sqs,
184*da0073e9SAndroid Build Coastguard Worker                mu_products,
185*da0073e9SAndroid Build Coastguard Worker                state_steps,
186*da0073e9SAndroid Build Coastguard Worker            )
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker            nadam(
189*da0073e9SAndroid Build Coastguard Worker                params_with_grad,
190*da0073e9SAndroid Build Coastguard Worker                grads,
191*da0073e9SAndroid Build Coastguard Worker                exp_avgs,
192*da0073e9SAndroid Build Coastguard Worker                exp_avg_sqs,
193*da0073e9SAndroid Build Coastguard Worker                mu_products,
194*da0073e9SAndroid Build Coastguard Worker                state_steps,
195*da0073e9SAndroid Build Coastguard Worker                beta1=beta1,
196*da0073e9SAndroid Build Coastguard Worker                beta2=beta2,
197*da0073e9SAndroid Build Coastguard Worker                lr=group["lr"],
198*da0073e9SAndroid Build Coastguard Worker                weight_decay=group["weight_decay"],
199*da0073e9SAndroid Build Coastguard Worker                momentum_decay=group["momentum_decay"],
200*da0073e9SAndroid Build Coastguard Worker                eps=group["eps"],
201*da0073e9SAndroid Build Coastguard Worker                maximize=group["maximize"],
202*da0073e9SAndroid Build Coastguard Worker                decoupled_weight_decay=group["decoupled_weight_decay"],
203*da0073e9SAndroid Build Coastguard Worker                foreach=group["foreach"],
204*da0073e9SAndroid Build Coastguard Worker                capturable=group["capturable"],
205*da0073e9SAndroid Build Coastguard Worker                differentiable=group["differentiable"],
206*da0073e9SAndroid Build Coastguard Worker                has_complex=has_complex,
207*da0073e9SAndroid Build Coastguard Worker            )
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker        return loss
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker
# The class docstring is assembled after the class body: the algorithm block is
# a plain raw string (so LaTeX braces stay literal) and the argument list is a
# raw f-string that interpolates the shared option docs imported from
# ``.optimizer``.
NAdam.__doc__ = (
    r"""Implements the NAdam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}                   \\
            &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)}    \\
            &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize}             \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ (first moment)},
                v_0 \leftarrow 0 \text{ (second moment)}                                  \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1}                                       \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
            &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay}                       \\
            &\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}        \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
            &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2}  0.96^{t \psi} \big)     \\
            &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
            & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i})                         \\
            &\hspace{5mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                   \\
            &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient; applied as an
            L2 penalty on the gradient unless ``decoupled_weight_decay`` is set,
            in which case it shrinks the parameters directly (default: 0)
        momentum_decay (float, optional): momentum decay :math:`\psi` used in the
            momentum schedule :math:`\mu_t` (default: 4e-3)
        decoupled_weight_decay (bool, optional): whether to use decoupled weight
            decay as in AdamW (`Decoupled Weight Decay Regularization`_) to
            obtain NAdamW (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}

    .. _Incorporating Nesterov Momentum into Adam:
        https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101

    """
)
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker
def _single_tensor_nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    """Run one NAdam step, updating one parameter tensor at a time.

    The list arguments are parallel: entry ``k`` of every list belongs to
    ``params[k]``. The entries of ``params``, ``exp_avgs``, ``exp_avg_sqs``,
    ``mu_products`` and ``state_steps`` are updated in place. ``grads`` is
    only read; a negated or weight-decayed copy is used where needed.

    ``has_complex`` is not consulted here; complex parameters are detected
    individually with ``torch.is_complex``. With ``capturable`` the step
    count is used as a tensor, otherwise it is converted via ``_get_value``.
    """

    def momentum_cache(t):
        # \mu^{t} = beta1 * (1 - 0.5 * 0.96^(t * psi))
        return beta1 * (1.0 - 0.5 * (0.96 ** (t * momentum_decay)))

    for idx, param in enumerate(params):
        grad = -grads[idx] if maximize else grads[idx]
        exp_avg = exp_avgs[idx]
        exp_avg_sq = exp_avg_sqs[idx]
        mu_product = mu_products[idx]
        step_t = state_steps[idx]

        if torch.is_complex(param):
            # Work on the real (..., 2) views so every elementwise op below
            # treats real and imaginary parts independently.
            param, grad, exp_avg, exp_avg_sq = (
                torch.view_as_real(t) for t in (param, grad, exp_avg, exp_avg_sq)
            )

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == mu_product.device.type == step_t.device.type
                and param.device.type in supported_devices
            ), (
                f"If capturable=True, params, mu_products and state_steps must be "
                f"on supported devices: {supported_devices}."
            )

        step_t += 1
        step = step_t if capturable else _get_value(step_t)

        bias_correction2 = 1 - beta2**step

        if weight_decay != 0:
            if decoupled_weight_decay:
                # AdamW-style: shrink the parameter itself.
                param.mul_(1 - lr * weight_decay)
            else:
                # L2 penalty folded into the gradient (out of place).
                grad = grad.add(param, alpha=weight_decay)

        mu = momentum_cache(step)
        mu_next = momentum_cache(step + 1)
        # Running product of the momentum cache, kept in the state tensor.
        mu_product *= mu

        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = exp_avg_sq.div(bias_correction2).sqrt()

        if not (differentiable or capturable):
            # Plain path: pass the scalar coefficients through addcdiv_'s
            # ``value`` argument, using a Python-number mu_product.
            mu_product_val = _get_value(mu_product)
            mu_product_next = mu_product_val * mu_next
            denom.add_(eps)
            param.addcdiv_(
                grad, denom, value=(-lr * (1.0 - mu) / (1.0 - mu_product_val))
            )
            param.addcdiv_(
                exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next)
            )
            continue

        # Differentiable / capturable path: scale grad and exp_avg by tensor
        # coefficients directly so autograd (or graph capture) sees the math
        # instead of a scalar ``value`` argument.
        denom = denom.add(eps)
        mu_product_next = mu_product * mu_next
        scaled_grad = grad * (-lr * (1.0 - mu) / (1.0 - mu_product))
        scaled_avg = exp_avg * (-lr * mu_next / (1.0 - mu_product_next))
        param.addcdiv_(scaled_grad, denom)
        param.addcdiv_(scaled_avg, denom)
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker
def _multi_tensor_nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    r"""Apply one NAdam step to lists of tensors using batched ``torch._foreach_*`` ops.

    Tensors are bucketed with ``Optimizer._group_tensors_by_device_and_dtype``.
    Each bucket is then updated with foreach kernels. Per parameter, the step is::

        step        += 1
        mu           = beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
        mu_next      = beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay))
        mu_product  *= mu
        exp_avg      = lerp(exp_avg, grad, 1 - beta1)
        exp_avg_sq   = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
        denom        = sqrt(exp_avg_sq) / sqrt(1 - beta2 ** step) + eps
        param       += -lr * (1 - mu) / (1 - mu_product) * grad / denom
        param       += -lr * mu_next / (1 - mu_product * mu_next) * exp_avg / denom

    Here ``grad`` is the negated gradient when ``maximize`` is set. When
    ``weight_decay != 0`` and ``decoupled_weight_decay`` is False,
    ``weight_decay * param`` is added to ``grad``. With ``decoupled_weight_decay``,
    the parameter is instead scaled by ``1 - lr * weight_decay`` before the
    moment updates.

    Args:
        params: parameters; updated in place.
        grads: gradients. The foreach ops here never write to them: negation
            (``maximize``) and non-decoupled weight decay both produce new
            tensor lists.
        exp_avgs: first-moment buffers; updated in place.
        exp_avg_sqs: second-moment buffers; updated in place.
        mu_products: running products of ``mu``; multiplied in place by this
            step's ``mu`` before they are used in the step sizes.
        state_steps: step counters; incremented in place before any
            coefficient is computed.
        beta1, beta2, lr, weight_decay, momentum_decay, eps: hyperparameters.
        decoupled_weight_decay: use multiplicative (AdamW-style) decay on the
            parameter instead of adding ``weight_decay * param`` to the gradient.
        maximize: ascend instead of descend (the gradient is negated).
        capturable: build every per-step coefficient with foreach tensor ops
            on the step tensors. When False, step values are read through
            ``_get_value``, presumably as host-side Python scalars (confirm
            in ``.optimizer``).
        differentiable: must be False; see Raises.
        has_complex: if True, ``_view_as_real`` is applied to params, grads,
            exp_avgs and exp_avg_sqs. This presumably swaps complex entries
            for real views (confirm in ``.optimizer``). mu_products and
            state_steps are not passed to it.

    Raises:
        AssertionError: if ``differentiable`` is True. The same error is also
            raised when ``capturable`` is True, we are not compiling, and
            params / mu_products / state_steps are not on a shared, supported
            device type.
    """
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mp.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mp, step in zip(params, mu_products, state_steps)
        ), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps]  # type: ignore[list-item]
    )
    # Each bucket holds same-device/same-dtype slices of the six input lists.
    # The second element of each value (index info) is not needed here.
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_avg_sqs_,
        grouped_mu_products_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
        grouped_mu_products = cast(List[Tensor], grouped_mu_products_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # handle complex
        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs
            )

        # Out-of-place negation: from here on grouped_grads is a fresh list of
        # tensors, so the caller's gradients are left untouched.
        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            if decoupled_weight_decay:
                # Perform stepweight decay
                torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
            else:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(
                        grouped_grads, grouped_params, alpha=weight_decay
                    )
                else:
                    # Out-of-place so the caller's gradients are not modified.
                    grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                        grouped_grads, grouped_params, alpha=weight_decay
                    )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2
        )

        # New buffers; later turned in place into the update denominator
        # sqrt(v) / sqrt(1 - beta2**step) + eps.
        exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)

        bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
        mus: Union[Tuple[Tensor, ...], List[Tensor]]
        mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
            exponent = torch._foreach_mul(grouped_state_steps, momentum_decay)
            mus = torch._foreach_pow(0.96, exponent)
            torch._foreach_mul_(mus, -0.5)
            torch._foreach_add_(mus, 1.0)
            torch._foreach_mul_(mus, beta1)

            # mu_nexts will be beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay))
            # (step * md + md == (step + 1) * md, so exponent is advanced in place)
            torch._foreach_add_(exponent, momentum_decay)
            mu_nexts = torch._foreach_pow(0.96, exponent)
            torch._foreach_mul_(mu_nexts, -0.5)
            torch._foreach_add_(mu_nexts, 1.0)
            torch._foreach_mul_(mu_nexts, beta1)

            # save peak memory as we don't need exponent anymore
            del exponent

            # sqrt(1 - beta2 ** step), computed as sqrt(-(beta2 ** step - 1))
            bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction_sqrt, 1.0)
            torch._foreach_neg_(bias_correction_sqrt)
            torch._foreach_sqrt_(bias_correction_sqrt)
        else:
            # Same three quantities as above, computed per step from _get_value(step).
            bias_correction_sqrt = [
                (1 - beta2 ** _get_value(step)) ** 0.5 for step in grouped_state_steps
            ]
            mus = [
                beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay)))
                for step in grouped_state_steps
            ]
            mu_nexts = [
                beta1
                * (1.0 - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
                for step in grouped_state_steps
            ]

        # update mu_products
        # Must happen before the step sizes below, which use the product
        # including this step's mu.
        torch._foreach_mul_(grouped_mu_products, mus)

        # exp_avg_sq_sqrt now holds the update denominator.
        torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
        torch._foreach_add_(exp_avg_sq_sqrt, eps)

        # explicitly delete bias_correction refs to save memory
        del bias_correction_sqrt

        if capturable:
            # Build up the step_size multiplier for grad, reusing mus' memory
            # (mu - 1) * lr == -lr * (1 - mu)
            torch._foreach_sub_(mus, 1.0)
            torch._foreach_mul_(mus, lr)
            # foreach_sub doesn't allow a scalar as the first arg
            denom = torch._foreach_sub(grouped_mu_products, 1.0)
            torch._foreach_neg_(denom)
            torch._foreach_div_(mus, denom)
            # - lr * (1 - mu) / (1 - mu_product)
            step_size_grads = mus
            # explicitly delete denom to save memory
            del denom

            # Build up the step_size multiplier for exp_avg, reusing mu_nexts' memory.
            # denom must be computed from the unscaled mu_nexts, i.e. before the
            # in-place multiplication by lr on the next line.
            denom = torch._foreach_mul(grouped_mu_products, mu_nexts)
            torch._foreach_mul_(mu_nexts, lr)
            # foreach_sub doesn't allow a scalar as the first arg, but it's okay because
            # we need a negative here anyway
            torch._foreach_sub_(denom, 1.0)
            torch._foreach_div_(mu_nexts, denom)
            # - lr * mu_next / (1 - mu_product * mu_next)
            step_size_expavg = mu_nexts
            # explicitly delete denom to save memory
            del denom

            # We cannot multiply in place into step_size_grads because it is a list of
            # scalar tensors, and multiplying by grouped_grads yields full-size tensors.
            numerator = torch._foreach_mul(step_size_grads, grouped_grads)
            torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs)

            # finally, update params
            torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt)
        else:
            # Per-parameter scalar step sizes. _stack_if_compiling presumably
            # stacks them into a tensor only under torch.compile (confirm in .optimizer).
            step_size_grads = _stack_if_compiling(
                [
                    (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1
                    for mu_product, mu in zip(grouped_mu_products, mus)
                ]
            )
            step_size_expavg = _stack_if_compiling(
                [
                    (
                        _get_value(lr)
                        * mu_next
                        / (1.0 - _get_value(mu_product) * mu_next)
                    )
                    * -1
                    for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)
                ]
            )

            # param += step_size_grads * grad / denom + step_size_expavg * exp_avg / denom
            torch._foreach_addcdiv_(
                grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads  # type: ignore[arg-type]
            )
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg  # type: ignore[arg-type]
            )
579*da0073e9SAndroid Build Coastguard Worker
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
def nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    maximize: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.

    The per-parameter state must already be tensor-valued. When ``foreach``
    is ``None`` it is resolved with the shared heuristic (the fused path is
    never considered). The work is then handed to either the multi-tensor
    (``_foreach``) or the single-tensor implementation, with all
    hyperparameters forwarded unchanged.

    Raises:
        RuntimeError: if ``state_steps`` or ``mu_products`` contain anything
            other than tensors, or if the foreach path is selected while
            running under TorchScript.
    """
    # Plain Python numbers here indicate a caller written against an older
    # version of this API; reject them up front, checking steps first.
    if not all(isinstance(step_t, torch.Tensor) for step_t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )
    if not all(isinstance(mu_t, torch.Tensor) for mu_t in mu_products):
        raise RuntimeError(
            "API has changed, `mu_products` argument must contain a list of singleton tensors"
        )

    # No explicit preference from the caller: let the shared helper decide,
    # discarding its fused recommendation.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # The repeated ``not torch.jit.is_scripting()`` is intentional; presumably it
    # lets TorchScript treat the foreach implementation as unreachable.
    if foreach and not torch.jit.is_scripting():
        step_impl = _multi_tensor_nadam
    else:
        step_impl = _single_tensor_nadam

    step_impl(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        mu_products,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        eps=eps,
        weight_decay=weight_decay,
        decoupled_weight_decay=decoupled_weight_decay,
        momentum_decay=momentum_decay,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )
650