# mypy: allow-untyped-defs
r"""Autograd anomaly mode."""
import warnings

import torch


__all__ = ["detect_anomaly", "set_detect_anomaly"]


class detect_anomaly:
    r"""Context-manager that enables anomaly detection for the autograd engine.

    This does two things:

    - Running the forward pass with detection enabled will allow the backward
      pass to print the traceback of the forward operation that created the failing
      backward function.
    - If ``check_nan`` is ``True``, any backward computation that generates a NaN
      value will raise an error. Default ``True``.

    The anomaly-detection state that was active when the ``with`` block is
    entered is restored when it exits.

    .. warning::
        This mode should be enabled only for debugging as the different tests
        will slow down your program execution.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY)
        >>> import torch
        >>> from torch import autograd
        >>> class MyFunc(autograd.Function):
        ...     @staticmethod
        ...     def forward(ctx, inp):
        ...         return inp.clone()
        ...     @staticmethod
        ...     def backward(ctx, gO):
        ...         # Error during the backward pass
        ...         raise RuntimeError("Some error in backward")
        ...         return gO.clone()
        >>> def run_fn(a):
        ...     out = MyFunc.apply(a)
        ...     return out.sum()
        >>> inp = torch.rand(10, 10, requires_grad=True)
        >>> out = run_fn(inp)
        >>> out.backward()
            Traceback (most recent call last):
              File "<stdin>", line 1, in <module>
              File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
                torch.autograd.backward(self, gradient, retain_graph, create_graph)
              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
                allow_unreachable=True)  # allow_unreachable flag
              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
                return self._forward_cls.backward(self, *args)
              File "<stdin>", line 8, in backward
            RuntimeError: Some error in backward
        >>> with autograd.detect_anomaly():
        ...     inp = torch.rand(10, 10, requires_grad=True)
        ...     out = run_fn(inp)
        ...     out.backward()
            Traceback of forward call that caused the error:
              File "tmp.py", line 53, in <module>
                out = run_fn(inp)
              File "tmp.py", line 44, in run_fn
                out = MyFunc.apply(a)
            Traceback (most recent call last):
              File "<stdin>", line 4, in <module>
              File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
                torch.autograd.backward(self, gradient, retain_graph, create_graph)
              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
                allow_unreachable=True)  # allow_unreachable flag
              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
                return self._forward_cls.backward(self, *args)
              File "<stdin>", line 8, in backward
            RuntimeError: Some error in backward

    """

    def __init__(self, check_nan=True) -> None:  # noqa: D107
        # Initial snapshot kept so code reading these attributes on a freshly
        # constructed (not yet entered) object still sees sensible values;
        # __enter__ refreshes them with the state at entry time.
        self.prev = torch.is_anomaly_enabled()
        self.check_nan = check_nan
        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
        # Emitted at construction; stacklevel=2 attributes it to the caller.
        warnings.warn(
            "Anomaly Detection has been enabled. "
            "This mode will increase the runtime "
            "and should only be enabled for debugging.",
            stacklevel=2,
        )

    def __enter__(self) -> None:  # noqa: D105
        # Capture the state at entry rather than relying on the construction-time
        # snapshot: the global flags may have changed in between, and __exit__
        # must restore what was active just before the ``with`` body ran.
        self.prev = torch.is_anomaly_enabled()
        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
        torch.set_anomaly_enabled(True, self.check_nan)

    def __exit__(self, *args: object) -> None:  # noqa: D105
        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)


class set_detect_anomaly:
    r"""Context-manager that sets the anomaly detection for the autograd engine on or off.

    ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
    based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    See ``detect_anomaly`` above for details of the anomaly detection behaviour.

    Args:
        mode (bool): Flag whether to enable anomaly detection (``True``),
                     or disable (``False``).
        check_nan (bool): Flag whether to raise an error when a backward
                          computation generates a NaN value. Default ``True``.

    """

    def __init__(self, mode: bool, check_nan: bool = True) -> None:  # noqa: D107
        # The setting is applied immediately (not in __enter__) so that calling
        # set_detect_anomaly(...) as a plain function takes effect; the previous
        # state is therefore captured here, right before it is overwritten.
        self.prev = torch.is_anomaly_enabled()
        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
        torch.set_anomaly_enabled(mode, check_nan)

    def __enter__(self) -> None:  # noqa: D105
        pass

    def __exit__(self, *args: object) -> None:  # noqa: D105
        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)