# xref: /aosp_15_r20/external/pytorch/torch/nn/utils/clip_grad.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
from torch import Tensor
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


__all__ = ["clip_grad_norm_", "clip_grad_norm", "clip_grad_value_"]


_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def _no_grad(func):
    """
    This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
    clip_grad_norm_ and clip_grad_value_ themselves.
    """

    def _no_grad_wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)

    functools.update_wrapper(_no_grad_wrapper, func)
    return _no_grad_wrapper


@_no_grad
def clip_grad_norm_(
    parameters: _tensor_or_tensors,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    The norm is computed over the norms of the individual gradients of all parameters,
    as if the norms of the individual gradients were concatenated into a single vector.
    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float): max norm of the gradients
        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
            fall back to the slow implementation for other device types.
            Default: ``None``

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
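
    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # A minimal usage sketch; ``model``, ``loss`` and ``optimizer`` are
        >>> # assumed to be defined by the caller.
        >>> loss.backward()
        >>> total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        >>> optimizer.step()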
66    """
67    if isinstance(parameters, torch.Tensor):
68        parameters = [parameters]
69    grads = [p.grad for p in parameters if p.grad is not None]
70    max_norm = float(max_norm)
71    norm_type = float(norm_type)
72    if len(grads) == 0:
73        return torch.tensor(0.0)
74    first_device = grads[0].device
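    # Group gradients by (device, dtype) so that each homogeneous group can be
    # handled with a single fused foreach kernel where supported.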
    grouped_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [grads]
    )  # type: ignore[assignment]

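    # Compute a norm for each individual gradient, using the fused foreach kernel
    # when the group's device supports it and falling back to a per-tensor
    # torch.linalg.vector_norm otherwise.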
    norms: List[Tensor] = []
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

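    # For a p-norm (including the inf-norm), the norm of all gradients concatenated
    # into one vector equals the norm of the vector of per-tensor norms, so the
    # per-group results can be combined on a single device here.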
    total_norm = torch.linalg.vector_norm(
        torch.stack([norm.to(first_device) for norm in norms]), norm_type
    )

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids an `if clip_coef < 1:` conditional, which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
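    # Scale every gradient in place by the clamped coefficient, again preferring the
    # fused foreach multiply for each (device, dtype) group.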
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm


@deprecated(
    "`torch.nn.utils.clip_grad_norm` is now deprecated "
    "in favor of `torch.nn.utils.clip_grad_norm_`.",
    category=FutureWarning,
)
def clip_grad_norm(
    parameters: _tensor_or_tensors,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    .. warning::
        This method is now deprecated in favor of
        :func:`torch.nn.utils.clip_grad_norm_`.
    """
    return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)


@_no_grad
def clip_grad_value_(
    parameters: _tensor_or_tensors,
    clip_value: float,
    foreach: Optional[bool] = None,
) -> None:
154    r"""Clip the gradients of an iterable of parameters at specified value.
155
156    Gradients are modified in-place.
157
158    Args:
159        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
160            single Tensor that will have gradients normalized
161        clip_value (float): maximum allowed value of the gradients.
162            The gradients are clipped in the range
163            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
164        foreach (bool): use the faster foreach-based implementation
165            If ``None``, use the foreach implementation for CUDA and CPU native tensors and
166            silently fall back to the slow implementation for other device types.
167            Default: ``None``
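
    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # A minimal usage sketch; ``model``, ``loss`` and ``optimizer`` are
        >>> # assumed to be defined by the caller.
        >>> loss.backward()
        >>> torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
        >>> optimizer.step()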
168    """
169    if isinstance(parameters, torch.Tensor):
170        parameters = [parameters]
171    clip_value = float(clip_value)
172
173    grads = [p.grad for p in parameters if p.grad is not None]
174    grouped_grads = _group_tensors_by_device_and_dtype([grads])
175
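    # Clamp each gradient in place to [-clip_value, clip_value], using the fused
    # foreach clamp kernels when the device supports them.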
    for (device, _), ([grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (
            foreach is None
            and _has_foreach_support(cast(List[Tensor], grads), device=device)
        ) or (foreach and _device_has_foreach_support(device)):
            torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
            torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            for grad in grads:
                cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)