# mypy: allow-untyped-defs
r"""Functional interface."""
import math
from typing import List

from torch import Tensor

from .adadelta import adadelta  # type: ignore[attr-defined] # noqa: F401
from .adagrad import _make_sparse, adagrad  # type: ignore[attr-defined] # noqa: F401
from .adam import adam  # type: ignore[attr-defined] # noqa: F401
from .adamax import adamax  # type: ignore[attr-defined] # noqa: F401
from .adamw import adamw  # type: ignore[attr-defined] # noqa: F401
from .asgd import asgd  # type: ignore[attr-defined] # noqa: F401
from .nadam import nadam  # type: ignore[attr-defined] # noqa: F401
from .radam import radam  # type: ignore[attr-defined] # noqa: F401
from .rmsprop import rmsprop  # type: ignore[attr-defined] # noqa: F401
from .rprop import rprop  # type: ignore[attr-defined] # noqa: F401
from .sgd import sgd  # type: ignore[attr-defined] # noqa: F401


# TODO: use foreach API in optim._functional to do all the computation


def sparse_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[int],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    maximize: bool,
):
    r"""Functional API that performs Sparse Adam algorithm computation.

    See :class:`~torch.optim.SparseAdam` for details.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        grad = grad.coalesce()  # the update is non-linear so indices must be unique
        grad_indices = grad._indices()
        grad_values = grad._values()
        if grad_values.numel() == 0:
            # Skip update for empty grad
            continue
        size = grad.size()

        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        def make_sparse(values):
            constructor = grad.new
            if grad_indices.dim() == 0 or values.dim() == 0:
                return constructor().resize_as_(grad)
            return constructor(grad_indices, values, size)

        # Decay the first and second moment running average coefficient
        #      old <- b * old + (1 - b) * new
        # <==> old += (1 - b) * (new - old)
        old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
        exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
        exp_avg.add_(make_sparse(exp_avg_update_values))
        old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
        exp_avg_sq_update_values = (
            grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
        )
        exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))

        # Dense addition again is intended, avoiding another sparse_mask
        numer = exp_avg_update_values.add_(old_exp_avg_values)
        exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
        denom = exp_avg_sq_update_values.sqrt_().add_(eps)
        del exp_avg_update_values, exp_avg_sq_update_values

        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        param.add_(make_sparse(-step_size * numer.div_(denom)))
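

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API surface):
    # shows how SparseAdam-style state could be fed to the functional `sparse_adam`
    # above. The shapes, hyperparameter values, and local names (`param`, `grad`,
    # `exp_avg`, ...) are assumptions chosen for the example; in practice
    # torch.optim.SparseAdam assembles these argument lists itself.
    import torch

    # One embedding-style parameter whose sparse gradient touches only rows 0 and 4.
    param = torch.randn(5, 3)
    grad = torch.sparse_coo_tensor(
        indices=torch.tensor([[0, 4]]),  # rows that received gradient
        values=torch.randn(2, 3),  # per-row gradient values
        size=param.shape,
    )

    # Optimizer state as SparseAdam initializes it: dense zeros, and a step count
    # of at least 1 so that the bias-correction terms are non-zero.
    exp_avg = torch.zeros_like(param)
    exp_avg_sq = torch.zeros_like(param)

    sparse_adam(
        [param],
        [grad],
        [exp_avg],
        [exp_avg_sq],
        [1],  # state_steps
        eps=1e-8,
        beta1=0.9,
        beta2=0.999,
        lr=1e-2,
        maximize=False,
    )

    # Only rows 0 and 4 of `param` (and of the moment buffers) are updated;
    # all other rows are left untouched, which is the point of the sparse update.
    print(param)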