# mypy: allow-untyped-defs
import dataclasses
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Protocol

from .. import _C, _ops, autograd, Tensor

from ..utils import _pytree
from . import utils


class InfoProtocol(Protocol):
    _backward_fn: Optional[Callable]
    _setup_context_fn: Optional[Callable]


@dataclasses.dataclass
class Info:
    _backward_fn: Optional[Callable]
    _setup_context_fn: Optional[Callable]


def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
    name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"

    has_kwarg_only_args = utils.has_kwarg_only_args(op._schema)

    @dataclass
    class Metadata:
        keyset: _C.DispatchKeySet
        keyword_only_args: Dict[str, Any]

    def forward(ctx, *args):
        metadata = args[-1]
        args = args[:-1]

        with _C._AutoDispatchBelowAutograd():
            keyset = metadata.keyset
            kwargs = metadata.keyword_only_args
            result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
            if info._setup_context_fn:
                # The Dispatcher will remove args that are equal to their default
                # values from (args, kwargs). We're going to add them back so that
                # the user can access them.
                #
                # This is OK to do: the Dispatcher removed the args for serialization
                # FC/BC reasons (that is, a graph will not store args that are equal
                # to their default values), but that doesn't matter here. If the user
                # adds a new default arg, then they must update their setup_context
                # (along with the rest of their operator registrations).
                args, kwargs = utils.fill_defaults(op._schema, args, kwargs)

                if has_kwarg_only_args:
                    info._setup_context_fn(
                        ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result
                    )
                else:
                    info._setup_context_fn(ctx=ctx, inputs=args, output=result)
            return result

    def backward(ctx, *grads):
        if info._backward_fn:
            try:
                prev_needs_input_grad = ctx.needs_input_grad
                ctx.needs_input_grad = ctx.needs_input_grad[:-1]
                result = info._backward_fn(ctx, *grads)
            finally:
                ctx.needs_input_grad = prev_needs_input_grad
            if isinstance(result, tuple):
                return (*result, None)
            return result, None
        raise RuntimeError(
            f"Trying to backward through {op} but no autograd "
            f"formula was registered. "
            f"Please use register_autograd to add one."
        )

    Generated = type(
        name,
        (autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
        },
    )

    schema = op._schema
    if any(
        utils.is_tensorlist_like_type(a.type)
        for a in (*schema.arguments, *schema.returns)
    ):
        Generated = supports_tensorlist(Generated)

    # The dispatcher passes any keyword-only-args as kwargs and the
    # rest of the args (even if specified as kwargs) as args.
    def autograd_impl(keyset, *args, **keyword_only_args):
        result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
        return result

    return autograd_impl


def supports_tensorlist(cls: Any) -> Any:
    """Allows a given autograd.Function class to support List[Tensor] inputs/outputs.

    Regular autograd.Function has a constraint that it only directly supports autograd for
    Tensors. Applying @supports_tensorlist enables an autograd.Function to support
    autograd for List[Tensor] inputs and outputs.
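
    Example (an illustrative sketch for this docstring only; ``MyStack`` is a
    hypothetical Function written to demonstrate the calling convention, not a
    PyTorch API)::

        @supports_tensorlist
        class MyStack(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                # ``xs`` arrives with its original List[Tensor] structure.
                return torch.stack(xs)

            @staticmethod
            def backward(ctx, grad):
                # Return grads with the same List[Tensor] structure as the input.
                return list(grad.unbind(0))

        stacked = MyStack.apply([x, y])  # x, y: Tensors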
    """
    orig_forward = cls.forward
    orig_backward = cls.backward
    orig_apply = cls.apply

    @dataclass
    class Metadata:
        input_spec: spec_t
        output_spec: Optional[spec_t] = None
        result_is_tuple: Optional[bool] = None

    def new_forward(ctx, *args):
        metadata = args[-1]
        args = args[:-1]
        if not isinstance(metadata, Metadata):
            raise NotImplementedError(
                "NYI: calling supports_tensorlist autograd.Function.forward directly. "
                "You should probably be calling .apply instead. "
                "Please file an issue if not."
            )
        args = unflatten(list(args), metadata.input_spec)
        result = orig_forward(ctx, *args)
        metadata.result_is_tuple = isinstance(result, tuple)
        if not metadata.result_is_tuple:
            result = (result,)
        flat_result, output_spec = flatten(result, not_list_of_tensor)
        metadata.output_spec = output_spec

        if hasattr(ctx, "_pt_metadata"):
            raise RuntimeError(
                "Please don't set ctx._pt_metadata; PyTorch uses it to store info"
            )
        ctx._pt_metadata = metadata

        return tuple(flat_result)

    def new_backward(ctx, *grads):
        if not hasattr(ctx, "_pt_metadata"):
            raise NotImplementedError(
                "NYI: calling supports_tensorlist autograd.Function.backward directly. "
                "This will automatically get called by PyTorch autograd. "
                "Please file an issue if you need this."
            )

        metadata = ctx._pt_metadata
        grads = unflatten(list(grads), metadata.output_spec)

        # If the user's input is ([x, y, z], w),
        # then needs_input_grad is (bool, bool, bool, bool, bool).
        # We need to:
        # 1. get rid of the additional bool (which comes from the extra
        #    `metadata` input)
        # 2. unflatten to get the right structure.
        prev_needs_input_grad = ctx.needs_input_grad
        try:
            ctx.needs_input_grad = unflatten(
                list(ctx.needs_input_grad[:-1]), metadata.input_spec
            )
            grad_inputs = orig_backward(ctx, *grads)
        finally:
            ctx.needs_input_grad = prev_needs_input_grad

        if not isinstance(grad_inputs, tuple):
            grad_inputs = (grad_inputs,)
        # Assume that any Nones in the backward are Tensors.
        # If the forward has an arg that is [1, 2, 3], the backward should
        # return None as the grad.
        # If the forward has an arg that is [tensor, tensor], the backward
        # may return [None, None], [grad, None], [None, grad], or [grad, grad].
        flat_grad_inputs, grad_inputs_spec = flatten(
            grad_inputs, not_list_of_optional_tensor
        )
        if grad_inputs_spec != metadata.input_spec:
            raise RuntimeError(
                f"Expected the return from backward to be of the same structure "
                f"as the inputs. Got: {grad_inputs_spec} (return from backward), "
                f"{metadata.input_spec} (inputs)"
            )
        return tuple(flat_grad_inputs + [None])

    def new_apply(*args):
        flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor)
        metadata = Metadata(input_spec)
        result = orig_apply(*flat_args, metadata)  # type: ignore[misc]
        assert metadata.output_spec is not None
        result = unflatten(list(result), metadata.output_spec)
        if not metadata.result_is_tuple:
            assert isinstance(result, tuple)
            assert len(result) == 1
            return result[0]
        return result

    cls.forward = new_forward
    cls.backward = new_backward
    cls.apply = new_apply
    return cls


def not_list_of_tensor(tree):
    if isinstance(tree, tuple):
        return False
    if isinstance(tree, list):
        return any(not isinstance(l, Tensor) for l in tree)
    return True


def not_list_of_optional_tensor(tree):
    if isinstance(tree, tuple):
        return False
    if isinstance(tree, list):
        return any(l is not None and not isinstance(l, Tensor) for l in tree)
    return True


flatten = _pytree.tree_flatten
unflatten = _pytree.tree_unflatten
spec_t = _pytree.TreeSpec
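
# Note on the leaf predicates above (an illustrative sketch; `t0` and `t1`
# stand for arbitrary Tensors and are not defined in this module): with
# `not_list_of_tensor` as the `is_leaf` predicate, a list containing only
# Tensors is treated as an interior pytree node and is flattened into its
# elements, while any other value (including a list that mixes in non-Tensors)
# is kept as a single leaf.
#
#   flat, spec = flatten(([t0, t1], 3), is_leaf=not_list_of_tensor)
#   # flat == [t0, t1, 3]; unflatten(flat, spec) rebuilds ([t0, t1], 3)
#
#   flat2, _ = flatten(([1, 2, 3],), is_leaf=not_list_of_tensor)
#   # flat2 == [[1, 2, 3]]: a list of non-Tensors stays a single leaf, which is
#   # why backward is expected to return a single None grad for it.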