# mypy: allow-untyped-defs
import abc
import collections
import contextlib
import functools
import logging
import threading
import weakref
from collections import defaultdict, namedtuple
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

log = logging.getLogger(__name__)


__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
    "GradientEdge",
    "get_gradient_edge",
    "increment_version",
]


class Node(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None


        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the Node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on how and when
            this hook is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove() # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the Node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on how and when
            this hook is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @classmethod
    def __subclasshook__(cls, C):
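        # Treat C++-implemented autograd node classes (attributes of torch._C._functions)
        # and custom autograd.Function backward classes as virtual subclasses of Node.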
        if cls is Node:
            if (
                C is not None and C is getattr(torch._C._functions, C.__name__, None)
            ) or issubclass(C, torch.autograd.function.BackwardCFunction):
                return True
        return NotImplemented


def _get_grad_fn_or_grad_acc(t):
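    # For a leaf Tensor (requires_grad=True but no grad_fn yet), take a trivial view so
    # that autograd materializes a node, and return the AccumulateGrad node behind it;
    # otherwise, return the tensor's grad_fn directly.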
    if t.requires_grad and t.grad_fn is None:
        with torch.enable_grad():
            return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


GradientEdge = namedtuple("GradientEdge", ["node", "output_nr"])
GradientEdge.__doc__ = """\
Object representing a given gradient edge within the autograd graph.
To get the gradient edge where a given Tensor gradient will be computed,
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
"""


def get_gradient_edge(tensor):
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to calling
    ``g = autograd.grad(loss, get_gradient_edge(input))``.
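
    Example (an illustrative sketch; assumes a differentiable scalar ``loss``
    computed from ``x``)::

        >>> # xdoctest: +SKIP
        >>> x = torch.randn(3, requires_grad=True)
        >>> loss = (x * 2).sum()
        >>> edge = torch.autograd.graph.get_gradient_edge(x)
        >>> torch.autograd.grad(loss, (edge,))  # same result as torch.autograd.grad(loss, (x,))
        (tensor([2., 2., 2.]),)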
180    """
181    if not tensor.requires_grad:
182        raise RuntimeError(
183            "It is not possible to get the gradient edge for a Tensor that does not require gradients"
184        )
185    grad_fn = _get_grad_fn_or_grad_acc(tensor)
186
    # Note that output_nr defaults to 0, which is the right value
    # for the AccumulateGrad node.
    return GradientEdge(grad_fn, tensor.output_nr)


def increment_version(tensor):
193    """Update autograd metadata tracking whether the given Tensor was modified in place.
194
195    This is to enable more accurate error checking within the autograd engine.
196    It is already done automatically by PyTorch functions and within custom Function
197    when mark_dirty() is called appropriately so you only need to call this explicitly
198    if you are doing inplace operation on the Tensor data in a way that Pytorch doesn't
199    know about. For example a custom kernel that reads the Tensor data_ptr and modifies
200    the memory inplace based on this pointer.
201
202    Note that incrementing the version counter multiple times for a single inplace operation
203    is not problematic.
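
    Example (an illustrative sketch; ``_version`` is an internal counter, shown here
    only to make the effect visible)::

        >>> # xdoctest: +SKIP
        >>> t = torch.randn(3)
        >>> t._version
        0
        >>> torch.autograd.graph.increment_version(t)
        >>> t._version
        1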
204    """
205    torch._C._increment_version(tensor)
206
207
208class saved_tensors_hooks:
209    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.
210
211    Use this context-manager to define how intermediary results of an operation
212    should be packed before saving, and unpacked on retrieval.
213
    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an in-place operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object):
        torch._C._autograd._pop_saved_tensors_default_hooks()


class save_on_cpu(saved_tensors_hooks):
288    """Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.


    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward

    """

    def __init__(self, pin_memory=False, device_type="cuda"):
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed):
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)


@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): When saved tensors default hooks are used while they
                             are disabled, a RuntimeError with this error message
                             is raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass

    """
    try:
        maybe_prev_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)


class _MultiHandle(RemovableHandle):
    handles: Tuple[RemovableHandle, ...]

    def __init__(self, handles: Tuple[RemovableHandle, ...]):
        self.handles = handles

    def remove(self):
        for handle in self.handles:
            handle.remove()

    def __getstate__(self):
        return self.handles

    def __setstate__(self, state):
        self.handles = state


def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: str = "all",
):
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
    is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    .. note::
        See :ref:`backward-hooks-execution` for more information on how and when this
        hook is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
        >>>
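
    The ``"any"`` mode can be used in the same way (an illustrative sketch, reusing
    ``a`` and ``b`` from above; the hook receives whichever gradient is computed
    first)::

        >>> # xdoctest: +SKIP
        >>> def fn_any(grad):
        ...     print("got a gradient of shape", grad.shape)
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b), fn_any, mode="any")
        >>> (a * b).sum().backward()
        got a gradient of shape torch.Size([2, 3])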
459    """
460    supported_modes = ("all", "any")
461    if mode not in supported_modes:
462        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")
463
464    if mode == "all":
465        count: Dict[int, int] = dict()
466        nb_calls = None
467        buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()
468
469        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
470        len_tensors = len(tensors)
471
472        def get_inner_hook(idx):
473            def inner_hook(grad: torch.Tensor):
474                nonlocal count, nb_calls, buffer, fn
475                id = torch._C._current_graph_task_id()
476                assert (
477                    id != -1
478                ), "expected this hook to be called inside a backward call"
479                count[id] = count.get(id, 0)
480                buffer[id] = buffer.get(id, [None] * len_tensors)
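                # Book-keeping is keyed by the current graph task id so that
                # overlapping backward calls do not clobber each other's state.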

                if count[id] == 0:
                    # On the first call for this graph task, count how many of the
                    # registered nodes the engine will actually execute; that is the
                    # number of hook calls to wait for.
                    nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns)  # type: ignore[attr-defined]

                buffer[id][idx] = grad
                count[id] += 1

                if count[id] == nb_calls:
                    fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
                    fn(buffer[id])
                    del count[id]
                    del buffer[id]

            return inner_hook

        handles: Tuple[RemovableHandle, ...] = tuple(
            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
        )
    elif mode == "any":
        fn = cast(Callable[[torch.Tensor], None], fn)
        lock = threading.Lock()
        ran_hook: Dict[int, bool] = defaultdict(bool)

        @functools.wraps(fn)
        def wrapped_fn(grad: torch.Tensor):
            nonlocal ran_hook
            id = torch._C._current_graph_task_id()
            assert id != -1, "expected this hook to be called inside a backward call"
            with lock:
                prev, ran_hook[id] = ran_hook[id], True
            if prev:
                return
            fn(grad)

        handles = tuple(
            tensor.register_hook(wrapped_fn)
            for tensor in tensors
            if tensor.requires_grad
        )

    return _MultiHandle(handles)  # type: ignore[possibly-undefined]


# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
#    - remember the python object id and the version of the tensor
#    - remember aliasing information (data_ptr of base + version)
#    - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
#    - for each tensor aliased to it:
#      - check using its object id and version to see if it has been saved
#      - if it has been saved, clone it
#      - delete the reference to the original
# 3. during backward
#    - if the clone exists, the tensor must've been modified in-place
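#
# Illustrative sketch of the steps above (assuming the allow_mutation_on_saved_tensors
# context manager defined below is active):
#
#   a = torch.ones(2, 3, requires_grad=True)
#   b = a * a            # step 1: `a` is saved for backward
#   a.add_(1.0)          # step 2: the saved `a` is cloned before the mutation
#   b.sum().backward()   # step 3: backward uses the clone, so a.grad is 2 * ones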
_allow_mutation_on_saved_tensors_enabled = False


def _get_tid(t) -> Tuple[int, int, int]:
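    # "Tensor id": identifies this exact tensor object at its current version
    # (python object id, data_ptr, version counter).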
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (id(t), data_ptr, t._version)


def _get_sid(t) -> Tuple[int, int]:
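    # "Storage id": (data_ptr, version), used to track aliasing between tensors that
    # share storage (see the NOTE above).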
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (data_ptr, t._version)


class _Handle:
    pass


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            else:
                # Store an additional strong reference to the handle
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(tup):
            handle = tup
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)


class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx):
        self.ctx = ctx

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                tid = _get_tid(t)
                sid = _get_sid(t)
                ctx = self.ctx
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # One might expect that if tid is in sid_to_tid, then it must
                            # also be in tid_to_weakhandle. However, it is possible for the
                            # tensor to be saved at one point, but cleared by backward before
                            # it is modified in-place. Consider the following example:
                            #
                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                            # >>> out = (a**2).sum()
                            # >>> out.backward()
                            # >>> a.sin_()
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # The same exact tensor has been cloned already
                            continue
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        rs = func(*args, **kwargs)
        return rs


class _AllowMutationOnSavedContext:
    def __init__(self):
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
            weakref.WeakValueDictionary()
        )
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
            set
        )

    def clear(self):
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()


@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        >>> b
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
698    """
699    global _allow_mutation_on_saved_tensors_enabled
700
701    ctx = _AllowMutationOnSavedContext()
702
703    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
704        try:
705            if _allow_mutation_on_saved_tensors_enabled:
706                raise RuntimeError(
707                    "allow_mutation_on_saved_tensors contexts cannot be nested"
708                )
709            _allow_mutation_on_saved_tensors_enabled = True
710            yield ctx
711        finally:
712            ctx.clear()
713            _allow_mutation_on_saved_tensors_enabled = False
714
715
716def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
717    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))
718
719    def iter_graph(roots):
720        if not roots:
721            return
722        seen = set()
723        q: Deque = collections.deque()
724        for node in roots:
725            if node is not None:
726                seen.add(node)
727                q.append(node)
728
729        while q:
730            node = q.popleft()
731            for fn, _idx in node.next_functions:
732                if fn in seen or fn is None:
733                    continue
734                seen.add(fn)
735                q.append(fn)
736
737            yield node
738
739    def fmt(t):
740        # Avoid circular import
741        from torch.testing._internal.common_utils import dtype_abbrs
742
743        if t is None:
744            return "None"
745        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"
746
747    def prehook(grad_outputs):
748        node = torch._C._current_autograd_node()
749        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
750        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
751        log.debug(log_str)
752
753    handles = []
754    for node in iter_graph(grad_fns):
755        handles.append(node.register_prehook(prehook))
756
757    def unregister_hooks():
758        for handle in handles:
759            handle.remove()
760
761    return unregister_hooks
762
763
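# Example (illustrative): one way to see the logging hooks above in action is to set
# this module's logger level to DEBUG before calling backward; _engine_run_backward
# below then attaches a prehook to every reachable node and logs each executed Node
# together with the dtypes and shapes of its grad_outputs:
#
#   import logging
#   logging.getLogger("torch.autograd.graph").setLevel(logging.DEBUG)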
def _engine_run_backward(t_outputs, *args, **kwargs):
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            t_outputs, *args, **kwargs
        )
    finally:
        if attach_logging_hooks:
            unregister_hooks()  # type: ignore[possibly-undefined]
775