# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _fused_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _stack_if_compiling,
    _use_grad_for_differentiable,
    _view_as_real,
    DeviceDict,
    Optimizer,
    ParamsT,
)


__all__ = ["Adam", "adam"]


class Adam(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if isinstance(lr, Tensor):
            if foreach and not capturable:
                raise ValueError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )
            if lr.numel() != 1:
                raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params, which would need a
            # higher-precision copy of the params so the update math can be
            # done in higher precision to mitigate the loss of information.
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val,
                            dtype=_get_scalar_dtype(is_fused=fused),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    if group["fused"]:
                        _device_dtype_check_for_fused(p)
                    # note(crcrpar): [special device hosting for step]
                    # Deliberately host `step` on CPU if both capturable and fused are off.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state["step"] = (
                        torch.zeros(
                            (),
                            dtype=_get_scalar_dtype(is_fused=group["fused"]),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if group["amsgrad"]:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if group["amsgrad"]:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
                if group["differentiable"] and state["step"].requires_grad:
                    raise RuntimeError(
                        "`requires_grad` is not supported for `step` in differentiable mode"
                    )

                # Foreach without capturable does not support a tensor lr
                if (
                    group["foreach"]
                    and torch.is_tensor(group["lr"])
                    and not group["capturable"]
                ):
                    raise RuntimeError(
                        "lr as a Tensor is not supported for capturable=False and foreach=True"
                    )

                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            max_exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = group["betas"]

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=group["amsgrad"],
                has_complex=has_complex,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss

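# A minimal sketch of passing a closure to `step()`, as described in its docstring
# (illustrative only; `optimizer`, `model`, `loss_fn`, `inputs`, and `targets` are
# assumed to exist in the caller's scope):
#
#     def closure():
#         optimizer.zero_grad()
#         loss = loss_fn(model(inputs), targets)
#         loss.backward()
#         return loss
#
#     loss = optimizer.step(closure)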

Adam.__doc__ = (
    r"""Implements Adam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}          \\
            &\hspace{13mm}      \lambda \text{ (weight decay)},  \: \textit{amsgrad},
                \:\textit{maximize}                                                              \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ (first moment)},
                v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\

            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{5mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                   \\
            &\hspace{5mm}\textbf{if} \: amsgrad                                                  \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                \widehat{v_t})                                                                   \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big)                                 \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
            is not yet supported for all our implementations. Please use a float
            LR if you are not also specifying fused=True or capturable=True.
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}
        {_fused_doc}
    .. Note::
        A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """
)
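
# A minimal usage sketch for the class documented above (illustrative only;
# `model`, `loss_fn`, and `data` are assumed to exist):
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     for inputs, targets in data:
#         optimizer.zero_grad()
#         loss = loss_fn(model(inputs), targets)
#         loss.backward()
#         optimizer.step()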


def _single_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # This assert exists because TorchScript does not recognize that the ops
        # below have overloads handling both float and Tensor lrs, so we assert
        # that lr is a float, which is what most TorchScript users pass anyway.
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            if amsgrad:
                max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
            param = torch.view_as_real(param)

        # Decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        if capturable or differentiable:
            step = step_t

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sq = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sq = max_exp_avg_sqs[i]

                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))

                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires `value` to be a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)

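            # At this point denom == (sqrt(v_hat) + eps) / (-step_size), where v_hat
            # is the bias-corrected second moment, so the addcdiv below computes
            # param -= step_size * exp_avg / (sqrt(v_hat) + eps), i.e. the standard
            # Adam update, without an extra parameter-sized intermediate.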
            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = bias_correction2**0.5

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)

        # Lastly, switch back to complex view
        if amsgrad and torch.is_complex(params[i]):
            max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])

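# A minimal sketch exercising the single-tensor path directly (illustrative only;
# normal users should construct torch.optim.Adam rather than call this helper):
#
#     p = torch.zeros(3)
#     exp_avg, exp_avg_sq = torch.zeros(3), torch.zeros(3)
#     step = torch.tensor(0.0)
#     _single_tensor_adam(
#         [p], [torch.ones(3)], [exp_avg], [exp_avg_sq], [], [step],
#         grad_scale=None, found_inf=None,
#         amsgrad=False, has_complex=False, beta1=0.9, beta2=0.999,
#         lr=1e-3, weight_decay=0.0, eps=1e-8,
#         maximize=False, capturable=False, differentiable=False,
#     )
#     # After the first step the update magnitude is approximately lr, so p ~= -1e-3.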

def _multi_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_max_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # Handle complex parameters
        if has_complex:
            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore, to save on peak memory
        del device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]

        if capturable:
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)  # type: ignore[assignment]

                # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)

            # at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2]  # type: ignore[arg-type]

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size  # type: ignore[arg-type]
            )


def _fused_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,  # Needed for consistency.
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
) -> None:
    if not params:
        return
    if differentiable:
        raise RuntimeError("Adam with fused=True does not support differentiable=True")

    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    # We only shuffle around the lr when it is a Tensor and not on CPU; otherwise,
    # we prefer treating it as a scalar.
655*da0073e9SAndroid Build Coastguard Worker    lr_dict: Optional[DeviceDict] = (
656*da0073e9SAndroid Build Coastguard Worker        {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
657*da0073e9SAndroid Build Coastguard Worker    )
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_exp_avgs_,
            device_exp_avg_sqs_,
            device_max_exp_avg_sqs,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        if device.type == "mps":  # type: ignore[union-attr]
            assert found_inf is None and grad_scale is None

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device, non_blocking=True)
            )
        if found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(
                device, found_inf.to(device, non_blocking=True)
            )
        if lr_dict is not None and device not in lr_dict:
            lr_dict[device] = lr.to(device=device, non_blocking=True)  # type: ignore[union-attr]
            lr = lr_dict[device]
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adam_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,  # type: ignore[arg-type]
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,  # type: ignore[arg-type]
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
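        # If the grad scaler found infs/NaNs, the fused kernel skipped the update above,
        # so roll back the step increment (device_found_inf is 1.0 in that case, 0.0 otherwise).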
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kw-only args with defaults are not supported by functions compiled with torchscript
    # (see issue #70627), so these stay as regular keyword args for now; the functional API
    # is compiled by torch/distributed/optim.
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.
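
    A minimal, illustrative sketch of calling this functional API directly (the
    hyperparameter values below are arbitrary; typical code goes through
    :class:`~torch.optim.Adam` instead)::

        >>> import torch
        >>> from torch.optim.adam import adam
        >>> param = torch.randn(4, requires_grad=True)
        >>> grad = torch.randn(4)
        >>> exp_avg, exp_avg_sq = torch.zeros(4), torch.zeros(4)
        >>> step = torch.tensor(0.0)
        >>> with torch.no_grad():
        ...     adam(
        ...         [param], [grad], [exp_avg], [exp_avg_sq], [], [step],
        ...         amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3,
        ...         weight_decay=0.0, eps=1e-8, maximize=False,
        ...     )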
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither has been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake: we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False
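    # A hedged sketch of the resulting dispatch (assuming no user override): CUDA float
    # params typically get foreach=True from the heuristic and take the _multi_tensor_adam
    # path below, while CPU params fall back to foreach=False and _single_tensor_adam; an
    # explicit fused=True or foreach=True from the caller always wins over this heuristic.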

    # This check is slow during compilation, so we skip it.
    # If it's strictly needed we can add this check back in dynamo.
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )
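    # For reference, a well-formed entry is a singleton tensor such as torch.tensor(0.0)
    # (the Adam class typically initializes these as torch.tensor(0.0, dtype=_get_scalar_dtype())
    # or torch.zeros(()) for the capturable/fused paths); plain Python numbers are exactly
    # what this error guards against.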

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )