1# mypy: allow-untyped-defs 2import collections 3import functools 4import warnings 5from typing import Any, Optional 6 7import torch 8from torch.types import _dtype 9 10 11try: 12 import numpy as np 13 14 HAS_NUMPY = True 15except ModuleNotFoundError: 16 HAS_NUMPY = False 17 np = None # type: ignore[assignment] 18 19__all__ = [ 20 "autocast_decorator", 21 "autocast", 22 "is_autocast_available", 23 "custom_fwd", 24 "custom_bwd", 25] 26 27 28def is_autocast_available(device_type: str) -> bool: 29 r""" 30 Return a bool indicating if autocast is available on :attr:`device_type`. 31 32 Args: 33 device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on. 34 The type is the same as the `type` attribute of a :class:`torch.device`. 35 Thus, you may obtain the device type of a tensor using `Tensor.device.type`. 36 """ 37 return torch._C._is_autocast_available(device_type) 38 39 40def autocast_decorator(autocast_instance, func): 41 @functools.wraps(func) 42 def decorate_autocast(*args, **kwargs): 43 with autocast_instance: 44 return func(*args, **kwargs) 45 46 decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined] 47 return decorate_autocast 48 49 50class autocast: 51 r""" 52 Instances of :class:`autocast` serve as context managers or decorators that 53 allow regions of your script to run in mixed precision. 54 55 In these regions, ops run in an op-specific dtype chosen by autocast 56 to improve performance while maintaining accuracy. 57 See the :ref:`Autocast Op Reference<autocast-op-reference>` for details. 58 59 When entering an autocast-enabled region, Tensors may be any type. 60 You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting. 61 62 :class:`autocast` should wrap only the forward pass(es) of your network, including the loss 63 computation(s). Backward passes under autocast are not recommended. 64 Backward ops run in the same type that autocast used for corresponding forward ops. 65 66 Example for CUDA Devices:: 67 68 # Creates model and optimizer in default precision 69 model = Net().cuda() 70 optimizer = optim.SGD(model.parameters(), ...) 71 72 for input, target in data: 73 optimizer.zero_grad() 74 75 # Enables autocasting for the forward pass (model + loss) 76 with torch.autocast(device_type="cuda"): 77 output = model(input) 78 loss = loss_fn(output, target) 79 80 # Exits the context manager before backward() 81 loss.backward() 82 optimizer.step() 83 84 See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling) 85 in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions). 86 87 :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model:: 88 89 class AutocastModel(nn.Module): 90 ... 91 @torch.autocast(device_type="cuda") 92 def forward(self, input): 93 ... 94 95 Floating-point Tensors produced in an autocast-enabled region may be ``float16``. 96 After returning to an autocast-disabled region, using them with floating-point 97 Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s) 98 produced in the autocast region back to ``float32`` (or other dtype if desired). 99 If a Tensor from the autocast region is already ``float32``, the cast is a no-op, 100 and incurs no additional overhead. 101 CUDA Example:: 102 103 # Creates some tensors in default dtype (here assumed to be float32) 104 a_float32 = torch.rand((8, 8), device="cuda") 105 b_float32 = torch.rand((8, 8), device="cuda") 106 c_float32 = torch.rand((8, 8), device="cuda") 107 d_float32 = torch.rand((8, 8), device="cuda") 108 109 with torch.autocast(device_type="cuda"): 110 # torch.mm is on autocast's list of ops that should run in float16. 111 # Inputs are float32, but the op runs in float16 and produces float16 output. 112 # No manual casts are required. 113 e_float16 = torch.mm(a_float32, b_float32) 114 # Also handles mixed input types 115 f_float16 = torch.mm(d_float32, e_float16) 116 117 # After exiting autocast, calls f_float16.float() to use with d_float32 118 g_float32 = torch.mm(d_float32, f_float16.float()) 119 120 CPU Training Example:: 121 122 # Creates model and optimizer in default precision 123 model = Net() 124 optimizer = optim.SGD(model.parameters(), ...) 125 126 for epoch in epochs: 127 for input, target in data: 128 optimizer.zero_grad() 129 130 # Runs the forward pass with autocasting. 131 with torch.autocast(device_type="cpu", dtype=torch.bfloat16): 132 output = model(input) 133 loss = loss_fn(output, target) 134 135 loss.backward() 136 optimizer.step() 137 138 139 CPU Inference Example:: 140 141 # Creates model in default precision 142 model = Net().eval() 143 144 with torch.autocast(device_type="cpu", dtype=torch.bfloat16): 145 for input in data: 146 # Runs the forward pass with autocasting. 147 output = model(input) 148 149 CPU Inference Example with Jit Trace:: 150 151 class TestModel(nn.Module): 152 def __init__(self, input_size, num_classes): 153 super().__init__() 154 self.fc1 = nn.Linear(input_size, num_classes) 155 def forward(self, x): 156 return self.fc1(x) 157 158 input_size = 2 159 num_classes = 2 160 model = TestModel(input_size, num_classes).eval() 161 162 # For now, we suggest to disable the Jit Autocast Pass, 163 # As the issue: https://github.com/pytorch/pytorch/issues/75956 164 torch._C._jit_set_autocast_mode(False) 165 166 with torch.cpu.amp.autocast(cache_enabled=False): 167 model = torch.jit.trace(model, torch.randn(1, input_size)) 168 model = torch.jit.freeze(model) 169 # Models Run 170 for _ in range(3): 171 model(torch.randn(1, input_size)) 172 173 Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe, 174 please file an issue. 175 176 ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions. 177 Locally disabling autocast can be useful, for example, if you want to force a subregion 178 to run in a particular ``dtype``. Disabling autocast gives you explicit control over 179 the execution type. In the subregion, inputs from the surrounding region 180 should be cast to ``dtype`` before use:: 181 182 # Creates some tensors in default dtype (here assumed to be float32) 183 a_float32 = torch.rand((8, 8), device="cuda") 184 b_float32 = torch.rand((8, 8), device="cuda") 185 c_float32 = torch.rand((8, 8), device="cuda") 186 d_float32 = torch.rand((8, 8), device="cuda") 187 188 with torch.autocast(device_type="cuda"): 189 e_float16 = torch.mm(a_float32, b_float32) 190 with torch.autocast(device_type="cuda", enabled=False): 191 # Calls e_float16.float() to ensure float32 execution 192 # (necessary because e_float16 was created in an autocasted region) 193 f_float32 = torch.mm(c_float32, e_float16.float()) 194 195 # No manual casts are required when re-entering the autocast-enabled region. 196 # torch.mm again runs in float16 and produces float16 output, regardless of input types. 197 g_float16 = torch.mm(d_float32, f_float32) 198 199 The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator 200 must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and 201 :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process 202 (see :ref:`Working with Multiple GPUs<amp-multigpu>`). 203 204 Args: 205 device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and 'hpu'. 206 The type is the same as the `type` attribute of a :class:`torch.device`. 207 Thus, you may obtain the device type of a tensor using `Tensor.device.type`. 208 enabled(bool, optional): Whether autocasting should be enabled in the region. 209 Default: ``True`` 210 dtype(torch_dtype, optional): Data type for ops run in autocast. It uses the default value 211 (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by 212 :func:`~torch.get_autocast_dtype`, if :attr:`dtype` is ``None``. 213 Default: ``None`` 214 cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled. 215 Default: ``True`` 216 """ 217 218 def __init__( 219 self, 220 device_type: str, 221 dtype: Optional[_dtype] = None, 222 enabled: bool = True, 223 cache_enabled: Optional[bool] = None, 224 ): 225 if not isinstance(device_type, str): 226 raise ValueError( 227 f"Expected `device_type` of type `str`, got: `{type(device_type)}`" 228 ) 229 if dtype is None: 230 dtype = torch.get_autocast_dtype(device_type) 231 if torch._jit_internal.is_scripting(): 232 self._enabled = enabled 233 self.device = device_type 234 self.fast_dtype = dtype 235 assert dtype is not None 236 return 237 self.device = device_type 238 if not is_autocast_available(self.device): 239 raise RuntimeError( 240 f"User specified an unsupported autocast device_type '{self.device}'" 241 ) 242 self.custom_backend_name = torch._C._get_privateuse1_backend_name() 243 self.fast_dtype = torch.get_autocast_dtype(self.device) 244 if self.device == self.custom_backend_name: 245 necessary_funcs = [ 246 "get_amp_supported_dtype", 247 ] 248 message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not " 249 message += "registered a module or the module miss some necessary funcs. The backend should register " 250 message += "a module by `torch._register_device_module`, and the module must have these funcs: \n" 251 message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n" 252 253 assert hasattr(torch, self.custom_backend_name), message 254 self.custom_device_mod = getattr(torch, self.custom_backend_name) 255 for func in necessary_funcs: 256 assert hasattr(self.custom_device_mod, func), ( 257 message + f"But the func `{func}` is missing. \n" 258 ) 259 260 self._cache_enabled = torch.is_autocast_cache_enabled() 261 if ( 262 enabled 263 and torch.cuda.amp.common.amp_definitely_not_available() 264 and self.device == "cuda" 265 ): 266 warnings.warn( 267 "User provided device_type of 'cuda', but CUDA is not available. Disabling" 268 ) 269 enabled = False 270 if dtype is not None: 271 self.fast_dtype = dtype 272 if cache_enabled is not None: 273 self._cache_enabled = cache_enabled 274 275 if self.device == "cpu": 276 supported_dtype = [torch.bfloat16, torch.float16] 277 if self.fast_dtype not in supported_dtype and enabled: 278 error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n" 279 error_message += "CPU Autocast only supports dtype of " 280 error_message += ( 281 ", ".join(str(dtype) for dtype in supported_dtype) + " currently." 282 ) 283 warnings.warn(error_message) 284 enabled = False 285 elif self.device == "xpu": 286 supported_dtype = [torch.bfloat16, torch.float16] 287 if self.fast_dtype not in supported_dtype: 288 error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n" 289 error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." 290 warnings.warn(error_message) 291 enabled = False 292 elif self.device == "ipu": 293 supported_dtypes = [torch.bfloat16, torch.float16] 294 if self.fast_dtype not in supported_dtypes: 295 error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n" 296 error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." 297 warnings.warn(error_message) 298 enabled = False 299 elif self.device == "hpu": 300 supported_dtype = [torch.bfloat16, torch.float16] 301 if self.fast_dtype not in supported_dtype: 302 error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n" 303 error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." 304 warnings.warn(error_message) 305 enabled = False 306 elif self.device == self.custom_backend_name: 307 supported_dtype = self.custom_device_mod.get_amp_supported_dtype() 308 if self.fast_dtype not in supported_dtype: 309 error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. " 310 error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of " 311 error_message += ( 312 ", ".join(str(dtype) for dtype in supported_dtype) + " currently." 313 ) 314 warnings.warn(error_message) 315 enabled = False 316 elif self.device == "cuda": 317 if ( 318 enabled 319 and self.fast_dtype == torch.bfloat16 320 and not torch.cuda.is_bf16_supported() 321 ): 322 raise RuntimeError( 323 "Current CUDA Device does not support bfloat16. Please switch dtype to float16." 324 ) 325 elif self.device == "mps": 326 supported_dtype = [torch.float16] 327 if self.fast_dtype not in supported_dtype: 328 error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n" 329 error_message += ( 330 "MPS Autocast only supports dtype of torch.bfloat16 currently." 331 ) 332 warnings.warn(error_message) 333 enabled = False 334 elif self.device == "xla": 335 supported_dtype = [torch.float16, torch.bfloat16] 336 if self.fast_dtype not in supported_dtype: 337 error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n" 338 error_message += ( 339 "XLA Autocast only supports dtype of torch.bfloat16 currently." 340 ) 341 warnings.warn(error_message) 342 enabled = False 343 self._enabled = enabled 344 345 def __enter__(self): 346 if torch._jit_internal.is_scripting(): 347 assert self.fast_dtype is not None 348 return self 349 350 self.prev_cache_enabled = torch.is_autocast_cache_enabled() 351 self.prev = torch.is_autocast_enabled(self.device) 352 self.prev_fastdtype = torch.get_autocast_dtype(self.device) 353 torch.set_autocast_enabled(self.device, self._enabled) 354 torch.set_autocast_dtype(self.device, self.fast_dtype) # type: ignore[arg-type] 355 torch.autocast_increment_nesting() 356 torch.set_autocast_cache_enabled(self._cache_enabled) 357 358 def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] 359 if torch._jit_internal.is_scripting(): 360 return 361 362 # Drop the cache when we exit to a nesting level that's outside any instance of autocast. 363 if torch.autocast_decrement_nesting() == 0: 364 torch.clear_autocast_cache() 365 torch.set_autocast_enabled(self.device, self.prev) 366 torch.set_autocast_dtype(self.device, self.prev_fastdtype) 367 torch.set_autocast_cache_enabled(self.prev_cache_enabled) 368 return False 369 370 def __call__(self, func): 371 if torch._jit_internal.is_scripting(): 372 return func 373 return autocast_decorator(self, func) 374 375 376# These functions aren't meant for public usage. 377# They are what we trace into a graph during pre_dispatch tracing 378# when we encounter an autocast context manager. 379def _enter_autocast(*vals): 380 # For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph. 381 if torch._C._is_torch_function_mode_enabled(): 382 return torch.overrides.handle_torch_function( 383 torch.amp._enter_autocast, [], *vals 384 ) 385 mode = torch.amp.autocast(*vals) 386 mode.__enter__() 387 return mode 388 389 390def _exit_autocast(mode): 391 if torch._C._is_torch_function_mode_enabled(): 392 return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode) 393 mode.__exit__(None, None, None) 394 395 396# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which 397# may be falsely detected as "Iterables." 398def _cast(value, device_type: str, dtype: _dtype): 399 if isinstance(value, torch.Tensor): 400 is_eligible = ( 401 value.is_floating_point() 402 and value.device.type == device_type 403 and (value.dtype is not torch.float64) 404 ) 405 return value.to(dtype) if is_eligible else value 406 elif isinstance(value, (str, bytes)): 407 return value 408 elif HAS_NUMPY and isinstance(value, np.ndarray): 409 return value 410 elif isinstance(value, collections.abc.Mapping): 411 return { 412 _cast(k, device_type, dtype): _cast(v, device_type, dtype) 413 for k, v in value.items() 414 } 415 elif isinstance(value, collections.abc.Iterable): 416 iterable = (_cast(v, device_type, dtype) for v in value) 417 if isinstance(value, (list, tuple)): 418 return type(value)(iterable) 419 else: 420 return iterable 421 else: 422 return value 423 424 425def custom_fwd( 426 fwd=None, 427 *, 428 device_type: str, 429 cast_inputs: Optional[_dtype] = None, 430): 431 """ 432 Create a helper decorator for ``forward`` methods of custom autograd functions. 433 434 Autograd functions are subclasses of :class:`torch.autograd.Function`. 435 See the :ref:`example page<amp-custom-examples>` for more detail. 436 437 Args: 438 device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on. 439 The type is the same as the `type` attribute of a :class:`torch.device`. 440 Thus, you may obtain the device type of a tensor using `Tensor.device.type`. 441 cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``, 442 when ``forward`` runs in an autocast-enabled region, casts incoming 443 floating-point Tensors to the target dtype (non-floating-point Tensors are not affected), 444 then executes ``forward`` with autocast disabled. 445 If ``None``, ``forward``'s internal ops execute with the current autocast state. 446 447 .. note:: 448 If the decorated ``forward`` is called outside an autocast-enabled region, 449 :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect. 450 """ 451 if not isinstance(device_type, str): 452 raise ValueError( 453 f"Expected `device_type` of type `str`, got: `{type(device_type)}`" 454 ) 455 if fwd is None: 456 return functools.partial( 457 custom_fwd, device_type=device_type, cast_inputs=cast_inputs 458 ) 459 460 @functools.wraps(fwd) 461 def decorate_fwd(*args, **kwargs): 462 args[0]._dtype = torch.get_autocast_dtype(device_type) 463 if cast_inputs is None: 464 args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type) 465 return fwd(*args, **kwargs) 466 else: 467 autocast_context = torch.is_autocast_enabled(device_type) 468 args[0]._fwd_used_autocast = False 469 if autocast_context: 470 with autocast(device_type=device_type, enabled=False): 471 return fwd( 472 *_cast(args, device_type, cast_inputs), 473 **_cast(kwargs, device_type, cast_inputs), 474 ) 475 else: 476 return fwd(*args, **kwargs) 477 478 return decorate_fwd 479 480 481# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate 482# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match 483# cast_inputs supplied to custom_fwd. 484def custom_bwd(bwd=None, *, device_type: str): 485 """Create a helper decorator for backward methods of custom autograd functions. 486 487 Autograd functions are subclasses of :class:`torch.autograd.Function`. 488 Ensures that ``backward`` executes with the same autocast state as ``forward``. 489 See the :ref:`example page<amp-custom-examples>` for more detail. 490 491 Args: 492 device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on. 493 The type is the same as the `type` attribute of a :class:`torch.device`. 494 Thus, you may obtain the device type of a tensor using `Tensor.device.type`. 495 """ 496 497 if not isinstance(device_type, str): 498 raise ValueError( 499 f"Expected `device_type` of type `str`, got: `{type(device_type)}`" 500 ) 501 if bwd is None: 502 return functools.partial(custom_bwd, device_type=device_type) 503 504 @functools.wraps(bwd) 505 def decorate_bwd(*args, **kwargs): 506 with autocast( 507 device_type=device_type, 508 enabled=args[0]._fwd_used_autocast, 509 dtype=args[0]._dtype, 510 ): 511 return bwd(*args, **kwargs) 512 513 return decorate_bwd 514