xref: /aosp_15_r20/external/pytorch/torch/optim/radam.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-decorators
2*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
3*da0073e9SAndroid Build Coastguard Workerr"""Implementation for the RAdam algorithm."""
4*da0073e9SAndroid Build Coastguard Workerfrom typing import cast, List, Optional, Tuple, Union
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerfrom .optimizer import (
10*da0073e9SAndroid Build Coastguard Worker    _capturable_doc,
11*da0073e9SAndroid Build Coastguard Worker    _default_to_fused_or_foreach,
12*da0073e9SAndroid Build Coastguard Worker    _differentiable_doc,
13*da0073e9SAndroid Build Coastguard Worker    _disable_dynamo_if_unsupported,
14*da0073e9SAndroid Build Coastguard Worker    _foreach_doc,
15*da0073e9SAndroid Build Coastguard Worker    _get_capturable_supported_devices,
16*da0073e9SAndroid Build Coastguard Worker    _get_scalar_dtype,
17*da0073e9SAndroid Build Coastguard Worker    _get_value,
18*da0073e9SAndroid Build Coastguard Worker    _maximize_doc,
19*da0073e9SAndroid Build Coastguard Worker    _use_grad_for_differentiable,
20*da0073e9SAndroid Build Coastguard Worker    _view_as_real,
21*da0073e9SAndroid Build Coastguard Worker    Optimizer,
22*da0073e9SAndroid Build Coastguard Worker    ParamsT,
23*da0073e9SAndroid Build Coastguard Worker)
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker__all__ = ["RAdam", "radam"]
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker
class RAdam(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        decoupled_weight_decay: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        # Reject invalid hyperparameters eagerly so a misconfiguration
        # surfaces at construction time rather than mid-training.
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        super().__init__(
            params,
            {
                "lr": lr,
                "betas": betas,
                "eps": eps,
                "weight_decay": weight_decay,
                "maximize": maximize,
                "foreach": foreach,
                "capturable": capturable,
                "decoupled_weight_decay": decoupled_weight_decay,
                "differentiable": differentiable,
            },
        )

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            # Checkpoints written before an option existed lack its key;
            # backfill the documented defaults.
            for option, default in (
                ("foreach", None),
                ("maximize", False),
                ("differentiable", False),
                ("decoupled_weight_decay", False),
                ("capturable", False),
            ):
                group.setdefault(option, default)
            # Older checkpoints stored `step` as a plain Python number;
            # convert it to a scalar tensor (placed on the param's device
            # when capturable) as the update kernels expect.
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    device = p.device if group["capturable"] else None
                    p_state["step"] = torch.tensor(
                        step_val, dtype=_get_scalar_dtype(), device=device
                    )

    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
    ):
        # Collect the per-parameter tensors consumed by the functional
        # radam() call, creating optimizer state lazily on first use.
        # Returns True iff any selected parameter is complex.
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("RAdam does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]
            if len(state) == 0:
                # Lazy state initialization.
                if group["capturable"]:
                    # Keep `step` on the same device as the parameter so the
                    # whole update can run without host-side synchronization.
                    state["step"] = torch.zeros(
                        (), dtype=_get_scalar_dtype(), device=p.device
                    )
                else:
                    state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
                # Exponential moving averages of the gradient and of its square.
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = cast(Tuple[float, float], group["betas"])

            has_complex = self._init_group(
                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
            )

            # Delegate the math to the functional API, which dispatches to the
            # single-tensor or foreach implementation.
            radam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                decoupled_weight_decay=group["decoupled_weight_decay"],
                has_complex=has_complex,
            )

        return loss
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker
# Attach the full class docstring after the class body so the shared option
# fragments (_foreach_doc, _maximize_doc, ...) imported from .optimizer can be
# interpolated into the Args section via the rf-string below.
RAdam.__doc__ = (
    r"""Implements RAdam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \beta_1, \beta_2
                \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
                \lambda \text{ (weightdecay)}, \:\textit{maximize}                               \\
            &\hspace{13mm} \epsilon \text{ (epsilon)}, \textit{decoupled\_weight\_decay}         \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ ( first moment)},
                v_0 \leftarrow 0 \text{ ( second moment)},                                       \\
            &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1                      \\[-1.ex]
            &\rule{110mm}{0.4pt}  \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{6mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{12mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{6mm}\textbf{else}                                                           \\
            &\hspace{12mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{6mm} \theta_t \leftarrow \theta_{t-1}                                       \\
            &\hspace{6mm} \textbf{if} \: \lambda \neq 0                                          \\
            &\hspace{12mm}\textbf{if} \: \textit{decoupled\_weight\_decay}                       \\
            &\hspace{18mm} \theta_t \leftarrow \theta_{t} - \gamma \lambda \theta_{t}            \\
            &\hspace{12mm}\textbf{else}                                                          \\
            &\hspace{18mm} g_t \leftarrow g_t + \lambda \theta_{t}                               \\
            &\hspace{6mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{6mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{6mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
                2 t \beta^t_2 /\big(1-\beta_2^t \big)                                    \\[0.1.ex]
            &\hspace{6mm}\textbf{if} \: \rho_t > 5                                               \\
            &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon  } \\
            &\hspace{12mm} r_t \leftarrow
      \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
            &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t        \\
            &\hspace{6mm}\textbf{else}                                                           \\
            &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}                \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.

    This implementation provides an option to use either the original weight_decay implementation as in Adam
    (where the weight_decay is applied to the gradient) or the one from AdamW (where weight_decay is applied
    to the weight) through the decoupled_weight_decay option. When decoupled_weight_decay is set to False
    (default), it uses the original Adam style weight decay, otherwise, it uses the AdamW style which
    corresponds more closely to the `author's implementation`_ in the RAdam paper. Further information
    about decoupled weight decay can be found in `Decoupled Weight Decay Regularization`_.

    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        decoupled_weight_decay (bool, optional): whether to use decoupled weight
            decay as in AdamW to obtain RAdamW (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}

    .. _On the variance of the adaptive learning rate and beyond:
        https://arxiv.org/abs/1908.03265
    .. _author's implementation:
        https://github.com/LiyuanLucasLiu/RAdam
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101

    """
)
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker
def _single_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    differentiable: bool,
    maximize: bool,
    capturable: bool,
    has_complex: bool,
):
    r"""Apply the RAdam update to each parameter one tensor at a time.

    All five list arguments are parallel (index ``i`` describes the same
    parameter). ``params``, ``exp_avgs``, ``exp_avg_sqs`` and ``state_steps``
    are mutated in place; ``grads`` is not modified.

    See :class:`RAdam` for the algorithm and the meaning of the keyword
    arguments.
    """
    for i, param in enumerate(params):
        # Maximization is implemented by descending the negated gradient.
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # Complex parameters are updated through their real view, so the
        # moment math below only ever sees real tensors.
        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)

        # update step
        # In capturable mode `step` stays a device tensor (no host sync);
        # otherwise it is unwrapped to a Python scalar for cheaper arithmetic.
        step_t += 1
        step = step_t if capturable else _get_value(step_t)

        if weight_decay != 0:
            if decoupled_weight_decay:
                # AdamW-style: shrink the weights directly (RAdamW).
                param.mul_(1 - lr * weight_decay)
            else:
                # Adam-style L2: fold the penalty into the gradient.
                # `add` (not `add_`) so the caller's grad is left untouched.
                grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step

        # correcting bias for the first moving moment
        bias_corrected_exp_avg = exp_avg / bias_correction1

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2

        def _compute_rect():
            # Variance rectification term r_t from the RAdam paper; only
            # evaluated when rho_t > 5, where the radicand is positive.
            return (
                (rho_t - 4)
                * (rho_t - 2)
                * rho_inf
                / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
            ) ** 0.5

        def _compute_adaptive_lr():
            # l_t = sqrt(1 - beta2^t) / (sqrt(v_t) + eps). Out-of-place `add`
            # under `differentiable` keeps the op graph autograd-friendly.
            exp_avg_sq_sqrt = exp_avg_sq.sqrt()
            if differentiable:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add(eps)
            else:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)

            return (bias_correction2**0.5) / exp_avg_sq_sqrt

        # Compute the variance rectification term and update parameters accordingly
        if capturable:
            # torch.where keeps the rho_t > 5 branch tensorized — no data-
            # dependent host-side control flow during graph capture. Both
            # branches are computed; the un-rectified case uses update == 1.
            update = torch.where(
                rho_t > 5.0, _compute_rect() * _compute_adaptive_lr(), 1.0
            )
            param.add_(bias_corrected_exp_avg * lr * update, alpha=-1.0)
        else:
            if rho_t > 5.0:
                # Rectified update: theta -= lr * m_hat * r_t * l_t.
                param.add_(
                    bias_corrected_exp_avg
                    * lr
                    * _compute_adaptive_lr()
                    * _compute_rect(),
                    alpha=-1.0,
                )
            else:
                # Early steps (rho_t <= 5): plain SGD-with-momentum-style step
                # on the bias-corrected first moment, no adaptive denominator.
                param.add_(bias_corrected_exp_avg * lr, alpha=-1.0)
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker
def _multi_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    differentiable: bool,
    maximize: bool,
    capturable: bool,
    has_complex: bool,
):
    """Foreach (multi-tensor) implementation of the RAdam update.

    Groups the per-parameter tensors by device and dtype and applies the RAdam
    step to each group with fused ``torch._foreach_*`` ops. Mutates ``params``,
    ``exp_avgs``, ``exp_avg_sqs``, and ``state_steps`` in place. When
    ``capturable`` is True, every scalar that depends on the step count is
    computed as a tensor so the update is CUDA-graph capturable.
    """
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_avg_sqs_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if has_complex:
            # View complex tensors as real so _foreach ops can process them.
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs
            )

        if maximize:
            # Allocates new grad tensors; the originals are never mutated.
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        rho_t_list: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # Tensor math so the computation stays on-device (graph capturable):
            # bias_correction1 temporarily holds 1 - beta2^t (NOT the usual
            # beta1 correction) and is only used to build rho_t below.
            bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_add_(bias_correction1, 1)
            # bias_correction2 is built up in place into
            # rho_t = rho_inf - 2 * t * beta2^t / (1 - beta2^t).
            bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
            torch._foreach_mul_(bias_correction2, grouped_state_steps)
            torch._foreach_mul_(bias_correction2, 2)
            torch._foreach_div_(bias_correction2, bias_correction1)
            torch._foreach_neg_(bias_correction2)
            torch._foreach_add_(bias_correction2, rho_inf)
            rho_t_list = bias_correction2
        else:
            # Host-side scalar math: same rho_t formula per step counter.
            rho_t_list = [
                rho_inf
                - 2
                * _get_value(step)
                * (beta2 ** _get_value(step))
                / (1 - beta2 ** _get_value(step))
                for step in grouped_state_steps
            ]

        if weight_decay != 0:
            if decoupled_weight_decay:
                # AdamW-style: shrink params directly instead of touching grads.
                torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
            else:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(
                        grouped_grads, grouped_params, alpha=weight_decay
                    )
                else:
                    grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                        grouped_grads, grouped_params, alpha=weight_decay
                    )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del grouped_grads

        if capturable:
            # Build the rectification factor
            # r_t = sqrt(((rho_t-4)(rho_t-2) rho_inf) / ((rho_inf-4)(rho_inf-2) rho_t))
            # entirely with tensor ops; `num` accumulates the numerator in place.
            num = torch._foreach_sub(rho_t_list, 4)
            sub2 = torch._foreach_sub(rho_t_list, 2)
            torch._foreach_mul_(num, sub2)
            del sub2
            torch._foreach_mul_(num, rho_inf)
            rho_inf = (rho_inf - 4) * (rho_inf - 2)
            denom = torch._foreach_mul(rho_t_list, rho_inf)
            torch._foreach_div_(num, denom)
            del denom
            torch._foreach_sqrt_(num)

            # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884
            # Rectification only applies once the SMA length exceeds 5;
            # otherwise the factor is zeroed and the unrectified path is used.
            rect = [
                torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list)
            ]
            del num
            del rho_t_list
            # Indicator of the unrectified branch: 1 where rect == 0, else 0,
            # scaled into the (negated) SGD-style step size -lr / (1 - beta1^t).
            unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect]
            torch._foreach_mul_(unrect_step_size, lr)

            bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_add_(bias_correction1, 1)

            torch._foreach_div_(unrect_step_size, bias_correction1)
            torch._foreach_neg_(unrect_step_size)

            # bias_correction2 is rebuilt in place into the (negated) rectified
            # step-size numerator: -lr * rect * sqrt(1 - beta2^t) / (1 - beta1^t).
            bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
            torch._foreach_neg_(bias_correction2)
            torch._foreach_add_(bias_correction2, 1)
            torch._foreach_sqrt_(bias_correction2)
            torch._foreach_mul_(bias_correction2, lr)
            torch._foreach_mul_(bias_correction2, rect)
            del rect
            torch._foreach_neg_(bias_correction2)
            torch._foreach_div_(bias_correction2, bias_correction1)
            del bias_correction1
        else:
            # Host-side scalar equivalents of the capturable branch above.
            rect = [
                (
                    (rho_t - 4)  # type: ignore[arg-type]
                    * (rho_t - 2)
                    * rho_inf
                    / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
                )
                ** 0.5
                if rho_t > 5
                else 0
                for rho_t in rho_t_list
            ]
            unrectified = [0 if rect > 0 else 1.0 for rect in rect]

            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in grouped_state_steps
            ]
            unrect_step_size = [
                (lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)
            ]
            bias_correction2 = [
                ((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1
                for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
            ]

        # Fuse the rectified and unrectified updates into one addcmul:
        # buffer = bias_correction2 / (sqrt(v) + eps) + unrect_step_size,
        # where bias_correction2 already carries -lr * rect * sqrt(1 - beta2^t) / (1 - beta1^t)
        # and unrect_step_size carries -lr / (1 - beta1^t) (or 0) — the negation
        # makes the addcmul below a descent step.
        buffer = torch._foreach_sqrt(grouped_exp_avg_sqs)
        torch._foreach_add_(buffer, eps)
        torch._foreach_div_(buffer, bias_correction2)
        torch._foreach_reciprocal_(buffer)
        torch._foreach_add_(buffer, unrect_step_size)

        # Here, buffer = sqrt(1 - beta2^t) * rect_step_size / (sqrt(v) + eps) + unrect_step_size
        torch._foreach_addcmul_(grouped_params, grouped_exp_avgs, buffer)
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam)
def radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    maximize: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.
    """
    # Every step counter must already be a singleton tensor, not a plain number.
    if any(not isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # When the caller did not pick an implementation, use the shared heuristic.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    # The foreach path is not scriptable; reject the combination explicitly.
    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # The extra is_scripting() check looks redundant after the raise above, but
    # it is kept so TorchScript compilation (see signature comment) statically
    # resolves this branch to the single-tensor implementation.
    if foreach and not torch.jit.is_scripting():
        impl = _multi_tensor_radam
    else:
        impl = _single_tensor_radam

    impl(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        eps=eps,
        weight_decay=weight_decay,
        decoupled_weight_decay=decoupled_weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )
609