# xref: /aosp_15_r20/external/pytorch/torch/amp/autocast_mode.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import collections
import functools
import warnings
from typing import Any, Optional

import torch
from torch.types import _dtype


try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    HAS_NUMPY = False
    np = None  # type: ignore[assignment]

__all__ = [
    "autocast_decorator",
    "autocast",
    "is_autocast_available",
    "custom_fwd",
    "custom_bwd",
]


def is_autocast_available(device_type: str) -> bool:
    r"""
    Return a bool indicating if autocast is available on :attr:`device_type`.

    Args:
        device_type(str):  Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
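
    Example (a minimal sketch)::

        if is_autocast_available("cuda"):
            # Safe to enter a torch.autocast("cuda") region on this build.
            ...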
    """
    return torch._C._is_autocast_available(device_type)


def autocast_decorator(autocast_instance, func):
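    """Wrap ``func`` so that it always executes under ``autocast_instance``.

    Illustrative use (a sketch; ``my_fn`` is a hypothetical function)::

        my_fn = autocast_decorator(torch.autocast("cuda"), my_fn)
    """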
    @functools.wraps(func)
    def decorate_autocast(*args, **kwargs):
        with autocast_instance:
            return func(*args, **kwargs)

    decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode"  # type: ignore[attr-defined]
    return decorate_autocast


class autocast:
    r"""
    Instances of :class:`autocast` serve as context managers or decorators that
    allow regions of your script to run in mixed precision.

    In these regions, ops run in an op-specific dtype chosen by autocast
    to improve performance while maintaining accuracy.
    See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.

    When entering an autocast-enabled region, Tensors may be any type.
    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.

    :class:`autocast` should wrap only the forward pass(es) of your network, including the loss
    computation(s).  Backward passes under autocast are not recommended.
    Backward ops run in the same type that autocast used for corresponding forward ops.

    Example for CUDA Devices::

        # Creates model and optimizer in default precision
        model = Net().cuda()
        optimizer = optim.SGD(model.parameters(), ...)

        for input, target in data:
            optimizer.zero_grad()

            # Enables autocasting for the forward pass (model + loss)
            with torch.autocast(device_type="cuda"):
                output = model(input)
                loss = loss_fn(output, target)

            # Exits the context manager before backward()
            loss.backward()
            optimizer.step()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
    in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::

        class AutocastModel(nn.Module):
            ...
            @torch.autocast(device_type="cuda")
            def forward(self, input):
                ...

    Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
    After returning to an autocast-disabled region, using them with floating-point
    Tensors of different dtypes may cause type mismatch errors.  If so, cast the Tensor(s)
    produced in the autocast region back to ``float32`` (or other dtype if desired).
    If a Tensor from the autocast region is already ``float32``, the cast is a no-op,
    and incurs no additional overhead.
    CUDA Example::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            # torch.mm is on autocast's list of ops that should run in float16.
            # Inputs are float32, but the op runs in float16 and produces float16 output.
            # No manual casts are required.
            e_float16 = torch.mm(a_float32, b_float32)
            # Also handles mixed input types
            f_float16 = torch.mm(d_float32, e_float16)

        # After exiting autocast, calls f_float16.float() to use with d_float32
        g_float32 = torch.mm(d_float32, f_float16.float())

    CPU Training Example::

        # Creates model and optimizer in default precision
        model = Net()
        optimizer = optim.SGD(model.parameters(), ...)

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()

                # Runs the forward pass with autocasting.
                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    output = model(input)
                    loss = loss_fn(output, target)

                loss.backward()
                optimizer.step()


    CPU Inference Example::

        # Creates model in default precision
        model = Net().eval()

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            for input in data:
                # Runs the forward pass with autocasting.
                output = model(input)

    CPU Inference Example with Jit Trace::

        class TestModel(nn.Module):
            def __init__(self, input_size, num_classes):
                super().__init__()
                self.fc1 = nn.Linear(input_size, num_classes)
            def forward(self, x):
                return self.fc1(x)

        input_size = 2
        num_classes = 2
        model = TestModel(input_size, num_classes).eval()

        # For now, we suggest disabling the JIT autocast pass
        # due to https://github.com/pytorch/pytorch/issues/75956
        torch._C._jit_set_autocast_mode(False)

        with torch.autocast(device_type="cpu", cache_enabled=False):
            model = torch.jit.trace(model, torch.randn(1, input_size))
        model = torch.jit.freeze(model)
        # Run the frozen model
        for _ in range(3):
            model(torch.randn(1, input_size))

    Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
    please file an issue.

    ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions.
    Locally disabling autocast can be useful, for example, if you want to force a subregion
    to run in a particular ``dtype``.  Disabling autocast gives you explicit control over
    the execution type.  In the subregion, inputs from the surrounding region
    should be cast to ``dtype`` before use::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with torch.autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)

    The autocast state is thread-local.  If you want it enabled in a new thread, the context manager or decorator
    must be invoked in that thread.  This affects :class:`torch.nn.DataParallel` and
    :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
    (see :ref:`Working with Multiple GPUs<amp-multigpu>`).
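
    A minimal per-thread sketch (assumes ``threading`` is imported and ``model`` and
    ``input`` exist, as in the examples above)::

        def worker(input):
            # The autocast context must be entered inside the new thread.
            with torch.autocast(device_type="cuda"):
                output = model(input)

        thread = threading.Thread(target=worker, args=(input,))
        thread.start()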

    Args:
        device_type(str, required):  Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and 'hpu'.
                                     The type is the same as the `type` attribute of a :class:`torch.device`.
                                     Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        enabled(bool, optional):  Whether autocasting should be enabled in the region.
            Default: ``True``
        dtype(torch_dtype, optional):  Data type for ops run in autocast. It uses the default value
            (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by
            :func:`~torch.get_autocast_dtype`, if :attr:`dtype` is ``None``.
            Default: ``None``
        cache_enabled(bool, optional):  Whether the weight cache inside autocast should be enabled.
            Default: ``True``
    """

    def __init__(
        self,
        device_type: str,
        dtype: Optional[_dtype] = None,
        enabled: bool = True,
        cache_enabled: Optional[bool] = None,
    ):
        if not isinstance(device_type, str):
            raise ValueError(
                f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
            )
        if dtype is None:
            dtype = torch.get_autocast_dtype(device_type)
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = device_type
            self.fast_dtype = dtype
            assert dtype is not None
            return
        self.device = device_type
        if not is_autocast_available(self.device):
            raise RuntimeError(
                f"User specified an unsupported autocast device_type '{self.device}'"
            )
        self.custom_backend_name = torch._C._get_privateuse1_backend_name()
        self.fast_dtype = torch.get_autocast_dtype(self.device)
        if self.device == self.custom_backend_name:
            necessary_funcs = [
                "get_amp_supported_dtype",
            ]
            message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
            message += "registered a module or the module is missing some necessary funcs. The backend should register "
            message += "a module via `torch._register_device_module`, and the module must have these funcs: \n"
            message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"

            assert hasattr(torch, self.custom_backend_name), message
            self.custom_device_mod = getattr(torch, self.custom_backend_name)
            for func in necessary_funcs:
                assert hasattr(self.custom_device_mod, func), (
                    message + f"But the func `{func}` is missing. \n"
                )

        self._cache_enabled = torch.is_autocast_cache_enabled()
        if (
            enabled
            and torch.cuda.amp.common.amp_definitely_not_available()
            and self.device == "cuda"
        ):
            warnings.warn(
                "User provided device_type of 'cuda', but CUDA is not available. Disabling"
            )
            enabled = False
        if dtype is not None:
            self.fast_dtype = dtype
        if cache_enabled is not None:
            self._cache_enabled = cache_enabled

        if self.device == "cpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype and enabled:
                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "CPU Autocast only supports dtype of "
                error_message += (
                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                )
                warnings.warn(error_message)
                enabled = False
        elif self.device == "xpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == "ipu":
            supported_dtypes = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtypes:
                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == "hpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == self.custom_backend_name:
            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
            if self.fast_dtype not in supported_dtype:
                error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. "
                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
                error_message += (
                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                )
                warnings.warn(error_message)
                enabled = False
        elif self.device == "cuda":
            if (
                enabled
                and self.fast_dtype == torch.bfloat16
                and not torch.cuda.is_bf16_supported()
            ):
                raise RuntimeError(
                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
                )
        elif self.device == "mps":
            supported_dtype = [torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += (
                    "MPS Autocast only supports dtype of torch.float16 currently."
                )
                warnings.warn(error_message)
                enabled = False
        elif self.device == "xla":
            supported_dtype = [torch.float16, torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += (
                    "XLA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                )
                warnings.warn(error_message)
                enabled = False
        self._enabled = enabled

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        self.prev = torch.is_autocast_enabled(self.device)
        self.prev_fastdtype = torch.get_autocast_dtype(self.device)
        torch.set_autocast_enabled(self.device, self._enabled)
        torch.set_autocast_dtype(self.device, self.fast_dtype)  # type: ignore[arg-type]
        torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return

        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.device, self.prev)
        torch.set_autocast_dtype(self.device, self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return autocast_decorator(self, func)


# These functions aren't meant for public usage.
# They are what we trace into a graph during pre_dispatch tracing
# when we encounter an autocast context manager.
def _enter_autocast(*vals):
    # For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph.
    if torch._C._is_torch_function_mode_enabled():
        return torch.overrides.handle_torch_function(
            torch.amp._enter_autocast, [], *vals
        )
    mode = torch.amp.autocast(*vals)
    mode.__enter__()
    return mode


def _exit_autocast(mode):
    if torch._C._is_torch_function_mode_enabled():
        return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode)
    mode.__exit__(None, None, None)


# Casts Tensors and containers of Tensors.  Special-cases passthroughs for strings and np.ndarrays, which
# may be falsely detected as "Iterables."
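# For illustration (a sketch, not executed here): a call such as
#     _cast({"w": w_fp32, "xs": [x_fp32]}, "cuda", torch.float16)
# returns a new container in which eligible floating-point CUDA tensors are cast to
# float16, while strings, np.ndarrays, and float64 tensors pass through unchanged.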
def _cast(value, device_type: str, dtype: _dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.device.type == device_type
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, (str, bytes)):
        return value
    elif HAS_NUMPY and isinstance(value, np.ndarray):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {
            _cast(k, device_type, dtype): _cast(v, device_type, dtype)
            for k, v in value.items()
        }
    elif isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, device_type, dtype) for v in value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        else:
            return iterable
    else:
        return value


def custom_fwd(
    fwd=None,
    *,
    device_type: str,
    cast_inputs: Optional[_dtype] = None,
):
    """
    Create a helper decorator for ``forward`` methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    See the :ref:`example page<amp-custom-examples>` for more detail.

    Args:
        device_type(str):  Device type to use. 'cuda', 'cpu', 'xpu' and so on.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        cast_inputs (:class:`torch.dtype` or None, optional, default=None):  If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point Tensors to the target dtype (non-floating-point Tensors are not affected),
            then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
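
    Example (a minimal sketch, mirroring the pattern on the
    :ref:`example page<amp-custom-examples>`; ``MyMM`` is an illustrative name)::

        class MyMM(torch.autograd.Function):
            @staticmethod
            @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
            def forward(ctx, a, b):
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @custom_bwd(device_type="cuda")
            def backward(ctx, grad):
                a, b = ctx.saved_tensors
                return grad.mm(b.t()), a.t().mm(grad)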
    """
    if not isinstance(device_type, str):
        raise ValueError(
            f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
        )
    if fwd is None:
        return functools.partial(
            custom_fwd, device_type=device_type, cast_inputs=cast_inputs
        )

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        args[0]._dtype = torch.get_autocast_dtype(device_type)
        if cast_inputs is None:
            args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
            return fwd(*args, **kwargs)
        else:
            autocast_context = torch.is_autocast_enabled(device_type)
            args[0]._fwd_used_autocast = False
            if autocast_context:
                with autocast(device_type=device_type, enabled=False):
                    return fwd(
                        *_cast(args, device_type, cast_inputs),
                        **_cast(kwargs, device_type, cast_inputs),
                    )
            else:
                return fwd(*args, **kwargs)

    return decorate_fwd


# Autograd ensures incoming gradients are the same type as forward outputs.  Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd=None, *, device_type: str):
    """Create a helper decorator for backward methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    Ensures that ``backward`` executes with the same autocast state as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.

    Args:
        device_type(str):  Device type to use. 'cuda', 'cpu', 'xpu' and so on.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
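
    Example (a minimal sketch; the surrounding class follows the
    :func:`custom_fwd<custom_fwd>` example above)::

        @staticmethod
        @custom_bwd(device_type="cuda")
        def backward(ctx, grad):
            ...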
    """

    if not isinstance(device_type, str):
        raise ValueError(
            f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
        )
    if bwd is None:
        return functools.partial(custom_bwd, device_type=device_type)

    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with autocast(
            device_type=device_type,
            enabled=args[0]._fwd_used_autocast,
            dtype=args[0]._dtype,
        ):
            return bwd(*args, **kwargs)

    return decorate_bwd