# xref: /aosp_15_r20/external/pytorch/torch/_library/autograd.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import dataclasses
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Protocol

from .. import _C, _ops, autograd, Tensor

from ..utils import _pytree
from . import utils


class InfoProtocol(Protocol):
    _backward_fn: Optional[Callable]
    _setup_context_fn: Optional[Callable]


@dataclasses.dataclass
class Info:
    _backward_fn: Optional[Callable]
    _setup_context_fn: Optional[Callable]


def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
    name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"

    has_kwarg_only_args = utils.has_kwarg_only_args(op._schema)

    @dataclass
    class Metadata:
        keyset: _C.DispatchKeySet
        keyword_only_args: Dict[str, Any]

    def forward(ctx, *args):
        metadata = args[-1]
        args = args[:-1]

        with _C._AutoDispatchBelowAutograd():
            keyset = metadata.keyset
            kwargs = metadata.keyword_only_args
            result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
            if info._setup_context_fn:
                # The Dispatcher will remove args that are equal to their default
                # values from (args, kwargs). We're going to add them back so that
                # the user can access them.
                #
                # This is OK to do: the Dispatcher removed the args for serialization
                # FC/BC reasons (that is, a graph will not store args that are equal
                # to their default values), but that doesn't matter here. If the user
                # adds a new default arg, then they must update their setup_context
                # (along with the rest of their operator registrations).
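                #
                # For example (hypothetical schema): given
                #   mylib::foo(Tensor x, *, int alpha=1) -> Tensor
                # a call foo(x) can reach this point with kwargs == {};
                # fill_defaults puts alpha=1 back before setup_context runs.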
                args, kwargs = utils.fill_defaults(op._schema, args, kwargs)

                if has_kwarg_only_args:
                    info._setup_context_fn(
                        ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result
                    )
                else:
                    info._setup_context_fn(ctx=ctx, inputs=args, output=result)
            return result

    def backward(ctx, *grads):
        if info._backward_fn:
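            # The last entry of ctx.needs_input_grad corresponds to the extra
            # Metadata argument passed to forward; hide it from the user's
            # backward and restore the original tuple afterwards.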
            try:
                prev_needs_input_grad = ctx.needs_input_grad
                ctx.needs_input_grad = ctx.needs_input_grad[:-1]
                result = info._backward_fn(ctx, *grads)
            finally:
                ctx.needs_input_grad = prev_needs_input_grad
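            # Append a None "grad" for the trailing Metadata argument that
            # forward received as its last positional input.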
            if isinstance(result, tuple):
                return (*result, None)
            return result, None
        raise RuntimeError(
            f"Trying to backward through {op} but no autograd "
            f"formula was registered. "
            f"Please use register_autograd to add one."
        )

    Generated = type(
        name,
        (autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
        },
    )

    schema = op._schema
    if any(
        utils.is_tensorlist_like_type(a.type)
        for a in (*schema.arguments, *schema.returns)
    ):
        Generated = supports_tensorlist(Generated)

    # The dispatcher passes any keyword-only-args as kwargs and the
    # rest of the args (even if specified as kwargs) as args.
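    # For example (hypothetical schema): for `foo(Tensor x, *, float alpha)`,
    # a call foo(x, alpha=0.5) arrives here as args == (x,) and
    # keyword_only_args == {"alpha": 0.5}.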
    def autograd_impl(keyset, *args, **keyword_only_args):
        result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
        return result

    return autograd_impl

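# How the generated kernel is expected to be wired up (hedged sketch; the
# operator name and callbacks below are hypothetical):
#
#     info = Info(_backward_fn=my_backward, _setup_context_fn=my_setup_context)
#     kernel = make_autograd_impl(torch.ops.mylib.foo.default, info)
#     torch.library.Library("mylib", "FRAGMENT").impl(
#         "foo", kernel, "Autograd", with_keyset=True
#     )
#
# Registering with with_keyset=True matches autograd_impl's leading `keyset`
# parameter: the dispatcher passes the current DispatchKeySet first.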

def supports_tensorlist(cls: Any) -> Any:
    """Allows a given autograd.Function class to support List[Tensor] inputs/outputs.

    Regular autograd.Function has a constraint that it only directly supports autograd for
    Tensors. Applying @supports_tensorlist enables an autograd.Function to support
    autograd for List[Tensor] inputs and outputs.
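
    Example (an illustrative sketch; MyFn and its formula are hypothetical,
    not part of this module)::

        @supports_tensorlist
        class MyFn(autograd.Function):
            @staticmethod
            def forward(ctx, xs):  # xs: List[Tensor]
                ctx.save_for_backward(*xs)
                return [x.sin() for x in xs]

            @staticmethod
            def backward(ctx, grads):  # grads arrive with the output structure
                xs = ctx.saved_tensors
                return [g * x.cos() for g, x in zip(grads, xs)]

        ys = MyFn.apply([torch.randn(3, requires_grad=True) for _ in range(2)])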
    """
    orig_forward = cls.forward
    orig_backward = cls.backward
    orig_apply = cls.apply

    @dataclass
    class Metadata:
        input_spec: spec_t
        output_spec: Optional[spec_t] = None
        result_is_tuple: Optional[bool] = None

    def new_forward(ctx, *args):
        metadata = args[-1]
        args = args[:-1]
        if not isinstance(metadata, Metadata):
            raise NotImplementedError(
                "NYI: calling supports_tensorlist autograd.Function.forward directly. "
                "You should probably be calling .apply instead. "
                "Please file an issue if not."
            )
        args = unflatten(list(args), metadata.input_spec)
        result = orig_forward(ctx, *args)
        metadata.result_is_tuple = isinstance(result, tuple)
        if not metadata.result_is_tuple:
            result = (result,)
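        # e.g. (hypothetical result) ([y0, y1], z) flattens to [y0, y1, z];
        # output_spec lets new_backward rebuild that structure from the
        # incoming flat grads.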
        flat_result, output_spec = flatten(result, not_list_of_tensor)
        metadata.output_spec = output_spec

        if hasattr(ctx, "_pt_metadata"):
            raise RuntimeError(
                "Please don't set ctx._pt_metadata; PyTorch uses it to store info"
            )
        ctx._pt_metadata = metadata

        return tuple(flat_result)

    def new_backward(ctx, *grads):
        if not hasattr(ctx, "_pt_metadata"):
            raise NotImplementedError(
                "NYI: calling supports_tensorlist autograd.Function.backward directly. "
                "This will automatically get called by PyTorch autograd. "
                "Please file an issue if you need this."
            )

        metadata = ctx._pt_metadata
        grads = unflatten(list(grads), metadata.output_spec)

        # If the user's input is ([x, y, z], w),
        # then needs_input_grad is (bool, bool, bool, bool, bool).
        # We need to
        # 1. get rid of the additional bool (which comes from the extra
        #    `metadata` input)
        # 2. unflatten to get the right structure.
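        # e.g. (True, True, True, False, False)[:-1] unflattened with
        # input_spec gives ([True, True, True], False).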
        prev_needs_input_grad = ctx.needs_input_grad
        try:
            ctx.needs_input_grad = unflatten(
                list(ctx.needs_input_grad[:-1]), metadata.input_spec
            )
            grad_inputs = orig_backward(ctx, *grads)
        finally:
            ctx.needs_input_grad = prev_needs_input_grad

        if not isinstance(grad_inputs, tuple):
            grad_inputs = (grad_inputs,)
        # Assume that any Nones in the backward are Tensors.
        # If the forward has an arg that is [1, 2, 3], the backward should
        # return None as the grad.
        # If the forward has an arg that is [tensor, tensor], the backward
        # may return [None, None], [grad, None], [None, grad], or [grad, grad].
        flat_grad_inputs, grad_inputs_spec = flatten(
            grad_inputs, not_list_of_optional_tensor
        )
        if grad_inputs_spec != metadata.input_spec:
            raise RuntimeError(
                f"Expected the return from backward to be of the same structure "
                f"as the inputs. Got: {grad_inputs_spec} (return from backward), "
                f"{metadata.input_spec} (inputs)"
            )
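        # The trailing None is the "grad" for the extra Metadata argument that
        # new_forward received as its last positional input.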
        return tuple(flat_grad_inputs + [None])

    def new_apply(*args):
        flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor)
        metadata = Metadata(input_spec)
        result = orig_apply(*flat_args, metadata)  # type: ignore[misc]
        assert metadata.output_spec is not None
        result = unflatten(list(result), metadata.output_spec)
        if not metadata.result_is_tuple:
            assert isinstance(result, tuple)
            assert len(result) == 1
            return result[0]
        return result

    cls.forward = new_forward
    cls.backward = new_backward
    cls.apply = new_apply
    return cls


def not_list_of_tensor(tree):
    """is_leaf predicate for flatten: recurse into tuples and into lists made
    up entirely of Tensors; treat everything else as a leaf."""
    if isinstance(tree, tuple):
        return False
    if isinstance(tree, list):
        return any(not isinstance(l, Tensor) for l in tree)
    return True


def not_list_of_optional_tensor(tree):
    """Like not_list_of_tensor, but list elements may also be None
    (an omitted gradient)."""
    if isinstance(tree, tuple):
        return False
    if isinstance(tree, list):
        return any(l is not None and not isinstance(l, Tensor) for l in tree)
    return True


flatten = _pytree.tree_flatten
unflatten = _pytree.tree_unflatten
spec_t = _pytree.TreeSpec