xref: /aosp_15_r20/external/pytorch/torch/optim/adadelta.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Any, cast, Dict, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Adadelta", "adadelta"]


class Adadelta(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1.0,
        rho: float = 0.9,
        eps: float = 1e-6,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        *,
        capturable: bool = False,
        maximize: bool = False,
        differentiable: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= rho <= 1.0:
            raise ValueError(f"Invalid rho value: {rho}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            rho=rho,
            eps=eps,
            weight_decay=weight_decay,
            maximize=maximize,
            capturable=capturable,
            foreach=foreach,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group: Dict[str, Any],
        params_with_grad: List[Tensor],
        grads: List[Tensor],
        square_avgs: List[Tensor],
        acc_deltas: List[Tensor],
        state_steps: List[Tensor],
    ):
        has_complex = False
        p: Tensor
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("Adadelta does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # Lazy state initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.zeros((), dtype=_get_scalar_dtype())
                )

                state["square_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                state["acc_delta"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            square_avgs.append(state["square_avg"])
            acc_deltas.append(state["acc_delta"])
            state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
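        # A minimal closure sketch for callers; `model`, `loss_fn`, `inputs`, and
        # `targets` are hypothetical objects used only for illustration:
        #
        #     def closure():
        #         optimizer.zero_grad()
        #         loss = loss_fn(model(inputs), targets)
        #         loss.backward()
        #         return loss
        #
        #     loss = optimizer.step(closure)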
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            square_avgs: List[Tensor] = []
            acc_deltas: List[Tensor] = []
            state_steps: List[Tensor] = []
            (
                lr,
                rho,
                eps,
                weight_decay,
                foreach,
                maximize,
                differentiable,
                capturable,
            ) = (
                group["lr"],
                group["rho"],
                group["eps"],
                group["weight_decay"],
                group["foreach"],
                group["maximize"],
                group["differentiable"],
                group["capturable"],
            )

            has_complex = self._init_group(
                group, params_with_grad, grads, square_avgs, acc_deltas, state_steps
            )

            adadelta(
                params_with_grad,
                grads,
                square_avgs,
                acc_deltas,
                state_steps,
                lr=lr,
                rho=rho,
                eps=eps,
                weight_decay=weight_decay,
                foreach=foreach,
                maximize=maximize,
                differentiable=differentiable,
                capturable=capturable,
                has_complex=has_complex,
            )

        return loss


Adadelta.__doc__ = (
    r"""Implements Adadelta algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
                \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)},
                \: \lambda \text{ (weight decay)}                                                \\
            &\textbf{initialize} :  v_0  \leftarrow 0 \: \text{ (square avg)},
                \: u_0 \leftarrow 0 \: \text{ (accumulate variables)}                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}if \: \lambda \neq 0                                                    \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm} v_t      \leftarrow v_{t-1} \rho + g^2_t (1 - \rho)                    \\
            &\hspace{5mm}\Delta x_t    \leftarrow   \frac{\sqrt{u_{t-1} +
                \epsilon }}{ \sqrt{v_t + \epsilon}  }g_t \hspace{21mm}                           \\
            &\hspace{5mm} u_t  \leftarrow   u_{t-1}  \rho +
                 \Delta x^2_t  (1 - \rho)                                                        \\
            &\hspace{5mm}\theta_t      \leftarrow   \theta_{t-1} - \gamma  \Delta x_t            \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        rho (float, optional): coefficient used for computing a running average
            of squared gradients (default: 0.9). A higher value of `rho` makes the
            running average adapt more slowly, which can help prevent oscillations
            during training.
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6).
        lr (float, Tensor, optional): coefficient that scales the delta before it is
            applied to the parameters (default: 1.0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_capturable_doc}
        {_maximize_doc}
        {_differentiable_doc}

    .. _ADADELTA\: An Adaptive Learning Rate Method:
        https://arxiv.org/abs/1212.5701

    """
)
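# A minimal usage sketch for the class above; `model`, `loss_fn`, `inputs`, and
# `targets` are hypothetical objects supplied by the caller, shown only for
# illustration:
#
#     optimizer = Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-6)
#     optimizer.zero_grad()
#     loss = loss_fn(model(inputs), targets)
#     loss.backward()
#     optimizer.step()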


def _single_tensor_adadelta(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    acc_deltas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    for param, grad, square_avg, acc_delta, step in zip(
        params, grads, square_avgs, acc_deltas, state_steps
    ):
        step += 1
        grad = grad if not maximize else -grad

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            square_avg = torch.view_as_real(square_avg)
            acc_delta = torch.view_as_real(acc_delta)
            grad = torch.view_as_real(grad)

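        # The three in-place updates below follow the math in the class docstring:
        #   v_t     = rho * v_{t-1} + (1 - rho) * g_t^2                  (square_avg)
        #   delta_t = sqrt(u_{t-1} + eps) / sqrt(v_t + eps) * g_t
        #   u_t     = rho * u_{t-1} + (1 - rho) * delta_t^2              (acc_delta)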
        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
        std = square_avg.add(eps).sqrt_()
        delta = acc_delta.add(eps).sqrt_()
        if differentiable:
            delta = delta.clone()
        delta.div_(std).mul_(grad)
        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)

        if torch.is_complex(param):
            delta = torch.view_as_complex(delta)
        param.add_(delta, alpha=-lr)


def _multi_tensor_adadelta(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    acc_deltas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    if len(params) == 0:
        return

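    # Group the flat tensor lists by device and dtype so that each vectorized
    # torch._foreach_* call below operates on tensors that share a device and dtype.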
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, square_avgs, acc_deltas, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_square_avgs_,
        device_acc_deltas_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_square_avgs = cast(List[Tensor], device_square_avgs_)
        device_acc_deltas = cast(List[Tensor], device_acc_deltas_)
        device_state_steps = cast(List[Tensor], device_state_steps_)
        if has_complex:
            _view_as_real(
                device_params, device_grads, device_square_avgs, device_acc_deltas
            )

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        torch._foreach_mul_(device_square_avgs, rho)
        torch._foreach_addcmul_(
            device_square_avgs, device_grads, device_grads, value=1 - rho
        )

        std = torch._foreach_add(device_square_avgs, eps)
        torch._foreach_sqrt_(std)

        deltas = torch._foreach_add(device_acc_deltas, eps)
        torch._foreach_sqrt_(deltas)
        torch._foreach_div_(deltas, std)
        torch._foreach_mul_(deltas, device_grads)

        torch._foreach_mul_(device_acc_deltas, rho)
        torch._foreach_addcmul_(device_acc_deltas, deltas, deltas, value=1 - rho)

        # If LR is a tensor, the else branch will internally call item()
        # which will cause silent incorrectness if we are capturing
        if capturable and isinstance(lr, torch.Tensor):
            torch._foreach_mul_(deltas, -lr)
            torch._foreach_add_(device_params, deltas)
        else:
            torch._foreach_add_(device_params, deltas, alpha=-lr)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta)
def adadelta(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    acc_deltas: List[Tensor],
    state_steps: List[Tensor],
    # kw-only args with defaults are not supported by functions compiled with TorchScript (see issue #70627),
    # so these are kept as regular keyword args for now, since the functional API is compiled by torch/distributed/optim
    capturable: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
):
    r"""Functional API that performs Adadelta algorithm computation.

    See :class:`~torch.optim.Adadelta` for details.
    """
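    # A minimal calling sketch, assuming a single hypothetical parameter `p` with a
    # populated `.grad` and freshly zeroed state; the state lists must line up
    # element-wise with `params`:
    #
    #     p = torch.randn(3, requires_grad=True)
    #     p.grad = torch.randn(3)
    #     with torch.no_grad():  # updates are applied in-place, outside autograd
    #         adadelta(
    #             [p],
    #             [p.grad],
    #             [torch.zeros_like(p)],  # square_avgs
    #             [torch.zeros_like(p)],  # acc_deltas
    #             [torch.zeros(())],  # state_steps (singleton tensors)
    #             lr=1.0,
    #             rho=0.9,
    #             eps=1e-6,
    #             weight_decay=0.0,
    #             maximize=False,
    #         )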

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # We still respect the user's choice when foreach is explicitly set to False.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adadelta
    else:
        func = _single_tensor_adadelta

    func(
        params,
        grads,
        square_avgs,
        acc_deltas,
        state_steps,
        lr=lr,
        rho=rho,
        eps=eps,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )