# mypy: allow-untyped-defs
"""
The various dataclasses, Enums, namedtuples etc used in AOTAutograd. This includes
input/output types, metadata, config, function signatures etc.
"""

import collections
import dataclasses
import functools
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, NewType, Optional, Set, Union

import torch
import torch.utils._pytree as pytree
from torch._guards import Source
from torch._ops import OpOverload
from torch._subclasses import FakeTensor
from torch._subclasses.fake_tensor import is_fake
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .. import config
from .functional_utils import (
    _check_if_mutation_can_be_in_graph,
    FunctionalTensorMetadataEq,
)
from .utils import strict_zip


zip = strict_zip

OutputType = Enum(
    "OutputType",
    (
        # output is not an alias
        "non_alias",
        # output aliases an input
        "alias_of_input",
        # output **is** an input tensor
        "is_input",
        # output has a ._base tensor, which is a graph intermediate.
        # We need to return its ._base as a graph output,
        # so its requires_grad info is populated correctly.
        # Instructs the runtime code to regenerate the current output
        # from a base tensor, graph_intermediates[base_idx]
        "alias_of_intermediate_save_as_output",
        # Same as above; but we don't need to explicitly add its ._base
        # as a graph output, because it already **is** a graph output.
        "alias_of_intermediate",
        # Same as above; but the output's ._base is **already** a user output.
        # Instructs the runtime code to regenerate the current output from
        # a base tensor, user_outputs[base_idx]
        "alias_of_intermediate_base_is_user_output",
        # See Note [Intermediate Bases Optimization]
        "unsafe_view_alias",
        # output is an alias, but has a custom autograd.Function backward.
        # In this case, we don't want to do view-replay, since we won't be able to replay the custom function.
        # Instead, we'll treat this output "normally", and trace its backward into the graph.
        "custom_function_view",
    ),
)


# This class stores info about every user output.
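# As a rough, illustrative sketch (hypothetical user code, not part of this module):
# given
#
#     def f(x):
#         y = x.mul(2)
#         return y, y.view(-1), x
#
# the three outputs would roughly be classified as OutputType.non_alias (y),
# OutputType.alias_of_intermediate_base_is_user_output (y.view(-1), whose base y is
# itself a user output), and OutputType.is_input (x), respectively.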
@dataclass(frozen=True)
class OutputAliasInfo:
    # Tells us if this output is:
    # (1) a regular (non-aliased) output
    # (2) an alias of a forward input
    # (3) **is** a forward input (special case of "alias_of_input")
    # (4) an alias of an intermediate (aka an alias of an output of the inner traced forward)
    # (5) an alias of an intermediate, that explicitly requires returning the intermediate
    #     as a graph output
    # (6) an alias of an intermediate, where that intermediate is also a user output
    output_type: OutputType
    # The raw type of the output (torch.Tensor, SymInt, etc)
    raw_type: type
    # If (1) above, then
    # - base_idx is None
    # If (2) or (3) above, then
    # - Tells us that the base of this alias is user_fwd_input[base_idx]
    #   (This is an index into the inputs *before* we make synthetic bases)
    # If (4) or (5) above, then
    # - Tells us that the base of this alias is output_graph_intermediates[base_idx]
    #   here, this refers to the index of the *direct* traced
    # If (6) above, then:
    # - Tells us that the base of this alias is output_user_fwds[base_idx]
    #   here, this refers to the index of the *direct* traced
    base_idx: Optional[int]
    # If it is a Tensor, what the dynamic dims are (otherwise is None)
    dynamic_dims: Optional[Set[int]]
    # requires_grad
    requires_grad: bool
    # FunctionalTensorWrapper that represents this output.
    #
    # Provides us the means to replay views from it.
    #
    # We need to wrap the actual FunctionalTensorWrapper with this class so that
    # we only compare the tensor's metadata. That's because with the transformations
    # of the model throughout AOTAutograd, the sequence of ViewMeta and the base
    # tensor might change.
    functional_tensor: Optional[FunctionalTensorMetadataEq] = None


class MutationType(Enum):
    NOT_MUTATED = 1
    MUTATED_IN_GRAPH = 2
    MUTATED_OUT_GRAPH = 3


# This class tells us info about user inputs.
@dataclass(frozen=True)
class InputAliasInfo:
    is_leaf: bool
    mutates_data: bool
    mutates_metadata: bool
    mutations_hidden_from_autograd: bool
    mutations_under_no_grad_or_inference_mode: bool
    mutation_inductor_storage_resize: bool
    mutates_storage_metadata: bool
    requires_grad: bool
    keep_input_mutations: bool

    def __post_init__(self):
        if self.mutates_storage_metadata:
            # For convenience, we guarantee that this is always true.
            # In practice, if we call .set_(), then at runtime there is no need
            # to additionally fix up the tensor metadata, since our runtime
            # call to inp.set_(updated_inp) will already have the right metadata
            assert self.mutates_metadata

    @functools.cached_property
    def mutation_type(self) -> MutationType:
        if (
            (not self.mutates_data)
            and (not self.mutates_metadata)
            and not (self.mutation_inductor_storage_resize)
        ):
            return MutationType.NOT_MUTATED

        if _check_if_mutation_can_be_in_graph(
            self.keep_input_mutations,
            self.mutates_data,
            self.mutates_metadata,
            self.mutations_hidden_from_autograd,
            self.mutations_under_no_grad_or_inference_mode,
            self.mutates_storage_metadata,
            self.mutation_inductor_storage_resize,
            self.requires_grad,
        ):
            return MutationType.MUTATED_IN_GRAPH

        return MutationType.MUTATED_OUT_GRAPH


@dataclass
class SubclassCreationMeta:
    """
    Used for AOTDispatch.
    This dataclass gives us the information we need to reconstruct a tensor subclass
    from our flat inputs.
    Why is this important? The graph that we'd like to trace out contains flat tensor inputs,
    but the user's original model may have subclass inputs and outputs.
    So we need to wrap/unwrap subclasses as necessary to translate between the user's
    view (subclass inps/outs), and the backend compiler's view (graph with no subclass args).

    Complications arise mostly from the fact that a subclass can hold more than one inner tensor;
    so for a given subclass input/output, we need to carefully track which indices map
    to the subclass tensor in the corresponding "dense-tensor-only" graph.
    """

    # In the inner graph that only takes in dense tensor inputs,
    # this maps to the first index of "tensors that should go in this subclass wrapper"
    flat_tensor_start_idx: int
    # arg_count is inclusive of the arg_counts of any
    # inner tensor subclasses: If I have a TwoTensor and
    # both of its inner elements are TwoTensors, then the
    # arg_count of the outer-most subclass will be 4
    arg_count: int
    # meta and attrs are produced by the subclass's __tensor_flatten__.
    # We need to keep them around along with outer_size / outer_stride to plumb them
    # into __tensor_unflatten__
    attrs: Dict[str, Union["SubclassCreationMeta", None]]
    outer_size: List[int]
    outer_stride: List[int]
    meta: Any
    # Stores the original subclass itself.
    # This is needed because we need the autograd metadata on the original subclass
    # (this is guaranteed to be a wrapper subclass that holds a fake tensor,
    # so holding onto this at runtime shouldn't leak memory)
    # This field is nulled out after calling make_runtime_safe()
    original_subclass: Optional[torch.Tensor]

    # Used at runtime to determine the subclass type, so we don't need to save the original subclass
    original_subclass_type: Optional[type] = None

    def creation_fn(self, all_args, *, is_runtime: bool):
        inner_tensors = {}

        curr_start_idx = self.flat_tensor_start_idx
        for attr, creation_meta in self.attrs.items():
            if creation_meta is None:
                subclass = all_args[curr_start_idx]
                curr_start_idx += 1
            else:
                subclass = creation_meta.creation_fn(all_args, is_runtime=is_runtime)
                curr_start_idx += creation_meta.arg_count
            inner_tensors[attr] = subclass

        if is_runtime:
            assert self.original_subclass_type is not None
            original_subclass_type = self.original_subclass_type
        else:
            original_subclass_type = type(self.original_subclass)

        rebuilt = original_subclass_type.__tensor_unflatten__(  # type: ignore[attr-defined]
            inner_tensors, self.meta, self.outer_size, self.outer_stride
        )

        if not is_runtime:
            # After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper
            # has correct autograd metadata, since we'll be tracing through the autograd engine with the subclass.
            # We don't trace through the autograd engine at runtime though, so no need
            # to compute this extra metadata then!
            torch._mirror_autograd_meta_to(self.original_subclass, rebuilt)  # type: ignore[attr-defined]

        return rebuilt

    def make_runtime_safe(self):
        assert self.original_subclass is not None
        self.original_subclass_type = type(self.original_subclass)
        self.original_subclass = None
        # Recurse on nested subclass info
        for creation_meta in self.attrs.values():
            if creation_meta is not None:
                creation_meta.make_runtime_safe()

    def __post_init__(self):
        # sanity assert to make sure we don't leak memory
        assert is_fake(self.original_subclass)

        # This saves the type of subclass nested structure to compare
        # against runtime tangent inputs. We want to compute this at AOT
        # time, since it is invoked in the hot path.
        from .subclass_utils import get_types_for_subclass

        self.subclass_type = get_types_for_subclass(self.original_subclass)


# This class encapsulates all aliasing + mutation info we need about the forward graph
# See a more detailed overview of the edge case handling at
# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
@dataclass(eq=False)
class ViewAndMutationMeta:
    # length = # user inputs
    # This gives us info about every input, and what sort of mutation happened to it (if any)
    input_info: List[InputAliasInfo]

    # length = # user outputs
    # This gives us info about every output (mostly around whether it aliases other tensors)
    output_info: List[OutputAliasInfo]

    # length = the number of intermediate bases appended as outputs to the end of the forward graph.
    # Note: this is not necessarily the same thing as:
    #   len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate])
    # Because outputs might share a ._base, or an output's ._base might itself be
    # another user output (in both cases, we won't redundantly append bases to the end of the graph)
    num_intermediate_bases: int

    # For inference only: instructs us to keep data-only input mutations directly in the graph
    keep_input_mutations: bool

    # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
    #        + (# intermediate bases)
    # These are the FakeTensor (or potential SymInt) outputs that we traced from our
    # metadata pass of the user's forward function.
    # Their only use today is to pass them as a best-guess for tangents when tracing the joint.
    # Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
    # pass once, and re-use the output throughout AOTAutograd
    traced_tangents: List[Any]

    # Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs.
    # They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors,
    # given a (potentially larger) list of plain torch tensors.

    # Taking subclass_inp_meta as an example:
    #   subclass_inp_meta[i] = j (an int) tells us:
    #     "The i'th user input is not a subclass, and corresponds to inputs[j] of the plain-tensor graph."
    #   subclass_inp_meta[i] = SubclassCreationMeta(flat_tensor_start_idx=3, arg_count=2) tells us:
    #     "The i'th user input is a subclass holding two inner tensors, which are
    #      inputs[3] and inputs[4] of the plain-tensor graph".
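    #
    # As a rough worked example (constructor arguments simplified; the real
    # SubclassCreationMeta has more required fields): if the user's forward takes
    # (TwoTensor(a, b), plain_tensor), the plain-tensor graph sees three inputs
    # (a, b, plain_tensor), and subclass_inp_meta would look roughly like
    #   [SubclassCreationMeta(flat_tensor_start_idx=0, arg_count=2), 2]
    # i.e. user input 0 is rebuilt from inputs[0:2], and user input 1 maps to inputs[2].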

    # length = # user inputs
    subclass_inp_meta: List[Union[int, SubclassCreationMeta]]
    # So, the full set of outputs to the forward graph looks something like:
    # (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors)
    # where the first 3 of those 4 can be subclasses
    # (but not saved_for_bw tensors, since these are internal to the compiler
    # and not user visible, so there's no point in wrapping/unwrapping them at runtime).
    # This list contains subclass information on all of the fw graph outputs
    # except for saved_for_bw_tensors.
    subclass_fw_graph_out_meta: List[Union[int, SubclassCreationMeta]]
    # length = # backward graph inputs
    subclass_tangent_meta: List[Union[int, SubclassCreationMeta]]
    # TODO: we should kill this
    # (need to default it to not break internal)
    is_train: bool = False

    # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
    #        + (# intermediate bases)
    # At runtime, we don't keep the traced_tangents around since they're not serializable.
    # Instead, we keep the necessary subclass metadata about each traced_tangent.
    # This list is generated after calling make_runtime_safe().
    traced_tangent_metas: Optional[List[Any]] = None

    num_symints_saved_for_bw: Optional[int] = None

    # The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue
    # NOTE: AOTAutograd will assume that the ambient `is_grad_enabled` is the grad mode
    # that is intended to be in effect prior to running the graph, in keeping with
    # equivalence to eager mode. It is the responsibility of upstream graph acquisition
    # to reset the grad mode to its pre-graph value prior to calling aot_autograd.
    grad_enabled_mutation: Optional[bool] = None

    # Keeps track of whether `torch.use_deterministic_algorithms` was turned on
    # when the forward was run. If deterministic mode was turned off during the
    # forward, but is turned on during the backward call, then an error is
    # raised
    deterministic: Optional[bool] = None

    # Keeps track of which input indices store parameters (which we will treat as static)
    static_input_indices: List[int] = field(default_factory=list)

    # Map of effect type (ex. _EffectType.ORDERED) to token. If there are
    # side-effectful operators, FunctionalTensorMode will populate this
    # dictionary telling us how many tokens we will need during tracing.
    tokens: Dict[Any, torch.Tensor] = field(default_factory=dict)

    # Only filled in if/when we trace the joint function
    # If an input requires grad and is mutated in the backward, it is only safe to keep the mutation
    # in the graph if gradients are disabled while the backward runs
    # (grad mode is disabled by default when users run the backward, but can be turned on with create_graph=True)
    # At runtime during the backward, we use this list of indices to error properly if we find out
    # that it was not safe to include a backward mutation in the graph.
    indices_of_inputs_that_requires_grad_with_mutations_in_bw: List[int] = field(
        default_factory=list
    )

    # Indices of saved tensors which are donated buffers.
    # A donated buffer is a tensor that is not an alias of any forward user input,
    # forward user output, or backward output.
    bw_donated_idxs: Optional[List[int]] = None

    # Number of tokens used in backward, appended at the end of backward outputs.
    # Filled after tracing joint function.
    num_backward_tokens: int = 0

    def __post_init__(self):
        # pre-compute the indices of the inputs that are mutated.
        # When keep_input_mutations is set, we don't need to worry about our epilogue
        # handling data-only mutations, because we keep them directly in the graph.

        mutated_inp_runtime_indices = [
            i
            for i, m in enumerate(self.input_info)
            if (m.mutation_type == MutationType.MUTATED_OUT_GRAPH)
        ]

        mutated_graph_handled_indices = [
            i
            for i, m in enumerate(self.input_info)
            if m.mutation_type == MutationType.MUTATED_IN_GRAPH
        ]
        self.mutated_graph_handled_indices = mutated_graph_handled_indices
        self.num_mutated_graph_handled_indices = len(self.mutated_graph_handled_indices)

        mutated_graph_handled_indices_seen_by_autograd = [
            i
            for i in mutated_graph_handled_indices
            if not self.input_info[i].mutations_hidden_from_autograd
        ]

        self.mutated_graph_handled_indices_seen_by_autograd = (
            mutated_graph_handled_indices_seen_by_autograd
        )
        self.num_mutated_graph_handled_indices_seen_by_autograd = len(
            self.mutated_graph_handled_indices_seen_by_autograd
        )

        aliased_out_indices = [
            i
            for i, m in enumerate(self.output_info)
            if m.output_type
            not in [
                OutputType.non_alias,
                OutputType.unsafe_view_alias,
                OutputType.custom_function_view,
            ]
        ]
        unsafe_view_out_indices = [
            i
            for i, m in enumerate(self.output_info)
            if m.output_type is OutputType.unsafe_view_alias
        ]

        # This is pre-computed in post_init for perf.
        # It contains the index of every element
        # of input_info that corresponds to a mutation (data or metadata or both)
        self.mutated_inp_runtime_indices = mutated_inp_runtime_indices
        self.num_mutated_inp_runtime_indices = len(self.mutated_inp_runtime_indices)

        # This is pre-computed for perf.
        # It contains the index of every element
        # of output_info that corresponds to an alias (either of an input or intermediate)
        self.aliased_out_indices = aliased_out_indices
        self.unsafe_view_out_indices = unsafe_view_out_indices
        self.num_outputs = len(self.output_info)
        self.num_outputs_non_aliased = len(
            [
                x
                for x in self.output_info
                if x.output_type
                in [
                    OutputType.non_alias,
                    OutputType.unsafe_view_alias,
                    OutputType.custom_function_view,
                ]
            ]
        )
        self.num_outputs_aliased_to_inputs = len(
            [
                x
                for x in self.output_info
                if x.output_type
                in [
                    OutputType.alias_of_input,
                    OutputType.is_input,
                ]
            ]
        )
        self.num_unsafe_view_outputs = len(self.unsafe_view_out_indices)
        self.num_outputs_aliased_to_intermediates = len(
            [
                x
                for x in self.output_info
                if x.output_type
                in [
                    OutputType.alias_of_intermediate,
                    OutputType.alias_of_intermediate_save_as_output,
                    OutputType.alias_of_intermediate_base_is_user_output,
                ]
            ]
        )
        self.num_outputs_aliased = (
            self.num_outputs_aliased_to_inputs
            + self.num_outputs_aliased_to_intermediates
        )

        self.dynamic_outputs = any(o.dynamic_dims for o in self.output_info)
        # See Note: [AOTAutograd Backward Guards]
        # This is pre-computed for fast asserts on the types of our grad_outputs in the backward.
        # Eventually, we should kill this and replace with real backward guards.
        # (we want to precompute the "runtime" types, so replace FakeTensor with torch.Tensor)
        self.output_types = [
            torch.Tensor if isinstance(x, FakeTensor) else type(x)
            for x in self.traced_tangents
        ]

        self.is_rng_op_functionalized = config.functionalize_rng_ops
        # All of the above metadata is collected by tracing the fw function.
        # However, extra outputs for rng offsets behave differently. Both fwd
        # and bwd graphs have their own outputs for the total consumed offsets.
        # Unlike mutated inputs, we don't have to worry about sending the right
        # set of tensors between fwd and bwd. Fwd and bwd offsets are
        # independent and simpler to handle. Therefore, we track them
        # separately.
        self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0

        # Our forward() returns (tokens, mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints).
        # Tokens will be split out before mutations/view handling and we do not count them here.
        self.num_forward_returns = (
            self.num_mutated_inp_runtime_indices
            + self.num_outputs
            + self.num_intermediate_bases
        )
        # In case of functionalization of rng ops, the fw_module returns one
        # additional output for rng offset. This rng offset is used right
        # away to advance the rng state, and is not passed on to the raw
        # outputs. However, we need to know the exact boundary to identify
        # which tensors to be saved for the bwd graph. num_forward captures
        # this information.
        self.num_forward = self.num_forward_returns + self.num_outputs_rng_offset

    def make_runtime_safe(self):
        """
        There are various fields in ViewAndMutationMeta that aren't serializable. This function is called after all tracing
        is completed to simplify certain fields in the metadata so that they can be safely cached.

        Doing so may lose information (in the case of traced_tangents), but none of the information is needed at runtime.
        """
        # TODO: This function is only a best effort: there are other fields that may not be cache safe
        # (i.e., there's no guarantee that tensor_flatten() returns a serializable result), or that
        # SubclassCreationMeta is cache safe.
        assert self.traced_tangent_metas is None

        def extract_metadata(t):
            if isinstance(t, torch.Tensor) and is_traceable_wrapper_subclass(t):
                (inner_tensors, flatten_spec) = t.__tensor_flatten__()  # type: ignore[attr-defined]
                # Technically, we only need the flatten_spec, not the inner tensors.
                # However, some Tensor subclasses (like TwoTensor) may have flatten_spec = None.
                # And we want to be able to assert that this metadata is non-None,
                # to distinguish between "this was a tensor subclass with no metadata" vs.
                # "this wasn't a tensor subclass at all".
                return (inner_tensors, flatten_spec)
            else:
                return None

        self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents]
        # Clear traced tangents at runtime
        self.traced_tangents = []
        new_output_info = []
        for out in self.output_info:
            if config.view_replay_for_aliased_outputs:
                new_out = out
            else:
                # If we're not using view_replay, remove the functional tensor.
                # Functional tensors are unfortunately not serializable,
                # so doing this is required for AOTAutograd caching.
                new_out = dataclasses.replace(out, functional_tensor=None)
            new_output_info.append(new_out)
        self.output_info = new_output_info
        for inp_meta in self.subclass_inp_meta:
            if isinstance(inp_meta, SubclassCreationMeta):
                inp_meta.make_runtime_safe()
        for inp_meta in self.subclass_fw_graph_out_meta:
            if isinstance(inp_meta, SubclassCreationMeta):
                inp_meta.make_runtime_safe()
        for inp_meta in self.subclass_tangent_meta:
            if isinstance(inp_meta, SubclassCreationMeta):
                inp_meta.make_runtime_safe()

    @property
    def tensors_saved_for_backwards_slice(self):
        assert self.num_symints_saved_for_bw is not None
        if self.num_symints_saved_for_bw > 0:
            return slice(self.num_forward, -self.num_symints_saved_for_bw)
        else:
            return slice(self.num_forward, None)

    @property
    def symints_saved_for_backwards_slice(self):
        assert self.num_symints_saved_for_bw is not None
        if self.num_symints_saved_for_bw > 0:
            return slice(-self.num_symints_saved_for_bw, None)
        else:
            return slice(0, 0)  # empty slice

    def __eq__(self, other):
        if not isinstance(other, ViewAndMutationMeta):
            return NotImplemented
        return (
            self.input_info == other.input_info
            and self.output_info == other.output_info
            and self.num_intermediate_bases == other.num_intermediate_bases
            and self.keep_input_mutations == other.keep_input_mutations
            and self.is_rng_op_functionalized == other.is_rng_op_functionalized
            and self.num_outputs_rng_offset == other.num_outputs_rng_offset
            and len(self.traced_tangents) == len(other.traced_tangents)
            and all(
                x.shape == y.shape and x.dtype == y.dtype
                for x, y in zip(self.traced_tangents, other.traced_tangents)
            )
            and self.num_backward_tokens == other.num_backward_tokens
        )


@dataclass(eq=False)
class SubclassMeta:
    # A copy of all forward metadata, but computed on the *dense* tensor forward (after desugaring subclasses)
    # So for example, if the user had a model containing two `TwoTensor` inputs,
    # then `SubclassMeta.fw_metadata.input_info` would have length 4 here.
    fw_metadata: ViewAndMutationMeta

    # Note: [Computing Subclass Metadata about grad_inputs]
    # Given a list of flattened, plain tensor grad_inputs, this tells us how to reconstruct the grad_input subclasses
    #
    # You might think: why not just assume that all grad_inputs will have the same subclass-ness as the original inputs?
    # (AOTAutograd generally assumes other properties, e.g. that grad_outputs are contiguous)
    #
    # This doesn't really work though. Take this example:
    #
    #   def f(DoubleTensor, DenseTensor):
    #       return DoubleTensor * DenseTensor
    #
    # In the above example, the .grad field of *both* DoubleTensor and DenseTensor will be a DoubleTensor.
    # When we trace out a joint fw-bw graph, we'll end up returning two subclasses for the two grad_inputs.
    # This means that our backward graph will return 4 outputs (two dense tensors for each DoubleTensor grad_input)
    # and we need to properly store the metadata that tells us how to turn these 4 outputs back into DoubleTensors.
    #
    # Note that this info **cannot** easily be figured out from ViewAndMutationMeta.
    # We can only compute this info by tracing the entire joint and examining the grad_inputs that we computed.
    #
    # See Note: [AOTAutograd Backward Guards]
    # This will also eventually require us to install backward guards,
    # in case we made incorrect assumptions about the subclass-ness of our grad_outputs
    #
    # Optional field because we don't compute this for inference graphs
    grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]] = None

    def __init__(self) -> None:
        # The fields in this class get set after its construction.
        pass


# This class exists because:
# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs
# - we only care about the metadata on those aliases, so we can regenerate them.
#   We do not want them to participate in the autograd.Function.
# We do that by wrapping them in an opaque class, so the autograd.Function
# does not know to treat them as tensors.
@dataclass(frozen=True)
class TensorAlias:
    alias: torch.Tensor


@dataclass
class BackwardSignature:
    """
    Provides information about the backward section of an exported
    joint forward-backward graph.
    For a particular fx GraphModule, this class contains information on:
    (1) A mapping from each gradient (backwards output) to the parameter
        it corresponds to (forward input)
    (2) A mapping from each gradient (backwards output) to the user input
        it corresponds to (forward input)
    (3) Which of the forward outputs corresponds to the loss, which we backprop on.

    Each string name is the `node.name` of the corresponding node in the fx graph.
    """

    gradients_to_parameters: Dict[str, str]
    gradients_to_user_inputs: Dict[str, str]
    loss_output: str


GraphOutputName = NewType("GraphOutputName", str)
GraphInputName = NewType("GraphInputName", str)
FQN = NewType("FQN", str)


@dataclass
class GraphSignature:
    """
    Provides information about an exported module.
    For a particular fx GraphModule, this class contains information on:
    (1) Which graph inputs are parameters, buffers, or user inputs
    (2) (for params/buffers) a mapping from the name of each graph argument
        to its parameter/buffer FQN in the original nn.Module.
    (3) If there are input mutations, these are represented as extra outputs
        in the fx GraphModule. We provide a mapping from these
        extra output names to the names of the actual inputs.
    (4) The pytree metadata on how to flatten/unflatten inputs and outputs.
        The corresponding FX GraphModule only accepts and returns
        pytree-flattened inputs/outputs.
    (5) (Optionally) if the FX graph is a joint forward-backward graph, we provide
        a signature on the backward section of the joint graph.
    """

    parameters: List[FQN]
    buffers: List[FQN]

    user_inputs: List[GraphInputName]
    user_outputs: List[GraphOutputName]
    inputs_to_parameters: Dict[GraphInputName, FQN]
    inputs_to_buffers: Dict[GraphInputName, FQN]

    # If the user's module mutates a buffer,
    # it's represented in the graph as an extra graph output.
    # This dict is a mapping from
    # "graph outputs that correspond to updated buffers"
    # to the FQN names of those mutated buffers.
    buffers_to_mutate: Dict[GraphOutputName, FQN]
    user_inputs_to_mutate: Dict[GraphOutputName, GraphInputName]

    in_spec: pytree.TreeSpec
    out_spec: pytree.TreeSpec

    backward_signature: Optional[BackwardSignature]

    input_tokens: List[GraphInputName]
    output_tokens: List[GraphOutputName]

    @classmethod
    def from_tracing_metadata(
        cls,
        *,
        in_spec: pytree.TreeSpec,
        out_spec: pytree.TreeSpec,
        graph_input_names: List[str],
        graph_output_names: List[str],
        view_mutation_metadata: ViewAndMutationMeta,
        named_parameters: List[str],
        named_buffers: List[str],
        num_user_inputs: int,
        num_user_outputs: int,
        loss_index: Optional[int],
        backward_signature: Optional[BackwardSignature],
    ) -> "GraphSignature":
        graph_inputs = graph_input_names
        graph_outputs = graph_output_names
        parameters = list(named_parameters)
        buffers = list(named_buffers)
        num_tokens = len(view_mutation_metadata.tokens)

        # Calling convention assumptions:
        # (1) graph inputs = (input_tokens, params, buffers, user_inputs)
        # (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients)
        # (If we are capturing an inference graph, this convention is identical
        #  except that param_gradients is empty)
        # See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens

        # Address input calling conventions:
        start, stop = 0, num_tokens
        input_tokens = graph_inputs[start:stop]

        start, stop = stop, stop + len(parameters)
        inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters))

        start, stop = stop, stop + len(buffers)
        inputs_to_buffers = dict(
            zip(
                graph_inputs[start:stop],
                buffers,
            )
        )

        start, stop = stop, stop + num_user_inputs
        user_inputs = graph_inputs[start:stop]

        # We should've gone through all the inputs now
        assert len(graph_inputs) - stop == 0

        # Address output calling conventions:
        start, stop = 0, num_tokens
        output_tokens = graph_outputs[start:stop]

        names = [*input_tokens, *parameters, *buffers, *user_inputs]
        mutations = []
        for idx, input_info in enumerate(view_mutation_metadata.input_info):
            if input_info.mutates_data:
                # Only buffers can be mutated, not parameters
                assert idx >= len(parameters)
                mutations.append(names[idx + num_tokens])

        assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices

        start, stop = (
            stop,
            stop + view_mutation_metadata.num_mutated_inp_runtime_indices,
        )
        outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations))

        user_inputs_to_mutate = {}
        buffers_to_mutate = {}
        for output_name, mutation_name in outputs_to_mutations.items():
            if mutation_name in user_inputs:
                user_inputs_to_mutate[output_name] = mutation_name
            else:
                assert mutation_name in buffers
                buffers_to_mutate[output_name] = mutation_name

        start, stop = stop, stop + num_user_outputs
        user_outputs = graph_outputs[start:stop]

        unused_outputs = len(graph_outputs) - stop
        if backward_signature is not None:
            unused_outputs -= len(backward_signature.gradients_to_parameters) + len(
                backward_signature.gradients_to_user_inputs
            )
        assert unused_outputs == 0

        return GraphSignature(
            parameters=parameters,  # type: ignore[arg-type]
            buffers=buffers,  # type: ignore[arg-type]
            user_inputs=user_inputs,  # type: ignore[arg-type]
            user_outputs=user_outputs,  # type: ignore[arg-type]
            inputs_to_buffers=inputs_to_buffers,  # type: ignore[arg-type]
            inputs_to_parameters=inputs_to_parameters,  # type: ignore[arg-type]
            user_inputs_to_mutate=user_inputs_to_mutate,
            buffers_to_mutate=buffers_to_mutate,  # type: ignore[arg-type]
            in_spec=in_spec,
            out_spec=out_spec,
            backward_signature=backward_signature,
            input_tokens=input_tokens,  # type: ignore[arg-type]
            output_tokens=output_tokens,  # type: ignore[arg-type]
        )


@dataclass
class AOTConfig:
    """
    Configuration for AOTDispatcher
    """

    fw_compiler: Callable
    bw_compiler: Callable
    partition_fn: Callable
    decompositions: Dict[OpOverload, Callable]
    num_params_buffers: int
    aot_id: int
    keep_inference_input_mutations: bool
    is_export: bool = False
    no_tangents: bool = False
    dynamic_shapes: bool = False
    aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
    static_input_indices: Optional[List[int]] = None
    inference_compiler: Optional[Callable] = None
    enable_log: bool = True
    # This is always False outside of export.
    pre_dispatch: bool = False

    # Key to use for AOTAutogradCache
    cache_key: Optional[str] = None

    def __post_init__(self):
        if self.pre_dispatch:
            assert self.is_export, "Can only have pre_dispatch IR for export."


SubclassTracingInfo = collections.namedtuple(
    "SubclassTracingInfo",
    ["plain_tensor_trace_fn", "plain_tensor_args", "maybe_subclass_meta"],
)
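# A rough description of the fields above (inferred from their names and usage elsewhere
# in AOTDispatch; treat as an informal sketch rather than authoritative documentation):
# - plain_tensor_trace_fn: the function to trace, operating on desugared plain tensors
# - plain_tensor_args: the flattened plain-tensor arguments fed to that function
# - maybe_subclass_meta: a SubclassMeta describing how to rewrap subclasses, or None
#   if no subclass inputs/outputs are involved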