# mypy: allow-untyped-defs
import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools


# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason why this indirection exists is
# so that we can swap out the autograd kernel (the PyTorch dispatcher
# doesn't actually allow us to do this). By default, we want
# the `autograd_not_implemented` behavior, but then the user may come
# and register something that is actually a backward formula
def autograd_kernel_indirection(custom_op):
    autograd_fallback = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            kernel = custom_op._get_impl('autograd').func
            return kernel(*args, **kwargs)
        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # after the user gives us "backward" and "save_for_backward", we generate
        # the "autograd" impl. If the user only provided one, then we tell
        # the user they've done something wrong.
        if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
            missing = (
                'save_for_backward' if custom_op._has_impl('backward')
                else 'backward'
            )
            found = 'save_for_backward' if missing == 'backward' else 'backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")
        return autograd_fallback(*args, **kwargs)
    return inner


# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
    def kernel(*args, **kwargs):
        if torch.is_grad_enabled() and pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
        ):
            raise RuntimeError("Autograd has not been implemented for operator")
        with torch._C._AutoDispatchBelowAutograd():
            return custom_op(*args, **kwargs)
    return kernel


def mark_non_differentiable(ctx, output, output_differentiability):
    # Output types are restricted to be:
    #   - Tensor
    #   - Tensor[]
    #   - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is not None:
        if not isinstance(output, tuple):
            tuple_output = (output,)
        else:
            tuple_output = output  # type: ignore[assignment]
        assert len(output_differentiability) == len(tuple_output)
        non_differentiable_tensors = []
        for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
            if isinstance(out, torch.Tensor):
                if not differentiable:
                    non_differentiable_tensors.append(out)
                continue
            if isinstance(out, list):
                if not differentiable:
                    non_differentiable_tensors.extend(out)
                continue
            if differentiable:
                raise RuntimeError(
                    f"With output_differentiability={output_differentiability}. "
                    f"At idx {idx}, we received an object of type {type(out)} that "
                    f"is not a Tensor, so it cannot be marked as differentiable in "
                    f"output_differentiability.")
        if non_differentiable_tensors:
            ctx.mark_non_differentiable(*non_differentiable_tensors)


def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):

    def apply(*args):
        flat_args, spec = pytree.tree_flatten(args)
        out_spec = None

        def forward(ctx, *flat_args):
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
                schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            nonlocal out_spec
            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_output), out_spec)
    return apply


def gen_autograd_function(name, forward, backward):
    generated_cls = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
        }
    )
    return generated_cls


@functools.lru_cache
def namedtuple_args_cls(schema):
    attribs = [arg.name for arg in schema.arguments.flat_all]
    name = str(schema.name) + "_args"
    # mypy doesn't support dynamic namedtuple name
    tuple_cls = namedtuple(name, attribs)  # type: ignore[misc]
    return tuple_cls


def namedtuple_args(schema, args):
    assert isinstance(args, tuple)
    tuple_cls = namedtuple_args_cls(schema)
    return tuple_cls(*args)


def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    def error(what):
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if not len(grad) == len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {arg_info}).")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")


def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
    result = []
    for name, arg_info in args_info._asdict().items():
        if name not in grad_inputs_dict:
            result.append(pytree.tree_map(lambda x: None, arg_info))
            continue
        result.append(grad_inputs_dict[name])
    return tuple(pytree.tree_leaves(result))


# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    flat_stuff, spec = pytree.tree_flatten(stuff)
    num_elts = len(flat_stuff)
    tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                   if isinstance(thing, torch.Tensor)]
    non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                       if not isinstance(thing, torch.Tensor)]
    tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
    non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]

    ctx.spec = spec
    ctx.num_elts = num_elts
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs


# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    flat_stuff = [None] * ctx.num_elts
    for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
        flat_stuff[idx] = tensor
    for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
        flat_stuff[idx] = non_tensor
    stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
    return stuff