xref: /aosp_15_r20/external/pytorch/torch/cuda/amp/grad_scaler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing_extensions import deprecated
2
3import torch
4
5# We need to keep this unused import for BC reasons
6from torch.amp.grad_scaler import OptState  # noqa: F401
7
8
9__all__ = ["GradScaler"]
10
11
class GradScaler(torch.amp.GradScaler):
    r"""
    See :class:`torch.amp.GradScaler`.
    ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead.
    """

    @deprecated(
        "`torch.cuda.amp.GradScaler(args...)` is deprecated. "
        "Please use `torch.amp.GradScaler('cuda', args...)` instead.",
        category=FutureWarning,
    )
    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True,
    ) -> None:
        # Backward-compatibility shim: delegate everything to the
        # device-generic scaler, hard-wired to the "cuda" device.
        scaler_kwargs = {
            "init_scale": init_scale,
            "growth_factor": growth_factor,
            "backoff_factor": backoff_factor,
            "growth_interval": growth_interval,
            "enabled": enabled,
        }
        super().__init__("cuda", **scaler_kwargs)
39