# mypy: allow-untyped-defs
from typing import Dict, List, Optional

import torch
import torch.optim._functional as F
from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional Adagrad Optimizer
# where we use this optimizer in a functional way.
# Instead of reading `param.grad` when updating parameters,
# we explicitly let the user pass gradients to the `step` function.
# This is so that we can separate the gradients from the parameters
# and allow a multithreaded trainer to update the parameters
# without data races from accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalAdagrad:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-2,
        lr_decay: float = 0.0,
        weight_decay: float = 0.0,
        initial_accumulator_value: float = 0.0,
        warmup_lr_multiplier: float = 1.0,
        warmup_num_iters: float = 0.0,
        eps: float = 1e-10,
        coalesce_grad: bool = True,
        foreach: bool = False,
        fused: bool = False,
        maximize: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        self.defaults = {
            "lr": lr,
            "lr_decay": lr_decay,
            "eps": eps,
            "weight_decay": weight_decay,
            "initial_accumulator_value": initial_accumulator_value,
            "warmup_lr_multiplier": warmup_lr_multiplier,
            "warmup_num_iters": warmup_num_iters,
        }
        self.coalesce_grad = coalesce_grad
        self.foreach = foreach
        self.fused = fused
        self.maximize = maximize
        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups, as that is not a common use case.
        self.param_group = {"params": params}

        # TODO: no Union or Any types in TorchScript, so make step a scalar tensor instead.
        # This is also needed if we want to share_memory on the step across processes.
        for p in self.param_group["params"]:
            self.state[p] = {
                "sum": torch.full_like(p.data, initial_accumulator_value),
                "step": torch.tensor(0.0),
            }

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        state_sums = []
        state_steps: List[Tensor] = []

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not equal the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        has_sparse_grad, has_complex = False, False
        for param, gradient in zip(self.param_group["params"], gradients):
            if gradient is not None:
                # Only parameters that received a gradient participate in the update.
                has_sparse_grad |= gradient.is_sparse
                has_complex |= torch.is_complex(param)
                params_with_grad.append(param)
                grads.append(gradient)
                state = self.state[param]
                state_sums.append(state["sum"])
                state_steps.append(state["step"])

        with torch.no_grad():
            F.adagrad(
                params,
                grads,
                state_sums,
                state_steps,
                lr=self.defaults["lr"],
                weight_decay=self.defaults["weight_decay"],
                lr_decay=self.defaults["lr_decay"],
                eps=self.defaults["eps"],
                has_sparse_grad=has_sparse_grad,
                foreach=self.foreach,
                maximize=self.maximize,
                has_complex=has_complex,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
            )
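

# A minimal usage sketch, guarded so it only runs when this file is executed
# directly. It is an illustrative assumption about how a caller drives this
# optimizer, not part of the distributed optimizer internals: gradients are
# computed by the caller (e.g. via `torch.autograd.grad`) and passed to `step`
# explicitly, one entry per parameter, instead of being read from `param.grad`.
if __name__ == "__main__":
    _params = [torch.randn(2, 3, requires_grad=True)]
    _opt = _FunctionalAdagrad(_params, lr=0.1)
    _loss = (_params[0] ** 2).sum()
    # Compute gradients explicitly instead of calling `_loss.backward()`.
    _grads: List[Optional[Tensor]] = [g for g in torch.autograd.grad(_loss, _params)]
    _opt.step(_grads)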