# mypy: allow-untyped-defs
from __future__ import annotations

import inspect
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union

import torch


__all__ = ["OptState", "GradScaler"]


class _MultiDeviceReplicator:
    """Lazily serves copies of a tensor to requested devices.

    Copies are cached per-device.
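
    Example (illustrative; ``scale`` here is any tensor, e.g. the scaler's scale)::

        rep = _MultiDeviceReplicator(scale)
        s0 = rep.get(torch.device("cuda:0"))   # first request copies the master tensor to cuda:0
        s1 = rep.get(torch.device("cuda:0"))   # subsequent requests return the cached copy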
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device: torch.device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval


# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values.  We prefer defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
#   causes a circular reference, which we'd rather avoid.
class OptState(Enum):
    READY = 0
    UNSCALED = 1
    STEPPED = 2


def _refresh_per_optimizer_state() -> Dict[str, Any]:
    return {"stage": OptState.READY, "found_inf_per_device": {}}


class GradScaler:
    """An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient
    scaling conveniently.

    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
    * ``scaler.update()`` updates ``scaler``'s scale factor.

    Example::

        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
    (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
    and multiple losses/optimizers.

    ``scaler`` dynamically estimates the scale factor each iteration.  To minimize gradient underflow,
    a large scale factor should be used.  However, ``float16`` values can "overflow" (become inf or NaN) if
    the scale factor is too large.  Therefore, the optimal scale factor is the largest factor that can be used
    without incurring inf or NaN gradient values.
    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).

    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.

    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
      ``growth_factor``.

    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
    value calibrates.  ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
    iterations.  After that, step skipping should occur rarely (once every few hundred or thousand iterations).
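
    A common pattern pairs ``scaler`` with :class:`torch.autocast` so the forward pass runs in mixed
    precision while the backward pass uses scaled gradients.  A minimal sketch (assuming a CUDA device
    and the ``model``, ``optimizer``, ``loss_fn`` and ``data`` from the example above)::

        scaler = GradScaler()

        for input, target in data:
            optimizer.zero_grad()
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()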

    Args:
        device (str, optional, default="cuda"): Device type to use. Possible values are: 'cuda' and 'cpu'.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        init_scale (float, optional, default=2.**16):  Initial scale factor.
        growth_factor (float, optional, default=2.0):  Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5):  Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000):  Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional):  If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
    """

    def __init__(
        self,
        device: str = "cuda",
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True,
    ) -> None:
        self._device = device
        self._enabled = enabled
        if self._device == "cuda":
            if enabled and torch.cuda.amp.common.amp_definitely_not_available():
                warnings.warn(
                    "torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling."
                )
                self._enabled = False

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale: Optional[torch.Tensor] = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker: Optional[torch.Tensor] = None
            self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict(
                _refresh_per_optimizer_state
            )

    def _check_scale_growth_tracker(
        self, funcname: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, (
            f"Attempted {funcname} but _scale is None.  " + fix
        )
        assert self._growth_tracker is not None, (
            f"Attempted {funcname} but _growth_tracker is None.  " + fix
        )
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None:
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
        self._growth_tracker = torch.full(
            (), self._init_growth_tracker, dtype=torch.int32, device=dev
        )

    @overload
    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
        ...

    @overload
    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
        ...

    @overload
    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        ...

    @overload
    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
        ...

    def scale(
        self,
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors):  Outputs to scale.
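
        Example (a sketch; ``loss`` is assumed to be a tensor produced by the forward pass)::

            scaled_loss = scaler.scale(loss)
            scaled_loss.backward()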
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[
            _MultiDeviceReplicator
        ] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
            if isinstance(val, torch.Tensor):
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            if isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                return iterable
            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _unscale_grads_(
        self,
        optimizer: torch.optim.Optimizer,
        inv_scale: torch.Tensor,
        found_inf: torch.Tensor,
        allow_fp16: bool,
    ) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with type annotations for nested defaultdicts.
        per_device_and_dtype_grads: Dict[
            torch.device, Dict[torch.dtype, List[torch.Tensor]]
        ] = defaultdict(lambda: defaultdict(list))
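        # For illustration, after the loop below this mapping ends up shaped roughly like
        # (devices and dtypes here are just an example):
        #   {device('cuda:0'): {torch.float16: [grad, ...], torch.float32: [grad, ...]}}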
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    assert isinstance(param, torch.Tensor)
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            param.grad = param.grad.coalesce()
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(
                        grads,
                        per_device_found_inf.get(device),
                        per_device_inv_scale.get(device),
                    )

        return per_device_found_inf._per_device_tensors

    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been accumulated.
            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
        """
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, False
        )
        optimizer_state["stage"] = OptState.UNSCALED

    def _maybe_opt_step(
        self,
        optimizer: torch.optim.Optimizer,
        optimizer_state: Dict[str, Any],
        *args: Any,
        **kwargs: Any,
    ) -> Optional[float]:
        retval: Optional[float] = None
        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
            retval = optimizer.step(*args, **kwargs)
        return retval

    def step(
        self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any
    ) -> Optional[float]:
        """Invoke ``unscale_(optimizer)`` followed by a parameter update, if gradients do not contain infs/NaNs.

        :meth:`step` carries out the following two operations:

        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
            earlier in the iteration).  As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
            gradients.  Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

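        Example (a sketch; ``loss``, ``model`` and ``optimizer`` are assumed to come from the
        class-level example above)::

            scaler.scale(loss).backward()
            scaler.step(optimizer)   # optimizer.step() runs only if the unscaled grads are finite
            scaler.update()
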
        Args:
            optimizer (torch.optim.Optimizer):  Optimizer that applies the gradients.
            args:  Any arguments.
            kwargs:  Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
        """
        if not self._enabled:
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError(
                "Closure use is not currently supported if GradScaler is enabled."
            )

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError(
                "step() has already been called since the last update()."
            )

        retval: Optional[float] = None

        if getattr(optimizer, "_step_supports_amp_scaling", False):
            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
            # The contract with custom optimizers is that their step() should accept an additional,
            # optional grad_scaler kwarg.  We append self to the kwargs so the custom optimizer has full information:
            # it can query its own state, invoke unscale_ on itself, etc.
            # The contract above is being deprecated to avoid introducing a `grad_scaler: GradScaler` argument
            # to `Optimizer.step`. The new behavior is to add two Tensor attributes, `grad_scale`
            # and `found_inf`, to the passed optimizer so that the optimizer can use them
            # to skip the parameter update or unscale gradients before updating parameters in
            # the fused kernel, e.g. `FusedAdamMathFunctor`.
            # In this behavior, `GradScaler._check_inf_per_device` is called if the state is `OptState.READY`,
            # while the method is expected to be called from the user's side, i.e. by their optimizers.
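            # A rough sketch of the attribute-based contract (the class name and body are hypothetical;
            # only `_step_supports_amp_scaling`, `grad_scale`, and `found_inf` are the real hooks):
            #
            #     class MyFusedOptimizer(torch.optim.Optimizer):
            #         _step_supports_amp_scaling = True
            #
            #         def step(self, closure=None):
            #             grad_scale = getattr(self, "grad_scale", None)  # set by GradScaler below
            #             found_inf = getattr(self, "found_inf", None)    # nonzero means: skip the update
            #             ...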
            kwargs_ = kwargs
            has_grad_scaler_kwarg = (
                "grad_scaler" in inspect.signature(optimizer.step).parameters
            )
            if has_grad_scaler_kwarg:
                warnings.warn(
                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
                    "optimizer. In the near future GradScaler will register `grad_scale: Tensor` and "
                    "`found_inf: Tensor` on the passed optimizer and let the optimizer use them directly.",
                    FutureWarning,
                )
                kwargs_.update({"grad_scaler": self})
            else:
                if optimizer_state["stage"] is OptState.READY:
                    self._check_inf_per_device(optimizer)
                scaler = self._get_scale_async()
                assert scaler is not None
                found_inf = cast(
                    torch.Tensor,
                    sum(
                        [  # noqa: C419
                            t.to(scaler.device, non_blocking=True)
                            for t in optimizer_state["found_inf_per_device"].values()
                        ]
                    ),
                )
                # Take the product of the scales, if the user has already set `optimizer.grad_scale`.
                optimizer.grad_scale = (  # type: ignore[attr-defined]
                    getattr(optimizer, "grad_scale", None)
                    if optimizer_state["stage"] == OptState.UNSCALED
                    else scaler * getattr(optimizer, "grad_scale", 1)
                )
                optimizer.found_inf = found_inf  # type: ignore[attr-defined]
            retval = optimizer.step(*args, **kwargs_)
            optimizer_state["stage"] = OptState.STEPPED
            if not has_grad_scaler_kwarg:
                del optimizer.grad_scale  # type: ignore[attr-defined]
                del optimizer.found_inf  # type: ignore[attr-defined]
            return retval

        if optimizer_state["stage"] is OptState.READY:
            self.unscale_(optimizer)

        assert (
            len(optimizer_state["found_inf_per_device"]) > 0
        ), "No inf checks were recorded for this optimizer."

        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

        optimizer_state["stage"] = OptState.STEPPED

        return retval

    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
        """Update the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it is used to fill GradScaler's internal scale tensor, so if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.Tensor`, optional, default=None):  New scale factor.

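        Example (a sketch; the override value is arbitrary)::

            # Typical use: once per iteration, after scaler.step() has run for every optimizer.
            scaler.update()

            # Manually overriding the scale instead of letting it adapt:
            scaler.update(new_scale=2.0**10)
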
        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.

        .. warning::
            For performance reasons, we do not check the scale factor value to avoid synchronizations,
            so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or
            you are seeing NaNs in your gradients or loss, something is likely wrong. For example,
            bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            assert self._scale is not None
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)
            else:
                reason = (
                    "new_scale should be a float or a 1-element torch.cuda.FloatTensor or "
                    "torch.FloatTensor with requires_grad=False."
                )
                assert new_scale.device.type == self._device, reason
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            torch._amp_update_scale_(
                _scale,
                _growth_tracker,
                found_inf_combined,
                self._growth_factor,
                self._backoff_factor,
                self._growth_interval,
            )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _get_scale_async(self) -> Optional[torch.Tensor]:
        return self._scale

    def get_scale(self) -> float:
        """Return a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
        """
        if self._enabled:
            return (
                self._init_scale
                if (scale := self._get_scale_async()) is None
                else cast(float, scale.item())
            )
        return 1.0

    def get_growth_factor(self) -> float:
        r"""Return a Python float containing the scale growth factor."""
        return self._growth_factor

    def set_growth_factor(self, new_factor: float) -> None:
        r"""Set a new scale growth factor.

        Args:
            new_factor (float):  Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor

    def get_backoff_factor(self) -> float:
        r"""Return a Python float containing the scale backoff factor."""
        return self._backoff_factor

    def set_backoff_factor(self, new_factor: float) -> None:
        r"""Set a new scale backoff factor.

        Args:
            new_factor (float):  Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor

    def get_growth_interval(self) -> int:
        r"""Return a Python int containing the growth interval."""
        return self._growth_interval

    def set_growth_interval(self, new_interval: int) -> None:
        r"""Set a new growth interval.

        Args:
            new_interval (int):  Value to use as the new growth interval.
        """
        self._growth_interval = new_interval

    def _get_growth_tracker(self) -> int:
        if self._enabled:
            return (
                self._init_growth_tracker
                if self._growth_tracker is None
                else cast(int, self._growth_tracker.item())
            )
        return 0

    def is_enabled(self) -> bool:
        r"""Return a bool indicating whether this instance is enabled."""
        return self._enabled

    def state_dict(self) -> Dict[str, Any]:
        r"""Return the state of the scaler as a :class:`dict`.

        It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
           If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
           should be called after :meth:`update`.
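
        Example (a sketch; ``PATH``, ``model`` and ``optimizer`` are placeholders)::

            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
            }
            torch.save(checkpoint, PATH)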
        """
        if self._enabled:
            return {
                "scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker(),
            }
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        r"""Load the scaler state.

        If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict (dict): scaler state.  Should be an object returned from a call to :meth:`state_dict`.
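
        Example (a sketch mirroring the :meth:`state_dict` example; ``PATH`` is a placeholder)::

            checkpoint = torch.load(PATH)
            model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            scaler.load_state_dict(checkpoint["scaler"])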
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError(
                "The source state dict is empty, possibly because it was saved "
                "from a disabled instance of GradScaler."
            )

        self._init_scale = cast(float, state_dict["scale"])
        if self._scale is not None:
            self._scale.fill_(state_dict["scale"])
        self._growth_factor = cast(float, state_dict["growth_factor"])
        self._backoff_factor = cast(float, state_dict["backoff_factor"])
        self._growth_interval = cast(int, state_dict["growth_interval"])
        self._init_growth_tracker = cast(int, state_dict["_growth_tracker"])
        if self._growth_tracker is not None:
            self._growth_tracker.fill_(state_dict["_growth_tracker"])

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, (
                "A GradScaler instance may only be pickled at the beginning "
                "of an iteration, or at the end after scaler.update()."
            )
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state["_init_scale"] = self.get_scale()
            state["_init_growth_tracker"] = self._get_growth_tracker()
            state["_scale"] = None
            state["_growth_tracker"] = None
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)

        self._per_optimizer_states[id(optimizer)][
            "found_inf_per_device"
        ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
