# mypy: allow-untyped-defs
from typing import Any

import torch
from torch.utils._contextlib import (
    _DecoratorContextManager,
    _NoParamDecoratorContextManager,
    F,
)


__all__ = [
    "no_grad",
    "enable_grad",
    "set_grad_enabled",
    "inference_mode",
    "set_multithreading_enabled",
]


class no_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.
    There is an exception! All factory functions, or functions that create
    a new Tensor and take a requires_grad kwarg, will NOT be affected by
    this mode.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
        >>> @torch.no_grad()
        ... def tripler(x):
        ...     return x * 3
        >>> z = tripler(x)
        >>> z.requires_grad
        False
        >>> # factory function exception
        >>> with torch.no_grad():
        ...     a = torch.nn.Parameter(torch.rand(10))
        >>> a.requires_grad
        True
    """

    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)


class enable_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
    or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        enable_grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
        >>> @torch.enable_grad()
        ... def tripler(x):
        ...     return x * 3
        >>> with torch.no_grad():
        ...     z = tripler(x)
        >>> z.requires_grad
        True

    """

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)


class set_grad_enabled(_DecoratorContextManager):
    r"""Context-manager that sets gradient calculation on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    .. note::
        set_grad_enabled is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False

    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch.is_grad_enabled()
        self.mode = mode
        torch._C._set_grad_enabled(mode)

    def __call__(self, orig_func: F) -> F:
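        # When used as a decorator, undo the global toggle performed in
        # __init__; the base class wraps the function so that a clone of this
        # context manager is entered around each call instead.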
        torch._C._set_grad_enabled(self.prev)
        return super().__call__(orig_func)

    def __enter__(self) -> None:
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)

    def clone(self) -> "set_grad_enabled":
        r"""
        Create a copy of this context manager.
        """
        return self.__class__(self.mode)


class inference_mode(_DecoratorContextManager):
    r"""Context-manager that enables or disables inference mode.

    InferenceMode is a context manager analogous to :class:`~no_grad`
    to be used when you are certain your operations will have no interactions
    with autograd (e.g., data processing or model evaluation). Code run under this mode gets better
    performance by disabling view tracking and version counter bumps. Note that
    unlike some other mechanisms that locally enable or disable grad,
    entering inference_mode also disables :ref:`forward-mode AD <forward-mode-ad>`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        Inference mode is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    Args:
        mode (bool or function): Either a boolean flag to enable or
            disable inference mode, or a Python function to decorate with
            inference mode enabled.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> import torch
        >>> x = torch.ones(1, 2, 3, requires_grad=True)
        >>> with torch.inference_mode():
        ...     y = x * x
        >>> y.requires_grad
        False
        >>> # xdoctest: +SKIP("want string isn't quite right")
        >>> y._version
        Traceback (most recent call last):
        File "<stdin>", line 1, in <module>
        RuntimeError: Inference tensors do not track version counter.
        >>> @torch.inference_mode()
        ... def func(x):
        ...     return x * x
        >>> out = func(x)
        >>> out.requires_grad
        False
        >>> @torch.inference_mode()
        ... def doubler(x):
        ...     return x * 2
        >>> out = doubler(x)
        >>> out.requires_grad
        False

    """

    def __init__(self, mode: bool = True) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.mode = mode

    def __new__(cls, mode=True):
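        # Support bare-decorator usage (``@torch.inference_mode`` without
        # parentheses): when ``mode`` is the decorated function rather than a
        # bool, build a default instance and apply it to that function.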
        if isinstance(mode, bool):
            return super().__new__(cls)
        return cls()(mode)

    def __enter__(self) -> None:
        self._inference_mode_context = torch._C._InferenceMode(self.mode)
        self._inference_mode_context.__enter__()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self._inference_mode_context.__exit__(exc_type, exc_value, traceback)

    def clone(self) -> "inference_mode":
        r"""
        Create a copy of this context manager.
        """
        return self.__class__(self.mode)


def _enter_inference_mode(mode):
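    # Imperative helper: enter InferenceMode without a ``with`` block and
    # return the guard, which must later be passed to _exit_inference_mode.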
    mode_context = torch._C._InferenceMode(mode)
    mode_context.__enter__()
    return mode_context


def _exit_inference_mode(mode):
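    # Close a guard previously returned by _enter_inference_mode.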
    mode.__exit__(None, None, None)


class set_multithreading_enabled(_DecoratorContextManager):
    r"""Context-manager that sets multithreaded backwards on or off.

    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
                     (``False``).

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

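    Example (an illustrative sketch; the printed gradient assumes the default
    float dtype)::
        >>> # xdoctest: +SKIP
        >>> x = torch.ones(3, requires_grad=True)
        >>> with torch.autograd.set_multithreading_enabled(False):
        ...     (x * 2).sum().backward()
        >>> x.grad
        tensor([2., 2., 2.])
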
    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_multithreading_enabled()
        torch._C._set_multithreading_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_multithreading_enabled(self.prev)

    def clone(self) -> "set_multithreading_enabled":
        r"""
        Create a copy of this context manager.
        """
        return self.__class__(self.mode)


class _force_original_view_tracking(_DecoratorContextManager):
    r"""Context-manager that sets whether or not to always enable view-replay in autograd.

    ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    When a tensor view is mutated, the autograd engine needs to decide whether or not
    to regenerate the "updated view", either by replaying the chain of views from the updated base
    or with a single call to as_strided.

    If set_view_replay_enabled is set to True, then autograd will always use view replay.
    Otherwise, it will fall back to its existing logic.

    Args:
        mode (bool): Flag whether to enable view-replay (``True``), or disable
                     (``False``).

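    Example (an illustrative sketch; this is a private API and may change)::
        >>> # xdoctest: +SKIP
        >>> base = torch.randn(4, requires_grad=True).clone()
        >>> view = base[:2]
        >>> with torch.autograd._force_original_view_tracking(True):
        ...     view.mul_(2)  # the updated view is regenerated via view replay
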
    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_view_replay_enabled()
        torch._C._set_view_replay_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_view_replay_enabled(self.prev)

    def clone(self):
        return self.__class__(self.mode)


class _unsafe_preserve_version_counter(_DecoratorContextManager):
    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.

    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
    (even the ones not touched directly by the context manager)!

    Ordinarily, autograd will track mutations to tensors by incrementing its `._version` attribute.
    This is generally important for correctness: for example, mutating a tensor that autograd has saved
    for the backward pass can result in incorrect gradients, and autograd uses the version counter to detect
    this situation and error out.

    However, there are rare instances where it might be useful to hide mutations from autograd. For example,
    if a tensor is very large, you might want to free its memory by storing it elsewhere and re-populate
    the tensor right before it is needed by autograd.

    Args:
        tensor (torch.Tensor): the tensor whose version counter you would like to preserve.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

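    Example (an illustrative sketch; this is a private API and may change)::
        >>> # xdoctest: +SKIP
        >>> t = torch.ones(2)
        >>> v = t._version
        >>> with torch.autograd._unsafe_preserve_version_counter(t):
        ...     t.copy_(torch.zeros(2))
        >>> assert t._version == v  # the bump from copy_ was rolled back on exit
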
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor
        self.prev_version = tensor._version

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args) -> None:
        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
398