# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch import Tensor

from .optimizer import (
    _disable_dynamo_if_unsupported,
    _get_scalar_dtype,
    _maximize_doc,
    Optimizer,
    ParamsT,
    TensorListList,
)


__all__ = ["Adafactor", "adafactor"]


class Adafactor(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        beta2_decay: float = -0.8,
        eps: Tuple[Optional[float], float] = (None, 1e-3),
        d: float = 1.0,
        weight_decay: float = 0.0,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Learning rate should be >= 0 but is: {lr}")
        if not 0.0 >= beta2_decay:
            raise ValueError(f"beta2_decay should be <= 0 but is: {beta2_decay}")
        if eps[0] is not None and not 0.0 <= eps[0]:
            raise ValueError(f"epsilon1 should be >= 0 but is: {eps[0]}")
        if not 0.0 <= eps[1]:
            raise ValueError(f"epsilon2 should be >= 0 but is: {eps[1]}")
        if not 1.0 <= d:
            raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
        defaults = dict(
            lr=lr,
            beta2_decay=beta2_decay,
            eps=eps,
            d=d,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype())

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        row_vars,
        col_vars,
        variances,
        state_steps,
    ):
        for p in group["params"]:
            if p.grad is None:
                continue
            if torch.is_complex(p):
                raise RuntimeError("Adafactor does not support complex parameters")
            if p.grad.is_sparse:
                raise RuntimeError("Adafactor does not support sparse gradients")

            params_with_grad.append(p)
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                # This is because kernel launches are costly on CUDA and XLA.
                state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
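
                # Factored second-moment state: e.g. for a grad of shape (n, m), the
                # row_var below has shape (n, 1) and col_var has shape (1, m), i.e.
                # O(n + m) memory instead of the O(n * m) a full second moment would
                # need; 1-D grads keep a full `variance` buffer instead.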
                if p.grad.dim() > 1:
                    row_shape = list(p.grad.shape)
                    row_shape[-1] = 1
                    # Row factor of variance, NOT the same shape as grads (will be reduced along last dim)
                    state["row_var"] = p.grad.new_zeros(row_shape)

                    col_shape = list(p.grad.shape)
                    col_shape[-2] = 1
                    # Col factor of variance, NOT the same shape as grads (will be reduced along penultimate dim)
                    state["col_var"] = p.grad.new_zeros(col_shape)
                else:
                    state["variance"] = torch.zeros_like(
                        p.grad, memory_format=torch.preserve_format
                    )

            row_vars.append(state.get("row_var", None))
            col_vars.append(state.get("col_var", None))
            variances.append(state.get("variance", None))
            state_steps.append(state["step"])
        return False  # has_complex

    @torch.no_grad()
    def step(self, closure=None):
        r"""Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            row_vars: List[Optional[Tensor]] = []
            col_vars: List[Optional[Tensor]] = []
            variances: List[Optional[Tensor]] = []
            state_steps: List[Tensor] = []
            eps1, eps2 = group["eps"]

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                row_vars,
                col_vars,
                variances,
                state_steps,
            )

            adafactor(
                params_with_grad,
                grads,
                row_vars,
                col_vars,
                variances,
                state_steps,
                d=group["d"],
                lr=group["lr"],
                beta2_decay=group["beta2_decay"],
                weight_decay=group["weight_decay"],
                eps1=eps1,
                eps2=eps2,
                foreach=group["foreach"],
                maximize=group["maximize"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
                has_complex=has_complex,
            )

        return loss

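
# A minimal usage sketch (illustrative only; `model`, `loss_fn`, `inputs`, and
# `targets` are assumed to be defined by the caller):
#
#     optimizer = Adafactor(model.parameters(), lr=1e-2)
#     optimizer.zero_grad()
#     loss_fn(model(inputs), targets).backward()
#     optimizer.step()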

Adafactor.__doc__ = (
    r"""Implements Adafactor algorithm.

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{(lr)}, \: \tau
                \text{(}\beta_2\text{ decay)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \\
            &\hspace{15mm} \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{(clipping threshold)}, \\
            &\hspace{15mm} \: \lambda \text{(weight decay)},
                \: \textit{maximize} \\
            &\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)}, \\
            &\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)}, \\
            &\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\

            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}G_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}G_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau} \\
            &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}}) \\
            &\hspace{5mm}\alpha_t \leftarrow max(\epsilon_2,
                \text{RMS}(\theta_{t-1}))\rho_t \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1: \\
            &\hspace{10mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
            &\hspace{10mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
            &\hspace{10mm}\widehat{V}_t \leftarrow
                \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+
                (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t) \\
            &\hspace{5mm}U_t \leftarrow
                \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
            &\hspace{5mm}\widehat{U}_t \leftarrow \frac{U_t}{max(1, \frac{\text{RMS}(U_t)}{d})} \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \alpha_t \widehat{U}_t \\

            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
        \end{aligned}

    For further details regarding the algorithm we refer to `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a
            learning rate, and Shazeer and Stern do not use lr at all. Deviating from the
            paper, this implementation uses lr for applying weight decay and as the maximum
            value for the relative step size rho_t. Note that in the paper, a constant of
            0.01 is used as the maximum value for the relative step size, and so we set
            0.01 as the default value. (default: 1e-2)
        beta2_decay (float, optional): the decay rate of beta2. beta2 standardly refers
            to the coefficient used for computing the running average of the gradient
            squared. (default: -0.8)
        eps (Tuple[float, float], optional): epsilon1 is the term added to the denominator
            of the update calculation to improve numerical stability. This use of epsilon1
            deviates from the algorithm written in the paper! See note below for more details.
            epsilon2 is the term used to avoid having too small a weight update when applying
            parameter scaling. (default: (None, 1e-3))
        d (float, optional): the clipping threshold, used to avoid larger-than-desired
            updates. (default: 1.0)
        weight_decay (float, optional): weight decay coefficient (default: 0.0)
        foreach (bool, optional): whether foreach implementation of optimizer is used. Note
            that the foreach implementation uses ~ sizeof(params) more peak memory than the
            for-loop version due to the intermediates being a tensorlist vs just one tensor.
            As Adafactor is commonly used when memory is prohibitive, Adafactor will default
            to the slower single tensor for-loop implementation unless this flag is explicitly
            True. This behavior is contrary to other optimizers, which will attempt defaulting
            to foreach on CUDA for faster runtime. (default: None)
    {_maximize_doc}"""
    + r"""
    .. Note::
        The implementation of Adafactor subtly differs from Shazeer and Stern
        and implementations in some other frameworks with its use of learning rate and
        :math:`\epsilon_1`.

        Regarding the learning rate hyperparameter: Shazeer and Stern do not
        use lr at all, as the stated algorithm uses :math:`\rho_t` and update clipping to
        affect the step size.

        This implementation allows `lr` to influence the maximum value for :math:`\rho_t`:

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}})
            \end{aligned}

        This differs from Shazeer and Stern, who use a constant of 0.01 as
        the maximum value of :math:`\rho_t`:

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(0.01, \frac{1}{\sqrt{t}})
            \end{aligned}

        Shazeer and Stern do not enforce an opinion on how weight decay should
        be computed, and so we use the learning rate as a coefficient for decoupled weight
        decay, similar to what is suggested in `Decoupled Weight Decay Regularization`_.

        Regarding the use of :math:`\epsilon_1`: The implementation attempts to replicate the
        presumed intention of Shazeer and Stern to use :math:`\epsilon_1` as
        a stabilizing term when the squared gradient becomes small.

        This stabilization can be written as

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                    (1-\widehat{\beta}_{2_t})(G_t \odot G_t + 1_n \cdot 1^\top_m) \cdot 1_m \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                    (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + 1_n \cdot 1^\top_m) \\
                &\hspace{5mm}\widehat{V}_t \leftarrow
                    \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
            \end{aligned}

        where the row and column factors of the squared gradient :math:`R_t` and :math:`C_t`
        are left alone, and we apply :math:`\epsilon_1` at the final calculation of
        the variance estimate :math:`\widehat{V}_t` and for the update :math:`U_t`.

        This is in contrast to Shazeer and Stern and other frameworks which
        apply :math:`\epsilon_1` to both row and column factors of the squared gradient, but
        not in the calculations after:

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                    (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                    (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \\
                &\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t} \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\
            \end{aligned}


    .. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost:
        https://arxiv.org/pdf/1804.04235
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    """
)


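# Illustration of the relative step size rho_t used by the implementations below
# (assuming the default lr=1e-2): rho_t = min(lr, 1/sqrt(t)), so rho_1 = 0.01,
# rho_10000 = 0.01, and rho_1000000 = 0.001; the applied step size alpha_t further
# scales rho_t by max(eps2, RMS(param)).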
def _single_tensor_adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    # If grad is 1-dimensional (aka a vector), there is no factorization necessary
    # so row_var and col_var will be None while variance will be filled.
    # Contrarily, for a grad with multiple dimensions, we will factor along the last
    # 2 dimensions, and so row_var and col_var will be filled and variance will be None.
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    d: float,
    lr: Union[Tensor, float],
    beta2_decay: float,
    weight_decay: float,
    eps1: Optional[float],
    eps2: float,
    maximize: bool,
    has_complex: bool,
):
    assert (
        grad_scale is None and found_inf is None
    ), "Grad scaling should occur outside of optimizer.step()"

    if torch.jit.is_scripting():
        # TorchScript does not recognize that the ops below have overloads handling
        # both float and Tensor lrs, so we assert lr is a float here since most
        # users scripting this code use float lrs.
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        step_t = state_steps[i]
        row_var = row_vars[i]
        col_var = col_vars[i]
        variance = variances[i]
        if eps1 is None:
            eps1 = torch.finfo(param.dtype).eps

        # update step
        step_t += 1
        step_float = step_t.item()

        one_minus_beta2_t = step_float**beta2_decay
        rho_t = min(lr, 1 / (step_float**0.5))
        alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t

        # Perform stepweight decay
        if weight_decay != 0:
            param.mul_(1 - lr * weight_decay)

        if grad.dim() > 1:
            assert (
                row_var is not None and col_var is not None
            ), "row_var and col_var should be defined when grad is multidimensional"
            # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
            row_mean = (
                torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
            )
            row_var.lerp_(row_mean, one_minus_beta2_t)
            # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
            col_mean = (
                torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))
            )
            col_var.lerp_(col_mean, one_minus_beta2_t)
            var_estimate = row_var @ col_var
            var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1))
        else:
            assert (
                variance is not None
            ), "variance should be defined when grad is a vector"
            grad_squared = grad * grad
            variance.lerp_(grad_squared, one_minus_beta2_t)
            # avoid writing into variance during update
            var_estimate = variance.clone()
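
        # The lines below implement the update rule from the class docstring:
        #   U_t     = G_t / max(sqrt(V_hat_t), eps1)
        #   U_hat_t = U_t / max(1, RMS(U_t) / d)      (update clipping)
        #   theta_t = theta_{t-1} - alpha_t * U_hat_t
        # where RMS(U_t) = ||U_t||_2 / sqrt(numel(U_t)).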
        # square the eps1 as we sqrt after to keep eps1's magnitude
        update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_()
        update.mul_(grad)
        denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d))
        param.add_(update, alpha=-alpha / denom)


def _group_tensors_by_device_dtype_and_is_multidim(
    tensorlists: TensorListList,
) -> Dict[
    Tuple[Optional[torch.device], Optional[torch.dtype], bool],
    List[List[Optional[Tensor]]],
]:
    """Groups tensors by device, dtype, AND multidimensionality -- whether the tensor
    has multiple dims or just one dim (is a vector). This allows the foreach impl of
    Adafactor to assume that every group of params will either be factored or not."""
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists)
    ultra_grouped_tensors: Dict[
        Tuple[Optional[torch.device], Optional[torch.dtype], bool],
        List[List[Optional[Tensor]]],
    ] = {}
    for (device, dtype), (tensorlists, _) in grouped_tensors.items():
        matrix_key = (device, dtype, True)
        vector_key = (device, dtype, False)

        # assumes grad is the second tensorlist
        for j, tensor in enumerate(tensorlists[1]):
            assert tensor is not None, "grad should not be None"
            if tensor.dim() > 1:
                if matrix_key not in ultra_grouped_tensors:
                    ultra_grouped_tensors[matrix_key] = [[] for _ in tensorlists]
                for i in range(len(tensorlists)):
                    ultra_grouped_tensors[matrix_key][i].append(tensorlists[i][j])
            else:
                if vector_key not in ultra_grouped_tensors:
                    ultra_grouped_tensors[vector_key] = [[] for _ in tensorlists]
                for i in range(len(tensorlists)):
                    ultra_grouped_tensors[vector_key][i].append(tensorlists[i][j])
    return ultra_grouped_tensors

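
# Illustration of the grouping above (hypothetical parameters): a float32 weight of
# shape (10, 20) and a float32 bias of shape (20,) on the same device end up in two
# separate groups, keyed (device, torch.float32, True) and (device, torch.float32,
# False), so the foreach path below can treat each group as uniformly factored or
# uniformly unfactored.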
def _multi_tensor_adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    # If grad is 1-dimensional (aka a vector), there is no factorization necessary
    # so row_var and col_var will be None while variance will be filled.
    # Contrarily, for a grad with multiple dimensions, we will factor along the last
    # 2 dimensions, and so row_var and col_var will be filled and variance will be None.
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    d: float,
    lr: Union[Tensor, float],
    beta2_decay: float,
    weight_decay: float,
    eps1: Optional[float],
    eps2: float,
    maximize: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert (
        grad_scale is None and found_inf is None
    ), "Grad scaling should occur outside of optimizer.step()"

    grouped_tensors = _group_tensors_by_device_dtype_and_is_multidim(
        [params, grads, row_vars, col_vars, variances, state_steps]  # type: ignore[list-item]
    )
    for (_, dtype, is_multidim), (
        (
            device_params_,
            device_grads_,
            device_row_vars_,
            device_col_vars_,
            device_variances_,
            device_state_steps_,
        )
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_steps = cast(List[Tensor], device_state_steps_)
        if eps1 is None:
            assert (
                dtype is not None
            ), "dtype is needed to compute eps1 when eps1 is unset"
            eps1 = torch.finfo(dtype).eps

        if TYPE_CHECKING:
            assert device_state_steps[0] is not None

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1.0)

        one_minus_beta2_ts = []
        beta2_ts = []
        rho_ts = []
        for s in device_state_steps:
            one_minus_beta2_ts.append(s.item() ** beta2_decay)
            beta2_ts.append(1 - s.item() ** beta2_decay)
            rho_ts.append(min(lr, 1 / (s.item() ** 0.5)))

        alphas = [
            max(eps2, p.norm(2).item() / (p.numel() ** 0.5)) * r
            for p, r in zip(device_params, rho_ts)
        ]

        # Perform stepweight decay
        if weight_decay != 0:
            torch._foreach_mul_(device_params, 1 - lr * weight_decay)

        if is_multidim:
            device_row_vars = cast(List[Tensor], device_row_vars_)
            device_col_vars = cast(List[Tensor], device_col_vars_)
            assert (
                device_row_vars[0] is not None and device_col_vars[0] is not None
            ), "row_var and col_var should be defined when grad is multidimensional"
            # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
            row_means = [
                torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads
            ]
            torch._foreach_mul_(row_means, row_means)
            torch._foreach_div_(row_means, [grad.size(-1) for grad in device_grads])
            torch._foreach_mul_(device_row_vars, beta2_ts)
            torch._foreach_mul_(row_means, one_minus_beta2_ts)
            torch._foreach_add_(device_row_vars, row_means)
            del row_means

            # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
            col_means = [
                torch.norm(grad, dim=-2, keepdim=True) for grad in device_grads
            ]
            torch._foreach_mul_(col_means, col_means)
            torch._foreach_div_(col_means, [grad.size(-2) for grad in device_grads])
            torch._foreach_mul_(device_col_vars, beta2_ts)
            torch._foreach_mul_(col_means, one_minus_beta2_ts)
            torch._foreach_add_(device_col_vars, col_means)
            del col_means

            var_estimates = [
                row_var @ col_var
                for row_var, col_var in zip(device_row_vars, device_col_vars)
            ]
            row_var_means = [
                row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars
            ]
            torch._foreach_clamp_min_(row_var_means, eps1)
            torch._foreach_div_(var_estimates, row_var_means)
            del row_var_means
        else:
            device_variances = cast(List[Tensor], device_variances_)
            assert (
                device_variances[0] is not None
            ), "variance should be defined when grad is a vector"

            grads_squared = torch._foreach_mul(device_grads, device_grads)
            torch._foreach_mul_(device_variances, beta2_ts)
            torch._foreach_mul_(grads_squared, one_minus_beta2_ts)
            torch._foreach_add_(device_variances, grads_squared)
            del grads_squared

            # avoid writing into variance during update
            var_estimates = [v.clone() for v in device_variances]

        # square the eps1 as we sqrt after to keep eps1's magnitude
        torch._foreach_clamp_min_(var_estimates, eps1 * eps1)
        torch._foreach_sqrt_(var_estimates)
        torch._foreach_reciprocal_(var_estimates)
        torch._foreach_mul_(var_estimates, device_grads)
        updates = var_estimates
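
        # As in the single-tensor path, fold the update clipping
        # max(1, RMS(U_t) / d) into the per-parameter step size alpha before
        # applying the update.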
        alphas = [
            -a / (max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d)))
            for a, update in zip(alphas, updates)
        ]
        torch._foreach_mul_(updates, alphas)
        torch._foreach_add_(device_params, updates)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor)
def adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    d: float,
    lr: Union[float, Tensor],
    beta2_decay: float,
    weight_decay: float,
    eps1: Optional[float],
    eps2: float,
    maximize: bool,
):
    r"""Functional API that performs Adafactor algorithm computation.

    See :class:`~torch.optim.Adafactor` for details.
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "`state_steps` argument must contain a list of singleton tensors"
        )

    if foreach:
        func = _multi_tensor_adafactor
    else:
        func = _single_tensor_adafactor

    func(
        params,
        grads,
        row_vars,
        col_vars,
        variances,
        state_steps,
        d=d,
        lr=lr,
        beta2_decay=beta2_decay,
        weight_decay=weight_decay,
        eps1=eps1,
        eps2=eps2,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )
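

# Functional-API sketch (illustrative only; the tensors below are hypothetical and
# would normally come from an optimizer's state):
#
#     p = torch.randn(4, 3)
#     g = torch.randn_like(p)
#     adafactor(
#         [p], [g],
#         row_vars=[p.new_zeros((4, 1))], col_vars=[p.new_zeros((1, 3))],
#         variances=[None], state_steps=[torch.tensor(0.0)],
#         d=1.0, lr=1e-2, beta2_decay=-0.8, weight_decay=0.0,
#         eps1=None, eps2=1e-3, maximize=False,
#     )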