1# mypy: allow-untyped-defs 2""" 3This module defines runtime wrappers, which, based on previous analysis attempts to: 41. process the inputs and outputs 52. apply mutations 63. handle functionalized randomness 74. deduplicate inputs and consolidate views into their bases (see input_output_analysis) 8""" 9import builtins 10import collections 11import pprint 12from contextlib import nullcontext 13from dataclasses import dataclass, field 14from functools import wraps 15from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union 16 17import torch 18import torch.utils.dlpack 19from torch import Tensor 20from torch._guards import ( 21 compile_context, 22 CompileContext, 23 detect_fake_mode, 24 DuplicateInputs, 25 tracing, 26 TracingContext, 27) 28from torch._prims_common import CUDARngStateHelper 29from torch._subclasses import FakeTensor 30from torch.fx.experimental._backward_state import BackwardState 31from torch.multiprocessing.reductions import StorageWeakRef 32from torch.utils._python_dispatch import is_traceable_wrapper_subclass 33 34from .. import config 35from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata 36from .functional_utils import gen_alias_from_base 37from .input_output_analysis import ( 38 compute_overlapping_inputs, 39 create_synthetic_base_metadata, 40 remove_dupe_metadata, 41) 42from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling 43from .schemas import ( 44 AOTConfig, 45 InputAliasInfo, 46 MutationType, 47 OutputType, 48 SubclassCreationMeta, 49 SubclassMeta, 50 TensorAlias, 51 ViewAndMutationMeta, 52) 53from .subclass_utils import ( 54 get_types_for_subclass, 55 requires_subclass_dispatch, 56 unwrap_tensor_subclasses, 57 wrap_tensor_subclasses, 58) 59from .traced_function_transforms import aot_dispatch_subclass 60from .utils import ( 61 call_func_at_runtime_with_args, 62 make_boxed_func, 63 normalize_as_list, 64 partial_flatten_asdict, 65 strict_zip, 66) 67 68 69zip = strict_zip 70 71 72class CompilerWrapper: 73 """ 74 A wrapper around the inputs and outputs to the compiler_fn. We separate these into two parts: 75 76 1. The prologue, which edits the input to the compiler_fn(flat_fn, flat_args, etc) 77 2. The epilogue, which edits the outputs of the compiler_fn (compiled_fn, real arguments) 78 79 Each wrapper below should be implemented as a CompilerWrapper, so that we can facilitate 80 caching on the compiled output, and re-wrapping the output via epilogues. 81 Extra metadata that is needed to compute pre or post compile can be passed in via attributes. 82 """ 83 84 def pre_compile( 85 self, 86 flat_fn, 87 flat_args: List[Tensor], 88 aot_config: AOTConfig, 89 *, 90 fw_metadata: ViewAndMutationMeta, 91 ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]: 92 """ 93 Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. 94 Args: 95 flat_fn: The function to compile 96 flat_args: Metadata from example inputs of the function to compile 97 aot_config: AOTConfig passed in at compile time 98 fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args 99 """ 100 return flat_fn, flat_args, fw_metadata 101 102 def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: 103 """ 104 Given an output of the compiler, wrap it with information received from prologue. 
105 Args: 106 compiled_fn: Callable after calling compiler_fn 107 aot_config: AOTConfig after calling prologue 108 runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. 109 Example: 110 111 def wrapped_compiled_fn(args): 112 # do something with args, aot_config, fw_metadata 113 return compiled_fn(args) 114 115 return wrapped_compiled_fn 116 """ 117 return compiled_fn 118 119 120# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic 121# that needs to run after the compiled function. 122# 123# This function accepts a trace_joint flag, indicating whether or not we're generating the runtime 124# epilogue for a forward-only inference graph, or for an autograd.Function.apply function. 125# This is because there are some minor differences in how we treat these cases at runtime: 126# - resize_() is currently handled in the inference case, but not fully handled in the autograd case. 127# - the autograd cases inserts TensorAlias wrapper objects for outputs that alias inputs 128@dataclass 129class RuntimeWrapper(CompilerWrapper): 130 indices_of_inps_to_detach: List[int] 131 trace_joint: bool 132 disable_amp: bool 133 134 def post_compile( 135 self, 136 compiled_fn, 137 aot_config: AOTConfig, 138 *, 139 runtime_metadata: ViewAndMutationMeta, 140 ): 141 return _create_runtime_wrapper( 142 compiled_fn, 143 runtime_metadata=runtime_metadata, 144 indices_of_inps_to_detach=self.indices_of_inps_to_detach, 145 trace_joint=self.trace_joint, 146 keep_input_mutations=aot_config.keep_inference_input_mutations, 147 disable_amp=self.disable_amp, 148 ) 149 150 151class NoopAliasHandler: 152 def __init__(self, info, runtime_metadata, trace_joint): 153 pass 154 155 def __call__(self, orig_inputs, fw_outs, out): 156 return out 157 158 159def _unwrap_tensoralias(x): 160 assert isinstance(x, TensorAlias) 161 return x.alias 162 163 164def _identity(x): 165 return x 166 167 168class AliasOfInputHandler: 169 def __init__(self, info, runtime_metadata, trace_joint): 170 self.base_idx = info.base_idx 171 self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity 172 self.requires_grad = info.requires_grad 173 self.functional_tensor = info.functional_tensor 174 self.replay_views = config.view_replay_for_aliased_outputs 175 176 def __call__(self, orig_inputs, fw_outs, out): 177 aliased_base_tensor = orig_inputs[self.base_idx] 178 return gen_alias_from_base( 179 aliased_base_tensor, 180 self.unwrap_out(out), 181 self.requires_grad, 182 self.functional_tensor, 183 replay_views=self.replay_views, 184 ) 185 186 187class IsInputHandler: 188 def __init__(self, info, runtime_metadata, trace_joint): 189 self.base_idx = info.base_idx 190 self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity 191 192 def __call__(self, orig_inputs, fw_outs, out): 193 aliased_base_tensor = orig_inputs[self.base_idx] 194 return aliased_base_tensor 195 196 197class AliasOfIntermediateHandler: 198 def __init__(self, info, runtime_metadata, trace_joint): 199 if info.output_type in ( 200 OutputType.alias_of_intermediate, 201 OutputType.alias_of_intermediate_save_as_output, 202 ): 203 num_user_outputs = len(runtime_metadata.output_info) 204 self.base_idx = info.base_idx + num_user_outputs 205 else: 206 self.base_idx = info.base_idx 207 208 self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity 209 self.requires_grad = info.requires_grad 210 self.functional_tensor = info.functional_tensor 211 self.replay_views = 
config.view_replay_for_aliased_outputs 212 213 def __call__(self, orig_inputs, fw_outs, out): 214 aliased_base_tensor = fw_outs[self.base_idx] 215 return gen_alias_from_base( 216 aliased_base_tensor, 217 self.unwrap_out(out), 218 self.requires_grad, 219 self.functional_tensor, 220 replay_views=self.replay_views, 221 ) 222 223 224_HANDLER_MAP = { 225 OutputType.non_alias: NoopAliasHandler, 226 OutputType.unsafe_view_alias: NoopAliasHandler, 227 OutputType.custom_function_view: NoopAliasHandler, 228 OutputType.alias_of_input: AliasOfInputHandler, 229 OutputType.is_input: IsInputHandler, 230 OutputType.alias_of_intermediate: AliasOfIntermediateHandler, 231 OutputType.alias_of_intermediate_save_as_output: AliasOfIntermediateHandler, 232 OutputType.alias_of_intermediate_base_is_user_output: AliasOfIntermediateHandler, 233} 234 235 236def make_output_handler(info, runtime_metadata, trace_joint): 237 handler_type = _HANDLER_MAP[info.output_type] 238 return handler_type(info, runtime_metadata, trace_joint) 239 240 241def _create_runtime_wrapper( 242 compiled_fn, 243 *, 244 runtime_metadata: ViewAndMutationMeta, 245 indices_of_inps_to_detach: List[int], 246 trace_joint: bool, 247 keep_input_mutations: bool, 248 disable_amp: bool, 249): 250 if not hasattr(compiled_fn, "_boxed_call"): 251 compiled_fn = make_boxed_func(compiled_fn) 252 253 # Note [Inputs needed in runtime epilogue after list clearing] 254 # In Python functions, you can't free the input arguments of a function within the scope of that function. A workaround is to 255 # wrap the input arguments in a list, and clear the list from within the function. 256 # Here, this is implemented as `call_func_at_runtime_with_args(..., steal_args=True)`. 257 # 258 # This is needed for Compiled Autograd since some of the inputs (activations) should be freed early. 259 # However, we cannot blindly clear the entire list, because AOTAutograd may need access to some of the graph inputs 260 # **after** the compiled function has finished running. There are two main cases: 261 # (1) Input mutations: If there are an input mutations that we must run outside of the graph, we need access to the input. 262 # (2) Output aliasing: Outputs that aliases graph inputs generally must be regenerated outside of the `autograd.Function`, 263 # and doing so requires us accessing the corresponding input after the compiled artifact has run. 
264 epilogue_args_idx = [] 265 epilogue_args_idx.extend(runtime_metadata.mutated_inp_runtime_indices) 266 for info in runtime_metadata.output_info: 267 if ( 268 info.output_type == OutputType.alias_of_input 269 or info.output_type == OutputType.is_input 270 ): 271 assert isinstance(info.base_idx, int) 272 epilogue_args_idx.append(info.base_idx) 273 274 if config.unlift_effect_tokens: 275 assert len(runtime_metadata.tokens) == 0 276 277 replay_views = config.view_replay_for_aliased_outputs 278 if runtime_metadata.num_outputs_aliased > 0: 279 output_handlers = tuple( 280 make_output_handler(info, runtime_metadata, trace_joint) 281 for info in runtime_metadata.output_info 282 ) 283 284 def runtime_wrapper(args: List[Any]): 285 # stash a ref to each input tensor we plan to use after the compiled function 286 orig_inputs = {i: args[i] for i in epilogue_args_idx} 287 288 if keep_input_mutations: 289 mutated_args = ( 290 args[i] 291 for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd 292 ) 293 torch.autograd.graph.increment_version(mutated_args) 294 295 if trace_joint: 296 args_ = list(args) 297 # See Note [Detaching inputs that never need gradients] 298 for idx in indices_of_inps_to_detach: 299 if isinstance(args_[idx], torch.Tensor): 300 args_[idx] = args_[idx].detach() 301 302 # It's possible to have trace_joint inside user specified with no_grad() region, 303 # if there is a nested with enable_grad(), that forces some outputs to require gradients. 304 # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution. 305 with torch.autograd._force_original_view_tracking( 306 True 307 ), torch.enable_grad(): 308 all_outs = call_func_at_runtime_with_args( 309 compiled_fn, args_, disable_amp=disable_amp, steal_args=True 310 ) 311 else: 312 # When we have an inference graph, we run with grad disabled. 313 # It's possible to get an inference graph with inputs that require grad, 314 # in which case we want to make sure autograd is disabled 315 # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on) 316 # NOTE: We use _set_grad_enabled directly to reduce runtime overhead 317 grad_enabled = torch.is_grad_enabled() 318 try: 319 if grad_enabled: 320 torch._C._set_grad_enabled(False) 321 all_outs = call_func_at_runtime_with_args( 322 compiled_fn, args, disable_amp=disable_amp, steal_args=True 323 ) 324 finally: 325 if grad_enabled: 326 torch._C._set_grad_enabled(True) 327 del args 328 329 num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices 330 num_intermediate_bases = runtime_metadata.num_intermediate_bases 331 332 assert ( 333 len(all_outs) 334 == num_mutated_runtime_inps 335 + runtime_metadata.num_outputs 336 + num_intermediate_bases 337 ) 338 339 # Step 3: After running the compiled fw, apply updates to mutated inputs 340 num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices 341 if num_mutations_to_apply > 0: 342 updated_inputs = all_outs[:num_mutations_to_apply] 343 fw_outs = all_outs[num_mutations_to_apply:] 344 345 for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices): 346 meta = runtime_metadata.input_info[inpt_idx] 347 if not meta.mutates_data and not meta.mutates_metadata: 348 continue 349 original_inpt = orig_inputs[inpt_idx] 350 updated_inpt = updated_inputs[i] 351 if meta.mutates_storage_metadata: 352 # See Note [set_() Input Mutations in AOTAutograd] 353 # mutates_storage_metadata means our input saw a x.set_(y) call. 
354 # What if x **also** saw a data and/or a metadata mutation? 355 # (1) If the [meta]data mutation occurred after the set_(), 356 # then there is no need to copy_() the data. 357 # When we perform x.set_(x_updated), we are guaranteed that 358 # x_updated already has the final version of the data/metadata 359 # (2) If a data mutation occurred before the set_(). 360 # This case seems very difficult to support. 361 # TODO: discuss on the PR and decide if we want to tr to 362 # either support it, or detect and ban it. 363 if trace_joint: 364 assert isinstance(updated_inpt, TensorAlias) 365 updated_inpt = updated_inpt.alias 366 with torch.no_grad(): 367 original_inpt.set_(updated_inpt) 368 continue 369 if meta.mutates_metadata and not meta.mutates_data: 370 if trace_joint: 371 assert isinstance(updated_inpt, TensorAlias) 372 updated_inpt = updated_inpt.alias 373 # We need to grab the size/stride/storage_offset from the compiled forward, 374 # and use that to mutate the metadata of the input 375 original_inpt.as_strided_( 376 updated_inpt.size(), 377 updated_inpt.stride(), 378 updated_inpt.storage_offset(), 379 ) 380 else: 381 if meta.mutates_data and meta.mutates_metadata: 382 original_inpt.as_strided_( 383 updated_inpt.size(), 384 updated_inpt.stride(), 385 updated_inpt.storage_offset(), 386 ) 387 else: 388 assert meta.mutates_data 389 if meta.is_leaf and original_inpt.requires_grad: 390 # We can hit this situation in this case: 391 # def f(x): 392 # x.detach().mul_(2) 393 # return x + 1 394 # AOTAutograd will see a mutation in the above case, and try to 395 # apply a copy_() here, in the epilogue. 396 # But if x required gradients, and is a leaf, then autograd 397 # will yell at us for trying to mutate it. 398 # However, it's only possible to end up in this scenario (like the above) 399 # if all of the mutations to the leaf input were non-autograd-tracking mutations 400 # (aka mutations under no_grad(), or on detached views). 401 # In that case, we fully want to hide the mutation from autograd, so detaching is ok. 402 original_inpt.detach().copy_(updated_inpt) 403 else: 404 original_inpt.copy_(updated_inpt) 405 else: 406 fw_outs = all_outs 407 408 # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of 409 # compiling them. 410 if runtime_metadata.num_outputs_aliased > 0: 411 # The compiled forward also returned intermediate bases. We don't want to return them to the user. 412 expect_num_outputs = ( 413 len(output_handlers) + runtime_metadata.num_intermediate_bases 414 ) 415 assert len(fw_outs) == expect_num_outputs 416 ret_outs = [ 417 handler(orig_inputs, fw_outs, out) 418 for out, handler in builtins.zip(fw_outs, output_handlers) 419 ] 420 else: 421 ret_outs = fw_outs 422 423 if runtime_metadata.dynamic_outputs: 424 for t, o in zip(ret_outs, runtime_metadata.output_info): 425 if o.dynamic_dims is None: 426 continue 427 if hasattr(t, "_dynamo_weak_dynamic_indices"): 428 t._dynamo_weak_dynamic_indices |= o.dynamic_dims 429 else: 430 t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy() 431 if runtime_metadata.grad_enabled_mutation is not None: 432 torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation) 433 return ret_outs 434 435 return runtime_wrapper 436 437 438@dataclass 439class FunctionalizedRngRuntimeWrapper(CompilerWrapper): 440 # TODO: I would love to get rid of this argument, but it's 441 # Wrapped pretty tightly around our aot_dispatch_autograd logic. 
442 # Specifically, tensors_saved_for_backwards_slice's value is both used for calculating indices 443 # for setting placeholder strides(which is done before runtime, before this wrapper runs) 444 # and for saving tensors for backward (which is done during runtime, after this wrapper runs) 445 # So in aot_dispatch_autograd, this wrapper can't edit the set of outs without making one 446 # of those two indices incorrect. 447 return_new_outs: bool = True 448 449 def pre_compile( 450 self, 451 flat_fn, 452 flat_args, 453 aot_config, 454 *, 455 fw_metadata, 456 ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]: 457 if config.functionalize_rng_ops: 458 # Update example inputs for the fw_compiler 459 fake_mode = detect_fake_mode() 460 seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode) 461 flat_args.extend([seed, offset]) 462 # We are not clearing flat_args here because 463 # 1) There is a check in the debug compiler at the end 464 # 2) It does not matter as these are fake tensors 465 return flat_fn, flat_args, fw_metadata 466 467 def post_compile( 468 self, 469 compiled_fn, 470 aot_config: AOTConfig, 471 *, 472 runtime_metadata: ViewAndMutationMeta, 473 ): 474 @wraps(compiled_fn) 475 def wrapper(runtime_args: List[Any]): 476 if runtime_metadata.is_rng_op_functionalized: 477 # Add the seed and offset to args 478 seed, offset = CUDARngStateHelper.get_torch_state_as_tuple() 479 runtime_args.extend([seed, offset]) 480 out = compiled_fn(runtime_args) 481 out = self._functionalized_rng_runtime_epilogue( 482 runtime_metadata, 483 out, 484 # TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper 485 runtime_metadata.num_forward_returns, 486 ) 487 return out 488 return compiled_fn(runtime_args) 489 490 return wrapper 491 492 # Calling convention: If we are running functionalized RNG, then outs consists 493 # of (user_outs, rng_offset) 494 def _functionalized_rng_runtime_epilogue( 495 self, 496 metadata: ViewAndMutationMeta, 497 outs, 498 offset_index, 499 ): 500 if metadata.is_rng_op_functionalized: 501 assert metadata.num_outputs_rng_offset == 1 502 new_rng_offset = outs[offset_index] 503 CUDARngStateHelper.set_new_offset(new_rng_offset) 504 if self.return_new_outs: 505 user_outs = outs[:offset_index] + outs[offset_index + 1 :] 506 return user_outs 507 else: 508 return outs 509 510 return outs 511 512 513@dataclass 514class FakifiedOutWrapper(CompilerWrapper): 515 out_metas: List[torch.Tensor] = field(default_factory=list) 516 # TracingContext.fwd_output_strides 517 # Generated from actually doing compile 518 fwd_output_strides: Optional[List[List[int]]] = None 519 needs_post_compile: bool = True 520 521 def pre_compile( 522 self, 523 fw_module, # Must be fw_module from aot_dispatch_*_graph 524 flat_args, 525 aot_config, 526 *, 527 fw_metadata, 528 ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]: 529 tracing_context = torch._guards.TracingContext.try_get() 530 if tracing_context and tracing_context.fakify_first_call: 531 self.out_metas = [ 532 n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0]) 533 ] 534 else: 535 self.needs_post_compile = False 536 return fw_module, flat_args, fw_metadata 537 538 def _compute_output_meta_with_inductor_strides(self): 539 out = self.out_metas 540 fwd_output_strides = self.fwd_output_strides 541 if not fwd_output_strides: 542 return out 543 544 from torch.fx.experimental.symbolic_shapes import statically_known_true 545 546 for i in range(len(out)): 547 if not isinstance(out[i], 
Tensor): 548 continue 549 if all( 550 statically_known_true(s1 == s2) 551 for s1, s2 in zip(out[i].stride(), fwd_output_strides[i]) 552 ): 553 continue 554 out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i]) 555 return out 556 557 # To be called post compile 558 def set_fwd_output_strides(self, fwd_output_strides): 559 self.fwd_output_strides = fwd_output_strides 560 561 def post_compile( 562 self, 563 compiled_fn, 564 aot_config: AOTConfig, 565 *, 566 runtime_metadata: ViewAndMutationMeta, 567 ): 568 if self.needs_post_compile: 569 assert self.fwd_output_strides is not None 570 fakified_out = self._compute_output_meta_with_inductor_strides() 571 572 @wraps(compiled_fn) 573 def wrapper(runtime_args): 574 nonlocal fakified_out 575 if fakified_out is not None: 576 out = fakified_out 577 fakified_out = None 578 return out 579 return compiled_fn(runtime_args) 580 581 return wrapper 582 # If we don't need to fakify, we can just return the original compiled function 583 return compiled_fn 584 585 586# This wrapper handles the AOTDispatch runtime logic for tensor subclasses. 587# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor, 588# But the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs). 589# This function handles the wrapping and unwrapping of tensor subclasses at runtime. 590@dataclass 591class AOTDispatchSubclassWrapper(CompilerWrapper): 592 trace_joint: bool 593 fw_only: Optional[Callable] # Not cached, only used in pre_compile 594 maybe_subclass_meta: Optional[SubclassMeta] 595 num_fw_outs_saved_for_bw: Optional[int] 596 597 def pre_compile( 598 self, 599 flat_fn, 600 flat_args: List[Tensor], 601 aot_config: AOTConfig, 602 *, 603 fw_metadata: ViewAndMutationMeta, 604 ): 605 (new_flat_fn, new_flat_args, subclass_meta) = aot_dispatch_subclass( 606 flat_fn, 607 flat_args, 608 is_joint_structure=self.trace_joint, 609 meta=fw_metadata, 610 fw_only=self.fw_only, # type: ignore[arg-type] 611 ) 612 self.maybe_subclass_meta = subclass_meta 613 return new_flat_fn, new_flat_args, fw_metadata 614 615 def post_compile( 616 self, 617 compiled_fn, 618 _aot_config: AOTConfig, 619 *, 620 runtime_metadata: ViewAndMutationMeta, 621 ): 622 if self.maybe_subclass_meta is None: 623 return compiled_fn 624 625 subclass_metas = runtime_metadata.subclass_fw_graph_out_meta 626 627 @wraps(compiled_fn) 628 def inner_fn(args: List[Any]): 629 unwrapped_args = unwrap_tensor_subclasses( 630 args, is_joint_structure=self.trace_joint 631 ) 632 args.clear() 633 # expectation: runtime_fn is a boxed fn 634 unwrapped_outs = compiled_fn(unwrapped_args) 635 wrapped_outs = wrap_tensor_subclasses( 636 unwrapped_outs, 637 subclass_metas=subclass_metas, 638 num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, 639 is_runtime=True, 640 ) 641 return wrapped_outs 642 643 # box it 644 inner_fn._boxed_call = True # type: ignore[attr-defined] 645 return inner_fn 646 647 648@dataclass 649class EffectTokensWrapper(CompilerWrapper): 650 def post_compile( 651 self, 652 compiled_fn, 653 _aot_config, 654 *, 655 runtime_metadata: ViewAndMutationMeta, 656 ): 657 num_tokens = len(runtime_metadata.tokens) 658 659 @wraps(compiled_fn) 660 def inner_fn(args: List[Any]): 661 if num_tokens > 0: 662 # Pass in forward effect tokens (See Note [Side-Effectful Tokens in AOTAutograd]) 663 old_args = args 664 args = [*([None] * num_tokens), *args] 665 old_args.clear() 666 667 outs = compiled_fn(args) 668 669 # Inductor cache DummyModule can 
return None 670 if outs is None: 671 return None 672 # Toss out the effect tokens (See Note [Side-Effectful Tokens in AOTAutograd]) 673 return outs[num_tokens:] if num_tokens != 0 else outs 674 675 # box it 676 inner_fn._boxed_call = True # type: ignore[attr-defined] 677 return inner_fn 678 679 680# MOTIVATION: 681# 682# When tracing functions for future execution, one must be careful not to pass 683# in the same input tensor multiple times (e.g., f(x, x), as this can result 684# in graphs that are ONLY valid if you later pass a new tensor in exactly the 685# same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct 686# tensors that alias each other is a different situation that is covered by 687# aot_dispatch_deduplicated_autograd). Here are two examples: 688# 689# (1) Suppose you have a function: 690# 691# def f(x, y): 692# return x + y 693# 694# If you make_fx(f)(x, x), you will trace out: 695# 696# def f(x, y): 697# return y + y 698# 699# Oops! 700# 701# (2) For most tensors x and y, you can compute f's gradient with respect to 702# these to inputs by saying torch.autograd.grad(f(x, y), (x, y)). However, 703# if x is y, you will trace out a program that gets incorrect gradients: 704# 705# >>> x = torch.randn(1, requires_grad=True) 706# >>> torch.autograd.grad(x + x, (x, x)) 707# (tensor([2.]), tensor([2.])) 708# 709# In other words, the gradient is double-counted. Deduplicating the arguments 710# gives you an appropriate gradient: 711# 712# >>> y = torch.randn(1, requires_grad=True) 713# >>> torch.autograd.grad(x + y, (x, y)) 714# (tensor([1.]), tensor([1.])) 715# 716# HOW TO DEDUPLICATE: 717# 718# There are a few strategies, in order of preference: 719# 720# 1. For every duplicate argument to the function, detach it into 721# a separate leaf tensor, so that it is no longer duplicated. 722# 723# PRO: The resulting compiled graph works for any configuration 724# of duplicated arguments. 725# 726# CON: It does not (naively) work if you mutate the metadata of inputs: 727# 728# def f(x, y): 729# x.transpose_(0, 1) 730# y.transpose_(0, 2) 731# 732# x = torch.randn(2, 3, 4) 733# f(x, x) 734# 735# The ordering of the transposes inside f dictates whether or not 736# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute 737# what metadata mutations should get applied to each input; you need to 738# assume they aren't duplicates (what we do today) or preserve 739# the original metadata mutations exactly in order, so that they work 740# for any duplicate configuration. 741# 742# CON: It does not (naively) work if you mutate the data of inputs. 743# In particular, leaf tensors that require grad cannot be mutated, 744# this makes it impossible to differentiate with respect to the original 745# base. 746# 747# 2. For every duplicate argument to the function, remove it, so it is 748# no longer part of the "true" signature: 749# 750# PRO: Implemented naively, it still works for metadata/data mutation. 751# 752# CON: The resulting compiled graph is duplicate-specialized: it only 753# works if future calls duplicate arguments in exactly the same way. 754# Horribly, Dynamo doesn't guard on this at the moment. But even if 755# it did, you could still end up recompiling a bunch of each duplicate. 756# 757# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if 758# Dynamo's guards are not enough. In practice, this seems to cover 759# everything. 
760# 761@dataclass 762class AOTDedupeWrapper(CompilerWrapper): 763 keep_arg_mask: List[bool] = field(default_factory=list) 764 add_dupe_map: List[int] = field(default_factory=list) 765 old_input_metadata: List[InputAliasInfo] = field(default_factory=list) 766 needs_post_compile: bool = True 767 768 # NB: Hot path, avoid set lookups here 769 # TODO: Can avoid the zip here too, probably 770 def remove_dupe_args(self, args): 771 return [t for t, keep in zip(args, self.keep_arg_mask) if keep] 772 773 def add_dupe_args(self, args): 774 return [args[i] for i in self.add_dupe_map] 775 776 def pre_compile( 777 self, 778 flat_fn, 779 flat_args: List[Tensor], 780 aot_config: AOTConfig, 781 *, 782 fw_metadata: ViewAndMutationMeta, 783 ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]: 784 # Use information about whether or not flat_fn mutates its arguments 785 # or not to handle dupe args 786 787 # Strategy 1: For any input that is not mutated, we can leafify it if we 788 # need to remove a duplicate. 789 leaf_flat_args = [] 790 args_set = set() 791 ok = True 792 793 for i, a in enumerate(flat_args): 794 if not isinstance(a, torch.Tensor): 795 leaf_flat_args.append(a) 796 elif a not in args_set: 797 args_set.add(a) 798 leaf_flat_args.append(a) 799 elif ( 800 not fw_metadata.input_info[i].mutates_data 801 and not fw_metadata.input_info[i].mutates_metadata 802 ): 803 leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad)) 804 else: 805 ok = False 806 break 807 808 if ok: 809 self.needs_post_compile = False 810 return flat_fn, leaf_flat_args, fw_metadata 811 812 if requires_subclass_dispatch(leaf_flat_args, fw_metadata): 813 raise RuntimeError( 814 """\ 815 Encountered duplicate inputs that are mutated in the graph, but at least one input/output 816 to the graph is a tensor subclass. This is not supported today. You can try to 817 remove the aliasing yourself as a workaround, or otherwise file an issue on github.""" 818 ) 819 820 # export path: ban duplicate inputs for now, add later if requested. 821 if aot_config.is_export: 822 raise RuntimeError( 823 f"""\ 824 Encountered duplicated inputs that are mutated in the graph you are trying to export. 825 This functionality is currently not supported. If needed, please file a github issue. 826 827 fw_metadata={str(fw_metadata)} 828 """ 829 ) 830 831 # Strategy 2: Duplicate specialize. 832 # 833 # In Haskell types, suppose you have: 834 # 835 # add_dupe_args :: DedupedArgs -> Args 836 # remove_dupe_args :: Args -> DedupedArgs 837 # 838 # compiler_fn 839 # :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R) 840 # deped_compiler_fn 841 # :: (Args -> R) -> Args -> AOTConfig -> (Args -> R) 842 # 843 # Then the code below can be written in point-free style as: 844 # 845 # deduped_compiler_fn f a c = 846 # compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . 
remove_dupe_args 847 # 848 # Suppose you have: 849 # 850 # [a, b, a, c] 851 # 852 # We want: 853 # 854 # remove_dupe_args([a, b, a, c]) == [a, b, c] 855 # add_dupe_args([a, b, c]) == [a, b, a, c] 856 # 857 # This is done via (respectively): 858 # 859 # seen_args = {a: 0, b: 1, c: 2} 860 # enumerate(add_dupe_map) = [ # how to get args from the deduped list 861 # (0, 0), 862 # (1, 1), 863 # (2, 0), 864 # (3, 2), 865 # ] 866 # keep_arg_mask = [True, True, False, True] 867 868 seen_args: Dict[Tensor, int] = {} 869 # Implicitly map duped arg position (list index) to de-duped arg position 870 keep_arg_mask: List[bool] = [] 871 add_dupe_map: List[int] = [] 872 duped_arg_len = len(flat_args) 873 874 j = 0 # index into deduped_flat_args 875 for t in flat_args: 876 if isinstance(t, torch.Tensor): 877 if t in seen_args: 878 keep_arg_mask.append(False) 879 add_dupe_map.append(seen_args[t]) 880 continue 881 seen_args[t] = j 882 883 keep_arg_mask.append(True) 884 add_dupe_map.append(j) 885 j += 1 886 assert ( 887 len(add_dupe_map) == duped_arg_len 888 ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}" 889 890 self.keep_arg_mask = keep_arg_mask 891 self.add_dupe_map = add_dupe_map 892 893 deduped_flat_args = self.remove_dupe_args(flat_args) 894 895 # Update our input metadata to remove duped input metadata. 896 updated_fw_metadata = remove_dupe_metadata( 897 fw_metadata, keep_arg_mask, add_dupe_map 898 ) 899 900 if ( 901 tracing_context := TracingContext.try_get() 902 and aot_config.aot_autograd_arg_pos_to_source 903 ): 904 # TODO(voz): This structure is 1:1, we could consider an alternate structure like 905 # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there, 906 # which feels like needless complexity for a tiny bit of efficiency at this point. 907 for dupe_arg_pos, (kept_pos, keep_arg) in enumerate( 908 zip(add_dupe_map, keep_arg_mask) 909 ): 910 if not keep_arg: 911 dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[ 912 dupe_arg_pos 913 ] 914 kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[ 915 kept_pos 916 ] 917 tracing_context.guards_context.aotautograd_guards.append( # type: ignore[attr-defined] 918 DuplicateInputs(kept_arg_source, dupe_arg_source) 919 ) 920 921 @wraps(flat_fn) 922 def wrapped_flat_fn(*args): 923 return flat_fn(*self.add_dupe_args(args)) 924 925 if config.debug_assert: 926 ref_fw_metadata = run_functionalized_fw_and_collect_metadata( 927 wrapped_flat_fn, 928 static_input_indices=aot_config.static_input_indices, 929 keep_input_mutations=fw_metadata.keep_input_mutations, 930 is_train=fw_metadata.is_train, 931 )(*deduped_flat_args) 932 assert ( 933 ref_fw_metadata == updated_fw_metadata 934 ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}" 935 936 return wrapped_flat_fn, deduped_flat_args, updated_fw_metadata 937 938 def post_compile( 939 self, 940 compiled_fn, 941 aot_config: AOTConfig, 942 *, 943 runtime_metadata: ViewAndMutationMeta, 944 ): 945 if not self.needs_post_compile: 946 return compiled_fn 947 948 @wraps(compiled_fn) 949 def wrapped_compiled_fn(args: List[Any]): 950 deduped_args = self.remove_dupe_args(args) 951 args.clear() 952 return compiled_fn(deduped_args) 953 954 wrapped_compiled_fn._boxed_call = True # type: ignore[attr-defined] 955 956 # This can be uncommented when we properly guard for duplicates, 957 # but right now we must not do it. 
958 # if not config.debug_assert: 959 # return wrapped_compiled_fn 960 961 @wraps(wrapped_compiled_fn) 962 def debugged_compiled_fn(args): 963 # Test that the computed remove/add arg functions are an inverse 964 new_args = self.add_dupe_args(self.remove_dupe_args(args)) 965 seen: Dict[Any, None] = {} 966 for i, (x, y) in enumerate(zip(new_args, args)): 967 seen[y] = None 968 assert x is y, format_guard_bug_msg( 969 aot_config, 970 f"{describe_input(i, aot_config)} would be a duplicate of " 971 f"{describe_input(self.add_dupe_map[i], aot_config)}", 972 ) 973 # This is only an error if there is metadata mutation on both of 974 # the duped arguments; in this case, we need to know what order 975 # the metadata mutation applies in. You'll get the correct result 976 # otherwise, because a graph that assumes distinct inputs works if 977 # you dupe the inputs (the gradient contributions from each input 978 # will get summed up appropriately.) 979 # 980 # TODO: work out how to setup this assert correctly 981 """ 982 assert len(seen) == unique_args, format_guard_bug_msg(aot_config, 983 f"there would be {unique_args} distinct arguments" 984 ) 985 """ 986 return wrapped_compiled_fn(args) 987 988 debugged_compiled_fn._boxed_call = True # type: ignore[attr-defined] 989 990 return debugged_compiled_fn 991 992 993# This layer handles the situation where you have two inputs that alias each other, 994# and one of the inputs is mutated. 995# We need to take special care to ensure that the mutation is applied to the other aliases in the graph. 996# 997# pre-condition: AOTDedupWrapper has already run. 998# (This function will in theory work if there are duplicate args. 999# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs 1000# would cause us to hit that path more frequently). 1001@dataclass 1002class AOTSyntheticBaseWrapper(CompilerWrapper): 1003 # Currently, the only reason we need to plumb this bool is because 1004 # the synthetic base code prohibits more cases in the autograd case than the inference case. 1005 trace_joint: bool # TODO: refactor trace_joint 1006 needs_post_compile: bool = True 1007 aliased_arg_idx_with_metadata_mutations: List[int] = field(default_factory=list) 1008 1009 def pre_compile( 1010 self, 1011 flat_fn, 1012 flat_args: List[Any], 1013 aot_config: AOTConfig, 1014 *, 1015 fw_metadata: ViewAndMutationMeta, 1016 ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]: 1017 is_inference = not self.trace_joint 1018 flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs( 1019 flat_args, 1020 fw_metadata.input_info, 1021 is_inference=is_inference, 1022 ) 1023 1024 # Happy path: we don't need synthetic bases 1025 if synthetic_base_info is None: 1026 self.needs_post_compile = False 1027 return flat_fn, flat_args, fw_metadata 1028 1029 # export path: ban synthetic bases for now, add later if requested. 1030 if requires_subclass_dispatch(flat_args, fw_metadata): 1031 raise RuntimeError( 1032 """\ 1033 Encountered aliased inputs that are mutated in the graph, but at least one input/output 1034 to the graph is a tensor subclass. This is not supported today. You can try to 1035 remove the aliasing yourself as a workaround, or otherwise file an issue on github.""" 1036 ) 1037 1038 if aot_config.is_export: 1039 raise RuntimeError( 1040 f"""\ 1041 Encountered aliased inputs that are mutated in the graph you are trying to export. 1042 This functionality is currently not supported. If needed, please file a github issue. 
1043 1044 synthetic_base_info={str(synthetic_base_info)} 1045 1046 fw_metadata={str(fw_metadata)} 1047 """ 1048 ) 1049 1050 assert len(fw_metadata.input_info) == len(synthetic_base_info) 1051 1052 # Update our forward metadata to take synthetic bases into account 1053 ( 1054 fw_metadata_updated, 1055 aliased_arg_idx_with_metadata_mutations, 1056 ) = create_synthetic_base_metadata( 1057 fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases 1058 ) 1059 # Save old input args for post-compile 1060 self.old_input_info = fw_metadata.input_info 1061 1062 self.aliased_arg_idx_with_metadata_mutations = ( 1063 aliased_arg_idx_with_metadata_mutations 1064 ) 1065 1066 num_aliased_args_with_metadata_mutations = len( 1067 aliased_arg_idx_with_metadata_mutations 1068 ) 1069 1070 replay_views = config.view_replay_for_aliased_outputs 1071 1072 def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]: 1073 f_args_inner = [] 1074 for inner_idx_or_tuple in synthetic_base_info: 1075 if isinstance(inner_idx_or_tuple, int): 1076 f_args_inner.append(primals[inner_idx_or_tuple]) 1077 else: 1078 inner_base_idx, view_tensor = inner_idx_or_tuple 1079 base = primals[inner_base_idx] 1080 view_arg = gen_alias_from_base( 1081 base, 1082 view_tensor, 1083 view_tensor.requires_grad, 1084 replay_views=replay_views, 1085 ) 1086 f_args_inner.append(view_arg) 1087 return f_args_inner 1088 1089 @wraps(flat_fn) 1090 def wrapped_flat_fn(*args): 1091 unpacked_args = _unpack_synthetic_bases(args) 1092 # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases) 1093 # is to relieve the downstream logic from having to reason about mutations on inputs that alias 1094 # each other, by replacing aliased inputs with a synthetic base. 1095 # One area where this breaks down a bit however is if one of those aliased inputs 1096 # experienced a metadata mutation. 1097 # We are now obligated to reapply the metadata mutation directly to the user's input; 1098 # it isn't enough to apply mutations back to the synthetic base in the downstream logic. 1099 # 1100 # The way we handle this is by pretending that those aliased inputs that experience metadata mutations 1101 # are additional outputs in the user's forward function. 1102 # The downstream logic will just treat these as "user outputs that alias inputs". 1103 # However, we will manually grab them at runtime here, use them to reapply the metadata mutation 1104 # to the user inputs, and not return them to the user. 
1105 aliased_args_with_metadata_mutations = [ 1106 x 1107 for i, x in enumerate(unpacked_args) 1108 if i in self.aliased_arg_idx_with_metadata_mutations 1109 ] 1110 if len(aliased_args_with_metadata_mutations) > 0: 1111 return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations 1112 else: 1113 return flat_fn(*unpacked_args) 1114 1115 if config.debug_assert: 1116 ref_fw_metadata = run_functionalized_fw_and_collect_metadata( 1117 wrapped_flat_fn, 1118 static_input_indices=aot_config.static_input_indices, 1119 keep_input_mutations=fw_metadata.keep_input_mutations, 1120 is_train=fw_metadata.is_train, 1121 )(*flat_args_with_synthetic_bases) 1122 assert ref_fw_metadata == fw_metadata_updated, ( 1123 f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, " 1124 f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}" 1125 ) 1126 return ( 1127 wrapped_flat_fn, 1128 flat_args_with_synthetic_bases, 1129 fw_metadata_updated, 1130 ) 1131 1132 def post_compile( 1133 self, 1134 compiled_fn, 1135 aot_config: AOTConfig, 1136 *, 1137 runtime_metadata: ViewAndMutationMeta, 1138 ): 1139 if not self.needs_post_compile: 1140 return compiled_fn 1141 1142 is_inference = not self.trace_joint 1143 1144 @wraps(compiled_fn) 1145 def wrapped_compiled_fn(args): 1146 args_with_synthetic_bases, synthetic_base_info = merge_view_inputs( 1147 args, self.old_input_info, is_inference=is_inference 1148 ) 1149 assert synthetic_base_info is not None 1150 aliased_args_w_metadata_mutations = [ 1151 args[i] for i in self.aliased_arg_idx_with_metadata_mutations 1152 ] 1153 num_aliased_args_with_metadata_mutations = len( 1154 aliased_args_w_metadata_mutations 1155 ) 1156 args.clear() 1157 outs = compiled_fn(args_with_synthetic_bases) 1158 if num_aliased_args_with_metadata_mutations > 0: 1159 # This code does not handle **all** input metadata mutations. 1160 # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases 1161 # (which only happens if at least one aliased input experienced a data mutation). 1162 # e.g: 1163 # def f(a, b): 1164 # a.mul_(2) 1165 # b.t_(1, 0) 1166 # f(x.view(2, 2), x.view(2, 2)) 1167 mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:] 1168 user_outs = outs[:-num_aliased_args_with_metadata_mutations] 1169 for inp, mutated_inp in zip( 1170 aliased_args_w_metadata_mutations, mutated_metadata_inps 1171 ): 1172 inp.as_strided_( 1173 mutated_inp.size(), 1174 mutated_inp.stride(), 1175 mutated_inp.storage_offset(), 1176 ) 1177 return user_outs 1178 return outs 1179 1180 return wrapped_compiled_fn 1181 1182 1183# Note [Handling mutations on an input that aliases other inputs] 1184# The easiest example to show-case this edge case is here: 1185# 1186# def f(a, b): 1187# a.mul_(2) 1188# out = a + b 1189# return out 1190# b = torch.ones(...) 1191# a = b.view(-1) 1192# f(a, b) 1193# 1194# In this situation, if a and b happened to be aliased, we need to trace something different! 1195# Suppose we had b = a.view(-1) 1196# (In this case, that means that `a._base is b`) 1197# 1198# We need to ensure that the aliasing relationship between a and b is preserved. 1199# We do that detecting the specific situation above (mutate an input that aliases another input), 1200# and when we do that, we create a synthetic base argument. Then inside of the traced forward, 1201# we regenerate a and b off of that base. 
1202# The complete example of the transformed function looks like this: 1203# 1204# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views 1205# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph 1206# def traced_forward(base): 1207# a = base.as_strided(...) 1208# b = base.as_strided(...) 1209# a_updated = a.mul(2) 1210# base_updated = torch.as_strided_scatter(base, a_updated, ...) 1211# b_updated = base_updated.as_strided(...) 1212# out = a_updated + b_updated 1213# return a_updated, out 1214# 1215# def compiled_fn(a, b): 1216# // we detect that a is the "differentiable base" here 1217# base = a 1218# // In other situations, we might do either: 1219# // (1) a and b are both views off of some larger differentiable base 1220# // assert a._base is b._base and a._base is not None 1221# // base = a._base 1222# // (2) a and b both don't require gradients. Create a base from the storage 1223# // assert a._base is None and b._base is None 1224# // base = torch.Tensor(a.storage()) 1225# a_updated, out = traced_forward(base) 1226# a.copy_(a_updated) 1227# return out 1228# 1229# This function: 1230# (1) Merges input views into a synthetic base argument, when any of those input views are mutated 1231# (2) Returns metadata telling the autograd.Function how to modify their arguments properly, 1232# to respect the new calling convention. 1233# 1234# The calling convention is as follows. 1235# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base. 1236# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN], 1237# Where the ordering of the bases is determined from the ordering of the original view args. 1238# baseA will come before baseB if the earliest original argument coming from baseA 1239# showed up earlier in the argument list than the earliest original argument coming from baseB. 1240# 1241# Example, given some tensors a, b, c, d 1242# call site: 1243# f(a, c.view(-1), b.view(-1), b, c, d) 1244# Modified argument list: 1245# c_base comes first because the first c view came earlier in arg list than the first b view 1246# a and d still show up in the modified arg list, but b and c don't- they're regenerated from their bases 1247# b_base = torch.Tensor(b.storage()) 1248# c_base = torch.Tensor(c.storage()) 1249# f(c_base, b_base, a, d) 1250def merge_view_inputs( 1251 fwd_inputs: List[Any], 1252 mutated_input_info: List[InputAliasInfo], 1253 *, 1254 # The autograd case currently has more restrictions than the inference case. 1255 is_inference: bool, 1256) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]: 1257 def _are_differentiable_views(view1, view2): 1258 if view1 is view2: 1259 return True 1260 if view1._base is None and view2._base is None: 1261 return False 1262 if view1._base is view2._base or view1._base is view2 or view1 is view2._base: 1263 return True 1264 return False 1265 1266 def _same_dtype_views(view1, view2): 1267 if view1.dtype != view2.dtype: 1268 return False 1269 if view1._base is not None and view1.dtype != view1._base.dtype: 1270 return False 1271 if view2._base is not None and view2.dtype != view2._base.dtype: 1272 return False 1273 return True 1274 1275 assert len(fwd_inputs) == len(mutated_input_info) 1276 if not [info for info in mutated_input_info if info.mutates_data]: 1277 # Return early when there are no mutations. 
1278 return fwd_inputs, None 1279 1280 storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list) 1281 base_args = [] 1282 other_args = [] 1283 for i, inpt in enumerate(fwd_inputs): 1284 if isinstance(inpt, Tensor): 1285 storage_ref = StorageWeakRef(inpt.untyped_storage()) 1286 storage_ref_to_idx[storage_ref].append(i) 1287 else: 1288 other_args.append(inpt) 1289 # Note [Synthetic Base Info Metadata] 1290 # This list contains metadata that tells you what the i'th argument in the inner calling convention should be. 1291 # It's either: 1292 # - another int (corresponding to the index in the argument list of the element from the outer calling convention) 1293 # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx]) 1294 # idx corresponds to which synthetic base from the outer calling context to view 1295 inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {} 1296 for aliased_input_indices in storage_ref_to_idx.values(): 1297 if len(aliased_input_indices) <= 1 or not any( 1298 # We only care about mutations that affect all aliases, 1299 # so metadata mutations on an input doesn't require us to do synthetic base handling. 1300 mutated_input_info[inpt_idx].mutates_data 1301 for inpt_idx in aliased_input_indices 1302 ): 1303 for curr_idx in aliased_input_indices: 1304 other_args.append(fwd_inputs[curr_idx]) 1305 continue 1306 1307 # Here, we attempt to do a more complicated check to detect false aliasing 1308 # (e.g. if all the tensors have the same storage, but don't actually overlap) 1309 # In theory, we could have a large group of tensors that all share storages, where only *some* of them 1310 # have overlapping memory. 1311 # I don't bother with that case for now: here, we only bail out earlier if we detect that **every** pair 1312 # of tensors in the current group that shares a storage is non-overlapping. 1313 aliased_input_indices_no_false_sharing = compute_overlapping_inputs( 1314 fwd_inputs, aliased_input_indices 1315 ) 1316 if len(aliased_input_indices_no_false_sharing) <= 1: 1317 for curr_idx in aliased_input_indices: 1318 other_args.append(fwd_inputs[curr_idx]) 1319 continue 1320 1321 # We detected an input that was mutated, AND aliases with another input. 1322 # we need to replace this set of aliased inputs with a single synthetic base. 1323 # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases 1324 # and error out. We can fix them later. 1325 # These checks are transitive, so we don't need to check every pair. 1326 for idx1, idx2 in zip( 1327 aliased_input_indices, aliased_input_indices[1:], strict=False 1328 ): 1329 view1 = fwd_inputs[idx1] 1330 view2 = fwd_inputs[idx2] 1331 # The "inputs that are aliased but have different differentiable bases" case 1332 # is more complicated and hopefully pretty rare. Not currently handled. 1333 if not is_inference: 1334 assert _are_differentiable_views( 1335 view1, view2 1336 ), "aot_autograd() does not yet handle non-differentiable view input mutations." 1337 # Regenerating views when reinterpreting complex / real tensors seems non-trivial, 1338 # not handling for now 1339 assert _same_dtype_views( 1340 view1, view2 1341 ), "aot_autograd() does not yet handle input mutations on views with different dtypes." 
1342 non_none_bases = [ 1343 fwd_inputs[i]._base 1344 for i in aliased_input_indices 1345 if fwd_inputs[i]._base is not None 1346 ] 1347 aliases_with_none_bases = [ 1348 fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None 1349 ] 1350 if len(non_none_bases) == 0: 1351 # Case where none of the aliases have a ._base 1352 # we generate a synthetic base without gradients, and generate views off of it 1353 # We hit this case when we have input tensors to the graph that share a storage, 1354 # but do not have a ._base field. 1355 # Wondering when we hit this case? 1356 # The _base field simply says that autograd knows about the aliasing relationship, 1357 # but sometimes we create tensors which are aliased out of the same storage but guaranteed 1358 # to be disjoint. In these cases, we will skip setting up the _base relationship 1359 # for performance reasons (because the fact that the tensors share the same storage 1360 # is unobservable unless you (1) do naughty things with resize_/as_strided 1361 # or (2) look at the storage--as we are doing here.) 1362 # One particular example of this is optimizer steps on the LSTM module: 1363 # LSTM parameters are packed into a contiguous storage for efficiency reasons when 1364 # calling cuDNN kernels, so when these parameters get passed to the optimizer we will 1365 # find they share the same storage, but do not have _base set since they are all disjoint. 1366 # 1367 # NOTE: There is one case where this is unsafe: 1368 # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily 1369 # the same shape as the "actual" base that the tensor came from. 1370 # For the most part this is fine, because we always use as_strided() 1371 # to generate the original aliased inputs again. 1372 # If we were to use view-replay though, this could cause the aliased views 1373 # to have incorrect sizes. 1374 example_idx = aliased_input_indices[0] 1375 example_alias = fwd_inputs[example_idx] 1376 # Note that this function is re-used at both trace time and runtime. 1377 # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor. 1378 synthetic_base = torch.empty( 1379 (0,), dtype=example_alias.dtype, device=example_alias.device 1380 ) 1381 # We don't actually have a convenient way of going from storage -> tensor, 1382 # So using set_() here (we suffer some minor overhead, but this case is rare). 1383 synthetic_base.set_(example_alias.untyped_storage()) 1384 else: 1385 # Case where all of the aliases require gradients, and have the same _base. 1386 synthetic_base = non_none_bases[0] 1387 for other_base in non_none_bases[1:]: 1388 assert ( 1389 other_base is synthetic_base 1390 ), "aot_autograd() does not yet handle non-differentiable view input mutations." 1391 for alias in aliases_with_none_bases: 1392 assert ( 1393 alias is synthetic_base 1394 ), "aot_autograd() does not yet handle non-differentiable view input mutations." 1395 base_args.append(synthetic_base) 1396 for curr_view_idx in aliased_input_indices: 1397 curr_view = fwd_inputs[curr_view_idx] 1398 base_idx = len(base_args) - 1 1399 # We store just enough info here so that we can regenerate the view later. 1400 # Regeneration: curr_view._view_func(args[base_idx]) 1401 inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view) 1402 if len(base_args) == 0: 1403 assert len(other_args) == len(fwd_inputs) 1404 # If no synthetic bases are necessary, just return the original inputs. 
1405 return fwd_inputs, None 1406 else: 1407 # Otherwise, return: 1408 # (1) The new args according to the updated calling convention: (synthetic_bases, other_args) 1409 # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention. 1410 # We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention. 1411 args_to_functionalization = base_args + other_args 1412 arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)} 1413 for i, other_arg in enumerate(other_args): 1414 new_idx = len(base_args) + i 1415 old_idx = arg_to_old_idx_map[other_arg] 1416 inner_calling_convention_meta[old_idx] = new_idx 1417 # post process into a list 1418 post_processed_calling_convention_meta: List[ 1419 Union[int, Tuple[int, torch.Tensor]] 1420 ] = [-1 for _ in range(len(inner_calling_convention_meta))] 1421 for k, v in inner_calling_convention_meta.items(): 1422 post_processed_calling_convention_meta[k] = v 1423 # Quick assert: every argument in the inner calling convention should be accounted for. 1424 for x in post_processed_calling_convention_meta: 1425 assert x != -1 1426 return args_to_functionalization, post_processed_calling_convention_meta 1427 1428 1429@dataclass 1430class AutogradLazyBackwardCompileInfo: 1431 bw_module: Callable 1432 placeholder_list: List[Any] 1433 saved_context: Optional[TracingContext] 1434 saved_compile_context: Optional[CompileContext] 1435 1436 1437# This is wrapped in a class just for namespacing purposes 1438# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly 1439class AOTDispatchAutograd: 1440 @staticmethod 1441 def _force_contiguous(x): 1442 if not isinstance(x, torch.Tensor): 1443 return x 1444 x = x.contiguous() 1445 if not is_traceable_wrapper_subclass(x): 1446 return x 1447 for attr in x.__tensor_flatten__()[0]: # type: ignore[attr-defined] 1448 elem = getattr(x, attr) 1449 if not elem.is_contiguous(): 1450 setattr(x, attr, elem.contiguous()) 1451 return x 1452 1453 # See Note [Tangents must be contiguous, Part 2] 1454 @staticmethod 1455 def coerce_runtime_tangent(x, metadata): 1456 if not isinstance(x, torch.Tensor): 1457 return x 1458 if not is_traceable_wrapper_subclass(x): 1459 return x 1460 assert metadata is not None 1461 (_, expected_tangent_metadata) = metadata 1462 _, runtime_tangent_metadata = x.__tensor_flatten__() # type: ignore[attr-defined] 1463 if runtime_tangent_metadata == expected_tangent_metadata: 1464 return x 1465 if not hasattr(x, "__coerce_same_metadata_as_tangent__"): 1466 raise RuntimeError( 1467 f""" 1468During the backward, we encountered a tensor subclass where we guessed its 1469metadata incorrectly. 1470 1471Expected metadata: {str(expected_tangent_metadata)} 1472 1473Runtime metadata: {str(runtime_tangent_metadata)} 1474 1475shape: {str(cast(torch.Tensor, x).shape)} 1476To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__. 
1477""" 1478 ) 1479 return x.__coerce_same_metadata_as_tangent__(expected_tangent_metadata) # type: ignore[attr-defined] 1480 1481 @staticmethod 1482 def post_compile( 1483 compiled_fw_func, # fw_module after compilation + wrappers 1484 compiled_bw_func, # bw_module after compilation + wrappers 1485 maybe_subclass_meta: Optional[SubclassMeta], 1486 num_symints_saved_for_bw_: int, 1487 backward_state_indices: List[int], 1488 disable_amp: bool, 1489 indices_of_inps_to_detach: List[int], 1490 lazy_backward_info: Optional[AutogradLazyBackwardCompileInfo], 1491 aot_config: AOTConfig, 1492 *, 1493 fw_metadata: ViewAndMutationMeta, # runtime metadata 1494 try_save_cache_entry: Optional[Callable], # Save cache entry after compilation 1495 ): 1496 class CompiledFunction(torch.autograd.Function): 1497 compiled_fw = compiled_fw_func 1498 compiled_bw = compiled_bw_func 1499 metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment] 1500 maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta 1501 num_symints_saved_for_bw = num_symints_saved_for_bw_ 1502 _compiled_autograd_should_lift = False 1503 _aot_id = aot_config.aot_id 1504 _lazy_backward_info = lazy_backward_info 1505 1506 @staticmethod 1507 def _compiled_autograd_key(ctx): 1508 return (ctx._autograd_function_id, *ctx.symints) 1509 1510 @staticmethod 1511 def forward(ctx, *deduped_flat_tensor_args): 1512 args = deduped_flat_tensor_args 1513 if backward_state_indices: 1514 bw_state = args[backward_state_indices[0]] 1515 assert isinstance(bw_state, BackwardState) 1516 ctx._compiled_autograd_backward_state = bw_state 1517 1518 # There is a pretty complicated calling convention around what the compiled fw returns. 1519 # The full list of outputs and their relative order is: 1520 # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints) 1521 # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version 1522 # of the original view, and not the synthetic base 1523 # - Note that donated buffer logic requires (*saved_tensors, *saved_symints) showing up last 1524 # in the fw output order. 
                fw_outs = call_func_at_runtime_with_args(
                    CompiledFunction.compiled_fw,
                    args,
                    disable_amp=disable_amp,
                )

                num_outputs = CompiledFunction.metadata.num_outputs
                num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased
                num_mutated_runtime_inps = (
                    CompiledFunction.metadata.num_mutated_inp_runtime_indices
                )
                num_forward_returns = CompiledFunction.metadata.num_forward_returns

                # Partitioners must put symint arguments at the end separate from tensor arguments
                tensors_saved_for_backwards = fw_outs[
                    CompiledFunction.metadata.tensors_saved_for_backwards_slice
                ]
                assert all(
                    isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards
                )
                # See Note [Detaching saved tensors in AOTAutograd]
                ctx.save_for_backward(
                    *(
                        x.detach() if x._is_view() else x
                        for x in tensors_saved_for_backwards
                    )
                )
                symint_outs = fw_outs[
                    CompiledFunction.metadata.symints_saved_for_backwards_slice
                ]
                assert all(
                    isinstance(x, (int, float, torch.SymInt, torch.SymFloat))
                    for x in symint_outs
                ), str([type(x) for x in symint_outs])
                ctx.symints = symint_outs

                raw_returns = fw_outs[0:num_forward_returns]

                # Wrap all autograd.Function.forward() outputs that are aliases
                # so that autograd.Function doesn't treat them as tensors
                if num_mutated_runtime_inps > 0:
                    for i, idx in enumerate(
                        CompiledFunction.metadata.mutated_inp_runtime_indices
                    ):
                        # We could make this faster by only looping over inputs with metadata-only mutations
                        # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many.
                        info = CompiledFunction.metadata.input_info[idx]
                        if info.mutates_metadata and not info.mutates_data:
                            raw_return_idx = i
                            raw_returns[raw_return_idx] = TensorAlias(
                                raw_returns[raw_return_idx]
                            )

                    if config.debug_assert:
                        user_mutated_inputs_raw = raw_returns[
                            0:num_mutated_runtime_inps
                        ]
                        mut_inp_infos = [
                            x
                            for x in CompiledFunction.metadata.input_info
                            if x.mutates_data or x.mutates_metadata
                        ]
                        assert len(user_mutated_inputs_raw) == len(mut_inp_infos)

                if CompiledFunction.metadata.num_unsafe_view_outputs > 0:
                    for idx in CompiledFunction.metadata.unsafe_view_out_indices:
                        raw_return_idx = num_mutated_runtime_inps + idx
                        o = raw_returns[raw_return_idx]
                        raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view(
                            o, o.shape
                        )

                if num_outputs_aliased > 0:
                    for idx in CompiledFunction.metadata.aliased_out_indices:
                        raw_return_idx = num_mutated_runtime_inps + idx
                        raw_returns[raw_return_idx] = TensorAlias(
                            raw_returns[raw_return_idx]
                        )

                    if config.debug_assert:
                        intermediates_raw = raw_returns[
                            num_mutated_runtime_inps + num_outputs :
                        ]
                        assert not any(
                            isinstance(x, TensorAlias) for x in intermediates_raw
                        )

                # invariant: intermediate bases always require gradients, so we don't have to
                # consider marking them as non-differentiable.
                raw_returns_not_including_intermediate_bases = raw_returns[
                    : num_mutated_runtime_inps + num_outputs
                ]
                raw_returns_meta = [
                    x
                    for x in CompiledFunction.metadata.input_info
                    if x.mutation_type == MutationType.MUTATED_OUT_GRAPH
                ] + CompiledFunction.metadata.output_info

                fw_outs_not_requiring_grad = [
                    x
                    for (i, x) in enumerate(
                        raw_returns_not_including_intermediate_bases
                    )
                    if isinstance(x, torch.Tensor)
                    and not raw_returns_meta[i].requires_grad
                ]
                ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
                ctx._materialize_non_diff_grads = False
                return tuple(raw_returns)
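
            # Illustrative (hypothetical) example of the aliasing epilogue in forward() above:
            # for a user function like
            #     def f(x, y):
            #         x.t_()                       # metadata-only mutation of x
            #         return y.view(-1), y.sin()   # one output aliases an input, one does not
            # raw_returns is [x_updated, view_out, sin_out]; x_updated (metadata-only mutation)
            # and view_out (alias of input y) get wrapped in TensorAlias so autograd does not
            # record them as differentiable outputs, while sin_out is returned as-is. The
            # runtime wrapper then regenerates the real aliases from the inputs afterwards.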

            @staticmethod
            def backward(ctx, *flat_args):
                # Calling convention: we expect a grad_out passed to the backward:
                # - for every output of the fw that does *not* alias an input or graph intermediate
                # - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations)
                # - for every graph intermediate that we need to use to generate an output later.
                # The outputs of autograd.Function.forward that do *not* participate in the backward include:
                # - outputs that alias inputs or graph intermediates
                # - updated inputs due to metadata-only mutations.
                # We need to return them in the forward, but ensure that they all do not get gradients in the backward,
                # and we filter them out here before passing the remaining grad_outputs into the compiled backward.
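                # Illustrative (hypothetical) example: if the fw returned
                #     [mut_inp0 (data mutation), out0 (plain), out1 (aliases an input), base0]
                # then flat_args here has 4 grad_outs, but (assuming everything requires grad)
                # only the grads for mut_inp0, out0 and base0 are forwarded to the compiled
                # backward; the grad for out1 is dropped by the filtering below.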
                num_intermediate_bases = (
                    CompiledFunction.metadata.num_intermediate_bases
                )
                num_mutated_runtime_inps = (
                    CompiledFunction.metadata.num_mutated_inp_runtime_indices
                )
                expected_grad_outs = (
                    CompiledFunction.metadata.num_outputs
                    + num_mutated_runtime_inps
                    + num_intermediate_bases
                )
                deterministic = CompiledFunction.metadata.deterministic
                global_deterministic = torch.are_deterministic_algorithms_enabled()
                if deterministic is not None:
                    torch._check(
                        not (not deterministic and global_deterministic),
                        lambda: (
                            "This compiled backward function is being run with "
                            "torch.use_deterministic_algorithms(True), "
                            "but it was previously generated during the forward function while "
                            "torch.use_deterministic_algorithms(False) was set."
                        ),
                    )

                assert len(flat_args) == expected_grad_outs
                out_info = CompiledFunction.metadata.output_info

                inp_tangents, out_tangents, intermediate_base_tangents = (
                    flat_args[:num_mutated_runtime_inps],
                    flat_args[
                        num_mutated_runtime_inps : num_mutated_runtime_inps
                        + CompiledFunction.metadata.num_outputs
                    ],
                    flat_args[
                        num_mutated_runtime_inps
                        + CompiledFunction.metadata.num_outputs :
                    ],
                )
                # input_info contains info on *every* input, but in the backward() we are only
                # given grad outputs for the mutated inputs, so we need to filter out the grad
                # outputs that correspond to metadata-only mutations or that don't require grad.
                input_info = CompiledFunction.metadata.input_info
                inp_tangents_filtered = [
                    x
                    for x, info_idx in zip(
                        inp_tangents,
                        CompiledFunction.metadata.mutated_inp_runtime_indices,
                    )
                    if input_info[info_idx].mutates_data
                    and input_info[info_idx].requires_grad
                ]
                # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
                out_tangents_filtered = [
                    x
                    for x, info in zip(out_tangents, out_info)
                    if info.output_type
                    in [
                        OutputType.non_alias,
                        OutputType.unsafe_view_alias,
                        OutputType.custom_function_view,
                    ]
                    and issubclass(info.raw_type, torch.Tensor)
                    and info.requires_grad
                ]
                # intermediate bases always require gradients, and always participate in the backward graph.
                flat_bw_args_with_grads = [
                    *inp_tangents_filtered,
                    *out_tangents_filtered,
                    *intermediate_base_tangents,
                ]
                num_flat_bw_args_with_grads = len(flat_bw_args_with_grads)

                # sanity asserts
                # metadata_only_inps = [
                #     x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
                #     if not input_info[info_idx].mutates_data
                # ]
                # aliased_outputs = [
                #     x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
                # assert all(x is None for x in metadata_only_inps)
                # assert all(x is None for x in aliased_outputs)
                # TODO: replace this with FunctionalizedRngRuntimeWrapper
                rng_args = []
                if CompiledFunction.metadata.is_rng_op_functionalized:
                    # Add the seed and offset to args
                    rng_args = CUDARngStateHelper.get_torch_state_as_tuple()

                bw_tokens = [None] * CompiledFunction.metadata.num_backward_tokens

                # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first
                #   in the bw input order.

                # Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls.
                # There are tests that count these calls, so save it to a local variable once.
                ctx_saved_tensors = ctx.saved_tensors
                num_ctx_saved_tensors = len(ctx_saved_tensors)
                all_args = [
                    *ctx.symints,
                    *ctx_saved_tensors,
                    *flat_bw_args_with_grads,
                    *bw_tokens,
                    *rng_args,
                ]
                del ctx_saved_tensors

                # Note: [AOTAutograd Backward Guards]
                # During AOTDispatch, we eagerly create and trace out a joint fw-bw graph.
                # Doing so requires us to "guess" about some of the metadata of our grad_outputs.
                #
                # In particular: if an output to the forward is a plain tensor or a subclass,
                # its corresponding grad_output in the backward **may or may not** be
                # a plain tensor or a subclass. The main cases are:
                # (1) If an output is a plain tensor, its grad_out will also be a plain tensor,
                #     *unless* the output is used in some subclass compute later in the forward graph,
                #     which will cause its grad_output to become a subclass.
                # (2) If an output is a subclass, its grad_out will also be a subclass,
                #     *unless* the output of the forward did not actually participate in the gradient computation,
                #     in which case autograd will insert a plain tensor of zeros for the grad_output.
                #     We could avoid this case with `torch.autograd.Function.set_materialize_grads`,
                #     although this is not turned on today in AOTAutograd and would require more work.
                #
                # Today, we make a guess on subclass-ness based on the above examples,
                # and hard-error in the backward if we guessed wrong.
                #
                # In the future, we should add backward guards that would allow us to
                # properly handle this case instead of erroring: we would need to retrace the backward graph,
                # since we might produce an entirely different trace if our grad_outputs are subclass or not.
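                # Illustrative (hypothetical) example of case (1): the compiled forward returns a
                # plain Tensor `out`, but the user later computes something like
                #     loss = (MyWrapperSubclass(w) * out).sum()
                # outside the compiled region. Autograd can then propagate a MyWrapperSubclass
                # grad_output for `out` back into this backward, while we traced with a plain
                # tensor tangent, and the asserts below will fire.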
                assert (
                    len(CompiledFunction.metadata.output_types)
                    == num_flat_bw_args_with_grads
                )

                grad_output_types = [type(x) for x in flat_bw_args_with_grads]
                # In general, we can add more asserts/guards here for when we partitioned
                # with incorrect assumptions about the grad_outputs.
                # Normalize FakeTensor -> torch.Tensor:
                # - during tracing our types are FakeTensor
                # - at runtime in the backward our types are torch.Tensor...
                # - unless we're running compiled backward, in which case they are also FakeTensor
                grad_output_types_ = [
                    torch.Tensor if x is FakeTensor else x for x in grad_output_types
                ]
                assert (
                    grad_output_types_ == CompiledFunction.metadata.output_types
                ), f"""\
We attempted to compile the backward with incorrect subclass metadata.
If you run into this error, please file an issue.
Expected grad_output types: {str(CompiledFunction.metadata.output_types)}
Got grad_output types: {str(grad_output_types)}"""

                del flat_bw_args_with_grads

                tangents_start_idx = (
                    len(all_args)
                    - num_flat_bw_args_with_grads
                    - len(rng_args)
                    - len(bw_tokens)
                )
                assert tangents_start_idx == len(ctx.symints) + num_ctx_saved_tensors
                tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens)

                # TODO: figure out how to refactor the backward properly
                # so I can use aot_dispatch_subclass_wrapper() here.
                if CompiledFunction.maybe_subclass_metadata is not None:
                    tangents = all_args[tangents_start_idx:tangents_end_idx]

                    def get_types_for_tangents(tangents):
                        infos = []
                        idx = 0
                        for a in tangents:
                            if isinstance(a, Tensor) and is_traceable_wrapper_subclass(
                                a
                            ):
                                infos.append(get_types_for_subclass(a))
                            else:
                                infos.append(idx)
                            idx += 1
                        return infos

                    runtime_subclass_info = get_types_for_tangents(tangents)

                    if len(runtime_subclass_info) != len(
                        CompiledFunction.metadata.subclass_tangent_meta
                    ):
                        raise RuntimeError(
                            "The grad inputs should be the same number as the forward output tangents"
                        )
                    for a, b in zip(
                        runtime_subclass_info,
                        CompiledFunction.metadata.subclass_tangent_meta,
                    ):
                        # Types should match between runtime and traced tangents.
                        # TODO (tmanlaibaatar) Should actually call coerce_runtime_tangent
                        if isinstance(a, List) and (
                            isinstance(b, SubclassCreationMeta) and b.subclass_type
                        ):
                            if not a == b.subclass_type:
                                raise RuntimeError(
                                    "The grad inputs should be the same tensor subclass type as the forward output"
                                )

                    # Get the number of tangents after unwrapping
                    len_tangents = len(
                        unwrap_tensor_subclasses(
                            tangents,
                            is_joint_structure=False,
                        )
                    )
                    assert CompiledFunction.metadata.traced_tangent_metas is not None
                    all_args = [
                        (
                            AOTDispatchAutograd.coerce_runtime_tangent(
                                t,
                                CompiledFunction.metadata.traced_tangent_metas[
                                    i - tangents_start_idx
                                ],
                            )
                            if tangents_start_idx <= i < tangents_end_idx
                            else t
                        )
                        for i, t in enumerate(all_args)
                    ]
                    all_args = unwrap_tensor_subclasses(
                        all_args, is_joint_structure=False
                    )
                    tangents_start_idx = (
                        len(all_args) - len_tangents - len(rng_args) - len(bw_tokens)
                    )
                    tangents_end_idx = tangents_start_idx + len_tangents

                # Make the tangents contiguous. Note that we must do this after subclass desugaring
                # because inputs to inductor have to be contiguous
                all_args = [
                    (
                        AOTDispatchAutograd._force_contiguous(t)
                        if (tangents_start_idx <= i < tangents_end_idx)
                        else t
                    )
                    for i, t in enumerate(all_args)
                ]

                def call_compiled_backward():
                    if ctx._is_compiled_autograd_tracing():
                        if lazy_backward_info is None:
                            raise RuntimeError(
                                """This compiled backward function was saved by AOTAutogradCache, which does not support
                                compiled autograd. Please turn off AOTAutogradCache using `ENABLE_AOT_AUTOGRAD_CACHE=0` to continue."""
                            )
                        bw_module = lazy_backward_info.bw_module
                        # For compiled autograd, run the raw FX graph so that it can be inlined into the larger graph
                        symints = ctx._get_compiled_autograd_symints()
                        assert len(symints) == len(ctx.symints)
                        all_args[: len(symints)] = symints
                        if backward_state_indices:
                            assert (
                                ctx._compiled_autograd_backward_state.proxy is not None
                            )
                            all_args.append(ctx._compiled_autograd_backward_state)
                        context = (
                            torch._C._DisableAutocast if disable_amp else nullcontext
                        )
                        with context():
                            out = normalize_as_list(bw_module(*all_args))
                        # TODO: replace with post_compile wrapper
                        out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue(
                            CompiledFunction.metadata, out, offset_index=len(out) - 1
                        )
                        return tuple(out)
                    assert (
                        not backward_state_indices
                    ), "BackwardState requires CompiledAutograd"
                    ctx.maybe_clear_saved_tensors()

                    saved_tensors_use_once = (
                        not torch._C._autograd._get_current_graph_task_keep_graph()
                    )

                    if CompiledFunction.compiled_bw is None:
                        assert lazy_backward_info is not None

                        if not saved_tensors_use_once:
                            fw_metadata.bw_donated_idxs = []
                            # Update bw_donated_idxs if using lazy_backward_info from `aot_dispatch_autograd`
                            if (
                                hasattr(lazy_backward_info, "saved_context")
                                and hasattr(
                                    lazy_backward_info.saved_context, "fw_metadata"
                                )
                                and hasattr(
                                    lazy_backward_info.saved_context.fw_metadata,  # type: ignore[union-attr]
                                    "bw_donated_idxs",
                                )
                            ):
                                lazy_backward_info.saved_context.fw_metadata.bw_donated_idxs = (  # type: ignore[union-attr]
                                    []
                                )

                        bw_module = lazy_backward_info.bw_module
                        placeholder_list = lazy_backward_info.placeholder_list
                        saved_context = lazy_backward_info.saved_context
                        saved_compile_context = lazy_backward_info.saved_compile_context

                        context = (
                            torch._C._DisableAutocast if disable_amp else nullcontext
                        )
                        with tracing(saved_context), compile_context(
                            saved_compile_context
                        ), context(), track_graph_compiling(aot_config, "backward"):
                            CompiledFunction.compiled_bw = aot_config.bw_compiler(
                                bw_module, placeholder_list
                            )
                            # Maybe save cache entry
                            if try_save_cache_entry is not None:
                                try_save_cache_entry(
                                    CompiledFunction.compiled_bw, fw_metadata
                                )

                    if (
                        torch._functorch.config.donated_buffer
                        and not saved_tensors_use_once
                        and fw_metadata.bw_donated_idxs != []
                    ):
                        torch._check(
                            False,
                            lambda: (
                                "This backward function was compiled with non-empty donated "
                                "buffers which requires create_graph=False and retain_graph=False. "
                                "Please keep backward(create_graph=False, retain_graph=False) "
                                "across all backward() function calls, or set "
                                "torch._functorch.config.donated_buffer=False to disable "
                                "donated buffer."
                            ),
                        )

                    out = call_func_at_runtime_with_args(
                        CompiledFunction.compiled_bw,
                        all_args,
                        steal_args=True,
                        disable_amp=disable_amp,
                    )

                    # Toss out the backward output tokens
                    num_bw_tokens = CompiledFunction.metadata.num_backward_tokens
                    if num_bw_tokens > 0:
                        out = out[:-num_bw_tokens]

                    # TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile
                    out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue(
                        CompiledFunction.metadata, out, offset_index=len(out) - 1
                    )
                    return tuple(out)

                # Backward with forward input mutations is not supported in double backward.
                if (
                    torch.is_grad_enabled()
                    and CompiledFunction.metadata.indices_of_inputs_that_requires_grad_with_mutations_in_bw
                ):
                    raise RuntimeError(
                        "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True"
                    )

                if torch.is_grad_enabled() and any(
                    t.requires_grad for t in all_args if isinstance(t, torch.Tensor)
                ):
                    # Ensure that the graph is connected, and error if double backward is performed.
                    # See comment for why once_differentiable is not sufficient:
                    # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107
                    class CompiledFunctionBackward(torch.autograd.Function):
                        # CompiledFunctionBackward is not yet supported in dynamo skipfiles
                        _compiled_autograd_should_lift = False
                        _aot_id = aot_config.aot_id

                        @staticmethod
                        def forward(ctx, *unused_args):
                            outs = call_compiled_backward()
                            # TODO: figure out how to refactor the backward properly
                            # so I can use aot_dispatch_subclass_wrapper() here.
                            if CompiledFunction.maybe_subclass_metadata is not None:
                                assert (
                                    CompiledFunction.maybe_subclass_metadata.grad_input_metas
                                    is not None
                                )
                                outs_wrapped = wrap_tensor_subclasses(
                                    outs,
                                    subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas,
                                )
                                return outs_wrapped
                            return outs

                        @staticmethod
                        def backward(ctx, *args):
                            raise RuntimeError(
                                "torch.compile with aot_autograd does not currently support double backward"
                            )

                    CompiledFunctionBackward._compiled_autograd_key = (  # type: ignore[method-assign]
                        CompiledFunction._compiled_autograd_key
                    )

                    # Pass args even though they're unused, so that the graph is built
                    out = CompiledFunctionBackward.apply(*all_args)
                else:
                    out = call_compiled_backward()

                # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
                if CompiledFunction.maybe_subclass_metadata is not None:
                    assert (
                        CompiledFunction.maybe_subclass_metadata.grad_input_metas
                        is not None
                    )
                    outs_wrapped = wrap_tensor_subclasses(
                        out,
                        subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas,
                    )
                    return outs_wrapped
                return out

        compiled_function = RuntimeWrapper(
            indices_of_inps_to_detach=indices_of_inps_to_detach,
            trace_joint=True,
            disable_amp=disable_amp,
        ).post_compile(
            CompiledFunction.apply,
            aot_config,
            runtime_metadata=fw_metadata,
        )

        return compiled_function


@dataclass
class DebugAssertWrapper(CompilerWrapper):
    flat_requires_grad: List[Optional[bool]] = field(default_factory=list)

    def post_compile(
        self,
        compiled_fn,
        aot_config: AOTConfig,
        *,
        runtime_metadata: ViewAndMutationMeta,
    ):
        @wraps(compiled_fn)
        def debug_compiled_function(args: List[Any]):
            # TODO: Check aliasing relationships
            # TODO: Check strides for metadata mutation
            # (NB: ideally, this logic is factored out of this function and
            # you move these debug checks there)

            # Check requires_grad. The bad case is when we compiled with
            # requires_grad = False, but the input has requires_grad = True
            # (vice versa is OK; we compute a gradient and then throw
            # it away when it hits the input.)
            for i, a in enumerate(args):
                can_require_grad = self.flat_requires_grad[i]
                if can_require_grad is None:
                    assert not isinstance(a, Tensor)
                elif not can_require_grad:
                    assert not a.requires_grad, format_guard_bug_msg(
                        aot_config,
                        f"{describe_input(i, aot_config)} would not require grad",
                    )

            return compiled_fn(args)

        return debug_compiled_function


def pre_compile(
    wrappers: List[CompilerWrapper],
    flat_fn: Callable,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
    """
    Runs a sequence of wrappers on the given function and arguments.
    Mutates wrappers in place.
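
    Illustrative usage sketch (names here are stand-ins, not a fixed API): the AOT
    compile path pairs this with post_compile() below, roughly like

        flat_fn, flat_args, fw_metadata = pre_compile(
            wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
        )
        compiled_fn = some_compiler_fn(flat_fn, flat_args)  # hypothetical compile step
        compiled_fn, _ = post_compile(
            wrappers, compiled_fn, aot_config, runtime_metadata=fw_metadata
        )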
    """
    for wrapper in wrappers:
        flat_fn, flat_args, fw_metadata = wrapper.pre_compile(
            flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
        )
    return flat_fn, flat_args, fw_metadata


def post_compile(
    wrappers: List[CompilerWrapper],
    compiled_fn: Callable,
    aot_config: AOTConfig,
    *,
    runtime_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, ViewAndMutationMeta]:
    """
    Runs a sequence of wrappers on the given function. Should be called after pre_compile().
    Wrappers are applied in reverse order, so the first wrapper in the list ends up outermost
    in the returned callable.
    """
    for wrapper in reversed(wrappers):
        compiled_fn = wrapper.post_compile(
            compiled_fn, aot_config, runtime_metadata=runtime_metadata
        )
    return compiled_fn, runtime_metadata


def make_runtime_safe(
    fw_metadata: ViewAndMutationMeta,
    maybe_subclass_meta: Optional[SubclassMeta],
):
    """
    Calls make_runtime_safe on all ViewAndMutationMetas.
    Modifies both arguments. Allows ViewAndMutationMetas to
    be safely cached in AOTAutogradCache.
    """
    fw_metadata.make_runtime_safe()
    if maybe_subclass_meta is not None:
        maybe_subclass_meta.fw_metadata.make_runtime_safe()