# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
from torch import Tensor
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


__all__ = ["clip_grad_norm_", "clip_grad_norm", "clip_grad_value_"]


_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def _no_grad(func):
    """
    This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
    clip_grad_norm_ and clip_grad_value_ themselves.
    """

    def _no_grad_wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)

    functools.update_wrapper(_no_grad_wrapper, func)
    return _no_grad_wrapper


@_no_grad
def clip_grad_norm_(
    parameters: _tensor_or_tensors,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    The norm is computed over the norms of the individual gradients of all parameters,
    as if the norms of the individual gradients were concatenated into a single vector.
    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float): max norm of the gradients
        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
            fall back to the slow implementation for other device types.
            Default: ``None``

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.0)
    first_device = grads[0].device
    grouped_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [grads]
    )  # type: ignore[assignment]

    norms: List[Tensor] = []
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    total_norm = torch.linalg.vector_norm(
        torch.stack([norm.to(first_device) for norm in norms]), norm_type
    )

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids an `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm


@deprecated(
    "`torch.nn.utils.clip_grad_norm` is now deprecated "
    "in favor of `torch.nn.utils.clip_grad_norm_`.",
    category=FutureWarning,
)
def clip_grad_norm(
    parameters: _tensor_or_tensors,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    .. warning::
        This method is now deprecated in favor of
        :func:`torch.nn.utils.clip_grad_norm_`.
    """
    return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)
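

# Illustrative sketch only (a hypothetical helper, not part of the public API and
# not listed in ``__all__``): it shows where clip_grad_norm_ typically sits in a
# training step, i.e. after ``backward()`` and before ``optimizer.step()``. The
# ``model``, ``optimizer``, and ``loss`` arguments are assumed placeholders.
def _example_clip_by_norm_step(model, optimizer, loss, max_norm=1.0):
    loss.backward()
    # Gradients are rescaled in place; the returned value is the total norm
    # measured *before* clipping, which is commonly logged for monitoring.
    total_norm = clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    optimizer.zero_grad()
    return total_norm
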
@_no_grad
def clip_grad_value_(
    parameters: _tensor_or_tensors,
    clip_value: float,
    foreach: Optional[bool] = None,
) -> None:
    r"""Clip the gradients of an iterable of parameters at the specified value.

    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients clipped
        clip_value (float): maximum allowed value of the gradients.
            The gradients are clipped in the range
            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and
            silently fall back to the slow implementation for other device types.
            Default: ``None``
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)

    grads = [p.grad for p in parameters if p.grad is not None]
    grouped_grads = _group_tensors_by_device_and_dtype([grads])

    for (device, _), ([grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (
            foreach is None
            and _has_foreach_support(cast(List[Tensor], grads), device=device)
        ) or (foreach and _device_has_foreach_support(device)):
            torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
            torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            for grad in grads:
                cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
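

# Illustrative sketch only (a hypothetical helper, not part of the public API):
# value-based clipping clamps every gradient element into
# [-clip_value, clip_value] in place and returns nothing, in contrast to
# clip_grad_norm_, which rescales all gradients by a single shared factor.
# The ``model``, ``optimizer``, and ``loss`` arguments are assumed placeholders.
def _example_clip_by_value_step(model, optimizer, loss, clip_value=0.5):
    loss.backward()
    # Each gradient tensor is clamped elementwise in place.
    clip_grad_value_(model.parameters(), clip_value=clip_value)
    optimizer.step()
    optimizer.zero_grad()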