xref: /aosp_15_r20/external/pytorch/torch/distributed/optim/functional_adagrad.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Dict, List, Optional
3
4import torch
5import torch.optim._functional as F
6from torch import Tensor
7
8
# This module exports no public names; everything in it is internal to the
# distributed optimizer implementation.
__all__: List[str] = list()
10
11
# Define a TorchScript-compatible functional Adagrad optimizer,
# i.e. one that is used in a functional way.
# Instead of reading `param.grad` when updating parameters,
# we let the caller pass gradients explicitly to the `step` function.
# This separates the gradients from the parameters and allows a
# multithreaded trainer to update the parameters without data races
# from accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
21@torch.jit.script
22class _FunctionalAdagrad:
23    def __init__(
24        self,
25        params: List[Tensor],
26        lr: float = 1e-2,
27        lr_decay: float = 0.0,
28        weight_decay: float = 0.0,
29        initial_accumulator_value: float = 0.0,
30        warmup_lr_multiplier: float = 1.0,
31        warmup_num_iters: float = 0.0,
32        eps: float = 1e-10,
33        coalesce_grad: bool = True,
34        foreach: bool = False,
35        fused: bool = False,
36        maximize: bool = False,
37        _allow_empty_param_list: bool = False,
38    ):
39        self.defaults = {
40            "lr": lr,
41            "lr_decay": lr_decay,
42            "eps": eps,
43            "weight_decay": weight_decay,
44            "initial_accumulator_value": initial_accumulator_value,
45            "warmup_lr_multiplier": warmup_lr_multiplier,
46            "warmup_num_iters": warmup_num_iters,
47        }
48        self.coalesce_grad = coalesce_grad
49        self.foreach = foreach
50        self.fused = fused
51        self.maximize = maximize
52        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
53
54        if len(params) == 0 and not _allow_empty_param_list:
55            raise ValueError("optimizer got an empty parameter list")
56
57        # NOTE: we only have one param_group and don't allow user to add additional
58        # param group as it's not a common use case.
59        self.param_group = {"params": params}
60
61        # TODO: no union or any types in TorchScript, make step a scalar tensor instead
62        # This is also needed by if we want to share_memory on the step across processes
63        for p in self.param_group["params"]:
64            self.state[p] = {
65                "sum": torch.full_like(p.data, initial_accumulator_value),
66                "step": torch.tensor(0.0),
67            }
68
69    def step(self, gradients: List[Optional[Tensor]]):
70        params = self.param_group["params"]
71        params_with_grad = []
72        grads = []
73        state_sums = []
74        state_steps: List[Tensor] = []
75
76        if len(params) != len(gradients):
77            raise ValueError(
78                "the gradients passed in does not equal to the size of the parameters!"
79                + f"Params length: {len(params)}. "
80                + f"Gradients length: {len(gradients)}"
81            )
82
83        has_sparse_grad, has_complex = False, False
84        for param, gradient in zip(self.param_group["params"], gradients):
85            if gradient is not None:
86                has_sparse_grad |= gradient.is_sparse
87                has_complex |= torch.is_complex(param)
88                params_with_grad.append(param)
89                grads.append(gradient)
90                state = self.state[param]
91                state_sums.append(state["sum"])
92                state_steps.append(state["step"])
93
94        with torch.no_grad():
95            F.adagrad(
96                params,
97                grads,
98                state_sums,
99                state_steps,
100                lr=self.defaults["lr"],
101                weight_decay=self.defaults["weight_decay"],
102                lr_decay=self.defaults["lr_decay"],
103                eps=self.defaults["eps"],
104                has_sparse_grad=has_sparse_grad,
105                foreach=self.foreach,
106                maximize=self.maximize,
107                has_complex=has_complex,
108                fused=self.fused,
109                grad_scale=None,
110                found_inf=None,
111            )
112