# mypy: allow-untyped-defs
"""
This file contains utilities for tracing through __torch_dispatch__ based tensor subclasses and modes.
AOTAutograd's responsibility is to trace through all pytorch capabilities that live in the pytorch dispatcher,
and this includes tensor subclasses that implement __torch_dispatch__.
"""

import typing
from typing import Any, List, Optional, Tuple, Union

import torch.utils._pytree as pytree
from torch import Tensor
from torch._subclasses.fake_tensor import get_plain_tensors
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .schemas import MutationType, SubclassCreationMeta, ViewAndMutationMeta
from .utils import strict_zip


zip = strict_zip


def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool:
    args_flattened = pytree.arg_tree_leaves(*args)
    any_subclass_args = any(
        is_traceable_wrapper_subclass(x)
        for x in args_flattened
        if isinstance(x, Tensor)
    )
    from torch._functorch._aot_autograd.schemas import SubclassCreationMeta

    any_subclass_outputs = any(
        type(x) is SubclassCreationMeta for x in fw_metadata.subclass_fw_graph_out_meta
    )
    # This tells us whether or not we need to perform any unwrapping/wrapping of tensor subclasses at runtime.
    return any_subclass_args or any_subclass_outputs


def create_subclass_metadata(a, start_idx):
    if not is_traceable_wrapper_subclass(a):
        return None, start_idx + 1

    inner_keys, metadata = a.__tensor_flatten__()
    new_start_idx = start_idx
    attrs = {}
    for key in inner_keys:
        new_subclass_meta, new_start_idx = create_subclass_metadata(
            getattr(a, key), new_start_idx
        )
        attrs[key] = new_subclass_meta

    # It *must* be a Tensor, because is_traceable_wrapper_subclass() returned True - but mypy is not smart enough to know that.
    assert isinstance(a, Tensor)

    return (
        SubclassCreationMeta(
            flat_tensor_start_idx=start_idx,
            arg_count=new_start_idx - start_idx,
            attrs=attrs,
            meta=metadata,
            outer_size=a.size(),  # type: ignore[attr-defined, arg-type]
            outer_stride=a.stride(),  # type: ignore[arg-type]
            original_subclass=a,
        ),
        new_start_idx,
    )


# Given a real tensor subclass, returns a nested list of plain tensor types
def get_types_for_subclass(tensor_subclass):
    if not is_traceable_wrapper_subclass(tensor_subclass):
        return ["Tensor"]
    inner_keys, _ = tensor_subclass.__tensor_flatten__()
    result = []
    for key in inner_keys:
        inner_tensor = getattr(tensor_subclass, key)
        result.extend(get_types_for_subclass(inner_tensor))
    return result


# Given a flat list of arguments, some of which may be tensor subclasses,
# computes metadata about "how to reconstruct the current list of subclasses,
# if we were given their flattened dense tensors instead".
def create_subclass_meta(
    curr_args: Union[List[Any], Tuple[Any, ...]]
) -> List[Union[int, SubclassCreationMeta]]:
    idx = 0
    infos: List[Union[int, SubclassCreationMeta]] = []
    for a in curr_args:
        if is_traceable_wrapper_subclass(a):
            assert isinstance(a, Tensor)
            start_idx = idx
            subclass_meta, _ = create_subclass_metadata(a, start_idx)
            infos.append(subclass_meta)
            cnt = subclass_meta.arg_count
        else:
            infos.append(idx)
            cnt = 1
        idx += cnt
    return infos
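

# Illustrative sketch (not exercised by this module): assuming a hypothetical wrapper
# subclass `TwoTensor` whose __tensor_flatten__ exposes two inner tensors ("a", "b"),
# create_subclass_meta maps each argument to either its flat index (plain tensors) or a
# SubclassCreationMeta spanning its inner tensors:
#
#     plain = torch.randn(3)
#     wrapped = TwoTensor(torch.randn(3), torch.randn(3))
#     infos = create_subclass_meta([plain, wrapped, plain])
#     # infos[0] == 0                       (plain tensor at flat index 0)
#     # infos[1].flat_tensor_start_idx == 1, infos[1].arg_count == 2
#     # infos[2] == 3                       (plain tensor after the 2 inner tensors)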


# Output structure:
# - List[Tensor] if tracing an inference graph
# - Tuple[List[Tensor], List[Tensor]] if tracing a joint graph.
# This function effectively concats each inner list of subclass tensors
# into a (potentially longer) list of inner tensors.
#
# This function takes in a pytree of arguments and unwraps any tensor subclasses.
# Annoyingly, we can't use pytrees to perform the unwrapping, because unwrapping returns
# a list of tensors that we would then need to concat together.
# Instead, we specialize the logic for the inference vs. joint graph case.
# NOTE: this function is hot, since we unwrap tensor subclass inputs at runtime
def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
    def concat_inner_tensors_from_subclasses(xs):
        xs_inner = []
        for x in xs:
            if is_traceable_wrapper_subclass(x):
                xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
            else:
                xs_inner.append(x)
        return xs_inner

    if is_joint_structure:
        assert isinstance(wrapped_args, tuple) and len(wrapped_args) == 2
        assert isinstance(wrapped_args[0], (tuple, list)) and isinstance(
            wrapped_args[1], (tuple, list)
        )
        unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args[0])
        unwrapped_args_tangents = concat_inner_tensors_from_subclasses(wrapped_args[1])
        unwrapped_args = (unwrapped_args_fw, unwrapped_args_tangents)
    else:
        assert isinstance(wrapped_args, (list, tuple))
        unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args)
        unwrapped_args = unwrapped_args_fw
    return unwrapped_args


def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
    static_input_indices = set(static_input_indices)
    new_ind = 0
    remapped_static_indices = []
    for i, arg in enumerate(wrapped_args):
        num_indices = 1
        if is_traceable_wrapper_subclass(arg):
            num_indices = len(get_plain_tensors(typing.cast(Tensor, arg)))

        for _ in range(num_indices):
            if i in static_input_indices:
                remapped_static_indices.append(new_ind)

            new_ind += 1

    return remapped_static_indices
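

# Illustrative sketch (assumption: a hypothetical `TwoTensor` wrapper subclass holding
# two inner tensors). Unwrapping flattens every subclass into its inner dense tensors,
# and remap_unwrapped_subclass_arg_indices shifts static-input indices to match the
# flattened list:
#
#     args = [TwoTensor(a, b), plain]               # outer indices: 0, 1
#     unwrap_tensor_subclasses(args, is_joint_structure=False)
#     # -> [a, b, plain]                            (flat indices: 0, 1, 2)
#     remap_unwrapped_subclass_arg_indices(args, static_input_indices=[1])
#     # -> [2]                                      (outer index 1 becomes flat index 2)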


# Turns a flattened list of tensor arguments into (maybe) subclass tensors.
# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
def wrap_tensor_subclasses(
    unwrapped_args: Union[Tuple[Any, ...], List[Any]],
    *,
    subclass_metas: List[Union[int, SubclassCreationMeta]],
    num_fw_outs_saved_for_bw: Optional[int] = None,
    is_runtime: bool = False,
) -> Tuple[Any, ...]:
    wrapped_args = []
    num_args_tallied = 0
    for subclass_meta in subclass_metas:
        if isinstance(subclass_meta, int):
            wrapped_args.append(unwrapped_args[subclass_meta])
            num_args_tallied += 1
        else:
            assert isinstance(subclass_meta, SubclassCreationMeta)
            wrapped_args.append(
                subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime)
            )
            num_args_tallied += subclass_meta.arg_count

    # Note: [Partitioner handling for Subclasses, Part 2]
    # At the beginning of AOTAutograd, we collect metadata on the inputs and outputs of the user fw,
    # to figure out which inputs/outputs are subclasses, and how to reconstruct the subclasses after flattening them.
    #
    # When this function is called at runtime in the forward,
    # we have been passed a list of (flattened) dense-tensor fw-outs, and need to reconstruct any subclass fw outs.
    #
    # One reasonable question that you should ask: when should the dense_tensor -> subclass_tensor wrapping happen?
    # Answer: we do it **inside of our compiled autograd.Function**.
    # This seems like morally the right place: autograd happens above subclass desugaring,
    # so autograd should see actual tensor subclasses at runtime, and not flattened dense tensors.
    #
    # This causes a tricky interaction though: when we run the min-cut partitioner to divvy up the joint graph
    # into a forward and backward graph, we end up with some activations that show up as extra outputs
    # in the compiled forward graph, that are **not** user outputs.
    # These activations are not visible to the user, and so there's no need for us to wrap them back into subclasses.
    #
    # On top of that, when we first computed subclass metadata (in `run_functionalized_fw_and_collect_metadata`),
    # we computed subclass metadata on every forward output, but this did **not** include activations
    # created by the partitioner.
    # As a result, `unwrapped_args` here will correspond to (*unwrapped_user_fw_outs, *activations),
    # but `subclass_metas` will only correspond to subclass metadata on `user_fw_outs`.
    # We then need to make sure that we return (*wrapped_user_fw_outs, *activations).
    if num_fw_outs_saved_for_bw is not None:
        assert len(unwrapped_args) == num_args_tallied + num_fw_outs_saved_for_bw, (
            f"Expected the number of actual unwrapped-subclass outputs ({len(unwrapped_args)}) to equal "
            f"the number of args calculated from subclasses ({num_args_tallied}) plus the number of "
            f"additional activations saved for the backward pass ({num_fw_outs_saved_for_bw})"
        )
        activations = unwrapped_args[num_args_tallied:]
        if isinstance(wrapped_args, tuple) and isinstance(activations, tuple):
            return wrapped_args + activations
        return tuple(list(wrapped_args) + list(activations))
    else:
        assert len(unwrapped_args) == num_args_tallied
        return tuple(wrapped_args)


# Given a bunch of "dense" tensor arguments, this function (potentially) wraps them into tensor subclasses.
# This function carefully handles the inference vs. joint cases:
# - when is_joint_structure is True, args is (primals, tangents)
# - when is_joint_structure is False, args is [*primals]
def wrap_tensor_subclasses_maybe_joint(
    unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta
) -> Union[Tuple[Any, ...], List[Any]]:
    # Since this function is re-used for both inference and joint graphs,
    # we need to handle both argument structures here.
    if is_joint_structure:
        assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2
        assert isinstance(unwrapped_args[0], (tuple, list)) and isinstance(
            unwrapped_args[1], (tuple, list)
        )
        primals, tangents = unwrapped_args[0], unwrapped_args[1]
        wrapped_primals = wrap_tensor_subclasses(
            primals, subclass_metas=meta.subclass_inp_meta
        )
        wrapped_tangents = wrap_tensor_subclasses(
            tangents, subclass_metas=meta.subclass_tangent_meta
        )
        return (wrapped_primals, wrapped_tangents)
    else:
        wrapped_args = wrap_tensor_subclasses(
            unwrapped_args, subclass_metas=meta.subclass_inp_meta
        )
        return wrapped_args
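

# Illustrative sketch (assumptions: a hypothetical `TwoTensor` subclass and a
# `subclass_metas` list produced earlier by create_subclass_meta). Wrapping is the
# inverse of unwrapping; at runtime, any trailing activations that the partitioner
# added as extra forward outputs are passed through untouched:
#
#     # subclass_metas describes one TwoTensor out (2 inner tensors) + 1 plain out
#     outs = wrap_tensor_subclasses(
#         [a, b, plain, act0, act1],
#         subclass_metas=subclass_metas,
#         num_fw_outs_saved_for_bw=2,
#         is_runtime=True,
#     )
#     # -> (TwoTensor(a, b), plain, act0, act1)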


# TODO: UNUSED. delete?
def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMeta:
    # input infos
    input_info = []
    for inp, subclass_meta in zip(meta.input_info, meta.subclass_inp_meta):
        num_inps = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count
        for _ in range(num_inps):
            input_info.append(inp)

    # output infos
    output_info = []
    subclass_out_meta_user_outs_only = meta.subclass_fw_graph_out_meta[
        meta.num_mutated_inp_runtime_indices :
    ]
    if meta.num_intermediate_bases > 0:
        subclass_out_meta_user_outs_only = subclass_out_meta_user_outs_only[
            : -meta.num_intermediate_bases
        ]
    # sanity assert
    assert len(meta.output_info) == len(subclass_out_meta_user_outs_only)
    # Assume that the information on the output is shared by all of its inner tensors.
    for out, subclass_meta in zip(meta.output_info, subclass_out_meta_user_outs_only):
        num_outs = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count
        for _ in range(num_outs):
            output_info.append(out)

    # A bit hacky, but we don't actually care about all of the metadata here.
    # This metadata is used **underneath** both autograd and subclass de-sugaring,
    # so all we really care about is stuff like:
    # - num inputs/outputs (needed by the partitioner)
    # - input mutations (**not** used today, since we don't handle input mutations inside the subclass,
    #   although we should handle this eventually)
    #   TODO: add a test case to assert we error when this happens, instead of silently producing incorrect results
    num_intermediate_bases = None
    keep_input_mutations = meta.keep_input_mutations
    traced_tangents = None
    subclass_inp_meta = None
    subclass_fw_graph_out_meta = None
    subclass_tangent_meta = None

    metadata = ViewAndMutationMeta(
        input_info=input_info,  # type: ignore[arg-type]
        output_info=output_info,  # type: ignore[arg-type]
        num_intermediate_bases=num_intermediate_bases,  # type: ignore[arg-type]
        keep_input_mutations=keep_input_mutations,  # type: ignore[arg-type]
        traced_tangents=traced_tangents,  # type: ignore[arg-type]
        subclass_inp_meta=subclass_inp_meta,  # type: ignore[arg-type]
        subclass_fw_graph_out_meta=subclass_fw_graph_out_meta,  # type: ignore[arg-type]
        subclass_tangent_meta=subclass_tangent_meta,  # type: ignore[arg-type]
    )
    return metadata
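

# Illustrative sketch of the expansion above (assuming a hypothetical `TwoTensor`
# subclass input): a single input_info entry for the outer subclass is duplicated
# once per inner dense tensor, so the "inner" metadata lines up with the flattened
# argument list:
#
#     meta.input_info        == [info_plain, info_two_tensor]
#     meta.subclass_inp_meta == [0, SubclassCreationMeta(arg_count=2, ...)]
#     # expanded input_info  == [info_plain, info_two_tensor, info_two_tensor]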


def compute_inner_mutated_inp_indices_from_subclass_meta(
    fw_metadata: ViewAndMutationMeta,
    inner_metadata: ViewAndMutationMeta,
) -> List[int]:
    # Note: [Recomputing subclass mutation handling]
    #
    # Generally, if a subclass requires grad, its components will not require grad.
    # But for the purposes of tracking returned tensors, we should treat those component
    # tensors as if they require grad.
    #
    # For example, if the subclass tensor requires grad and will be mutated in a way that
    # requires us to handle the mutation outside of the graph, we need to return it
    # from the forward graph. The inner metadata won't consider the component tensors
    # as if they need to be returned, because they don't require grad; but really, we
    # should handle those tensors the same way we handle the subclass tensor itself; i.e.
    # if we'd include the subclass tensor as part of the outputs, then we should also
    # include the component tensors.
    #
    # To do this, we patch num_mutated_inp_runtime_indices below by expanding the inputs
    # from the outer subclass tensors and propagating their mutation info to each inner tensor.

    updated_input_info = []
    inner_idx = 0
    if not fw_metadata.subclass_inp_meta:
        # Sometimes we don't have subclass info, e.g. synthetic_base codepaths
        return inner_metadata.mutated_inp_runtime_indices
    assert len(fw_metadata.subclass_inp_meta) == len(fw_metadata.input_info)
    for outer_idx, inp_meta in enumerate(fw_metadata.subclass_inp_meta):
        if isinstance(inp_meta, int):
            assert outer_idx < len(fw_metadata.input_info)
            if inner_metadata is not None:
                assert inner_idx < len(inner_metadata.input_info)
                assert (
                    inner_metadata.input_info[inner_idx]
                    == fw_metadata.input_info[outer_idx]
                )
            updated_input_info.append(fw_metadata.input_info[outer_idx])
            inner_idx += 1
        else:
            for _ in range(inp_meta.arg_count):
                updated_input_info.append(fw_metadata.input_info[outer_idx])
                inner_idx += 1
    if inner_metadata is not None:
        assert len(inner_metadata.input_info) == len(updated_input_info)

    return [
        i
        for i, inp in enumerate(updated_input_info)
        if inp.mutation_type == MutationType.MUTATED_OUT_GRAPH
    ]
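

# Illustrative sketch (assumptions: a hypothetical `TwoTensor` subclass input whose
# outer input_info entry has mutation_type == MutationType.MUTATED_OUT_GRAPH): the outer
# mutation info is replicated across both inner tensors, so both inner (flattened)
# indices are reported as mutated:
#
#     fw_metadata.input_info        == [info_plain, info_mutated_two_tensor]
#     fw_metadata.subclass_inp_meta == [0, SubclassCreationMeta(arg_count=2, ...)]
#     compute_inner_mutated_inp_indices_from_subclass_meta(fw_metadata, inner_metadata)
#     # -> [1, 2]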