# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Adamax", "adamax"]


class Adamax(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 2e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
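            # Older checkpoints may store `step` as a plain Python number; promote it
            # to a tensor so it matches what the capturable and non-capturable code
            # paths below expect.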
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("Adamax does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                state["exp_inf"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(state["exp_avg"])
            exp_infs.append(state["exp_inf"])
            state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
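
        Example (a minimal sketch; ``optimizer``, ``model``, ``loss_fn``, ``input``,
        and ``target`` are placeholders assumed to exist)::

            >>> # xdoctest: +SKIP
            >>> def closure():
            ...     optimizer.zero_grad()
            ...     loss = loss_fn(model(input), target)
            ...     loss.backward()
            ...     return loss
            >>> loss = optimizer.step(closure)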
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_infs: List[Tensor] = []
            state_steps: List[Tensor] = []

            beta1, beta2 = group["betas"]
            eps = group["eps"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            foreach = group["foreach"]
            maximize = group["maximize"]
            differentiable = group["differentiable"]
            capturable = group["capturable"]

            has_complex = self._init_group(
                group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
            )

            adamax(
                params_with_grad,
                grads,
                exp_avgs,
                exp_infs,
                state_steps,
                eps=eps,
                beta1=beta1,
                beta2=beta2,
                lr=lr,
                weight_decay=weight_decay,
                foreach=foreach,
                maximize=maximize,
                differentiable=differentiable,
                capturable=capturable,
                has_complex=has_complex,
            )

        return loss


Adamax.__doc__ = (
    r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
                \: \lambda \text{ (weight decay)},                                                \\
            &\hspace{13mm}    \epsilon \text{ (epsilon)}                                          \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ ( first moment)},
                u_0 \leftarrow 0 \text{ ( infinity norm)}                                 \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}if \: \lambda \neq 0                                                    \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}m_t      \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t               \\
            &\hspace{5mm}u_t      \leftarrow   \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon)   \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of the gradient and its infinity norm
            (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}

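    Example (a minimal usage sketch; ``model``, ``loss_fn``, ``input``, and
    ``target`` are placeholders assumed to exist)::

        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.Adamax(model.parameters(), lr=2e-3)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
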
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980

    """
)


def _single_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        exp_avg = exp_avgs[i]
        exp_inf = exp_infs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_inf = torch.view_as_real(exp_inf)

        # Update biased first moment estimate.
        exp_avg.lerp_(grad, 1 - beta1)
        # Update the exponentially weighted infinity norm.
        if not differentiable:
            torch.maximum(
                exp_inf.mul_(beta2),
                grad.abs().add_(eps),
                out=exp_inf,
            )
        else:
            norm_buf = torch.cat(
                [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)],
                0,
            )
            exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False))

        if capturable:
            # why jump through extra hoops and negate bias_correction? check out #121238
            # once fixed, we should use bias_correction with addcdiv value=-1 for readability
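            # Mathematically this block applies the same update as the non-capturable
            # branch below, i.e. param -= (lr / (1 - beta1**t)) * exp_avg / exp_inf,
            # just rearranged so every intermediate stays a tensor op (no host sync),
            # which is what capturable/CUDA-graph execution requires.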
            neg_bias_correction = beta1**step_t - 1
            neg_bias_correction.div_(lr)
            denom = exp_inf * neg_bias_correction
            param.addcdiv_(exp_avg, denom)
        else:
            bias_correction = 1 - beta1 ** _get_value(step_t)
            clr = lr / bias_correction

            param.addcdiv_(exp_avg, exp_inf, value=-clr)


def _multi_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    assert not differentiable, "_foreach ops don't support autograd"

    if len(params) == 0:
        return

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_infs, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_infs_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_infs = cast(List[Tensor], grouped_exp_infs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs
            )

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If the steps are on CPU, foreach will fall back to the slow path, a for-loop that calls
        # t.add(1) over and over, re-wrapping the Python scalar 1 into a Tensor each time. That is
        # slower than wrapping it once here. The alpha kwarg is required to ensure we hit the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            if maximize:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                    grouped_grads, grouped_params, alpha=weight_decay
                )

        # Update biased first moment estimate.
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        # Update the exponentially weighted infinity norm.
        torch._foreach_mul_(grouped_exp_infs, beta2)

        # If neither maximize nor weight_decay has already produced a fresh copy of the grads,
        # take one here so the in-place abs/add below do not mutate the caller's gradients.
        if not maximize and weight_decay == 0:
            grouped_grads = torch._foreach_abs(grouped_grads)  # type: ignore[assignment]
        else:
            torch._foreach_abs_(grouped_grads)

        torch._foreach_add_(grouped_grads, eps)
        torch._foreach_maximum_(grouped_exp_infs, grouped_grads)

        bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
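            # Same negated bias-correction rearrangement as in _single_tensor_adamax
            # (see the comment and the #121238 reference there).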
            bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_corrections, 1)
            torch._foreach_div_(bias_corrections, lr)

            denom = torch._foreach_mul(grouped_exp_infs, bias_corrections)
            torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom)
        else:
            bias_corrections = [
                1 - beta1 ** _get_value(step) for step in grouped_state_steps
            ]
            step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections]
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
def adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with TorchScript (issue #70627),
    # so these are exposed as regular keyword args with defaults for now, since the functional API
    # is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.
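
    Example (a minimal sketch with made-up tensors, shown only to illustrate the
    call signature; the state tensors are normally managed by the ``Adamax``
    optimizer itself)::

        >>> # xdoctest: +SKIP
        >>> param = torch.zeros(3, requires_grad=True)
        >>> param.grad = torch.ones(3)
        >>> exp_avg, exp_inf = torch.zeros(3), torch.zeros(3)
        >>> step = torch.tensor(0.0)
        >>> with torch.no_grad():
        ...     adamax([param], [param.grad], [exp_avg], [exp_inf], [step],
        ...            eps=1e-8, beta1=0.9, beta2=0.999, lr=2e-3, weight_decay=0.0)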
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamax
    else:
        func = _single_tensor_adamax

    func(
        params,
        grads,
        exp_avgs,
        exp_infs,
        state_steps,
        eps=eps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        capturable=capturable,
    )