# mypy: allow-untyped-defs
"""Deprecated CUDA-specific autocast helpers.

The public entry points here (``autocast``, ``custom_fwd``, ``custom_bwd``)
and the private ``_cast`` delegate to the device-generic ``torch.amp`` APIs
with the device fixed to ``"cuda"``. Each emits a ``FutureWarning`` that
points at the ``torch.amp`` replacement. New code should call ``torch.amp``
directly.
"""
import functools
from typing import Any
from typing_extensions import deprecated

import torch


__all__ = ["autocast", "custom_fwd", "custom_bwd"]


class autocast(torch.amp.autocast_mode.autocast):
    r"""See :class:`torch.autocast`.

    ``torch.cuda.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` instead.

    This subclass fixes the device type to ``"cuda"``. Outside TorchScript,
    ``enabled`` (default ``True``), ``dtype`` (default ``torch.float16``) and
    ``cache_enabled`` (default ``True``) are forwarded unchanged to
    :class:`torch.autocast`. Constructing an instance emits a ``FutureWarning``.
    """

    @deprecated(
        "`torch.cuda.amp.autocast(args...)` is deprecated. "
        "Please use `torch.amp.autocast('cuda', args...)` instead.",
        category=FutureWarning,
    )
    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.float16,
        cache_enabled: bool = True,
    ):
        # Under TorchScript the parent constructor is skipped entirely and only
        # the three attributes below are stored.
        # NOTE(review): presumably these exact attribute names are what the
        # TorchScript autocast handling reads, so keep them in sync with it.
        # Note that `cache_enabled` is NOT recorded on this path -- confirm
        # that is intended.
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = "cuda"
            self.fast_dtype = dtype
            return
        # Eager mode: behave exactly like torch.amp.autocast("cuda", ...).
        super().__init__(
            "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

    def __enter__(self):
        # Under TorchScript, return self without running the parent's
        # __enter__; in eager mode, delegate to the generic autocast.
        if torch._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        # Under TorchScript this is a no-op returning None; in eager mode the
        # parent's __exit__ result is returned unchanged.
        if torch._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        # Decorator usage (`@autocast(...)`): under TorchScript `func` is
        # returned unwrapped; in eager mode the parent's wrapping is used.
        if torch._jit_internal.is_scripting():
            return func
        return super().__call__(func)


# Preserved only for backward-compatibility (BC) reasons; private and not
# exported via __all__.
@deprecated(
    "`torch.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. "
    "Please use `torch.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.",
    category=FutureWarning,
)
def _cast(value, dtype):
    """Forward to ``torch.amp.autocast_mode._cast`` with the device fixed to ``"cuda"``.

    Calling this emits a ``FutureWarning``; the return value is whatever
    ``torch.amp.autocast_mode._cast(value, "cuda", dtype)`` returns.
    """
    return torch.amp.autocast_mode._cast(value, "cuda", dtype)


@deprecated(
    "`torch.cuda.amp.custom_fwd(args...)` is deprecated. "
    "Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.",
    category=FutureWarning,
)
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    ``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use
    ``torch.amp.custom_fwd(args..., device_type='cuda')`` instead.

    Forwards ``fwd`` and ``cast_inputs`` unchanged to
    :func:`torch.amp.custom_fwd` with ``device_type="cuda"`` and returns its
    result. Each call emits a ``FutureWarning``.
    """
    # NOTE(review): `fwd` defaults to None, which suggests both the bare
    # `@custom_fwd` and the `@custom_fwd(cast_inputs=...)` forms are supported.
    # That handling lives in torch.amp.custom_fwd -- confirm there.
    # Binding device_type with functools.partial and calling it immediately is
    # equivalent to torch.amp.custom_fwd(fwd=fwd, cast_inputs=cast_inputs,
    # device_type="cuda").
    return functools.partial(torch.amp.custom_fwd, device_type="cuda")(
        fwd=fwd, cast_inputs=cast_inputs
    )


@deprecated(
    "`torch.cuda.amp.custom_bwd(args...)` is deprecated. "
    "Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.",
    category=FutureWarning,
)
def custom_bwd(bwd):
    """
    ``torch.cuda.amp.custom_bwd(args...)`` is deprecated. Please use
    ``torch.amp.custom_bwd(args..., device_type='cuda')`` instead.

    Forwards ``bwd`` to :func:`torch.amp.custom_bwd` with
    ``device_type="cuda"`` and returns its result. Each call emits a
    ``FutureWarning``. Unlike ``custom_fwd``, ``bwd`` is a required argument
    here.
    """
    # Equivalent to torch.amp.custom_bwd(bwd, device_type="cuda").
    return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd)