# mypy: allow-untyped-defs
import abc
import collections
import contextlib
import functools
import logging
import threading
import weakref
from collections import defaultdict, namedtuple
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

log = logging.getLogger(__name__)


__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
    "GradientEdge",
    "get_gradient_edge",
    "increment_version",
]


class Node(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None


        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on how and when
            this hook is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove() # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on how and when
            this hook is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @classmethod
    def __subclasshook__(cls, C):
        if cls is Node:
            if (
                C is not None and C is getattr(torch._C._functions, C.__name__, None)
            ) or issubclass(C, torch.autograd.function.BackwardCFunction):
                return True
        return NotImplemented


def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        with torch.enable_grad():
            return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn

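# A hedged sketch (not part of the original file) of what the helper above
# returns, assuming a simple leaf / non-leaf pair:
#
#   x = torch.randn(2, requires_grad=True)
#   _get_grad_fn_or_grad_acc(x)   # the AccumulateGrad node for the leaf x,
#                                 # obtained via the view_as(x) trick above
#   y = x * 2
#   _get_grad_fn_or_grad_acc(y)   # simply y.grad_fn (MulBackward0)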

GradientEdge = namedtuple("GradientEdge", ("node", "output_nr"))
GradientEdge.__doc__ = """\
Object representing a given gradient edge within the autograd graph.
To get the gradient edge where a given Tensor gradient will be computed,
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
"""


def get_gradient_edge(tensor):
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to
    calling ``g = autograd.grad(loss, get_gradient_edge(input))``.
    """
    if not tensor.requires_grad:
        raise RuntimeError(
            "It is not possible to get the gradient edge for a Tensor that does not require gradients"
        )
    grad_fn = _get_grad_fn_or_grad_acc(tensor)

    # Note that output_nr defaults to 0, which is the right value
    # for the AccumulateGrad node.
    return GradientEdge(grad_fn, tensor.output_nr)

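# A hedged usage sketch (not part of the original file) of the equivalence the
# docstring above describes, assuming a scalar ``loss`` that depends on ``x``:
#
#   x = torch.randn(3, requires_grad=True)
#   loss = (x * 2).sum()
#   edge = torch.autograd.graph.get_gradient_edge(x)
#   (g1,) = torch.autograd.grad(loss, x, retain_graph=True)
#   (g2,) = torch.autograd.grad(loss, edge)
#   # g1 and g2 hold the same gradient values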

def increment_version(tensor):
    """Update autograd metadata tracking whether the given Tensor was modified in place.

    This is to enable more accurate error checking within the autograd engine.
    It is already done automatically by PyTorch functions and within custom Functions
    when mark_dirty() is called appropriately, so you only need to call this explicitly
    if you are performing an in-place operation on the Tensor data in a way that PyTorch
    doesn't know about; for example, a custom kernel that reads the Tensor's data_ptr and
    modifies the memory in place based on this pointer.

    Note that incrementing the version counter multiple times for a single in-place
    operation is not problematic.
    """
    torch._C._increment_version(tensor)

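# A hedged sketch (not part of the original file): ``my_custom_kernel`` is a
# hypothetical extension that writes through the raw data pointer, bypassing
# PyTorch ops, so autograd must be told about the mutation by hand:
#
#   t = torch.zeros(8)
#   my_custom_kernel(t.data_ptr(), t.numel())   # mutates t's storage in place
#   torch.autograd.graph.increment_version(t)   # bump the version counter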

class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an in-place operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object):
        torch._C._autograd._pop_saved_tensors_default_hooks()


class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on CPU, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.


    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward

    """

    def __init__(self, pin_memory=False, device_type="cuda"):
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed):
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)

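# A hedged sketch (not part of the original file) of the ``pin_memory`` option
# documented above: packing copies into pinned CPU memory so the copy back to
# the GPU during backward can be issued asynchronously. ``model`` and the CUDA
# tensor ``x`` are assumed to exist:
#
#   with torch.autograd.graph.save_on_cpu(pin_memory=True):
#       y = model(x)
#   y.sum().backward()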

@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): When saved tensors default hooks are used while they
                             are disabled, a RuntimeError with this error message
                             is raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass

    """
    try:
        maybe_prev_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)

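# A hedged sketch (not part of the original file) of the restore logic above:
# nesting the context manager re-disables with the outer message on exit
# instead of unconditionally re-enabling the feature:
#
#   with torch.autograd.graph.disable_saved_tensors_hooks("outer reason"):
#       with torch.autograd.graph.disable_saved_tensors_hooks("inner reason"):
#           pass
#       # hooks are still disabled here, with "outer reason" as the message
#   # hooks are enabled again here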

class _MultiHandle(RemovableHandle):
    handles: Tuple[RemovableHandle, ...]

    def __init__(self, handles: Tuple[RemovableHandle, ...]):
        self.handles = handles

    def remove(self):
        for handle in self.handles:
            handle.remove()

    def __getstate__(self):
        return self.handles

    def __setstate__(self, state):
        self.handles = state


def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: str = "all",
):
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
    is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    .. note::
        See :ref:`backward-hooks-execution` for more information on how and when
        this hook is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
        >>>
    """
    supported_modes = ("all", "any")
    if mode not in supported_modes:
        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")

    if mode == "all":
        count: Dict[int, int] = dict()
        nb_calls = None
        buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()

        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
        len_tensors = len(tensors)

        def get_inner_hook(idx):
            def inner_hook(grad: torch.Tensor):
                nonlocal count, nb_calls, buffer, fn
                id = torch._C._current_graph_task_id()
                assert (
                    id != -1
                ), "expected this hook to be called inside a backward call"
                count[id] = count.get(id, 0)
                buffer[id] = buffer.get(id, [None] * len_tensors)

                if count[id] == 0:
                    # On the first call, compute the actual nb_calls and buffer
                    nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns)  # type: ignore[attr-defined]

                buffer[id][idx] = grad
                count[id] += 1

                if count[id] == nb_calls:
                    fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
                    fn(buffer[id])
                    del count[id]
                    del buffer[id]

            return inner_hook

        handles: Tuple[RemovableHandle, ...] = tuple(
            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
        )
    elif mode == "any":
        fn = cast(Callable[[torch.Tensor], None], fn)
        lock = threading.Lock()
        ran_hook: Dict[int, bool] = defaultdict(bool)

        @functools.wraps(fn)
        def wrapped_fn(grad: torch.Tensor):
            nonlocal ran_hook
            id = torch._C._current_graph_task_id()
            assert id != -1, "expected this hook to be called inside a backward call"
            with lock:
                prev, ran_hook[id] = ran_hook[id], True
            if prev:
                return
            fn(grad)

        handles = tuple(
            tensor.register_hook(wrapped_fn)
            for tensor in tensors
            if tensor.requires_grad
        )

    return _MultiHandle(handles)  # type: ignore[possibly-undefined]

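# A hedged sketch (not part of the original file) of the ``mode="any"`` variant
# documented above: ``fn`` runs once per backward call, with the first gradient
# computed among the registered tensors:
#
#   a = torch.rand(2, 3, requires_grad=True)
#   b = torch.rand(2, 3, requires_grad=True)
#   handle = torch.autograd.graph.register_multi_grad_hook(
#       (a, b), lambda grad: print("got", grad.shape), mode="any"
#   )
#   (a * b).sum().backward()   # prints once, e.g. "got torch.Size([2, 3])"
#   handle.remove()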

# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
#    - remember the python object id and the version of the tensor
#    - remember aliasing information (data_ptr of base + version)
#    - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
#    - for each tensor aliased to it:
#      - check using its object id and version to see if it has been saved
#      - if it has been saved, clone it
#      - delete the reference to the original
# 3. during backward
#    - if the clone exists, the tensor must've been modified in-place
_allow_mutation_on_saved_tensors_enabled = False

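# A hedged walkthrough (not part of the original file) of the steps in the NOTE
# above, expressed with the public context manager defined later in this file:
#
#   with torch.autograd.graph.allow_mutation_on_saved_tensors():
#       a = torch.ones(2, requires_grad=True)
#       b = a.clone()
#       out = (b ** 2).sum()   # step 1: b is saved for backward
#       b.sin_()               # step 2: b is cloned before the in-place op
#       out.backward()         # step 3: backward uses the clone of b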

def _get_tid(t) -> Tuple[int, int, int]:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (id(t), data_ptr, t._version)


def _get_sid(t) -> Tuple[int, int]:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (data_ptr, t._version)


class _Handle:
    pass


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            else:
                # Store an additional strong reference to the handle
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(tup):
            handle = tup
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)


class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx):
        self.ctx = ctx

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                tid = _get_tid(t)
                sid = _get_sid(t)
                ctx = self.ctx
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # We know that if tid is in sid_to_tid, then it must also be in
                            # tid_to_weakhandle. However, it is possible for the tensor to be
                            # saved at one point, but cleared by backward before it is modified
                            # in-place. Consider the following example:
                            #
                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                            # >>> out = (a**2).sum()
                            # >>> out.backward()
                            # >>> a.sin_()
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # The same exact tensor has been cloned already
                            continue
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        rs = func(*args, **kwargs)
        return rs


class _AllowMutationOnSavedContext:
    def __init__(self):
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
            weakref.WeakValueDictionary()
        )
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
            set
        )

    def clear(self):
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()


@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled

    ctx = _AllowMutationOnSavedContext()

    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        try:
            if _allow_mutation_on_saved_tensors_enabled:
                raise RuntimeError(
                    "allow_mutation_on_saved_tensors contexts cannot be nested"
                )
            _allow_mutation_on_saved_tensors_enabled = True
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False


def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots):
        if not roots:
            return
        seen = set()
        q: Deque = collections.deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def fmt(t):
        # Avoid circular import
        from torch.testing._internal.common_utils import dtype_abbrs

        if t is None:
            return "None"
        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"

    def prehook(grad_outputs):
        node = torch._C._current_autograd_node()
        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
        log.debug(log_str)

    handles = []
    for node in iter_graph(grad_fns):
        handles.append(node.register_prehook(prehook))

    def unregister_hooks():
        for handle in handles:
            handle.remove()

    return unregister_hooks

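# A hedged note (not part of the original file): these hooks are only attached
# when this module's logger is effective at DEBUG level (checked in
# ``_engine_run_backward`` below), for example:
#
#   import logging
#   logging.getLogger("torch.autograd.graph").setLevel(logging.DEBUG)
#   # later .backward() calls then emit a log.debug line per executed Node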

def _engine_run_backward(t_outputs, *args, **kwargs):
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            t_outputs, *args, **kwargs
        )
    finally:
        if attach_logging_hooks:
            unregister_hooks()  # type: ignore[possibly-undefined]