# mypy: allow-untyped-defs
import functools
import inspect
import itertools
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
from typing_extensions import deprecated

import torch
import torch._C as _C
import torch._functorch as _functorch
import torch.utils.hooks as hooks
from torch._C import _functions
from torch._functorch.autograd_function import custom_function_call


__all__ = [
    "FunctionCtx",
    "BackwardCFunction",
    "FunctionMeta",
    "Function",
    "once_differentiable",
    "InplaceFunction",
    "NestedIOFunction",
]

# Unique id provider for each class inheriting from Function
# This is incremented in FunctionMeta during class definition
AUTOGRAD_FUNCTION_COUNTER = itertools.count()


# Formerly known as: _ContextMethodMixin
class FunctionCtx:
    def save_for_backward(self, *tensors: torch.Tensor):
        r"""Save given tensors for a future call to :func:`~Function.backward`.

        ``save_for_backward`` should be called at most once, in either the
        :func:`setup_context` or :func:`forward` methods, and only with tensors.

        All tensors intended to be used in the backward pass should be saved
        with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
        incorrect gradients and memory leaks, and enable the application of saved
        tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.

        Note that if intermediary tensors, tensors that are neither inputs
        nor outputs of :func:`forward`, are saved for backward, your custom Function
        may not support double backward.
        Custom Functions that do not support double backward should decorate their
        :func:`backward` method with ``@once_differentiable`` so that performing
        double backward raises an error. If you'd like to support double backward,
        you can either recompute intermediaries based on the inputs during backward
        or return the intermediaries as the outputs of the custom Function. See the
        `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_
        for more details.

        In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
        attribute. Before returning them to the user, a check is made to ensure
        they weren't used in any in-place operation that modified their content.

        Arguments can also be ``None``. This is a no-op.

        See :ref:`extending-autograd` for more details on how to use this method.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
            >>>         w = x * z
            >>>         out = x * y + y * z + w * y
            >>>         ctx.save_for_backward(x, y, w, out)
            >>>         ctx.z = z  # z is not a tensor
            >>>         return out
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, grad_out):
            >>>         x, y, w, out = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         gx = grad_out * (y + y * z)
            >>>         gy = grad_out * (x + z + w)
            >>>         gz = None
            >>>         return gx, gy, gz
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
            >>> c = 4
            >>> d = Func.apply(a, b, c)

        """
        self.to_save = tensors

    def save_for_forward(self, *tensors: torch.Tensor):
        r"""Save given tensors for a future call to :func:`~Function.jvp`.

        ``save_for_forward`` should be called at most once, in either the
        :func:`setup_context` or :func:`forward` methods, and all arguments
        should be tensors.

        In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`
        attribute.

        Arguments can also be ``None``. This is a no-op.

        See :ref:`extending-autograd` for more details on how to use this method.

        Example::
            >>> # xdoctest: +SKIP
            >>> class Func(torch.autograd.Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
            >>>         ctx.save_for_backward(x, y)
            >>>         ctx.save_for_forward(x, y)
            >>>         ctx.z = z
            >>>         return x * y * z
            >>>
            >>>     @staticmethod
            >>>     def jvp(ctx, x_t, y_t, _):
            >>>         x, y = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         return z * (y * x_t + x * y_t)
            >>>
            >>>     @staticmethod
            >>>     def vjp(ctx, grad_out):
            >>>         x, y = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         return z * grad_out * y, z * grad_out * x, None
            >>>
            >>> import torch.autograd.forward_ad as fwAD
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
            >>> t = torch.tensor(1., dtype=torch.double)
            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
            >>> c = 4
            >>>
            >>> with fwAD.dual_level():
            >>>     a_dual = fwAD.make_dual(a, t)
            >>>     d = Func.apply(a_dual, b, c)

        """
        for tensor in tensors:
            assert isinstance(tensor, torch.Tensor) or tensor is None, (
                "save_for_forward expects all arguments to be tensors; you should "
                "save non-tensors as attributes on ctx."
            )

        self.saved_for_forward = tensors

    def mark_dirty(self, *args: torch.Tensor):
        r"""Mark given tensors as modified in an in-place operation.

        This should be called at most once, in either the :func:`setup_context`
        or :func:`forward` methods, and all arguments should be inputs.

        Every tensor that's been modified in-place in a call to :func:`forward`
        should be given to this function, to ensure correctness of our checks.
        It doesn't matter whether the function is called before or after
        modification.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class Inplace(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         x_npy = x.numpy() # x_npy shares storage with x
            >>>         x_npy += 1
            >>>         ctx.mark_dirty(x)
            >>>         return x
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, grad_output):
            >>>         return grad_output
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
            >>> b = a * a
            >>> Inplace.apply(a)  # This would lead to wrong gradients!
            >>>                   # but the engine would not know unless we mark_dirty
            >>> # xdoctest: +SKIP
            >>> b.backward() # RuntimeError: one of the variables needed for gradient
            >>>              # computation has been modified by an inplace operation

        """
        self.dirty_tensors = args

    @deprecated(
        "`mark_shared_storage` is deprecated. "
        "Tensors with shared storages are automatically tracked. "
        "Note that calls to `set_()` are not tracked",
        category=FutureWarning,
    )
    def mark_shared_storage(self, *pairs):
        pass

    def mark_non_differentiable(self, *args: torch.Tensor):
        r"""Mark outputs as non-differentiable.

        This should be called at most once, in either the :func:`setup_context`
        or :func:`forward` methods, and all arguments should be tensor outputs.

        This will mark outputs as not requiring gradients, increasing the
        efficiency of backward computation. You still need to accept a gradient
        for each output in :meth:`~Function.backward`, but it's always going to
        be a zero tensor with the same shape as the corresponding output.

        This is used e.g. for indices returned from a sort. See example::
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         sorted, idx = x.sort()
            >>>         ctx.mark_non_differentiable(idx)
            >>>         ctx.save_for_backward(x, idx)
            >>>         return sorted, idx
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):  # still need to accept g2
            >>>         x, idx = ctx.saved_tensors
            >>>         grad_input = torch.zeros_like(x)
            >>>         grad_input.index_add_(0, idx, g1)
            >>>         return grad_input

        """
        self.non_differentiable = args

    def set_materialize_grads(self, value: bool):
        r"""Set whether to materialize grad tensors. Default is ``True``.

        This should be called only from either the :func:`setup_context` or
        :func:`forward` methods.

        If ``True``, undefined grad tensors will be expanded to tensors full of zeros
        prior to calling the :func:`backward` and :func:`jvp` methods.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class SimpleFunc(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         return x.clone(), x.clone()
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):
            >>>         return g1 + g2  # No check for None necessary
            >>>
            >>> # We modify SimpleFunc to handle non-materialized grad outputs
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         ctx.set_materialize_grads(False)
            >>>         ctx.save_for_backward(x)
            >>>         return x.clone(), x.clone()
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):
            >>>         x, = ctx.saved_tensors
            >>>         grad_input = torch.zeros_like(x)
            >>>         if g1 is not None:  # We must check for None now
            >>>             grad_input += g1
            >>>         if g2 is not None:
            >>>             grad_input += g2
            >>>         return grad_input
            >>>
            >>> a = torch.tensor(1., requires_grad=True)
            >>> b, _ = Func.apply(a)  # induces g2 to be undefined

        """
        self.materialize_grads = value


# DO NOT USE: This is only defined to be able to load old serialized models
_ContextMethodMixin = FunctionCtx


class _HookMixin:
    @staticmethod
    def _register_hook(backward_hooks, hook):
        if backward_hooks is None:
            backward_hooks = OrderedDict()
        handle = hooks.RemovableHandle(backward_hooks)
        backward_hooks[handle.id] = hook
        return backward_hooks, handle


class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
    r"""
    This class is used for internal autograd work. Do not use.
    """

    def apply(self, *args):
        r"""
        Apply method used when executing this Node during the backward pass.
        """
        # _forward_cls is defined by derived class
        # The user should define either backward or vjp but never both.
        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
        vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
        if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
            raise RuntimeError(
                "Implementing both 'backward' and 'vjp' for a custom "
                "Function is not allowed. You should only implement one "
                "of them."
            )
        user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
        return user_fn(self, *args)

    def apply_jvp(self, *args):
        r"""
        Apply method used when executing forward mode AD during the forward pass.
        """
        # _forward_cls is defined by derived class
        return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]

    def _compiled_autograd_key(self):
        return self._forward_cls._compiled_autograd_key(self)  # type: ignore[attr-defined]


class FunctionMeta(type):
    """Function metaclass.

    This metaclass sets up the following properties:
        _backward_cls: The Function class corresponding to the differentiated
            version of this function (which is generated on the fly by this
            metaclass).
    """

    def __init__(cls, name, bases, attrs):
        backward_fn = type(
            name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
        )
        backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER)  # type: ignore[attr-defined]
        backward_fn._compiled_autograd_should_lift = attrs.get(  # type: ignore[attr-defined]
            "_compiled_autograd_should_lift", True
        )
        cls._backward_cls = backward_fn

        super().__init__(name, bases, attrs)


class _SingleLevelFunction(
    _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta
):
    @staticmethod
    def forward(*args: Any, **kwargs: Any) -> Any:
        r"""Define the forward of the custom autograd Function.

        This function is to be overridden by all subclasses.
        There are two ways to define forward:

        Usage 1 (Combined forward and ctx)::

            @staticmethod
            def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
                pass

        - It must accept a context ctx as the first argument, followed by any
          number of arguments (tensors or other types).
        - See :ref:`combining-forward-context` for more details

        Usage 2 (Separate forward and ctx)::

            @staticmethod
            def forward(*args: Any, **kwargs: Any) -> Any:
                pass

            @staticmethod
            def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
                pass

        - The forward no longer accepts a ctx argument.
        - Instead, you must also override the :meth:`torch.autograd.Function.setup_context`
          staticmethod to handle setting up the ``ctx`` object.
          ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs
          to the forward.
        - See :ref:`extending-autograd` for more details

        The context can be used to store arbitrary data that can be then
        retrieved during the backward pass. Tensors should not be stored
        directly on `ctx` (though this is not currently enforced for
        backward compatibility). Instead, tensors should be saved either with
        :func:`ctx.save_for_backward` if they are intended to be used in
        ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
        if they are intended to be used in ``jvp``.
        """
        raise NotImplementedError(
            "You must implement the forward function for custom autograd.Function."
        )

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
        r"""There are two ways to define the forward pass of an autograd.Function.

        Either:

        1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.
           ``setup_context`` is not overridden. Setting up the ctx for backward
           happens inside the ``forward``.
        2. Override forward with the signature ``forward(*args, **kwargs)`` and
           override ``setup_context``. Setting up the ctx for backward happens
           inside ``setup_context`` (as opposed to inside the ``forward``)

        See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.
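
        Example (a minimal sketch of the second option; ``MyCube`` is an
        illustrative Function, not something defined in this module)::

            >>> class MyCube(torch.autograd.Function):
            >>>     @staticmethod
            >>>     def forward(x):
            >>>         # No ctx here; forward only computes the output.
            >>>         return x ** 3
            >>>
            >>>     @staticmethod
            >>>     def setup_context(ctx, inputs, output):
            >>>         # inputs is the tuple of arguments passed to forward.
            >>>         x, = inputs
            >>>         ctx.save_for_backward(x)
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_output):
            >>>         x, = ctx.saved_tensors
            >>>         return grad_output * 3 * x ** 2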
        """
        raise NotImplementedError("setup_context is not implemented.")

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        r"""Define a formula for differentiating the operation with backward mode automatic differentiation.

        This function is to be overridden by all subclasses.
        (Defining this function is equivalent to defining the ``vjp`` function.)

        It must accept a context :attr:`ctx` as the first argument, followed by
        as many outputs as the :func:`forward` returned (None will be passed in
        for non-tensor outputs of the forward function),
        and it should return as many tensors as there were inputs to
        :func:`forward`. Each argument is the gradient w.r.t. the given output,
        and each returned value should be the gradient w.r.t. the
        corresponding input. If an input is not a Tensor or is a Tensor not
        requiring grads, you can just pass None as a gradient for that input.

        The context can be used to retrieve tensors saved during the forward
        pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
        of booleans representing whether each input needs gradient. E.g.,
        :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
        first input to :func:`forward` needs gradient computed w.r.t. the
        output.
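
        Example (a minimal sketch for a hypothetical two-tensor multiply whose
        inputs were saved with ``ctx.save_for_backward``)::

            >>> @staticmethod
            >>> def backward(ctx, grad_output):
            >>>     x, y = ctx.saved_tensors
            >>>     # d(x * y)/dx = y and d(x * y)/dy = x; skip gradients
            >>>     # that are not needed.
            >>>     gx = grad_output * y if ctx.needs_input_grad[0] else None
            >>>     gy = grad_output * x if ctx.needs_input_grad[1] else None
            >>>     return gx, gy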
        """
        raise NotImplementedError(
            "You must implement either the backward or vjp method for "
            "your custom autograd.Function to use it with backward "
            "mode AD."
        )

    # vjp and backward are aliases of each other
    vjp = backward

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        r"""Define a formula for differentiating the operation with forward mode automatic differentiation.

        This function is to be overridden by all subclasses.
        It must accept a context :attr:`ctx` as the first argument, followed by
        as many inputs as the :func:`forward` got (None will be passed in
        for non-tensor inputs of the forward function),
        and it should return as many tensors as there were outputs of
        :func:`forward`. Each argument is the gradient w.r.t. the given input,
        and each returned value should be the gradient w.r.t. the
        corresponding output. If an output is not a Tensor or the function is not
        differentiable with respect to that output, you can just pass None as a
        gradient for that output.

        You can use the :attr:`ctx` object to pass any value from the forward to this
        function.
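
        Example (a minimal sketch for a hypothetical two-tensor multiply whose
        inputs were saved with ``ctx.save_for_forward``)::

            >>> @staticmethod
            >>> def jvp(ctx, x_t, y_t):
            >>>     x, y = ctx.saved_tensors
            >>>     # Tangent of x * y is y * dx + x * dy.
            >>>     return y * x_t + x * y_t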
        """
        raise NotImplementedError(
            "You must implement the jvp function for custom "
            "autograd.Function to use it with forward mode AD."
        )


class Function(_SingleLevelFunction):
    r"""Base class to create custom `autograd.Function`.

    To create a custom `autograd.Function`, subclass this class and implement
    the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom
    op in the forward pass, call the class method ``apply``. Do not call
    :meth:`forward` directly.

    To ensure correctness and best performance, make sure you are calling the
    correct methods on ``ctx`` and validating your backward function using
    :func:`torch.autograd.gradcheck`.

    See :ref:`extending-autograd` for more details on how to use this class.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> class Exp(Function):
        >>>     @staticmethod
        >>>     def forward(ctx, i):
        >>>         result = i.exp()
        >>>         ctx.save_for_backward(result)
        >>>         return result
        >>>
        >>>     @staticmethod
        >>>     def backward(ctx, grad_output):
        >>>         result, = ctx.saved_tensors
        >>>         return grad_output * result
        >>>
        >>> # Use it by calling the apply method:
        >>> # xdoctest: +SKIP
        >>> output = Exp.apply(input)
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            f"{self.__class__} should not be instantiated. Methods on autograd functions "
            "are all static, so you should invoke them on the class itself. "
            "Instantiating an autograd function will raise an "
            "error in a future version of PyTorch.",
            DeprecationWarning,
            stacklevel=2,
        )

    def __call__(self, *args, **kwargs):
        raise RuntimeError(
            "Legacy autograd function with non-static forward method is deprecated. "
            "Please use new-style autograd function with static forward method. "
            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)"
        )

    """
    Bool that specifies if PyTorch should attempt to autogenerate
    :func:`torch.vmap` support for this autograd.Function. You may set this to
    True only if this autograd.Function's forward, backward, and jvp (if they
    exist) are written using PyTorch operations; otherwise, please override
    :meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`.

    Please see :ref:`func-autograd-function` for more details.
    """
    generate_vmap_rule = False

    @staticmethod
    def vmap(info, in_dims, *args):
        r"""Define the behavior for this autograd.Function underneath :func:`torch.vmap`.

        For a :func:`torch.autograd.Function` to support
        :func:`torch.vmap`, you must either override this static method, or set
        ``generate_vmap_rule`` to ``True`` (you may not do both).

        If you choose to override this staticmethod: it must accept

        - an ``info`` object as the first argument. ``info.batch_size``
          specifies the size of the dimension being vmapped over,
          while ``info.randomness`` is the randomness option passed to
          :func:`torch.vmap`.
        - an ``in_dims`` tuple as the second argument.
          For each arg in ``args``, ``in_dims`` has a corresponding
          ``Optional[int]``. It is ``None`` if the arg is not a Tensor or if
          the arg is not being vmapped over, otherwise, it is an integer
          specifying what dimension of the Tensor is being vmapped over.
        - ``*args``, which is the same as the args to :meth:`~Function.forward`.

        The return of the vmap staticmethod is a tuple of ``(output, out_dims)``.
        Similar to ``in_dims``, ``out_dims`` should be of the same structure as
        ``output`` and contain one ``out_dim`` per output that specifies if the
        output has the vmapped dimension and what index it is in.

        Please see :ref:`func-autograd-function` for more details.
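
        Example (a minimal sketch for a hypothetical elementwise ``MyExp``
        Function; an elementwise op can treat the vmapped dimension like any
        other dimension)::

            >>> @staticmethod
            >>> def vmap(info, in_dims, x):
            >>>     # x is the only argument, so it is the one being vmapped
            >>>     # over and x_bdim is an integer.
            >>>     x_bdim, = in_dims
            >>>     # Expose the vmapped dimension as an ordinary leading
            >>>     # dimension and reuse the Function itself on the batch.
            >>>     x = x.movedim(x_bdim, 0)
            >>>     return MyExp.apply(x), 0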
        """
        raise NotImplementedError(
            "To use autograd.Function with vmap, you must either override the "
            "vmap staticmethod or set generate_vmap_rule=True."
        )

    @classmethod
    def apply(cls, *args, **kwargs):
        def bind_default_args(func, *args, **kwargs):
            signature = inspect.signature(func)
            bound_args = signature.bind(*args, **kwargs)
            bound_args.apply_defaults()

            return bound_args.args

        is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context)
        if is_setup_ctx_defined:
            args = bind_default_args(cls.forward, *args, **kwargs)

        if not torch._C._are_functorch_transforms_active():
            # See NOTE: [functorch vjp and autograd interaction]
            args = _functorch.utils.unwrap_dead_wrappers(args)
            return super().apply(*args, **kwargs)  # type: ignore[misc]

        if not is_setup_ctx_defined:
            raise RuntimeError(
                "In order to use an autograd.Function with functorch transforms "
                "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
                "staticmethod. For more details, please see "
                "https://pytorch.org/docs/main/notes/extending.func.html"
            )

        return custom_function_call(cls, *args, **kwargs)

    @staticmethod
    def _compiled_autograd_key(ctx):
        return (ctx._autograd_function_id,)


def _is_setup_context_defined(fn):
    return fn != _SingleLevelFunction.setup_context


def once_differentiable(fn):
    @functools.wraps(fn)
    def wrapper(ctx, *args):
        with torch.no_grad():
            outputs = fn(ctx, *args)

        if not torch.is_grad_enabled():
            return outputs

        # If any of the inputs have requires_grad=True, we force the outputs
        # to have requires_grad=True but point to a grad_fn which throws an
        # error message during (double) back-propagation.
        # XXX: this is only an approximation of requires_grad - there's no way
        # to figure out if fn didn't use ctx.saved_tensors and as a result
        # some Tensors might require grad, even if no args do.
        # Unfortunately, this leads to unexpected error messages ("no nodes
        # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
        requires_grad = any(
            isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args
        )
        if not requires_grad:
            return outputs

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        err_fn = _functions.DelayedError(
            b"trying to differentiate twice a function that was marked "
            b"with @once_differentiable",
            len(outputs),
        )

        # Create aliases of each output that has requires_grad=True. We need
        # at least one of the inputs to err_fn to require grad so that the
        # output will have a grad_fn.
        def fake_requires_grad(var):
            if var is not None:
                var = var.detach()
                var.requires_grad = True
            return var

        return err_fn(*[fake_requires_grad(v) for v in outputs])

    return wrapper


class InplaceFunction(Function):
    r"""
    This class is here only for backward compatibility reasons.
    Use :class:`Function` instead of this for any new use case.
    """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace


def _nested_map(condition, fn, condition_msg=None):
    def _map(obj):
        if condition(obj):
            return fn(obj)
        elif obj is None:
            return None
        elif isinstance(obj, (list, tuple)):
            mapped = (_map(x) for x in obj)
            if hasattr(obj, "_fields"):
                # obj is namedtuple
                return type(obj)(*mapped)
            return type(obj)(mapped)
        elif isinstance(obj, dict):
            return {x: _map(obj[x]) for x in obj}
        else:
            raise ValueError(
                "Auto nesting doesn't know how to process "
                "an input object of type "
                + torch.typename(obj)
                + (
                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
                    if condition_msg
                    else ""
                )
            )

    return _map


def _jit_unwrap_structured(obj):
    if hasattr(obj, "_jit_unwrap"):
        return obj._jit_unwrap()
    return obj


def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None):
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                yield from _iter(o)
        elif isinstance(obj, dict):
            # We only accept primitive key types, so we needn't inspect them
            for o in obj.values():
                yield from _iter(o)
        elif allow_unknown:
            yield obj
        else:
            raise ValueError(
                "Auto nesting doesn't know how to process "
                "an input object of type "
                + torch.typename(obj)
                + (
                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
                    if condition_msg
                    else ""
                )
            )

    return _iter


def _unflatten(input, proto):
    # unflatten a list or tuple input into a nested list/tuple structure
    # specified by proto
    def unflatten_helper(input, proto):
        res: List[Optional[torch.Tensor]] = []
        if hasattr(proto, "_jit_wrap"):
            return proto._jit_wrap(input)
        if not isinstance(proto, (list, tuple)):
            return input[0], input[1:]
        for e in proto:
            if e is None:
                res.append(e)
            else:
                res_e, input = unflatten_helper(input, e)
                res.append(res_e)
        return type(proto)(res), input

    return unflatten_helper(input, proto)[0]


_iter_jit_values = _iter_filter(
    lambda o: o is None or isinstance(o, torch._C.Value),
    condition_msg="jit's Values or None",
)
_iter_tensors = _iter_filter(
    lambda x: isinstance(x, torch.Tensor),
    condition_msg="Tensors",
    conversion=_jit_unwrap_structured,
)
_iter_tensors_permissive = _iter_filter(
    lambda x: isinstance(x, torch.Tensor),
    allow_unknown=True,
    condition_msg="Tensors (permissive)",
)
_iter_None_tensors = _iter_filter(
    lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None"
)
_map_tensor_data = _nested_map(
    lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors"
)


class NestedIOFunction(Function):
    r"""
    This class is here only for backward compatibility reasons.
    Use :class:`Function` instead of this for any new use case.
    """
    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
    # superclass (Function) but are instance methods here, which mypy reports as incompatible.

    def _do_forward(self, *input):
        self._nested_input = input
        flat_input = tuple(_iter_tensors(input))
        flat_output = super()._do_forward(*flat_input)  # type: ignore[misc]
        nested_output = self._nested_output
        nested_tensors = _unflatten(flat_output, self._nested_output)
        return nested_tensors

    def _do_backward(self, gradients, retain_variables):
        self.retain_variables = retain_variables
        result = super()._do_backward(gradients, retain_variables)  # type: ignore[misc]
        if not retain_variables:
            del self._nested_output
            del self._to_save_nested
        return result

    def backward(self, *gradients: Any) -> Any:  # type: ignore[override]
        r"""
        Shared backward utility.
        """
        nested_gradients = _unflatten(gradients, self._nested_output)
        result = self.backward_extended(*nested_gradients)  # type: ignore[func-returns-value]
        return tuple(_iter_None_tensors(result))

    __call__ = _do_forward

    def forward(self, *args: Any) -> Any:  # type: ignore[override]
        r"""
        Shared forward utility.
        """
        nested_tensors = _map_tensor_data(self._nested_input)
        result = self.forward_extended(*nested_tensors)  # type: ignore[func-returns-value]
        del self._nested_input
        self._nested_output = result
        return tuple(_iter_tensors(result))

    def save_for_backward(self, *args: Any) -> None:
        r"""
        See :meth:`Function.save_for_backward`.
        """
        self.to_save = tuple(_iter_tensors(args))
        self._to_save_nested = args

    @property
    def saved_tensors(self):
        r"""
        See :meth:`Function.saved_tensors`.
        """
        flat_tensors = super().saved_tensors  # type: ignore[misc]
        return _unflatten(flat_tensors, self._to_save_nested)

    def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
        r"""
        See :meth:`Function.mark_dirty`.
        """
        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
        r"""
        See :meth:`Function.mark_non_differentiable`.
        """
        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

    def forward_extended(self, *input: Any) -> None:
        r"""
        User defined forward.
        """
        raise NotImplementedError

    def backward_extended(self, *grad_output: Any) -> None:
        r"""
        User defined backward.
        """
        raise NotImplementedError