1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4import inspect 5import warnings 6from collections import abc, defaultdict 7from enum import Enum 8from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union 9 10import torch 11 12 13__all__ = ["OptState", "GradScaler"] 14 15 16class _MultiDeviceReplicator: 17 """Lazily serves copies of a tensor to requested devices. 18 19 Copies are cached per-device. 20 """ 21 22 def __init__(self, master_tensor: torch.Tensor) -> None: 23 self.master = master_tensor 24 self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} 25 26 def get(self, device: torch.device) -> torch.Tensor: 27 retval = self._per_device_tensors.get(device, None) 28 if retval is None: 29 retval = self.master.to(device=device, non_blocking=True, copy=True) 30 self._per_device_tensors[device] = retval 31 return retval 32 33 34# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, 35# as well as associated "enum" values. Prefers defining these at top level because 36# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. 37# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler 38# causes a circular reference, which we'd rather avoid. 39class OptState(Enum): 40 READY = 0 41 UNSCALED = 1 42 STEPPED = 2 43 44 45def _refresh_per_optimizer_state() -> Dict[str, Any]: 46 return {"stage": OptState.READY, "found_inf_per_device": {}} 47 48 49class GradScaler: 50 """An instance ``scaler`` of :class:`GradScaler`. 51 52 Helps perform the steps of gradient scaling 53 conveniently. 54 55 * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. 56 * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. 57 * ``scaler.update()`` updates ``scaler``'s scale factor. 58 59 Example:: 60 61 # Creates a GradScaler once at the beginning of training. 62 scaler = GradScaler() 63 64 for epoch in epochs: 65 for input, target in data: 66 optimizer.zero_grad() 67 output = model(input) 68 loss = loss_fn(output, target) 69 70 # Scales loss. Calls backward() on scaled loss to create scaled gradients. 71 scaler.scale(loss).backward() 72 73 # scaler.step() first unscales gradients of the optimizer's params. 74 # If gradients don't contain infs/NaNs, optimizer.step() is then called, 75 # otherwise, optimizer.step() is skipped. 76 scaler.step(optimizer) 77 78 # Updates the scale for next iteration. 79 scaler.update() 80 81 See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage 82 (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, 83 and multiple losses/optimizers. 84 85 ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, 86 a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if 87 the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used 88 without incurring inf or NaN gradient values. 89 ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every 90 ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). 91 92 * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params 93 themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. 94 95 * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. 96 If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by 97 ``growth_factor``. 98 99 The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its 100 value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these 101 iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). 102 103 Args: 104 device (str, optional, default="cuda"): Device type to use. Possible values are: 'cuda' and 'cpu'. 105 The type is the same as the `type` attribute of a :class:`torch.device`. 106 Thus, you may obtain the device type of a tensor using `Tensor.device.type`. 107 init_scale (float, optional, default=2.**16): Initial scale factor. 108 growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during 109 :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. 110 backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during 111 :meth:`update` if inf/NaN gradients occur in an iteration. 112 growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients 113 that must occur for the scale to be multiplied by ``growth_factor``. 114 enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply 115 invokes the underlying ``optimizer.step()``, and other methods become no-ops. 116 Default: ``True`` 117 """ 118 119 def __init__( 120 self, 121 device: str = "cuda", 122 init_scale: float = 2.0**16, 123 growth_factor: float = 2.0, 124 backoff_factor: float = 0.5, 125 growth_interval: int = 2000, 126 enabled: bool = True, 127 ) -> None: 128 self._device = device 129 self._enabled = enabled 130 if self._device == "cuda": 131 if enabled and torch.cuda.amp.common.amp_definitely_not_available(): 132 warnings.warn( 133 "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling." 134 ) 135 self._enabled = False 136 137 if self._enabled: 138 assert growth_factor > 1.0, "The growth factor must be > 1.0." 139 assert backoff_factor < 1.0, "The backoff factor must be < 1.0." 140 141 self._init_scale = init_scale 142 # self._scale will be lazily initialized during the first call to scale() 143 self._scale: Optional[torch.Tensor] = None 144 self._growth_factor = growth_factor 145 self._backoff_factor = backoff_factor 146 self._growth_interval = growth_interval 147 self._init_growth_tracker = 0 148 # self._growth_tracker will be lazily initialized during the first call to scale() 149 self._growth_tracker: Optional[torch.Tensor] = None 150 self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict( 151 _refresh_per_optimizer_state 152 ) 153 154 def _check_scale_growth_tracker( 155 self, funcname: str 156 ) -> Tuple[torch.Tensor, torch.Tensor]: 157 fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." 158 assert self._scale is not None, ( 159 f"Attempted {funcname} but _scale is None. " + fix 160 ) 161 assert self._growth_tracker is not None, ( 162 f"Attempted {funcname} but _growth_tracker is None. " + fix 163 ) 164 return (self._scale, self._growth_tracker) 165 166 def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None: 167 assert self._growth_tracker is None, "_growth_tracker initialized before _scale" 168 self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev) 169 self._growth_tracker = torch.full( 170 (), self._init_growth_tracker, dtype=torch.int32, device=dev 171 ) 172 173 @overload 174 def scale(self, outputs: torch.Tensor) -> torch.Tensor: 175 ... 176 177 @overload 178 def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]: 179 ... 180 181 @overload 182 def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: 183 ... 184 185 @overload 186 def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: 187 ... 188 189 def scale( 190 self, 191 outputs: Union[torch.Tensor, Iterable[torch.Tensor]], 192 ) -> Union[torch.Tensor, Iterable[torch.Tensor]]: 193 """ 194 Multiplies ('scales') a tensor or list of tensors by the scale factor. 195 196 Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned 197 unmodified. 198 199 Args: 200 outputs (Tensor or iterable of Tensors): Outputs to scale. 201 """ 202 if not self._enabled: 203 return outputs 204 205 # Short-circuit for the common case. 206 if isinstance(outputs, torch.Tensor): 207 if self._scale is None: 208 self._lazy_init_scale_growth_tracker(outputs.device) 209 assert self._scale is not None 210 return outputs * self._scale.to(device=outputs.device, non_blocking=True) 211 212 # Invoke the more complex machinery only if we're treating multiple outputs. 213 stash: List[ 214 _MultiDeviceReplicator 215 ] = [] # holds a reference that can be overwritten by apply_scale 216 217 def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]): 218 if isinstance(val, torch.Tensor): 219 if len(stash) == 0: 220 if self._scale is None: 221 self._lazy_init_scale_growth_tracker(val.device) 222 assert self._scale is not None 223 stash.append(_MultiDeviceReplicator(self._scale)) 224 return val * stash[0].get(val.device) 225 if isinstance(val, abc.Iterable): 226 iterable = map(apply_scale, val) 227 if isinstance(val, (list, tuple)): 228 return type(val)(iterable) 229 return iterable 230 raise ValueError("outputs must be a Tensor or an iterable of Tensors") 231 232 return apply_scale(outputs) 233 234 def _unscale_grads_( 235 self, 236 optimizer: torch.optim.Optimizer, 237 inv_scale: torch.Tensor, 238 found_inf: torch.Tensor, 239 allow_fp16: bool, 240 ) -> Dict[torch.device, torch.Tensor]: 241 per_device_inv_scale = _MultiDeviceReplicator(inv_scale) 242 per_device_found_inf = _MultiDeviceReplicator(found_inf) 243 244 # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. 245 # There could be hundreds of grads, so we'd like to iterate through them just once. 246 # However, we don't know their devices or dtypes in advance. 247 248 # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict 249 # Google says mypy struggles with defaultdicts type annotations. 250 per_device_and_dtype_grads: Dict[ 251 torch.device, Dict[torch.dtype, List[torch.Tensor]] 252 ] = defaultdict(lambda: defaultdict(list)) 253 with torch.no_grad(): 254 for group in optimizer.param_groups: 255 for param in group["params"]: 256 assert isinstance(param, torch.Tensor) 257 if param.grad is None: 258 continue 259 if (not allow_fp16) and param.grad.dtype == torch.float16: 260 raise ValueError("Attempting to unscale FP16 gradients.") 261 if param.grad.is_sparse: 262 # is_coalesced() == False means the sparse grad has values with duplicate indices. 263 # coalesce() deduplicates indices and adds all values that have the same index. 264 # For scaled fp16 values, there's a good chance coalescing will cause overflow, 265 # so we should check the coalesced _values(). 266 if param.grad.dtype is torch.float16: 267 param.grad = param.grad.coalesce() 268 to_unscale = param.grad._values() 269 else: 270 to_unscale = param.grad 271 272 # TODO: is there a way to split by device and dtype without appending in the inner loop? 273 per_device_and_dtype_grads[to_unscale.device][ 274 to_unscale.dtype 275 ].append(to_unscale) 276 277 for device, per_dtype_grads in per_device_and_dtype_grads.items(): 278 for grads in per_dtype_grads.values(): 279 torch._amp_foreach_non_finite_check_and_unscale_( 280 grads, 281 per_device_found_inf.get(device), 282 per_device_inv_scale.get(device), 283 ) 284 285 return per_device_found_inf._per_device_tensors 286 287 def unscale_(self, optimizer: torch.optim.Optimizer) -> None: 288 """ 289 Divides ("unscales") the optimizer's gradient tensors by the scale factor. 290 291 :meth:`unscale_` is optional, serving cases where you need to 292 :ref:`modify or inspect gradients<working-with-unscaled-gradients>` 293 between the backward pass(es) and :meth:`step`. 294 If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. 295 296 Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: 297 298 ... 299 scaler.scale(loss).backward() 300 scaler.unscale_(optimizer) 301 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 302 scaler.step(optimizer) 303 scaler.update() 304 305 Args: 306 optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. 307 308 .. note:: 309 :meth:`unscale_` does not incur a CPU-GPU sync. 310 311 .. warning:: 312 :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, 313 and only after all gradients for that optimizer's assigned parameters have been accumulated. 314 Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. 315 316 .. warning:: 317 :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. 318 """ 319 if not self._enabled: 320 return 321 322 self._check_scale_growth_tracker("unscale_") 323 324 optimizer_state = self._per_optimizer_states[id(optimizer)] 325 326 if optimizer_state["stage"] is OptState.UNSCALED: 327 raise RuntimeError( 328 "unscale_() has already been called on this optimizer since the last update()." 329 ) 330 elif optimizer_state["stage"] is OptState.STEPPED: 331 raise RuntimeError("unscale_() is being called after step().") 332 333 # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. 334 assert self._scale is not None 335 inv_scale = self._scale.double().reciprocal().float() 336 found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device) 337 338 optimizer_state["found_inf_per_device"] = self._unscale_grads_( 339 optimizer, inv_scale, found_inf, False 340 ) 341 optimizer_state["stage"] = OptState.UNSCALED 342 343 def _maybe_opt_step( 344 self, 345 optimizer: torch.optim.Optimizer, 346 optimizer_state: Dict[str, Any], 347 *args: Any, 348 **kwargs: Any, 349 ) -> Optional[float]: 350 retval: Optional[float] = None 351 if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): 352 retval = optimizer.step(*args, **kwargs) 353 return retval 354 355 def step( 356 self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any 357 ) -> Optional[float]: 358 """Invoke ``unscale_(optimizer)`` followed by parameter update, if gradients are not infs/NaN. 359 360 :meth:`step` carries out the following two operations: 361 362 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` 363 earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. 364 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled 365 gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. 366 367 ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. 368 369 Returns the return value of ``optimizer.step(*args, **kwargs)``. 370 371 Args: 372 optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. 373 args: Any arguments. 374 kwargs: Any keyword arguments. 375 376 .. warning:: 377 Closure use is not currently supported. 378 """ 379 if not self._enabled: 380 return optimizer.step(*args, **kwargs) 381 382 if "closure" in kwargs: 383 raise RuntimeError( 384 "Closure use is not currently supported if GradScaler is enabled." 385 ) 386 387 self._check_scale_growth_tracker("step") 388 389 optimizer_state = self._per_optimizer_states[id(optimizer)] 390 391 if optimizer_state["stage"] is OptState.STEPPED: 392 raise RuntimeError( 393 "step() has already been called since the last update()." 394 ) 395 396 retval: Optional[float] = None 397 398 if getattr(optimizer, "_step_supports_amp_scaling", False): 399 # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. 400 # The contract with custom optimizers is that their step() should accept an additional, 401 # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: 402 # it can query its own state, invoke unscale_ on itself, etc 403 # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument 404 # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` 405 # and `found_inf` to the passed optimizer so that the optimizer can utilize those 406 # to skip the parameter updates or unscale gradients before updating parameters in 407 # the fused kernel, e.g. `FusedAdamMathFunctor`. 408 # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, 409 # while the method is expected to be called by users side, i.e. their optimizers. 410 kwargs_ = kwargs 411 has_grad_scaler_kwarg = ( 412 "grad_scaler" in inspect.signature(optimizer.step).parameters 413 ) 414 if has_grad_scaler_kwarg: 415 warnings.warn( 416 "GradScaler is going to stop passing itself as a keyword argument to the passed " 417 "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " 418 "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", 419 FutureWarning, 420 ) 421 kwargs_.update({"grad_scaler": self}) 422 else: 423 if optimizer_state["stage"] is OptState.READY: 424 self._check_inf_per_device(optimizer) 425 scaler = self._get_scale_async() 426 assert scaler is not None 427 found_inf = cast( 428 torch.Tensor, 429 sum( 430 [ # noqa: C419 431 t.to(scaler.device, non_blocking=True) 432 for t in optimizer_state["found_inf_per_device"].values() 433 ] 434 ), 435 ) 436 # Take the product of the scales, if the user has already set `optimizer.grad_scale`. 437 optimizer.grad_scale = ( # type: ignore[attr-defined] 438 getattr(optimizer, "grad_scale", None) 439 if optimizer_state["stage"] == OptState.UNSCALED 440 else scaler * getattr(optimizer, "grad_scale", 1) 441 ) 442 optimizer.found_inf = found_inf # type: ignore[attr-defined] 443 retval = optimizer.step(*args, **kwargs_) 444 optimizer_state["stage"] = OptState.STEPPED 445 if not has_grad_scaler_kwarg: 446 del optimizer.grad_scale # type: ignore[attr-defined] 447 del optimizer.found_inf # type: ignore[attr-defined] 448 return retval 449 450 if optimizer_state["stage"] is OptState.READY: 451 self.unscale_(optimizer) 452 453 assert ( 454 len(optimizer_state["found_inf_per_device"]) > 0 455 ), "No inf checks were recorded for this optimizer." 456 457 retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) 458 459 optimizer_state["stage"] = OptState.STEPPED 460 461 return retval 462 463 def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None: 464 """Update the scale factor. 465 466 If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` 467 to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, 468 the scale is multiplied by ``growth_factor`` to increase it. 469 470 Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not 471 used directly, it's used to fill GradScaler's internal scale tensor. So if 472 ``new_scale`` was a tensor, later in-place changes to that tensor will not further 473 affect the scale GradScaler uses internally.) 474 475 Args: 476 new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor. 477 478 .. warning:: 479 :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has 480 been invoked for all optimizers used this iteration. 481 482 .. warning:: 483 For performance reasons, we do not check the scale factor value to avoid synchronizations, 484 so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or 485 you are seeing NaNs in your gradients or loss, something is likely wrong. For example, 486 bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges. 487 """ 488 if not self._enabled: 489 return 490 491 _scale, _growth_tracker = self._check_scale_growth_tracker("update") 492 493 if new_scale is not None: 494 assert self._scale is not None 495 # Accept a new user-defined scale. 496 if isinstance(new_scale, float): 497 self._scale.fill_(new_scale) 498 else: 499 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ 500 torch.FloatTensor with requires_grad=False." 501 assert new_scale.device.type == self._device, reason 502 assert new_scale.numel() == 1, reason 503 assert new_scale.requires_grad is False, reason 504 self._scale.copy_(new_scale) 505 else: 506 # Consume shared inf/nan data collected from optimizers to update the scale. 507 # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. 508 found_infs = [ 509 found_inf.to(device=_scale.device, non_blocking=True) 510 for state in self._per_optimizer_states.values() 511 for found_inf in state["found_inf_per_device"].values() 512 ] 513 514 assert len(found_infs) > 0, "No inf checks were recorded prior to update." 515 516 found_inf_combined = found_infs[0] 517 if len(found_infs) > 1: 518 for i in range(1, len(found_infs)): 519 found_inf_combined += found_infs[i] 520 521 torch._amp_update_scale_( 522 _scale, 523 _growth_tracker, 524 found_inf_combined, 525 self._growth_factor, 526 self._backoff_factor, 527 self._growth_interval, 528 ) 529 530 # To prepare for next iteration, clear the data collected from optimizers this iteration. 531 self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) 532 533 def _get_scale_async(self) -> Optional[torch.Tensor]: 534 return self._scale 535 536 def get_scale(self) -> float: 537 """Return a Python float containing the current scale, or 1.0 if scaling is disabled. 538 539 .. warning:: 540 :meth:`get_scale` incurs a CPU-GPU sync. 541 """ 542 if self._enabled: 543 return ( 544 self._init_scale 545 if (scale := self._get_scale_async()) is None 546 else cast(float, scale.item()) 547 ) 548 return 1.0 549 550 def get_growth_factor(self) -> float: 551 r"""Return a Python float containing the scale growth factor.""" 552 return self._growth_factor 553 554 def set_growth_factor(self, new_factor: float) -> None: 555 r"""Set a new scale growth factor. 556 557 Args: 558 new_scale (float): Value to use as the new scale growth factor. 559 """ 560 self._growth_factor = new_factor 561 562 def get_backoff_factor(self) -> float: 563 r"""Return a Python float containing the scale backoff factor.""" 564 return self._backoff_factor 565 566 def set_backoff_factor(self, new_factor: float) -> None: 567 r"""Set a new scale backoff factor. 568 569 Args: 570 new_scale (float): Value to use as the new scale backoff factor. 571 """ 572 self._backoff_factor = new_factor 573 574 def get_growth_interval(self) -> int: 575 r"""Return a Python int containing the growth interval.""" 576 return self._growth_interval 577 578 def set_growth_interval(self, new_interval: int) -> None: 579 r"""Set a new growth interval. 580 581 Args: 582 new_interval (int): Value to use as the new growth interval. 583 """ 584 self._growth_interval = new_interval 585 586 def _get_growth_tracker(self) -> int: 587 if self._enabled: 588 return ( 589 self._init_growth_tracker 590 if self._growth_tracker is None 591 else cast(int, self._growth_tracker.item()) 592 ) 593 return 0 594 595 def is_enabled(self) -> bool: 596 r"""Return a bool indicating whether this instance is enabled.""" 597 return self._enabled 598 599 def state_dict(self) -> Dict[str, Any]: 600 r"""Return the state of the scaler as a :class:`dict`. 601 602 It contains five entries: 603 604 * ``"scale"`` - a Python float containing the current scale 605 * ``"growth_factor"`` - a Python float containing the current growth factor 606 * ``"backoff_factor"`` - a Python float containing the current backoff factor 607 * ``"growth_interval"`` - a Python int containing the current growth interval 608 * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. 609 610 If this instance is not enabled, returns an empty dict. 611 612 .. note:: 613 If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` 614 should be called after :meth:`update`. 615 """ 616 if self._enabled: 617 return { 618 "scale": self.get_scale(), 619 "growth_factor": self._growth_factor, 620 "backoff_factor": self._backoff_factor, 621 "growth_interval": self._growth_interval, 622 "_growth_tracker": self._get_growth_tracker(), 623 } 624 return {} 625 626 def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 627 r"""Load the scaler state. 628 629 If this instance is disabled, :meth:`load_state_dict` is a no-op. 630 631 Args: 632 state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. 633 """ 634 if not self._enabled: 635 return 636 637 if len(state_dict) == 0: 638 raise RuntimeError( 639 "The source state dict is empty, possibly because it was saved " 640 "from a disabled instance of GradScaler." 641 ) 642 643 self._init_scale = cast(float, state_dict["scale"]) 644 if self._scale is not None: 645 self._scale.fill_(state_dict["scale"]) 646 self._growth_factor = cast(float, state_dict["growth_factor"]) 647 self._backoff_factor = cast(float, state_dict["backoff_factor"]) 648 self._growth_interval = cast(int, state_dict["growth_interval"]) 649 self._init_growth_tracker = cast(int, state_dict["_growth_tracker"]) 650 if self._growth_tracker is not None: 651 self._growth_tracker.fill_(state_dict["_growth_tracker"]) 652 653 def __getstate__(self) -> Dict[str, Any]: 654 state = self.__dict__.copy() 655 if self._enabled: 656 assert len(self._per_optimizer_states) == 0, ( 657 "A GradScaler instance may only be pickled at the beginning " 658 "of an iteration, or at the end after scaler.update()." 659 ) 660 # Pickling _scale and _growth_tracker Tensors directly triggers 661 # "warnings.warn("pickle support for Storage will be removed in 1.5..." 662 # so instead, we set the unpickled instance up to reinitialize them lazily. 663 state["_init_scale"] = self.get_scale() 664 state["_init_growth_tracker"] = self._get_growth_tracker() 665 state["_scale"] = None 666 state["_growth_tracker"] = None 667 return state 668 669 def __setstate__(self, state: Dict[str, Any]) -> None: 670 self.__dict__.update(state) 671 672 def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]: 673 _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") 674 675 dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device) 676 found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device) 677 678 self._per_optimizer_states[id(optimizer)][ 679 "found_inf_per_device" 680 ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) 681 682 return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] 683 684 def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]: 685 return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] 686