# mypy: allow-untyped-defs
r"""Autograd anomaly mode."""
import warnings

import torch


__all__ = ["detect_anomaly", "set_detect_anomaly"]


class detect_anomaly:
    r"""Context-manager that enables anomaly detection for the autograd engine.

    This does two things:

    - Running the forward pass with detection enabled will allow the backward
      pass to print the traceback of the forward operation that created the failing
      backward function.
    - If ``check_nan`` is ``True``, any backward computation that generates "nan"
      values will raise an error. Default ``True``.

    .. warning::
        This mode should be enabled only for debugging, as the extra checks
        will slow down your program execution.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY)
        >>> import torch
        >>> from torch import autograd
        >>> class MyFunc(autograd.Function):
        ...     @staticmethod
        ...     def forward(ctx, inp):
        ...         return inp.clone()
        ...     @staticmethod
        ...     def backward(ctx, gO):
        ...         # Error during the backward pass
        ...         raise RuntimeError("Some error in backward")
        ...         return gO.clone()
        >>> def run_fn(a):
        ...     out = MyFunc.apply(a)
        ...     return out.sum()
        >>> inp = torch.rand(10, 10, requires_grad=True)
        >>> out = run_fn(inp)
        >>> out.backward()
            Traceback (most recent call last):
              File "<stdin>", line 1, in <module>
              File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
                torch.autograd.backward(self, gradient, retain_graph, create_graph)
              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
                allow_unreachable=True)  # allow_unreachable flag
              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
                return self._forward_cls.backward(self, *args)
              File "<stdin>", line 8, in backward
            RuntimeError: Some error in backward
        >>> with autograd.detect_anomaly():
        ...     inp = torch.rand(10, 10, requires_grad=True)
        ...     out = run_fn(inp)
        ...     out.backward()
            Traceback of forward call that caused the error:
              File "tmp.py", line 53, in <module>
                out = run_fn(inp)
              File "tmp.py", line 44, in run_fn
                out = MyFunc.apply(a)
            Traceback (most recent call last):
              File "<stdin>", line 4, in <module>
              File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
                torch.autograd.backward(self, gradient, retain_graph, create_graph)
              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
                allow_unreachable=True)  # allow_unreachable flag
              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
                return self._forward_cls.backward(self, *args)
              File "<stdin>", line 8, in backward
            RuntimeError: Some error in backward

    """

    def __init__(self, check_nan=True) -> None:  # noqa: D107
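        # Snapshot the current global state so that __exit__ can restore it.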
        self.prev = torch.is_anomaly_enabled()
        self.check_nan = check_nan
        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
        warnings.warn(
            "Anomaly Detection has been enabled. "
            "This mode will increase the runtime "
            "and should only be enabled for debugging.",
            stacklevel=2,
        )

    def __enter__(self) -> None:  # noqa: D105
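        # Turn detection (and NaN checking, if requested) on for the duration
        # of the block.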
        torch.set_anomaly_enabled(True, self.check_nan)

    def __exit__(self, *args: object) -> None:  # noqa: D105
        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)


class set_detect_anomaly:
    r"""Context-manager that turns anomaly detection for the autograd engine on or off.

    ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
    based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    See ``detect_anomaly`` above for details of the anomaly detection behaviour.

    Args:
        mode (bool): Flag whether to enable anomaly detection (``True``),
                     or disable (``False``).
        check_nan (bool): Flag whether to raise an error when the backward
                          pass generates "nan" values. Default ``True``.

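    Example:
        >>> # xdoctest: +SKIP
        >>> # A brief sketch of both usages (as a function and as a
        >>> # context-manager); skipped under doctest since it toggles
        >>> # global state.
        >>> import torch
        >>> torch.autograd.set_detect_anomaly(True)  # as a function: enable globally
        >>> with torch.autograd.set_detect_anomaly(False):
        ...     pass  # detection is disabled only inside this block
        >>> # on exit, the previous mode (enabled above) is restored
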
110    """
111
112    def __init__(self, mode: bool, check_nan: bool = True) -> None:  # noqa: D107
113        self.prev = torch.is_anomaly_enabled()
114        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
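        # The new mode takes effect here, at construction time; this is what
        # lets the class double as a plain function call. __exit__ restores
        # the snapshot taken above.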
        torch.set_anomaly_enabled(mode, check_nan)

    def __enter__(self) -> None:  # noqa: D105
        pass

    def __exit__(self, *args: object) -> None:  # noqa: D105
        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)