# mypy: allow-untyped-defs
r"""Functional interface."""
import math
from typing import List

from torch import Tensor

from .adadelta import adadelta  # type: ignore[attr-defined]  # noqa: F401
from .adagrad import _make_sparse, adagrad  # type: ignore[attr-defined]  # noqa: F401
from .adam import adam  # type: ignore[attr-defined]  # noqa: F401
from .adamax import adamax  # type: ignore[attr-defined]  # noqa: F401
from .adamw import adamw  # type: ignore[attr-defined]  # noqa: F401
from .asgd import asgd  # type: ignore[attr-defined]  # noqa: F401
from .nadam import nadam  # type: ignore[attr-defined]  # noqa: F401
from .radam import radam  # type: ignore[attr-defined]  # noqa: F401
from .rmsprop import rmsprop  # type: ignore[attr-defined]  # noqa: F401
from .rprop import rprop  # type: ignore[attr-defined]  # noqa: F401
from .sgd import sgd  # type: ignore[attr-defined]  # noqa: F401


# TODO: use foreach API in optim._functional to do all the computation


def sparse_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[int],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    maximize: bool,
):
    r"""Functional API that performs Sparse Adam algorithm computation.

    See :class:`~torch.optim.SparseAdam` for details.
    """
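    # The loop below applies the standard Adam recurrences, restricted to the
    # non-zero indices of each sparse gradient:
    #     exp_avg    <- beta1 * exp_avg    + (1 - beta1) * grad
    #     exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad**2
    #     param      <- param - lr * sqrt(1 - beta2**step) / (1 - beta1**step)
    #                           * exp_avg / (sqrt(exp_avg_sq) + eps)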
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        grad = grad.coalesce()  # the update is non-linear so indices must be unique
        grad_indices = grad._indices()
        grad_values = grad._values()
        if grad_values.numel() == 0:
            # Skip update for empty grad
            continue
        size = grad.size()

        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        def make_sparse(values):
            # Wrap `values` in a sparse tensor that reuses grad's indices and shape.
            constructor = grad.new
            if grad_indices.dim() == 0 or values.dim() == 0:
                return constructor().resize_as_(grad)
            return constructor(grad_indices, values, size)

        # Decay the first and second moment running average coefficient
        #      old <- b * old + (1 - b) * new
        # <==> old += (1 - b) * (new - old)
        old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
        exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
        exp_avg.add_(make_sparse(exp_avg_update_values))
        old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
        exp_avg_sq_update_values = (
            grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
        )
        exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))

        # Dense addition again is intended: adding the old values back onto the
        # update deltas reconstructs the new moment values at grad's indices,
        # avoiding another sparse_mask.
        numer = exp_avg_update_values.add_(old_exp_avg_values)
        exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
        denom = exp_avg_sq_update_values.sqrt_().add_(eps)
        del exp_avg_update_values, exp_avg_sq_update_values

        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1
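        # Both bias corrections are folded into the scalar step size, so numer and
        # denom can be built from the raw (uncorrected) moment values; the result
        # matches dense Adam up to where eps is added (here it is added to
        # sqrt(exp_avg_sq) before the bias correction).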

        param.add_(make_sparse(-step_size * numer.div_(denom)))
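

# A minimal usage sketch: driving the functional API directly, roughly mirroring
# what torch.optim.SparseAdam.step does per parameter. The setup below (a
# sparse-gradient Embedding, hand-rolled optimizer state, and the chosen
# hyperparameter values) is illustrative only, not a public API.
if __name__ == "__main__":
    import torch

    # An Embedding created with sparse=True produces sparse gradients on its weight.
    emb = torch.nn.Embedding(10, 4, sparse=True)
    weight = emb.weight

    # Per-parameter optimizer state, analogous to what SparseAdam keeps.
    exp_avg = torch.zeros_like(weight)
    exp_avg_sq = torch.zeros_like(weight)
    step_count = 0

    for _ in range(3):
        emb.zero_grad()
        loss = emb(torch.tensor([1, 2, 3])).sum()
        loss.backward()

        # The step count is incremented before the update, matching SparseAdam.step.
        step_count += 1
        with torch.no_grad():  # in-place updates of a leaf parameter need no_grad
            sparse_adam(
                [weight],
                [weight.grad],
                [exp_avg],
                [exp_avg_sq],
                [step_count],
                eps=1e-8,
                beta1=0.9,
                beta2=0.999,
                lr=1e-2,
                maximize=False,
            )
        print(f"step {step_count}: loss = {loss.item():.4f}")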