1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3from typing import Any, cast, Dict, List, Optional, Union 4 5import torch 6from torch import Tensor 7 8from .optimizer import ( 9 _capturable_doc, 10 _default_to_fused_or_foreach, 11 _differentiable_doc, 12 _disable_dynamo_if_unsupported, 13 _foreach_doc, 14 _get_capturable_supported_devices, 15 _get_scalar_dtype, 16 _maximize_doc, 17 _use_grad_for_differentiable, 18 _view_as_real, 19 Optimizer, 20 ParamsT, 21) 22 23 24__all__ = ["Adadelta", "adadelta"] 25 26 27class Adadelta(Optimizer): 28 def __init__( 29 self, 30 params: ParamsT, 31 lr: Union[float, Tensor] = 1.0, 32 rho: float = 0.9, 33 eps: float = 1e-6, 34 weight_decay: float = 0, 35 foreach: Optional[bool] = None, 36 *, 37 capturable: bool = False, 38 maximize: bool = False, 39 differentiable: bool = False, 40 ): 41 if isinstance(lr, Tensor) and lr.numel() != 1: 42 raise ValueError("Tensor lr must be 1-element") 43 if not 0.0 <= lr: 44 raise ValueError(f"Invalid learning rate: {lr}") 45 if not 0.0 <= rho <= 1.0: 46 raise ValueError(f"Invalid rho value: {rho}") 47 if not 0.0 <= eps: 48 raise ValueError(f"Invalid epsilon value: {eps}") 49 if not 0.0 <= weight_decay: 50 raise ValueError(f"Invalid weight_decay value: {weight_decay}") 51 52 defaults = dict( 53 lr=lr, 54 rho=rho, 55 eps=eps, 56 weight_decay=weight_decay, 57 maximize=maximize, 58 capturable=capturable, 59 foreach=foreach, 60 differentiable=differentiable, 61 ) 62 super().__init__(params, defaults) 63 64 def __setstate__(self, state): 65 super().__setstate__(state) 66 for group in self.param_groups: 67 group.setdefault("foreach", None) 68 group.setdefault("maximize", False) 69 group.setdefault("differentiable", False) 70 group.setdefault("capturable", False) 71 for p in group["params"]: 72 p_state = self.state.get(p, []) 73 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): 74 step_val = float(p_state["step"]) 75 p_state["step"] = ( 76 torch.tensor( 77 step_val, dtype=_get_scalar_dtype(), device=p.device 78 ) 79 if group["capturable"] 80 else torch.tensor(step_val, dtype=_get_scalar_dtype()) 81 ) 82 83 def _init_group( 84 self, 85 group: Dict[str, Any], 86 params_with_grad: List[Tensor], 87 grads: List[Tensor], 88 square_avgs: List[Tensor], 89 acc_deltas: List[Tensor], 90 state_steps: List[Tensor], 91 ): 92 has_complex = False 93 p: Tensor 94 for p in group["params"]: 95 if p.grad is None: 96 continue 97 has_complex |= torch.is_complex(p) 98 params_with_grad.append(p) 99 if p.grad.is_sparse: 100 raise RuntimeError("Adadelta does not support sparse gradients") 101 grads.append(p.grad) 102 103 state = self.state[p] 104 105 # Lazy state initialization 106 if len(state) == 0: 107 state["step"] = ( 108 torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) 109 if group["capturable"] 110 else torch.zeros((), dtype=_get_scalar_dtype()) 111 ) 112 113 state["square_avg"] = torch.zeros_like( 114 p, memory_format=torch.preserve_format 115 ) 116 state["acc_delta"] = torch.zeros_like( 117 p, memory_format=torch.preserve_format 118 ) 119 120 square_avgs.append(state["square_avg"]) 121 acc_deltas.append(state["acc_delta"]) 122 state_steps.append(state["step"]) 123 124 return has_complex 125 126 @_use_grad_for_differentiable 127 def step(self, closure=None): 128 """Perform a single optimization step. 129 130 Args: 131 closure (Callable, optional): A closure that reevaluates the model 132 and returns the loss. 133 """ 134 self._cuda_graph_capture_health_check() 135 136 loss = None 137 if closure is not None: 138 with torch.enable_grad(): 139 loss = closure() 140 141 for group in self.param_groups: 142 params_with_grad: List[Tensor] = [] 143 grads: List[Tensor] = [] 144 square_avgs: List[Tensor] = [] 145 acc_deltas: List[Tensor] = [] 146 state_steps: List[Tensor] = [] 147 ( 148 lr, 149 rho, 150 eps, 151 weight_decay, 152 foreach, 153 maximize, 154 differentiable, 155 capturable, 156 ) = ( 157 group["lr"], 158 group["rho"], 159 group["eps"], 160 group["weight_decay"], 161 group["foreach"], 162 group["maximize"], 163 group["differentiable"], 164 group["capturable"], 165 ) 166 167 has_complex = self._init_group( 168 group, params_with_grad, grads, square_avgs, acc_deltas, state_steps 169 ) 170 171 adadelta( 172 params_with_grad, 173 grads, 174 square_avgs, 175 acc_deltas, 176 state_steps, 177 lr=lr, 178 rho=rho, 179 eps=eps, 180 weight_decay=weight_decay, 181 foreach=foreach, 182 maximize=maximize, 183 differentiable=differentiable, 184 capturable=capturable, 185 has_complex=has_complex, 186 ) 187 188 return loss 189 190 191Adadelta.__doc__ = ( 192 r"""Implements Adadelta algorithm. 193 194 .. math:: 195 \begin{aligned} 196 &\rule{110mm}{0.4pt} \\ 197 &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, 198 \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)}, 199 \: \lambda \text{ (weight decay)} \\ 200 &\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)}, 201 \: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex] 202 &\rule{110mm}{0.4pt} \\ 203 &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ 204 &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ 205 &\hspace{5mm}if \: \lambda \neq 0 \\ 206 &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ 207 &\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\ 208 &\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} + 209 \epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\ 210 &\hspace{5mm} u_t \leftarrow u_{t-1} \rho + 211 \Delta x^2_t (1 - \rho) \\ 212 &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\ 213 &\rule{110mm}{0.4pt} \\[-1.ex] 214 &\bf{return} \: \theta_t \\[-1.ex] 215 &\rule{110mm}{0.4pt} \\[-1.ex] 216 \end{aligned} 217 218 For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_. 219 """ 220 + rf""" 221 Args: 222 params (iterable): iterable of parameters to optimize or dicts defining 223 parameter groups 224 rho (float, optional): coefficient used for computing a running average 225 of squared gradients (default: 0.9). A higher value of `rho` will 226 result in a slower average, which can be helpful for preventing 227 oscillations in the learning process. 228 eps (float, optional): term added to the denominator to improve 229 numerical stability (default: 1e-6). 230 lr (float, Tensor, optional): coefficient that scale delta before it is applied 231 to the parameters (default: 1.0) 232 weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 233 {_foreach_doc} 234 {_capturable_doc} 235 {_maximize_doc} 236 {_differentiable_doc} 237 238 .. _ADADELTA\: An Adaptive Learning Rate Method: 239 https://arxiv.org/abs/1212.5701 240 241 """ 242) 243 244 245def _single_tensor_adadelta( 246 params: List[Tensor], 247 grads: List[Tensor], 248 square_avgs: List[Tensor], 249 acc_deltas: List[Tensor], 250 state_steps: List[Tensor], 251 *, 252 lr: float, 253 rho: float, 254 eps: float, 255 weight_decay: float, 256 maximize: bool, 257 differentiable: bool, 258 capturable: bool, 259 has_complex: bool, 260): 261 # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 262 if not torch._utils.is_compiling() and capturable: 263 capturable_supported_devices = _get_capturable_supported_devices( 264 supports_xla=False 265 ) 266 assert all( 267 p.device.type == step.device.type 268 and p.device.type in capturable_supported_devices 269 for p, step in zip(params, state_steps) 270 ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 271 272 for param, grad, square_avg, acc_delta, step in zip( 273 params, grads, square_avgs, acc_deltas, state_steps 274 ): 275 step += 1 276 grad = grad if not maximize else -grad 277 278 if weight_decay != 0: 279 grad = grad.add(param, alpha=weight_decay) 280 281 if torch.is_complex(param): 282 square_avg = torch.view_as_real(square_avg) 283 acc_delta = torch.view_as_real(acc_delta) 284 grad = torch.view_as_real(grad) 285 286 square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho) 287 std = square_avg.add(eps).sqrt_() 288 delta = acc_delta.add(eps).sqrt_() 289 if differentiable: 290 delta = delta.clone() 291 delta.div_(std).mul_(grad) 292 acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho) 293 294 if torch.is_complex(param): 295 delta = torch.view_as_complex(delta) 296 param.add_(delta, alpha=-lr) 297 298 299def _multi_tensor_adadelta( 300 params: List[Tensor], 301 grads: List[Tensor], 302 square_avgs: List[Tensor], 303 acc_deltas: List[Tensor], 304 state_steps: List[Tensor], 305 *, 306 lr: float, 307 rho: float, 308 eps: float, 309 weight_decay: float, 310 maximize: bool, 311 differentiable: bool, 312 capturable: bool, 313 has_complex: bool, 314): 315 assert not differentiable, "_foreach ops don't support autograd" 316 317 # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 318 if not torch._utils.is_compiling() and capturable: 319 capturable_supported_devices = _get_capturable_supported_devices( 320 supports_xla=False 321 ) 322 assert all( 323 p.device.type == step.device.type 324 and p.device.type in capturable_supported_devices 325 for p, step in zip(params, state_steps) 326 ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 327 328 if len(params) == 0: 329 return 330 331 grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( 332 [params, grads, square_avgs, acc_deltas, state_steps] # type: ignore[list-item] 333 ) 334 for ( 335 device_params_, 336 device_grads_, 337 device_square_avgs_, 338 device_acc_deltas_, 339 device_state_steps_, 340 ), _ in grouped_tensors.values(): 341 device_params = cast(List[Tensor], device_params_) 342 device_grads = cast(List[Tensor], device_grads_) 343 device_square_avgs = cast(List[Tensor], device_square_avgs_) 344 device_acc_deltas = cast(List[Tensor], device_acc_deltas_) 345 device_state_steps = cast(List[Tensor], device_state_steps_) 346 if has_complex: 347 _view_as_real( 348 device_params, device_grads, device_square_avgs, device_acc_deltas 349 ) 350 351 # Update steps 352 # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over 353 # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just 354 # wrapped it once now. The alpha is required to assure we go to the right overload. 355 if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: 356 torch._foreach_add_( 357 device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 358 ) 359 else: 360 torch._foreach_add_(device_state_steps, 1) 361 362 if maximize: 363 device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] 364 365 if weight_decay != 0: 366 # Re-use the intermediate memory (device_grads) already allocated for maximize 367 if maximize: 368 torch._foreach_add_(device_grads, device_params, alpha=weight_decay) 369 else: 370 device_grads = torch._foreach_add( # type: ignore[assignment] 371 device_grads, device_params, alpha=weight_decay 372 ) 373 374 torch._foreach_mul_(device_square_avgs, rho) 375 torch._foreach_addcmul_( 376 device_square_avgs, device_grads, device_grads, value=1 - rho 377 ) 378 379 std = torch._foreach_add(device_square_avgs, eps) 380 torch._foreach_sqrt_(std) 381 382 deltas = torch._foreach_add(device_acc_deltas, eps) 383 torch._foreach_sqrt_(deltas) 384 torch._foreach_div_(deltas, std) 385 torch._foreach_mul_(deltas, device_grads) 386 387 torch._foreach_mul_(device_acc_deltas, rho) 388 torch._foreach_addcmul_(device_acc_deltas, deltas, deltas, value=1 - rho) 389 390 # If LR is a tensor, the else branch will internally call item() 391 # which will cause silent incorrectness if we are capturing 392 if capturable and isinstance(lr, torch.Tensor): 393 torch._foreach_mul_(deltas, -lr) 394 torch._foreach_add_(device_params, deltas) 395 else: 396 torch._foreach_add_(device_params, deltas, alpha=-lr) 397 398 399@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta) 400def adadelta( 401 params: List[Tensor], 402 grads: List[Tensor], 403 square_avgs: List[Tensor], 404 acc_deltas: List[Tensor], 405 state_steps: List[Tensor], 406 # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 407 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim 408 capturable: bool = False, 409 foreach: Optional[bool] = None, 410 differentiable: bool = False, 411 has_complex: bool = False, 412 *, 413 lr: float, 414 rho: float, 415 eps: float, 416 weight_decay: float, 417 maximize: bool, 418): 419 r"""Functional API that performs Adadelta algorithm computation. 420 421 See :class:`~torch.optim.Adadelta` for details. 422 """ 423 424 # this check is slow during compilation, so we skip it 425 # if it's strictly needed we can add this check back in dynamo 426 if not torch._utils.is_compiling() and not all( 427 isinstance(t, torch.Tensor) for t in state_steps 428 ): 429 raise RuntimeError( 430 "API has changed, `state_steps` argument must contain a list of singleton tensors" 431 ) 432 433 # We still respect when the user inputs False for foreach. 434 if foreach is None: 435 _, foreach = _default_to_fused_or_foreach( 436 params, differentiable, use_fused=False 437 ) 438 439 if foreach and torch.jit.is_scripting(): 440 raise RuntimeError("torch.jit.script not supported with foreach optimizers") 441 442 if foreach and not torch.jit.is_scripting(): 443 func = _multi_tensor_adadelta 444 else: 445 func = _single_tensor_adadelta 446 447 func( 448 params, 449 grads, 450 square_avgs, 451 acc_deltas, 452 state_steps, 453 lr=lr, 454 rho=rho, 455 eps=eps, 456 weight_decay=weight_decay, 457 maximize=maximize, 458 differentiable=differentiable, 459 capturable=capturable, 460 has_complex=has_complex, 461 ) 462