# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch import Tensor

from .optimizer import (
    _disable_dynamo_if_unsupported,
    _get_scalar_dtype,
    _maximize_doc,
    Optimizer,
    ParamsT,
    TensorListList,
)


__all__ = ["Adafactor", "adafactor"]


class Adafactor(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        beta2_decay: float = -0.8,
        eps: Tuple[Optional[float], float] = (None, 1e-3),
        d: float = 1.0,
        weight_decay: float = 0.0,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Learning rate should be >= 0 but is: {lr}")
        if not 0.0 >= beta2_decay:
            raise ValueError(f"beta2_decay should be <= 0 but is: {beta2_decay}")
        if eps[0] is not None and not 0.0 <= eps[0]:
            raise ValueError(f"epsilon1 should be >= 0 but is: {eps[0]}")
        if not 0.0 <= eps[1]:
            raise ValueError(f"epsilon2 should be >= 0 but is: {eps[1]}")
        if not 1.0 <= d:
            raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
        defaults = dict(
            lr=lr,
            beta2_decay=beta2_decay,
            eps=eps,
            d=d,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype())

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        row_vars,
        col_vars,
        variances,
        state_steps,
    ):
        for p in group["params"]:
            if p.grad is None:
                continue
            if torch.is_complex(p):
                raise RuntimeError("Adafactor does not support complex parameters")
            if p.grad.is_sparse:
                raise RuntimeError("Adafactor does not support sparse gradients")

            params_with_grad.append(p)
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                # This is because kernel launches are costly on CUDA and XLA.
                state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                if p.grad.dim() > 1:
                    row_shape = list(p.grad.shape)
                    row_shape[-1] = 1
                    # Row factor of variance, NOT the same shape as grads (will be reduced along last dim)
                    state["row_var"] = p.grad.new_zeros(row_shape)

                    col_shape = list(p.grad.shape)
                    col_shape[-2] = 1
                    # Col factor of variance, NOT the same shape as grads (will be reduced along penultimate dim)
                    state["col_var"] = p.grad.new_zeros(col_shape)
                else:
                    state["variance"] = torch.zeros_like(
                        p.grad, memory_format=torch.preserve_format
                    )

            row_vars.append(state.get("row_var", None))
            col_vars.append(state.get("col_var", None))
            variances.append(state.get("variance", None))
            state_steps.append(state["step"])
        return False  # has_complex

    @torch.no_grad()
    def step(self, closure=None):
        r"""Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            row_vars: List[Optional[Tensor]] = []
            col_vars: List[Optional[Tensor]] = []
            variances: List[Optional[Tensor]] = []
            state_steps: List[Tensor] = []
            eps1, eps2 = group["eps"]

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                row_vars,
                col_vars,
                variances,
                state_steps,
            )

            adafactor(
                params_with_grad,
                grads,
                row_vars,
                col_vars,
                variances,
                state_steps,
                d=group["d"],
                lr=group["lr"],
                beta2_decay=group["beta2_decay"],
                weight_decay=group["weight_decay"],
                eps1=eps1,
                eps2=eps2,
                foreach=group["foreach"],
                maximize=group["maximize"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
                has_complex=has_complex,
            )

        return loss


Adafactor.__doc__ = (
    r"""Implements Adafactor algorithm.

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{(lr)}, \: \tau
                \text{(}\beta_2\text{ decay)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},    \\
            &\hspace{15mm}      \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{(clipping threshold)}, \\
            &\hspace{15mm}      \: \lambda \text{(weight decay)},
                \: \textit{maximize}                                                             \\
            &\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)},       \\
            &\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)},               \\
            &\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)}     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\

            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}G_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}G_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau}                           \\
            &\hspace{5mm}\rho_t         \leftarrow min(lr, \frac{1}{\sqrt{t}})                   \\
            &\hspace{5mm}\alpha_t       \leftarrow max(\epsilon_2,
                \text{RMS}(\theta_{t-1}))\rho_t                                                  \\
            &\hspace{5mm}\theta_t       \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}    \\
            &\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1:                                     \\
            &\hspace{10mm}R_t           \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m                               \\
            &\hspace{10mm}C_t           \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t)                         \\
            &\hspace{10mm}\widehat{V}_t \leftarrow
                \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)}                        \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+
                (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t)                                  \\
            &\hspace{5mm}U_t            \leftarrow
                \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)}                                \\
            &\hspace{5mm}\widehat{U}_t  \leftarrow \frac{U_t}{max(1, \frac{\text{RMS}(U_t)}{d})} \\
            &\hspace{5mm}\theta_t       \leftarrow \theta_{t-1} - \alpha_t \widehat{U}_t         \\

            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
        \end{aligned}

    For further details regarding the algorithm we refer to `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a
            learning rate; Noam Shazeer and Mitchell Stern do not use lr at all.
            Deviating from the paper, this implementation uses lr for applying weight
            decay and as the maximum value for the relative step size rho_t. Note that in
            the paper, a constant of 0.01 is used as the maximum value for the relative
            step size, so we set 0.01 as the default value. (default: 1e-2)
        beta2_decay (float, optional): the decay rate of beta2. beta2 conventionally refers
            to the coefficient used for computing the running average of the squared
            gradient. (default: -0.8)
        eps (Tuple[float, float], optional): epsilon1 is the term added to the denominator
            of the update calculation to improve numerical stability. This use of epsilon1
            deviates from the algorithm written in the paper! See the note below for more
            details. epsilon2 is the term used to avoid having too small a weight update
            when applying parameter scaling. (default: (None, 1e-3))
        d (float, optional): the clipping threshold, used to avoid larger-than-desired
            updates. (default: 1.0)
        weight_decay (float, optional): weight decay coefficient (default: 0.0)
        foreach (bool, optional): whether the foreach implementation of the optimizer is
            used. Note that the foreach implementation uses ~ sizeof(params) more peak
            memory than the for-loop version due to the intermediates being a tensorlist
            vs just one tensor. As Adafactor is commonly used when memory is the limiting
            factor, Adafactor will default to the slower single-tensor for-loop
            implementation unless this flag is explicitly set to True. This behavior is
            contrary to other optimizers, which attempt to default to foreach on CUDA for
            faster runtime. (default: None)
        {_maximize_doc}"""
    + r"""
    .. Note::
        The implementation of Adafactor subtly differs from Shazeer and Stern and from
        implementations in some other frameworks in its use of the learning rate and
        :math:`\epsilon_1`.

        Regarding the learning rate hyperparameter: Shazeer and Stern do not use lr at
        all, as the stated algorithm uses :math:`\rho_t` and update clipping to control
        the step size.

        This implementation allows `lr` to influence the maximum value for :math:`\rho_t`:

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}})
            \end{aligned}

        This differs from Shazeer and Stern, who use a constant of 0.01 as the maximum
        value of :math:`\rho_t`:

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(0.01, \frac{1}{\sqrt{t}})
            \end{aligned}

        Shazeer and Stern do not prescribe how weight decay should be computed, so we use
        the learning rate as a coefficient for decoupled weight decay, similar to what is
        suggested in `Decoupled Weight Decay Regularization`_.

        Regarding the use of :math:`\epsilon_1`: the implementation attempts to replicate
        the presumed intention of Shazeer and Stern to use :math:`\epsilon_1` as a
        stabilizing term when the squared gradient becomes small.

        This stabilization can be written as

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                    (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m                               \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                    (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t)                         \\
                &\hspace{5mm}\widehat{V}_t \leftarrow
                    \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)}                        \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)}        \\
            \end{aligned}

        where the row and column factors of the squared gradient, :math:`R_t` and :math:`C_t`,
        are left alone, and we apply :math:`\epsilon_1` at the final calculation of
        the variance estimate :math:`\widehat{V}_t` and of the update :math:`U_t`.

        This is in contrast to Shazeer and Stern and other frameworks, which apply
        :math:`\epsilon_1` to both the row and column factors of the squared gradient but
        not in the calculations that follow:

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                            (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m          \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                            (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m)    \\
                &\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t}                          \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}}                                            \\
            \end{aligned}


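    Example (an illustrative usage sketch; ``model``, ``loss_fn``, ``input``, and
    ``target`` below are placeholders for user-provided objects)::

        >>> # xdoctest: +SKIP("undefined placeholder variables")
        >>> optimizer = Adafactor(model.parameters(), lr=1e-2)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
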
    .. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost:
        https://arxiv.org/pdf/1804.04235
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    """
)


def _single_tensor_adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    # If grad is 1-dimensional (aka a vector), there is no factorization necessary
    # so row_var and col_var will be None while variance will be filled.
    # Contrarily, for a grad with multiple dimensions, we will factor along the last
    # 2 dimensions, and so row_var and col_var will be filled and variance will be None.
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    d: float,
    lr: Union[Tensor, float],
    beta2_decay: float,
    weight_decay: float,
    eps1: Optional[float],
    eps2: float,
    maximize: bool,
    has_complex: bool,
):
    assert (
        grad_scale is None and found_inf is None
    ), "Grad scaling should occur outside of optimizer.step()"

    if torch.jit.is_scripting():
        # This assert is here because JIT does not realize that the ops below have
        # overloads to handle both float and Tensor lrs, so we just assert lr is a
        # float since most people using JIT are using floats.
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        step_t = state_steps[i]
        row_var = row_vars[i]
        col_var = col_vars[i]
        variance = variances[i]
        if eps1 is None:
            eps1 = torch.finfo(param.dtype).eps

        # update step
        step_t += 1
        step_float = step_t.item()

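        # Per-step scalars from the update rule in the class docstring: one_minus_beta2_t
        # is (1 - beta2_hat_t) = t**beta2_decay, rho_t is the relative step size
        # min(lr, 1/sqrt(t)), and alpha is the parameter-scaled step size
        # max(eps2, RMS(param)) * rho_t.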
        one_minus_beta2_t = step_float**beta2_decay
        rho_t = min(lr, 1 / (step_float**0.5))
        alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t

        # Perform stepweight decay
        if weight_decay != 0:
            param.mul_(1 - lr * weight_decay)

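        # Second-moment estimation: multi-dimensional grads are factored along the last
        # two dims into a row variance and a column variance (O(n + m) memory instead of
        # O(n * m) for an n x m matrix), while vectors keep a full per-element variance.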
        if grad.dim() > 1:
            assert (
                row_var is not None and col_var is not None
            ), "row_var and col_var should be defined when grad is multidimensional"
            # same as (g * g).mean(dim=-1) w/o materializing an intermediate the size of g
            row_mean = (
                torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
            )
            row_var.lerp_(row_mean, one_minus_beta2_t)
            # same as (g * g).mean(dim=-2) w/o materializing an intermediate the size of g
            col_mean = (
                torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))
            )
            col_var.lerp_(col_mean, one_minus_beta2_t)
            var_estimate = row_var @ col_var
            var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1))
        else:
            assert (
                variance is not None
            ), "variance should be defined when grad is a vector"
            grad_squared = grad * grad
            variance.lerp_(grad_squared, one_minus_beta2_t)
            # avoid writing into variance during update
            var_estimate = variance.clone()

        # clamp by eps1 squared since we take the sqrt afterwards, preserving eps1's magnitude
        update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_()
        update.mul_(grad)
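        # Update clipping: denom is max(1, RMS(update) / d), so the applied step is
        # -alpha * update / max(1, RMS(update) / d), as in the docstring algorithm.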
        denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d))
        param.add_(update, alpha=-alpha / denom)


def _group_tensors_by_device_dtype_and_is_multidim(
    tensorlists: TensorListList,
) -> Dict[
    Tuple[Optional[torch.device], Optional[torch.dtype], bool],
    List[List[Optional[Tensor]]],
]:
    """Groups tensors by device, dtype, AND multidimensionality -- whether the tensor
    has multiple dims or just one dim (is a vector). This allows the foreach impl of
    Adafactor to assume that every group of params will either be factored or not."""
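    # For example, a float32 CUDA weight matrix and a float32 CUDA bias vector land in
    # separate groups, keyed (device, torch.float32, True) and (device, torch.float32, False).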
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists)
    ultra_grouped_tensors: Dict[
        Tuple[Optional[torch.device], Optional[torch.dtype], bool],
        List[List[Optional[Tensor]]],
    ] = {}
    for (device, dtype), (tensorlists, _) in grouped_tensors.items():
        matrix_key = (device, dtype, True)
        vector_key = (device, dtype, False)

        # assumes grad is the second tensorlist
        for j, tensor in enumerate(tensorlists[1]):
            assert tensor is not None, "grad should not be None"
            if tensor.dim() > 1:
                if matrix_key not in ultra_grouped_tensors:
                    ultra_grouped_tensors[matrix_key] = [[] for _ in tensorlists]
                for i in range(len(tensorlists)):
                    ultra_grouped_tensors[matrix_key][i].append(tensorlists[i][j])
            else:
                if vector_key not in ultra_grouped_tensors:
                    ultra_grouped_tensors[vector_key] = [[] for _ in tensorlists]
                for i in range(len(tensorlists)):
                    ultra_grouped_tensors[vector_key][i].append(tensorlists[i][j])
    return ultra_grouped_tensors


def _multi_tensor_adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    # If grad is 1-dimensional (aka a vector), there is no factorization necessary
    # so row_var and col_var will be None while variance will be filled.
    # Contrarily, for a grad with multiple dimensions, we will factor along the last
    # 2 dimensions, and so row_var and col_var will be filled and variance will be None.
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    d: float,
    lr: Union[Tensor, float],
    beta2_decay: float,
    weight_decay: float,
    eps1: Optional[float],
    eps2: float,
    maximize: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert (
        grad_scale is None and found_inf is None
    ), "Grad scaling should occur outside of optimizer.step()"

    grouped_tensors = _group_tensors_by_device_dtype_and_is_multidim(
        [params, grads, row_vars, col_vars, variances, state_steps]  # type: ignore[list-item]
    )
    for (_, dtype, is_multidim), (
        (
            device_params_,
            device_grads_,
            device_row_vars_,
            device_col_vars_,
            device_variances_,
            device_state_steps_,
        )
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_steps = cast(List[Tensor], device_state_steps_)
        if eps1 is None:
            assert (
                dtype is not None
            ), "dtype is needed to compute eps1 when eps1 is unset"
            eps1 = torch.finfo(dtype).eps

        if TYPE_CHECKING:
            assert device_state_steps[0] is not None

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we dispatch to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1.0)

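        # Per-tensor scalars mirroring the single-tensor path: (1 - beta2_hat_t) = t**beta2_decay,
        # beta2_hat_t, and the relative step size rho_t = min(lr, 1/sqrt(t)).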
        one_minus_beta2_ts = []
        beta2_ts = []
        rho_ts = []
        for s in device_state_steps:
            one_minus_beta2_ts.append(s.item() ** beta2_decay)
            beta2_ts.append(1 - s.item() ** beta2_decay)
            rho_ts.append(min(lr, 1 / (s.item() ** 0.5)))

        alphas = [
            max(eps2, p.norm(2).item() / (p.numel() ** 0.5)) * r
            for p, r in zip(device_params, rho_ts)
        ]

        # Perform stepweight decay
        if weight_decay != 0:
            torch._foreach_mul_(device_params, 1 - lr * weight_decay)

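        # Factored second-moment update for matrices: the lerp_ in the single-tensor path
        # is unrolled here into foreach mul_/add_ ops on the row and column variances.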
        if is_multidim:
            device_row_vars = cast(List[Tensor], device_row_vars_)
            device_col_vars = cast(List[Tensor], device_col_vars_)
            assert (
                device_row_vars[0] is not None and device_col_vars[0] is not None
            ), "row_var and col_var should be defined when grad is multidimensional"
            # same as (g * g).mean(dim=-1) w/o materializing an intermediate the size of g
            row_means = [
                torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads
            ]
            torch._foreach_mul_(row_means, row_means)
            torch._foreach_div_(row_means, [grad.size(-1) for grad in device_grads])
            torch._foreach_mul_(device_row_vars, beta2_ts)
            torch._foreach_mul_(row_means, one_minus_beta2_ts)
            torch._foreach_add_(device_row_vars, row_means)
            del row_means

            # same as (g * g).mean(dim=-2) w/o materializing an intermediate the size of g
            col_means = [
                torch.norm(grad, dim=-2, keepdim=True) for grad in device_grads
            ]
            torch._foreach_mul_(col_means, col_means)
            torch._foreach_div_(col_means, [grad.size(-2) for grad in device_grads])
            torch._foreach_mul_(device_col_vars, beta2_ts)
            torch._foreach_mul_(col_means, one_minus_beta2_ts)
            torch._foreach_add_(device_col_vars, col_means)
            del col_means

            var_estimates = [
                row_var @ col_var
                for row_var, col_var in zip(device_row_vars, device_col_vars)
            ]
            row_var_means = [
                row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars
            ]
            torch._foreach_clamp_min_(row_var_means, eps1)
            torch._foreach_div_(var_estimates, row_var_means)
            del row_var_means
        else:
            device_variances = cast(List[Tensor], device_variances_)
            assert (
                device_variances[0] is not None
            ), "variance should be defined when grad is a vector"

            grads_squared = torch._foreach_mul(device_grads, device_grads)
            torch._foreach_mul_(device_variances, beta2_ts)
            torch._foreach_mul_(grads_squared, one_minus_beta2_ts)
            torch._foreach_add_(device_variances, grads_squared)
            del grads_squared

            # avoid writing into variance during update
            var_estimates = [v.clone() for v in device_variances]

        # clamp by eps1 squared since we take the sqrt afterwards, preserving eps1's magnitude
        torch._foreach_clamp_min_(var_estimates, eps1 * eps1)
        torch._foreach_sqrt_(var_estimates)
        torch._foreach_reciprocal_(var_estimates)
        torch._foreach_mul_(var_estimates, device_grads)
        updates = var_estimates

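        # Fold the parameter-scaled step size and the update clipping term
        # max(1, RMS(update) / d) into one scalar per tensor, then apply the step.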
        alphas = [
            -a / (max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d)))
            for a, update in zip(alphas, updates)
        ]
        torch._foreach_mul_(updates, alphas)
        torch._foreach_add_(device_params, updates)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor)
def adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    d: float,
    lr: Union[float, Tensor],
    beta2_decay: float,
    weight_decay: float,
    eps1: Optional[float],
    eps2: float,
    maximize: bool,
):
    r"""Functional API that performs Adafactor algorithm computation.

    See :class:`~torch.optim.Adafactor` for details.
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "`state_steps` argument must contain a list of singleton tensors"
        )

    if foreach:
        func = _multi_tensor_adafactor
    else:
        func = _single_tensor_adafactor

    func(
        params,
        grads,
        row_vars,
        col_vars,
        variances,
        state_steps,
        d=d,
        lr=lr,
        beta2_decay=beta2_decay,
        weight_decay=weight_decay,
        eps1=eps1,
        eps2=eps2,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )