# mypy: allow-untyped-defs
import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools


# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# This indirection exists so that we can swap out the autograd kernel
# (the PyTorch dispatcher doesn't actually allow us to swap kernels directly).
# By default we want the `autograd_not_implemented` behavior, but the user may
# later come and register something that is actually a backward formula.
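#
# An illustrative sketch (not the exact registration call, which lives
# elsewhere in torch._custom_op) of how this gets wired up:
#
#   kernel = autograd_kernel_indirection(my_custom_op)
#   # `kernel` is (roughly) what gets registered to the Autograd dispatch key;
#   # each call re-checks whether an 'autograd' impl has been registered since.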
def autograd_kernel_indirection(custom_op):
    autograd_fallback = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            kernel = custom_op._get_impl('autograd').func
            return kernel(*args, **kwargs)
        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # after the user gives us "backward" and "save_for_backward", we generate
        # the "autograd" impl. If the user only provided one, then we tell
        # the user they've done something wrong.
        if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
            missing = (
                'save_for_backward' if custom_op._has_impl('backward')
                else 'backward'
            )
            found = 'save_for_backward' if missing == 'backward' else 'backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")
        return autograd_fallback(*args, **kwargs)
    return inner


# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
    def kernel(*args, **kwargs):
        if torch.is_grad_enabled() and pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
        ):
            raise RuntimeError("Autograd has not been implemented for operator")
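        # The guard below excludes the Autograd dispatch keys, so the call
        # re-dispatches to the op's backend kernel without recording an
        # autograd graph.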
        with torch._C._AutoDispatchBelowAutograd():
            return custom_op(*args, **kwargs)
    return kernel


def mark_non_differentiable(ctx, output, output_differentiability):
    # Output types are restricted to be:
    # - Tensor
    # - Tensor[]
    # - int, bool, Scalar, float
    # See _check_can_register_backward
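    # For example (illustrative): with output_differentiability=[True, False]
    # and output=(t0, t1), only t1 is passed to ctx.mark_non_differentiable.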
    if output_differentiability is not None:
        if not isinstance(output, tuple):
            tuple_output = (output,)
        else:
            tuple_output = output  # type: ignore[assignment]
        assert len(output_differentiability) == len(tuple_output)
        non_differentiable_tensors = []
        for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
            if isinstance(out, torch.Tensor):
                if not differentiable:
                    non_differentiable_tensors.append(out)
                continue
            if isinstance(out, list):
                if not differentiable:
                    non_differentiable_tensors.extend(out)
                continue
            if differentiable:
                raise RuntimeError(
                    f"With output_differentiability={output_differentiability}. "
                    f"At idx {idx}, we received an object of type {type(out)} that "
                    f"is not a Tensor, so it cannot be marked as differentiable in "
                    f"output_differentiability.")
        if non_differentiable_tensors:
            ctx.mark_non_differentiable(*non_differentiable_tensors)


def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):
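    # Builds and returns `apply`: it flattens the args, runs a freshly
    # generated autograd.Function subclass (see gen_autograd_function), and
    # unflattens the result back to the op's output structure.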

    def apply(*args):
        flat_args, spec = pytree.tree_flatten(args)
        out_spec = None
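        # out_spec is filled in by forward() via `nonlocal`: forward must hand
        # autograd a flat tuple of outputs, so we remember the pytree structure
        # here and use it to unflatten in both backward() and apply().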

        def forward(ctx, *flat_args):
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
                schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            nonlocal out_spec
            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_output), out_spec)
    return apply


def gen_autograd_function(name, forward, backward):
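    # Dynamically creates an autograd.Function subclass; equivalent to writing
    # `class <name>(torch.autograd.Function)` with the given forward/backward
    # as staticmethods.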
    generated_cls = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
        }
    )
    return generated_cls


@functools.lru_cache
def namedtuple_args_cls(schema):
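    # Builds (and caches, via lru_cache) a namedtuple class whose fields are
    # the op's argument names. Illustrative example: for an op
    # `foo(Tensor x, int n)`, this produces a `foo_args(x, n)` namedtuple.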
    attribs = [arg.name for arg in schema.arguments.flat_all]
    name = str(schema.name) + "_args"
    # mypy doesn't support dynamic namedtuple name
    tuple_cls = namedtuple(name, attribs)  # type: ignore[misc]
    return tuple_cls


def namedtuple_args(schema, args):
    assert isinstance(args, tuple)
    tuple_cls = namedtuple_args_cls(schema)
    return tuple_cls(*args)


def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
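    # Checks that the user's backward returned a dict mapping each Tensor-like
    # input name to a gradient (None is allowed). Illustrative example: for
    # `foo(Tensor x, Tensor[] ys, int n)`, a valid return value would be
    # {'x': grad_x, 'ys': [g0, None]}; 'n' must not appear since it is not
    # Tensor-like.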
    def error(what):
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if len(grad) != len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {info}).")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")


def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
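    # autograd.Function.backward must return one gradient per flat forward
    # input, in declaration order. Walk the args in schema order, substituting
    # None (structure-preserving, via tree_map) for inputs that have no entry
    # in grad_inputs_dict, then flatten everything into a single tuple.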
    result = []
    for name, arg_info in args_info._asdict().items():
        if name not in grad_inputs_dict:
            result.append(pytree.tree_map(lambda x: None, arg_info))
            continue
        result.append(grad_inputs_dict[name])
    return tuple(pytree.tree_leaves(result))


# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users save Tensors via ctx.save_for_backward
# (to avoid reference cycles) and stash non-Tensors directly on the ctx object.
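#
# A small pytree refresher (illustrative, not specific to this module):
#
#   flat, spec = pytree.tree_flatten((to_save, args_info))
#   # `flat` is a flat list of leaves; `spec` records the nesting so that
#   # pytree.tree_unflatten(flat, spec) reconstructs the original structure.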
def save_pytree_for_backward(ctx, stuff):
    flat_stuff, spec = pytree.tree_flatten(stuff)
    num_elts = len(flat_stuff)
    tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                   if isinstance(thing, torch.Tensor)]
    non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                       if not isinstance(thing, torch.Tensor)]
    tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
    non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]

    ctx.spec = spec
    ctx.num_elts = num_elts
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs


# Inverse operation to save_pytree_for_backward
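# Tensors come back via ctx.saved_tensors (which participates in autograd's
# saved-tensor machinery); everything is slotted back into its original flat
# position and then unflattened with the stored spec.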
def unpack_saved(ctx):
    flat_stuff = [None] * ctx.num_elts
    for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
        flat_stuff[idx] = tensor
    for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
        flat_stuff[idx] = non_tensor
    stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
    return stuff