1# mypy: allow-untyped-defs 2from typing import Any 3 4import torch 5from torch.utils._contextlib import ( 6 _DecoratorContextManager, 7 _NoParamDecoratorContextManager, 8 F, 9) 10 11 12__all__ = [ 13 "no_grad", 14 "enable_grad", 15 "set_grad_enabled", 16 "inference_mode", 17 "set_multithreading_enabled", 18] 19 20 21class no_grad(_NoParamDecoratorContextManager): 22 r"""Context-manager that disables gradient calculation. 23 24 Disabling gradient calculation is useful for inference, when you are sure 25 that you will not call :meth:`Tensor.backward()`. It will reduce memory 26 consumption for computations that would otherwise have `requires_grad=True`. 27 28 In this mode, the result of every computation will have 29 `requires_grad=False`, even when the inputs have `requires_grad=True`. 30 There is an exception! All factory functions, or functions that create 31 a new Tensor and take a requires_grad kwarg, will NOT be affected by 32 this mode. 33 34 This context manager is thread local; it will not affect computation 35 in other threads. 36 37 Also functions as a decorator. 38 39 .. note:: 40 No-grad is one of several mechanisms that can enable or 41 disable gradients locally see :ref:`locally-disable-grad-doc` for 42 more information on how they compare. 43 44 .. note:: 45 This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`. 46 If you want to disable forward AD for a computation, you can unpack 47 your dual tensors. 48 49 Example:: 50 >>> # xdoctest: +SKIP 51 >>> x = torch.tensor([1.], requires_grad=True) 52 >>> with torch.no_grad(): 53 ... y = x * 2 54 >>> y.requires_grad 55 False 56 >>> @torch.no_grad() 57 ... def doubler(x): 58 ... return x * 2 59 >>> z = doubler(x) 60 >>> z.requires_grad 61 False 62 >>> @torch.no_grad() 63 ... def tripler(x): 64 ... return x * 3 65 >>> z = tripler(x) 66 >>> z.requires_grad 67 False 68 >>> # factory function exception 69 >>> with torch.no_grad(): 70 ... a = torch.nn.Parameter(torch.rand(10)) 71 >>> a.requires_grad 72 True 73 """ 74 75 def __init__(self) -> None: 76 if not torch._jit_internal.is_scripting(): 77 super().__init__() 78 self.prev = False 79 80 def __enter__(self) -> None: 81 self.prev = torch.is_grad_enabled() 82 torch.set_grad_enabled(False) 83 84 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 85 torch.set_grad_enabled(self.prev) 86 87 88class enable_grad(_NoParamDecoratorContextManager): 89 r"""Context-manager that enables gradient calculation. 90 91 Enables gradient calculation, if it has been disabled via :class:`~no_grad` 92 or :class:`~set_grad_enabled`. 93 94 This context manager is thread local; it will not affect computation 95 in other threads. 96 97 Also functions as a decorator. 98 99 .. note:: 100 enable_grad is one of several mechanisms that can enable or 101 disable gradients locally see :ref:`locally-disable-grad-doc` for 102 more information on how they compare. 103 104 .. note:: 105 This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`. 106 107 Example:: 108 >>> # xdoctest: +SKIP 109 >>> x = torch.tensor([1.], requires_grad=True) 110 >>> with torch.no_grad(): 111 ... with torch.enable_grad(): 112 ... y = x * 2 113 >>> y.requires_grad 114 True 115 >>> y.backward() 116 >>> x.grad 117 tensor([2.]) 118 >>> @torch.enable_grad() 119 ... def doubler(x): 120 ... return x * 2 121 >>> with torch.no_grad(): 122 ... z = doubler(x) 123 >>> z.requires_grad 124 True 125 >>> @torch.enable_grad() 126 ... def tripler(x): 127 ... return x * 3 128 >>> with torch.no_grad(): 129 ... z = tripler(x) 130 >>> z.requires_grad 131 True 132 133 """ 134 135 def __enter__(self) -> None: 136 self.prev = torch.is_grad_enabled() 137 torch._C._set_grad_enabled(True) 138 139 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 140 torch._C._set_grad_enabled(self.prev) 141 142 143class set_grad_enabled(_DecoratorContextManager): 144 r"""Context-manager that sets gradient calculation on or off. 145 146 ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`. 147 It can be used as a context-manager or as a function. 148 149 This context manager is thread local; it will not affect computation 150 in other threads. 151 152 Args: 153 mode (bool): Flag whether to enable grad (``True``), or disable 154 (``False``). This can be used to conditionally enable 155 gradients. 156 157 .. note:: 158 set_grad_enabled is one of several mechanisms that can enable or 159 disable gradients locally see :ref:`locally-disable-grad-doc` for 160 more information on how they compare. 161 162 .. note:: 163 This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`. 164 165 Example:: 166 >>> # xdoctest: +SKIP 167 >>> x = torch.tensor([1.], requires_grad=True) 168 >>> is_train = False 169 >>> with torch.set_grad_enabled(is_train): 170 ... y = x * 2 171 >>> y.requires_grad 172 False 173 >>> _ = torch.set_grad_enabled(True) 174 >>> y = x * 2 175 >>> y.requires_grad 176 True 177 >>> _ = torch.set_grad_enabled(False) 178 >>> y = x * 2 179 >>> y.requires_grad 180 False 181 182 """ 183 184 def __init__(self, mode: bool) -> None: 185 self.prev = torch.is_grad_enabled() 186 self.mode = mode 187 torch._C._set_grad_enabled(mode) 188 189 def __call__(self, orig_func: F) -> F: 190 torch._C._set_grad_enabled(self.prev) 191 return super().__call__(orig_func) 192 193 def __enter__(self) -> None: 194 torch._C._set_grad_enabled(self.mode) 195 196 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 197 torch._C._set_grad_enabled(self.prev) 198 199 def clone(self) -> "set_grad_enabled": 200 r""" 201 Create a copy of this class 202 """ 203 return self.__class__(self.mode) 204 205 206class inference_mode(_DecoratorContextManager): 207 r"""Context-manager that enables or disables inference mode. 208 209 InferenceMode is a context manager analogous to :class:`~no_grad` 210 to be used when you are certain your operations will have no interactions 211 with autograd (e.g., model training). Code run under this mode gets better 212 performance by disabling view tracking and version counter bumps. Note that 213 unlike some other mechanisms that locally enable or disable grad, 214 entering inference_mode also disables to :ref:`forward-mode AD <forward-mode-ad>`. 215 216 This context manager is thread local; it will not affect computation 217 in other threads. 218 219 Also functions as a decorator. 220 221 .. note:: 222 Inference mode is one of several mechanisms that can enable or 223 disable gradients locally see :ref:`locally-disable-grad-doc` for 224 more information on how they compare. 225 226 Args: 227 mode (bool or function): Either a boolean flag whether to enable or 228 disable inference mode or a Python function to decorate with 229 inference mode enabled 230 231 Example:: 232 >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) 233 >>> import torch 234 >>> x = torch.ones(1, 2, 3, requires_grad=True) 235 >>> with torch.inference_mode(): 236 ... y = x * x 237 >>> y.requires_grad 238 False 239 >>> # xdoctest: +SKIP("want string isnt quite right") 240 >>> y._version 241 Traceback (most recent call last): 242 File "<stdin>", line 1, in <module> 243 RuntimeError: Inference tensors do not track version counter. 244 >>> @torch.inference_mode() 245 ... def func(x): 246 ... return x * x 247 >>> out = func(x) 248 >>> out.requires_grad 249 False 250 >>> @torch.inference_mode() 251 ... def doubler(x): 252 ... return x * 2 253 >>> out = doubler(x) 254 >>> out.requires_grad 255 False 256 257 """ 258 259 def __init__(self, mode: bool = True) -> None: 260 if not torch._jit_internal.is_scripting(): 261 super().__init__() 262 self.mode = mode 263 264 def __new__(cls, mode=True): 265 if isinstance(mode, bool): 266 return super().__new__(cls) 267 return cls()(mode) 268 269 def __enter__(self) -> None: 270 self._inference_mode_context = torch._C._InferenceMode(self.mode) 271 self._inference_mode_context.__enter__() 272 273 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 274 self._inference_mode_context.__exit__(exc_type, exc_value, traceback) 275 276 def clone(self) -> "inference_mode": 277 r""" 278 Create a copy of this class 279 """ 280 return self.__class__(self.mode) 281 282 283def _enter_inference_mode(mode): 284 mode_context = torch._C._InferenceMode(mode) 285 mode_context.__enter__() 286 return mode_context 287 288 289def _exit_inference_mode(mode): 290 mode.__exit__(None, None, None) 291 292 293class set_multithreading_enabled(_DecoratorContextManager): 294 r"""Context-manager that sets multithreaded backwards on or off. 295 296 ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`. 297 It can be used as a context-manager or as a function. 298 299 This context manager is thread local; it will not affect computation 300 in other threads. 301 302 Args: 303 mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable 304 (``False``). 305 306 .. note:: 307 This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`. 308 309 """ 310 311 def __init__(self, mode: bool) -> None: 312 self.prev = torch._C._is_multithreading_enabled() 313 torch._C._set_multithreading_enabled(mode) 314 self.mode = mode 315 316 def __enter__(self) -> None: 317 pass 318 319 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 320 torch._C._set_multithreading_enabled(self.prev) 321 322 def clone(self) -> "set_multithreading_enabled": 323 r""" 324 Create a copy of this class 325 """ 326 return self.__class__(self.mode) 327 328 329class _force_original_view_tracking(_DecoratorContextManager): 330 r"""Context-manager that sets whether or not to always enable view-replay in autograd. 331 332 ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`. 333 It can be used as a context-manager or as a function. 334 335 This context manager is thread local; it will not affect computation 336 in other threads. 337 338 When a tensor view is mutated, the autograd engine needs to decide whether or not 339 to regenerate the "updated view" by either replaying the chain of views from the updated base, 340 or with a single call to as_strided. 341 342 If set_view_replay_enabled is set to True, then autograd will always use view replay. 343 Otherwise, it will fall back to its existing logic. 344 345 Args: 346 mode (bool): Flag whether to enable view-replay (``True``), or disable 347 (``False``). 348 349 """ 350 351 def __init__(self, mode: bool) -> None: 352 self.prev = torch._C._is_view_replay_enabled() 353 torch._C._set_view_replay_enabled(mode) 354 self.mode = mode 355 356 def __enter__(self) -> None: 357 pass 358 359 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 360 torch._C._set_view_replay_enabled(self.prev) 361 362 def clone(self): 363 return self.__class__(self.mode) 364 365 366class _unsafe_preserve_version_counter(_DecoratorContextManager): 367 r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING. 368 369 This context manager can lead to arbitrary silent-correctness issues in any other part of your code 370 (even the ones not touched directly by the context manager)! 371 372 Ordinarily, autograd will track mutations to tensors by incrementing it's `._version` attribute. 373 This is generally important for correctness, as for example, mutating a tensor that autograd has saved 374 for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect 375 and error out in this situation. 376 377 However, there are rare instances where it might be useful to hide mutations from autograd. For example: 378 if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate 379 the tensor right before it is needed by autograd. 380 381 Args: 382 tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of. 383 384 .. note:: 385 This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`. 386 387 """ 388 389 def __init__(self, tensor: torch.Tensor) -> None: 390 self.tensor = tensor 391 self.prev_version = tensor._version 392 393 def __enter__(self) -> None: 394 pass 395 396 def __exit__(self, *args) -> None: 397 torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version) 398