1from typing_extensions import deprecated 2 3import torch 4 5# We need to keep this unused import for BC reasons 6from torch.amp.grad_scaler import OptState # noqa: F401 7 8 9__all__ = ["GradScaler"] 10 11 12class GradScaler(torch.amp.GradScaler): 13 r""" 14 See :class:`torch.amp.GradScaler`. 15 ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead. 16 """ 17 18 @deprecated( 19 "`torch.cuda.amp.GradScaler(args...)` is deprecated. " 20 "Please use `torch.amp.GradScaler('cuda', args...)` instead.", 21 category=FutureWarning, 22 ) 23 def __init__( 24 self, 25 init_scale: float = 2.0**16, 26 growth_factor: float = 2.0, 27 backoff_factor: float = 0.5, 28 growth_interval: int = 2000, 29 enabled: bool = True, 30 ) -> None: 31 super().__init__( 32 "cuda", 33 init_scale=init_scale, 34 growth_factor=growth_factor, 35 backoff_factor=backoff_factor, 36 growth_interval=growth_interval, 37 enabled=enabled, 38 ) 39