1# mypy: ignore-errors 2 3import itertools 4from contextlib import contextmanager, nullcontext 5from functools import partial, wraps 6from typing import Any, Callable, Dict, List, NewType, Optional, Tuple 7from unittest.mock import patch 8 9import torch 10import torch._dynamo.logging 11import torch.nn as nn 12import torch.utils._pytree as pytree 13import torch.utils.dlpack 14from torch import Tensor 15from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions 16from torch._dispatch.python import enable_python_dispatcher 17from torch._dynamo import compiled_autograd 18from torch._dynamo.utils import dynamo_timed, preserve_rng_state 19from torch._guards import detect_fake_mode 20from torch._inductor.utils import BoxedBool 21from torch._subclasses import FakeTensor, FakeTensorMode 22from torch.fx.experimental.proxy_tensor import make_fx 23from torch.fx.experimental.symbolic_shapes import ShapeEnv 24from torch.utils._python_dispatch import is_traceable_wrapper_subclass 25 26 27static_inputs_log = torch._logging.getArtifactLogger( 28 __name__, "cudagraph_static_inputs" 29) 30 31from . import config 32from ._aot_autograd.autograd_cache import ( # noqa: F401 33 AOTAutogradCache, 34 autograd_cache_key, 35) 36from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 37 run_functionalized_fw_and_collect_metadata, 38) 39from ._aot_autograd.functional_utils import ( # noqa: F401 40 _check_if_mutation_can_be_in_graph, 41 are_all_mutations_hidden_from_autograd, 42 are_all_mutations_under_no_grad_or_inference_mode, 43 assert_functional_graph, 44 from_fun, 45 gen_alias_from_base, 46 has_data_mutation, 47 has_metadata_mutation, 48 is_fun, 49 sync_functional_tensor, 50 to_fun, 51) 52from ._aot_autograd.input_output_analysis import ( # noqa: F401 53 _tensors_definitely_do_not_overlap, 54 compute_overlapping_inputs, 55 create_graph_signature, 56 create_synthetic_base_metadata, 57 remove_dupe_metadata, 58) 59from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401 60 aot_dispatch_autograd, 61 aot_dispatch_base, 62 aot_dispatch_export, 63) 64from ._aot_autograd.logging_utils import ( # noqa: F401 65 callback_set, 66 describe_input, 67 format_guard_bug_msg, 68 get_aot_compilation_context, 69 get_aot_graph_name, 70 get_graph_being_compiled, 71 graph_being_compiled, 72 model_name, 73 nth_graph, 74 set_model_name, 75 setup_stacktrace_preservation_hooks, 76 track_graph_compiling, 77) 78from ._aot_autograd.runtime_wrappers import ( # noqa: F401 79 AOTDedupeWrapper, 80 AOTSyntheticBaseWrapper, 81) 82from ._aot_autograd.schemas import ( # noqa: F401 83 AOTConfig, 84 BackwardSignature, 85 FQN, 86 GraphInputName, 87 GraphOutputName, 88 GraphSignature, 89 InputAliasInfo, 90 MutationType, 91 OutputAliasInfo, 92 OutputType, 93 SubclassCreationMeta, 94 SubclassMeta, 95 TensorAlias, 96 ViewAndMutationMeta, 97) 98from ._aot_autograd.subclass_utils import ( # noqa: F401 99 create_metadata_for_subclass, 100 requires_subclass_dispatch, 101 unwrap_tensor_subclasses, 102 wrap_tensor_subclasses, 103 wrap_tensor_subclasses_maybe_joint, 104) 105from ._aot_autograd.traced_function_transforms import ( # noqa: F401 106 aot_dispatch_subclass, 107 create_functional_call, 108 create_functionalized_fn, 109 create_functionalized_rng_ops_wrapper, 110 create_joint, 111 fn_input_mutations_to_outputs, 112 fn_prepped_for_autograd, 113) 114from ._aot_autograd.utils import ( # noqa: F401 115 _get_autocast_states, 116 _get_symint_hints, 117 call_func_at_runtime_with_args, 118 
    create_tree_flattened_fn,
    KNOWN_TYPES,
    make_boxed_compiler,
    make_boxed_func,
    maybe_to_fresh_input,
    normalize_as_list,
    partial_flatten_asdict,
    root_module_when_exporting_non_strict,
    strict_zip,
)
from .partitioners import default_partition


zip = strict_zip

# This global counter increments every time we compile a graph with
# AOTAutograd. You can use this to correlate runtime error messages
# with compile time (e.g., if you get an error at runtime saying
# compiled graph 3 failed, you can set a breakpoint at compile time
# for this graph number to investigate further at compile time.)
#
# NB: this is different from get_aot_compilation_context, which tracks
# each underlying graph that is compiled. In contrast, AOT_COUNTER
# corresponds to top-level invocations of aot_module/aot_function;
# one counter is allocated per entire compiled block (but this block
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
AOT_COUNTER = itertools.count()

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation
# that are external to the graph (they show up as side effects in some way when you run the graph).
#
# Take a look at `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some example functions
# and what their compiled graphs look like.
# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them.
#
# Note [AOT Autograd: input data mutations]
#
# If we compile a function that mutates inputs, then those input mutations are real side effects
# that a user expects to see after running the compiled graph.
# However, the graph that we want to send to a backend needs to be *entirely* functional.
# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile,
# but we update the graph to return (updated_inputs, user_outputs).
# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals.
#
# Example: original user code:
# def f(x):
#     x.mul_(2)
#     out = x.mul(3)
#     return out
#
# After AOT Autograd compiles, we end up with a:
# (a) compiled graph
# (b) autograd.Function.forward() method, that executes the compiled graph
# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue
#
# The outputs of (a), (b) and (c) are all written below.
#
# def compiled_forward_graph(x):
#     x_updated = x.mul(2)
#     out = x_updated.mul(3)
#     return x_updated, out
#
# # x_updated gets a gradient in the compiled backward
# def compiled_backward_graph(grad_x_updated, grad_out):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     x_updated, out = compiled_forward_graph(x)
#     return x_updated, out
#
# def compiled_wrapper(x):
#     x_updated, out = autograd.Function.apply(x)
#     x.copy_(x_updated)
#     return out
#
# Another important thing to note is that updated inputs (due to data mutations) *do* participate
# in the compiled backward graph! Since the compiled forward graph gets N extra outputs
# (due to updated inputs showing up as graph outputs),
# the compiled backward gets an additional N inputs.
# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input
# back to the original input.


# Note [AOT Autograd: input metadata mutations]
#
# For the same reason as input mutations, we also don't put input metadata mutations in the graph.
# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph.
#
# Example: original user code:
# def f(x):
#     x.t_()
#     out = x.mul(3)
#     return out
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
#     x_updated = x.t()
#     out = x_updated.mul(3)
#     return x_updated, out
#
# # x_updated does *not* get a gradient in the compiled backward
# def compiled_backward_graph(grad_out):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     x_updated, out = compiled_forward_graph(x)
#     return x_updated, out
#
# def compiled_wrapper(x):
#     x_updated, out = autograd.Function.apply(x)
#     x.as_strided_(x_updated)
#     return out
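#
# For either kind of input mutation above, an easy way to see the
# (updated_inputs, user_outputs) convention in practice is to compile a mutating
# function with a compiler that just prints the graph it is given. This is a
# hedged sketch only; `print_compile` is an illustrative name, not an API from
# this file:
#
# def print_compile(fx_g, _):
#     print(fx_g.code)
#     return fx_g
#
# def f(x):
#     x.mul_(2)
#     return x.mul(3)
#
# aot_function(f, print_compile)(torch.randn(3))
#
# The printed forward should return both the updated input and the user output;
# the copy_() back into the original input happens in the wrapper epilogue, not
# inside the graph.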


# Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates!
# Why?
# (1) autograd.Function.forward() has a limitation, where views that are returned in the forward cannot later be mutated.
# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph,
#     in an epilogue.
# For outputs that alias inputs, we do the following:
# (a) *still* return the aliased output as a graph output
# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output.
#
# For outputs that alias *intermediates*, we do the following:
# (a) Return the output in the compiled forward, **and** return its ._base (a graph intermediate) as an output in the forward
# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output).
# You might wonder why we return the aliased output directly in the graph (and make the graph compute it),
# only to not return it and instead generate a fresh alias off of the intermediate,
# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons:
# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call
# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance.
#     This can result in problems if a user later tries to .view() that output expecting it to have one set of strides,
#     when it has a different set of strides.
#     By including the view op directly in the graph, Inductor takes that into account when deciding what memory format
#     the graph intermediate should be.
#
# Another important thing to note is how our traced backward() graph handles aliases.
# (this applies to outputs aliasing inputs, outputs aliasing intermediates,
# *and* updated inputs returned in the compiled forward due to metadata-only mutations).
# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph.
# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly
# at the end of the forward.
#
# Example: original user code:
# def f(x):
#     out1 = x.t()
#     intermediate = x.mul(2)
#     out2 = intermediate.view(-1)
#     return out1, out2
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
#     out1 = x.t()
#     intermediate = x.mul(2)
#     out2 = intermediate.view(-1)
#     # the compiled graph also returns the intermediate
#     return out1, out2, intermediate
#
# # intermediate gets a gradient in the compiled backward.
# # both output aliases (out1 and out2) do not.
# def compiled_backward_graph(grad_intermediate):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     out1, out2, intermediate = compiled_forward_graph(x)
#     return out1, out2, intermediate
#
# def compiled_wrapper(x):
#     out1, out2, intermediate = autograd.Function.apply(x)
#     # regenerate out1 from the input
#     out1_regenerated = out1._view_func(x)
#     # regenerate out2 from the intermediate
#     out2_regenerated = out2._view_func(intermediate)
#     return out1_regenerated, out2_regenerated
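#
# As a rough illustration of that regeneration step (a hedged sketch; the real
# logic lives in gen_alias_from_base() and handles many more cases than this):
#
# def regenerate_alias(aliased_out, new_base):
#     if aliased_out._base is not None:
#         # preferred: replay the original view ops on top of the new base
#         return aliased_out._view_func(new_base)
#     # fallback: rebuild the view from size/stride/storage_offset metadata
#     return new_base.as_strided(
#         aliased_out.size(), aliased_out.stride(), aliased_out.storage_offset()
#     )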


# Note [AOT Autograd: mutations to inputs that alias other inputs]
#
# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input.
# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other.
# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias
# given the mutation that occurred.
#
# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input
# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base
# inside of the compiled function.
#
# This logic is fully encapsulated in aot_wrapper_synthetic_base()
#
# Example: original user code:
# def f(x, x_view):
#     x.mul_(2)
#     out = x * x_view
#     return out
# f(x, x.view(-1))
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(base):
#     x = generate_x(base)
#     x_view = generate_x_view(base)
#     x_updated = x.mul(2)
#     x_view_updated = x_updated.view(-1)
#     out = x_updated * x_view_updated
#     return x_updated, out
#
# # The calling convention change from (aliases) -> (base) happens
# # *outside* of the autograd.Function.forward().
# # That means the forward() only has 1 input (base),
# # and the backward() only has 1 output (grad_base)
# def compiled_backward_graph(grad_out):
#     grad_base = ...
#     return grad_base
#
# def autograd.Function.forward(base):
#     x_updated, out = compiled_forward_graph(base)
#     return x_updated, out
#
# # The compiled wrapper is where we create synthetic bases.
# # The info on which inputs are mutated is also tracked *before* synthetic base creation.
# def compiled_wrapper(x, x_view):
#     base = merge_view_inputs(x, x_view)
#     x_updated, out = autograd.Function.apply(base)
#     # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view.
#     x.copy_(x_updated)
#     return out


# Note [AOT Autograd: Views to avoid tangents aliasing inputs]
#
# We view every forward output when creating our tangent tensors to handle the problematic
# case in which a subclass does extra aliasing between graph outputs/inputs in a way that
# is not visible above the subclass.
#
# Ordinarily, when constructing the joint function that we want to trace in AOTAutograd,
# we're guaranteed that the tangent tensors that we pass
# into the joint are distinct tensors from the primals. This is because when
# we decide which forward outputs to create tangents for, we only create tangents
# for forward outputs that are not aliases of inputs (See Note
# [AOT Autograd: outputs aliasing inputs or intermediates!]).
#
# However, when wrapper tensor subclasses enter the picture, it is possible
# to have an output of the forward that is a subclass that is not an
# input / alias of an input, but one of its inner tensors is an alias!
# NestedTensor is an example: Performing an out-of-place pointwise op on a
# NestedTensor constructs a fresh NestedTensor that holds onto the input's
# offsets tensor directly.
#
# Having tangent tensors that are the same as the (primal) forward inputs
# can cause problems during tracing as make_fx() will specialize on our
# duplicate inputs: If we passed in the same tensor for primals_1 and
# tangents_1 during tracing, make_fx() will happily sub out all usages of
# tangents_1 with primals_1 in the graph, which is not what we want.
#
# To work around this, we view every forward output when creating our tangent
# tensors so that tangents can never be the same as forward inputs even if
# forward inputs alias forward outputs.
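#
# Roughly, the defensive view looks like this when we construct tangents (a hedged
# sketch, not the exact metadata-collection code):
#
# def make_tangent_for(fw_out):
#     # view_as() returns a *new* tensor object aliasing fw_out, so even if fw_out
#     # is literally one of the primal inputs, the tangent we trace with is a
#     # distinct object and make_fx() will not collapse the two graph inputs.
#     return fw_out.view_as(fw_out)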

# Note [Side-Effectful Tokens in AOTAutograd]
#
# We allow some side-effectful operators in
# the post-AOTAutograd (functional) graph, such as prints and torchbind operations.
# To ensure that these side effects are compatible with future graph passes that
# assume that the graph is functional, we will thread "effect tokens" to show
# data dependence between these side-effectful operators. Practically speaking,
# effect tokens are just dummy values (torch.tensor([])). The graph would look
# like the following:
#
# def gm(self, token0, reader):
#     token1, frame = with_token(ordered_effect_op, (reader,), token0)
#     frame = frame * 2
#     token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
#     frame2 = frame2 * 2
#     return token2, frame, frame2
#
# We will pass the token as an input to the graph, thread it through
# side-effectful operators using the `with_effects` higher-order operator, and then
# return the updated token as an output.
# So the signature of the graph input would look something like
# (*tokens, *params_buffers, *user_inputs), and the signature of the graph
# output would look something like (*tokens, *outputs).
#
# However, Inductor does not want the concept of tokens in the final generated
# code's input and output. Since changing the graph signature inside of Inductor
# is difficult, after generating the forward graph, we will run a pass to
# remove the tokens from the inputs and outputs and instead generate the
# following graph for Inductor, where the tokens are created and sunk within
# the graph, rather than being passed in as inputs and outputs:
#
# def gm(self, reader):
#     token0 = torch.ops.prims._make_token()
#     token1, frame = with_token(ordered_effect_op, (reader,), token0)
#     frame = frame * 2
#     token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
#     frame2 = frame2 * 2
#     sink_token = torch.ops.prims._sink_tokens([token2])
#     return frame, frame2

#
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


aot_autograd_decompositions = {}

FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any])


def process_inputs(
    flat_args: List[Any],
    aot_config: AOTConfig,
    fake_mode: FakeTensorMode,
    shape_env: Optional[ShapeEnv],
) -> FakifiedFlatArgs:
    with fake_mode:

        def convert(idx, x):
            if shape_env is not None:
                from torch._dynamo.source import ConstantSource

                if isinstance(x, int):
                    # We always specialize on scalar values in export.
                    if aot_config.is_export:
                        return x
                    source = ConstantSource(f"sym_{idx}")
                    return shape_env.create_symintnode(
                        shape_env.create_symbol(x, source), hint=x, source=source
                    )
            if isinstance(x, torch.ScriptObject):
                return torch._library.fake_class_registry.maybe_to_fake_obj(
                    fake_mode, x
                )
            if not isinstance(x, torch.Tensor):
                return x
            if isinstance(x, FakeTensor):
                assert x.fake_mode is fake_mode
                return x
            if is_traceable_wrapper_subclass(x):
                attrs, _ = x.__tensor_flatten__()
                if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs):
                    assert all(
                        getattr(x, attr).fake_mode is fake_mode for attr in attrs
                    )
                    return x

            # see note [Tensor Fakification and Symbol Caching]
            symbolic_context = None
            source = None
            trace = True
            if tracing_context := torch._guards.TracingContext.try_get():
                if x in tracing_context.tensor_to_context:
                    symbolic_context = tracing_context.tensor_to_context[x]
                    source = symbolic_context.tensor_source
                    # We already fakeified this tensor in Dynamo, don't
                    # dump the trace for it again
                    trace = False
            if (
                idx < aot_config.num_params_buffers
                and config.static_weight_shapes
                and not symbolic_context
            ):
                # TODO: Ensure that this codepath is never exercised from
                # Dynamo
                return fake_mode.from_tensor(x, static_shapes=True)

            return fake_mode.from_tensor(
                x,
                static_shapes=False,
                symbolic_context=symbolic_context,
                source=source,
                trace=trace,
            )

        return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)])


def construct_fake_mode(
    flat_args: List[Any], aot_config: AOTConfig
) -> Tuple[FakeTensorMode, Optional[ShapeEnv]]:
    fake_mode = detect_fake_mode(flat_args)
    if fake_mode is None:
        shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
        fake_mode = FakeTensorMode(shape_env=shape_env)
    else:
        shape_env = fake_mode.shape_env
    return (fake_mode, shape_env)


def create_aot_dispatcher_function(
    flat_fn,
    fake_flat_args:
FakifiedFlatArgs, 517 aot_config: AOTConfig, 518 fake_mode: FakeTensorMode, 519 shape_env: Optional[ShapeEnv], 520) -> Tuple[Callable, ViewAndMutationMeta]: 521 with dynamo_timed("create_aot_dispatcher_function"): 522 return _create_aot_dispatcher_function( 523 flat_fn, fake_flat_args, aot_config, fake_mode, shape_env 524 ) 525 526 527def _create_aot_dispatcher_function( 528 flat_fn, 529 fake_flat_args: FakifiedFlatArgs, 530 aot_config: AOTConfig, 531 fake_mode: FakeTensorMode, 532 shape_env: Optional[ShapeEnv], 533) -> Tuple[Callable, ViewAndMutationMeta]: 534 """ 535 Traces the forward and backward graphs of the attr:`flat_fn` to generate a 536 joint graph. The joint graph is an Fx graph with Aten ops. Please refer to 537 the tracing mechanism to understand the graph capturing details. 538 539 The joint graph is then passed through attr:`partition_fn` to isolate the 540 forward and backward portions, which are then respectively compiled via the 541 provided attr:`fw_compiler` and attr:`bw_compiler`. 542 543 The resulting compiled forward and backward graphs are then wrapped up in a 544 ``torch.autograd.Function`` object. 545 546 The calling convention here is that the first aot_config.num_params_buffers 547 inputs in flat_args are parameters and buffers, and the rest are inputs. 548 549 We use this to assume that parameters/buffer's shapes don't change. 550 551 Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export) 552 When aot_config.is_export is True, we return an FX graph + metadata 553 When aot_config.is_export is False, we return an ordinary runtime function 554 """ 555 556 # This is the main entry point. 557 # TODO: Chillee argues that dynamo itself should pass in fake tensors to 558 # the list of arguments when compiling; at the moment we do not do this 559 560 if aot_config.decompositions is None: 561 aot_config.decompositions = {} 562 563 aot_config.decompositions = { 564 **aot_autograd_decompositions, 565 **aot_config.decompositions, 566 } 567 568 if config.functionalize_rng_ops: 569 # Update the decompositions with functionalized random decompositions 570 aot_config.decompositions = { 571 **rng_decompositions, 572 **aot_config.decompositions, 573 } 574 575 # Check flat_args to see if they're already fake. If so, use that fake 576 # mode instead. 577 578 python_dispatcher_mode = ( 579 enable_python_dispatcher() if shape_env is not None else nullcontext() 580 ) 581 582 # See NOTE: [Deferring tensor pack/unpack hooks until runtime] 583 # If any saved tensor hooks are active, we **don't** want to trace them. 584 # Instead, we'll let them run at runtime, around the custom autograd.Function 585 # that we generate in torch.compile. 586 with torch.autograd.set_multithreading_enabled( 587 False 588 ), preserve_rng_state(), ( 589 fake_mode 590 ), ( 591 python_dispatcher_mode 592 ), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): 593 from torch._library.fake_class_registry import ( 594 FakeScriptObject, 595 maybe_to_fake_obj, 596 ) 597 598 # Tracing may mutate the states the fake script object, 599 # so we need to duplicate the fake script objects so that subsequent tracing 600 # won't be affected. 
601 def _dup_fake_script_obj(fake_flat_args): 602 return [ 603 maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj) 604 if isinstance(arg, FakeScriptObject) 605 else arg 606 for arg in fake_flat_args 607 ] 608 609 needs_autograd = any( 610 x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) 611 ) 612 613 with enable_python_dispatcher(): 614 # Patch set_rng_state as set_rng_state with fake tensors is 615 # nonsensical. This does not affect the collection of metadata. 616 with patch("torch.cuda.set_rng_state", lambda *args: None): 617 mod = root_module_when_exporting_non_strict(flat_fn) 618 if mod is not None: 619 ctx = _detect_attribute_assignment(mod) 620 else: 621 ctx = nullcontext() 622 with ctx: 623 fw_metadata = run_functionalized_fw_and_collect_metadata( 624 flat_fn, 625 static_input_indices=aot_config.static_input_indices, 626 keep_input_mutations=aot_config.keep_inference_input_mutations, 627 is_train=needs_autograd, 628 pre_dispatch=aot_config.pre_dispatch, 629 )(*_dup_fake_script_obj(fake_flat_args)) 630 631 req_subclass_dispatch = requires_subclass_dispatch( 632 fake_flat_args, fw_metadata 633 ) 634 635 output_and_mutation_safe = not any( 636 x.requires_grad 637 # view-type operations preserve requires_grad even in no_grad. 638 # Do not count aliases of inputs with requires_grad as reason to make a training graph, 639 # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime, 640 # setting their grad_fn properly. 641 and not ( 642 x.output_type 643 in (OutputType.alias_of_input, OutputType.is_input) 644 and fw_metadata.input_info[x.base_idx].requires_grad 645 ) 646 for x in fw_metadata.output_info 647 ) and not any( 648 x.requires_grad 649 and x.mutates_data 650 and not x.mutations_under_no_grad_or_inference_mode 651 and not x.mutations_hidden_from_autograd 652 for x in fw_metadata.input_info 653 ) 654 655 if needs_autograd and output_and_mutation_safe: 656 # We realized that none of the outputs require grad, 657 # and none of the inputs that require grad are mutated. 658 # so we actually have an inference graph. 659 needs_autograd = False 660 # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta 661 # changes depending on whether we pass in is_train / keep_input_mutations, 662 # so we're forced to recompute the metadata. 663 # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata 664 # so that this is unnecessary. 
665 if req_subclass_dispatch: 666 fw_metadata = run_functionalized_fw_and_collect_metadata( 667 flat_fn, 668 keep_input_mutations=aot_config.keep_inference_input_mutations, 669 is_train=False, 670 pre_dispatch=aot_config.pre_dispatch, 671 static_input_indices=aot_config.static_input_indices, 672 )(*fake_flat_args) 673 else: 674 fw_metadata = ViewAndMutationMeta( 675 input_info=fw_metadata.input_info, 676 output_info=fw_metadata.output_info, 677 num_intermediate_bases=fw_metadata.num_intermediate_bases, 678 keep_input_mutations=aot_config.keep_inference_input_mutations, 679 traced_tangents=fw_metadata.traced_tangents, 680 subclass_inp_meta=fw_metadata.subclass_inp_meta, 681 subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, 682 subclass_tangent_meta=fw_metadata.subclass_tangent_meta, 683 is_train=False, 684 tokens=fw_metadata.tokens, 685 static_input_indices=fw_metadata.static_input_indices, 686 ) 687 688 if fw_metadata.num_intermediate_bases > 0: 689 assert not req_subclass_dispatch, f"""\ 690torch.compile is currently being used with tensor subclass inputs: 691{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs 692that alias one another, which is currently unsupported in the subclass use case. If you run into this, 693please file a github issue""" 694 695 if aot_config.is_export: 696 # aot_export: ban input metadata mutations for now to keep shared code paths simpler. 697 # Keeping .resize_() in the graph will require some work 698 # Allowing it but keeping the graph functional will require some calling convention changes. 699 if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: 700 raise RuntimeError( 701 f"""\ 702Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`. 703This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. 704 705fw_metadata={str(fw_metadata)}""" 706 ) 707 # In export, banning data mutations on inputs that require grad for now. 708 # This should be rare, and is tricky to get right. When we trace the backward, 709 # we currently trace with autograd.grad instead of .backward(), which makes it difficult 710 # to ensure that we run autograd all the way through the input **before** it saw the mutation. 711 if ( 712 len( 713 [ 714 x 715 for x in fw_metadata.input_info 716 if x.requires_grad and x.mutates_data 717 ] 718 ) 719 != 0 720 ): 721 raise RuntimeError( 722 f"""\ 723Found a graph input that requires gradients, and received a mutation. 724This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. 725 726fw_metadata={str(fw_metadata)}""" 727 ) 728 if req_subclass_dispatch: 729 raise RuntimeError( 730 """\ 731aot_export is not currently supported with traceable tensor subclass. 732If you need this feature, please comment on <CREATE_ISSUE_LINK>""" 733 ) 734 735 # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad, 736 # and turning it on will require a non-trivial calling convention change for any export runtime. 737 if config.functionalize_rng_ops: 738 raise RuntimeError( 739 """\ 740Functionalized RNG is not currently supported in the aot_export workflow. 
Please file a github issue, 741or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" 742 ) 743 744 def choose_dispatcher(needs_autograd, aot_config): 745 """ 746 Pick a dispatcher based on the config rules. 747 """ 748 if aot_config.is_export: 749 # export uses just the "graph bits", whereas the other 750 # two dispatchers include some extra work around handling a runtime epilogue 751 return partial(aot_dispatch_export, needs_autograd=needs_autograd) 752 elif needs_autograd and not aot_config.pre_dispatch: 753 return aot_dispatch_autograd 754 else: 755 return aot_dispatch_base 756 757 compiler_fn = choose_dispatcher(needs_autograd, aot_config) 758 759 compiled_fn, fw_metadata = compiler_fn( 760 flat_fn, 761 _dup_fake_script_obj(fake_flat_args), 762 aot_config, 763 fw_metadata=fw_metadata, 764 ) 765 return compiled_fn, fw_metadata 766 767 768def aot_function( 769 fn: Callable, 770 fw_compiler: Callable, 771 bw_compiler: Optional[Callable] = None, 772 partition_fn: Callable = default_partition, 773 decompositions: Optional[Dict] = None, 774 num_params_buffers: int = 0, 775 keep_inference_input_mutations: bool = False, 776 inference_compiler: Optional[Callable] = None, 777 *, 778 # Whether or not to trace with dynamic shapes 779 dynamic=False, 780 enable_log=True, 781) -> Callable: 782 """ 783 Traces the forward and backward graph of :attr:`fn` using torch dispatch 784 mechanism, and then compiles the generated forward and backward graphs 785 through :attr:`fw_compiler` and :attr:`bw_compiler`. 786 787 :func:`aot_function` traces the forward and backward graph ahead of time, 788 and generates a joint forward and backward graph. :attr:`partition_fn` is 789 then used to separate out forward and backward graphs. The partitioner 790 function can be used to perform optimizations such as recomputation. One can 791 set `decompositions` dictionary to decompose the operators into a sequence 792 of core or simpler operators supported by the backend compilers. 793 794 .. warning:: 795 This API is experimental and likely to change. 796 797 Args: 798 fn (Callable): A Python function that takes one ore more arguments. Must 799 return one or more Tensors. 800 fw_compiler (Callable): A Python function that accepts an Fx graph with 801 Aten ops and input args, and returns a Callable that semantically is 802 equivalent to the input Fx graph. 803 bw_compiler (Optional[Callable]): A Python function that accepts an 804 Fx graph with Aten ops and input args, and returns a Callable that 805 semantically is equivalent to the input Fx graph. Default: None 806 (when None, it defaults to the :attr:`fw_compiler`) 807 partition_fn (Callable): A Python function that takes a joint forward 808 and backward graph, and partitions it into separate forward and 809 backward graphs. 810 decompositions (Dict): A dictionary to define the decomposition of 811 larger Aten ops into simpler or core Aten ops. 812 inference_compiler (Optional[Callable]): A Python function that accepts an 813 Fx graph with Aten ops and input args, and returns a Callable that 814 semantically is equivalent to the input Fx graph. inference_compiler is invoked 815 if no autograd is needed. Default: None 816 (when None, it defaults to the :attr:`fw_compiler`) 817 Returns: 818 Returns a ``Callable`` that retains the eager behavior of the original 819 :attr:`fn`, but with forward and backward graph compiled via 820 :attr:`fw_compile` and :attr:`bw_compile`. 821 822 A simple example usage of :func:`aot_function` is as follows. 
This example 823 will print the forward and backward graphs of the function ``fn`` 824 825 >>> fn = lambda x : x.sin().cos() 826 >>> def print_compile_fn(fx_module, args): 827 >>> print(fx_module) 828 >>> return fx_module 829 >>> aot_fn = aot_function(fn, print_compile_fn) 830 >>> x = torch.randn(4, 5, requires_grad=True) 831 >>> aot_fn(x) 832 """ 833 834 if bw_compiler is None: 835 bw_compiler = fw_compiler 836 if inference_compiler is None: 837 inference_compiler = fw_compiler 838 aot_config = AOTConfig( 839 fw_compiler=fw_compiler, 840 bw_compiler=bw_compiler, 841 inference_compiler=inference_compiler, 842 partition_fn=partition_fn, 843 decompositions=decompositions, 844 num_params_buffers=num_params_buffers, 845 aot_id=next(AOT_COUNTER), 846 keep_inference_input_mutations=keep_inference_input_mutations, 847 dynamic_shapes=dynamic, 848 aot_autograd_arg_pos_to_source=None, 849 is_export=False, 850 no_tangents=False, 851 enable_log=enable_log, 852 ) 853 cached_res = None 854 855 @wraps(fn) 856 def returned_function(*args, **kwargs): 857 nonlocal cached_res 858 # Now flatten the tensor args 859 flat_args = pytree.arg_tree_leaves(*args, **kwargs) 860 861 # Compile the function and save it in the cache 862 if cached_res is None: 863 flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs) 864 (fake_mode, shape_env) = construct_fake_mode(flat_args, aot_config) 865 fake_flat_args: FakifiedFlatArgs = process_inputs( 866 flat_args, aot_config, fake_mode, shape_env 867 ) 868 compiled_fn, _ = create_aot_dispatcher_function( 869 flat_fn, 870 fake_flat_args, 871 aot_config, 872 fake_mode, 873 shape_env, 874 ) 875 cached_res = (compiled_fn, out_spec) 876 877 cached_fn, out_spec = cached_res 878 out = cached_fn(flat_args) 879 return out_spec.unflatten(out) 880 881 return returned_function 882 883 884def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module: 885 """ 886 Traces the forward and backward graph of :attr:`mod` using torch dispatch 887 tracing mechanism. It is wrapper function, that underneath uses 888 :func:`aot_function` to perform tracing and compilation. 889 890 :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs 891 to a new callable which is then compiled through :func:`aot_function`. 892 893 .. warning:: 894 This API is experimental and likely to change. 895 896 Args: 897 mod (Callable): A ``nn.Module`` module. 898 args : args to be passed to :func:`aot_function` 899 kwargs : kwargs to be passed to :func:`aot_function` 900 901 Returns: 902 Returns a ``nn.Module`` that retains the eager behavior of the original 903 :attr:`mod`, but with forward and backward graph compiled. 
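
    A minimal usage sketch, in the same spirit as the :func:`aot_function` example
    (``nop_compile`` is just an illustrative name, not part of this API):

    >>> mod = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    >>> def nop_compile(fx_module, args):
    >>>     return fx_module
    >>> aot_mod = aot_module(mod, nop_compile)
    >>> out = aot_mod(torch.randn(2, 4))
    >>> out.sum().backward()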
904 905 """ 906 # See Note: [Fake Modules and AOTAutograd] 907 torch._dynamo.utils.assert_no_fake_params_or_buffers(mod) 908 909 def functional_call(named_params, named_buffers, *args, **kwargs): 910 params_and_buffers = {**named_params, **named_buffers} 911 return torch.func.functional_call(mod, params_and_buffers, args, kwargs) 912 913 named_params = dict(mod.named_parameters(remove_duplicate=False)) 914 named_buffers = dict(mod.named_buffers(remove_duplicate=False)) 915 num_params_buffers = len(named_params) + len(named_buffers) 916 compiled_f = aot_function( 917 functional_call, *args, num_params_buffers=num_params_buffers, **kwargs 918 ) 919 920 class AOTModule(nn.Module): 921 def __init__(self) -> None: 922 super().__init__() 923 self.orig_module = mod 924 925 def forward(self, *args, **kwargs): 926 return compiled_f( 927 named_params, 928 named_buffers, 929 *args, 930 **kwargs, 931 ) 932 933 return AOTModule() 934 935 936def aot_module_simplified( 937 mod: nn.Module, 938 args, 939 fw_compiler: Callable, 940 bw_compiler: Optional[Callable] = None, 941 partition_fn: Callable = default_partition, 942 decompositions: Optional[Dict] = None, 943 keep_inference_input_mutations=False, 944 inference_compiler: Optional[Callable] = None, 945 cudagraphs: Optional[BoxedBool] = None, 946) -> nn.Module: 947 """ 948 This is the simplified or low overhead version of aot_module. For frontends 949 like TorchDynamo, the input functions/modules to AOT are static and have 950 unpacked inputs/outputs. This gives us an opportunity to remove the 951 (1) pytree overhead to parse inputs/outputs, 952 (2) AOT Autograd cache, 953 (3) Reading of params/buffers in every forward call 954 955 :func:`aot_module_simplified` removes these overheads. 956 """ 957 params = { 958 **dict(mod.named_parameters(remove_duplicate=False)), 959 **dict(mod.named_buffers(remove_duplicate=False)), 960 } 961 params_flat, params_spec = pytree.tree_flatten(params) 962 params_flat = list(params_flat) 963 params_len = len(params_flat) 964 965 if cudagraphs is None: 966 cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) 967 968 if bw_compiler is None: 969 bw_compiler = fw_compiler 970 if inference_compiler is None: 971 inference_compiler = fw_compiler 972 973 seen_sources = set() 974 975 full_args = [] 976 # First, the params 977 full_args.extend(params_flat) 978 979 if tracing_context := torch._guards.TracingContext.try_get(): 980 tracing_context.params_flat = params_flat 981 982 aot_autograd_arg_pos_to_source = None 983 # Then, the params 1:1 mapped sources, if relevant. 984 if hasattr(mod, "_param_name_to_source"): 985 aot_autograd_arg_pos_to_source = [] 986 # We now know this came from dynamo, and (1) we care about guards, 987 # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards 988 # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. 989 for name in params.keys(): 990 assert name in mod._param_name_to_source, f"{name} not found." 991 source = mod._param_name_to_source[name] 992 assert source not in seen_sources, source 993 seen_sources.add(source) 994 aot_autograd_arg_pos_to_source.append(source) 995 996 # Next, the input args 997 full_args.extend(args) 998 999 static_input_indices = [] 1000 if hasattr(mod, "graph"): 1001 # Non dynamo entrypoints can get to here... 1002 for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): 1003 if hasattr(node, "_dynamo_source"): 1004 # ... but not here! 
1005 if aot_autograd_arg_pos_to_source is None: 1006 aot_autograd_arg_pos_to_source = [] 1007 source = node._dynamo_source 1008 assert source not in seen_sources, source 1009 seen_sources.add(source) 1010 aot_autograd_arg_pos_to_source.append(source) 1011 source_name = source.name() if source else str(source) 1012 1013 if "tensor_dict" in node.meta and node.meta["tensor_dict"].get( 1014 "_dynamo_static_input_type", None 1015 ): 1016 static_inputs_log.debug( 1017 "Adding static input pos %s for source %s", pos, source_name 1018 ) 1019 static_input_indices.append(pos) 1020 else: 1021 static_inputs_log.debug( 1022 "Non-static input pos %s for source %s", pos, source_name 1023 ) 1024 1025 if aot_autograd_arg_pos_to_source is not None: 1026 assert len(full_args) == len(aot_autograd_arg_pos_to_source) 1027 1028 dynamic_shapes = False 1029 for x in full_args: 1030 if isinstance(x, FakeTensor): 1031 dynamic_shapes = x.fake_mode.shape_env is not None 1032 break 1033 1034 aot_config = AOTConfig( 1035 fw_compiler=fw_compiler, 1036 bw_compiler=bw_compiler, 1037 inference_compiler=inference_compiler, 1038 partition_fn=partition_fn, 1039 decompositions=decompositions, 1040 num_params_buffers=params_len, 1041 aot_id=next(AOT_COUNTER), 1042 keep_inference_input_mutations=keep_inference_input_mutations, 1043 dynamic_shapes=dynamic_shapes, 1044 aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source, 1045 static_input_indices=static_input_indices, 1046 is_export=False, 1047 no_tangents=False, 1048 cache_key=None, 1049 ) 1050 fake_mode, shape_env = construct_fake_mode(full_args, aot_config) 1051 fake_flat_args = process_inputs(full_args, aot_config, fake_mode, shape_env) 1052 1053 def dispatch_and_compile(): 1054 functional_call = create_functional_call(mod, params_spec, params_len) 1055 with compiled_autograd.disable(): 1056 compiled_fn, _ = create_aot_dispatcher_function( 1057 functional_call, 1058 fake_flat_args, 1059 aot_config, 1060 fake_mode, 1061 shape_env, 1062 ) 1063 return compiled_fn 1064 1065 # Autograd cache stuff 1066 if config.enable_autograd_cache: 1067 compiled_fn = AOTAutogradCache.load( 1068 dispatch_and_compile, mod, fake_flat_args, aot_config, cudagraphs 1069 ) 1070 else: 1071 compiled_fn = dispatch_and_compile() 1072 1073 if isinstance(mod, torch._dynamo.utils.GmWrapper): 1074 # This function is called by the flatten_graph_inputs wrapper, which boxes 1075 # the inputs so that they can be freed before the end of this scope. 1076 # For overhead reasons, this is not the default wrapper, see comment: 1077 # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481 1078 def boxed_forward(runtime_args: List[Any]): 1079 flat_args = [] 1080 flat_args.extend(params_flat) 1081 flat_args.extend(runtime_args) 1082 runtime_args.clear() 1083 return compiled_fn(flat_args) 1084 1085 # Just for convenience 1086 boxed_forward.zero_grad = mod.zero_grad 1087 boxed_forward.named_parameters = mod.named_parameters 1088 boxed_forward.named_buffers = mod.named_buffers 1089 return boxed_forward 1090 1091 # TODO: There is something deeply wrong here; compiled_fn running with 1092 # the boxed calling convention, but aot_module_simplified somehow 1093 # historically returned a function that was not the boxed calling 1094 # convention. This should get fixed... 
    # NB: GraphModule/nn.Module rely on the non-boxed calling convention here
    def forward(*runtime_args: Tuple[Any]):
        full_args = []
        full_args.extend(params_flat)
        full_args.extend(runtime_args)
        return compiled_fn(full_args)

    # Just for convenience
    forward.zero_grad = mod.zero_grad
    forward.named_parameters = mod.named_parameters
    forward.named_buffers = mod.named_buffers

    return forward


def aot_export_module(
    mod: nn.Module,
    args,
    *,
    decompositions: Optional[Dict] = None,
    # If true, we'll return a joint forward-backward graph,
    # as well as metadata on the loss + gradients in the backward.
    trace_joint: bool,
    # If trace_joint is True, we expect your module to return a scalar loss.
    # Your module can return multiple outputs, so you must specify which output the loss is.
    output_loss_index: Optional[int] = None,
    pre_dispatch: bool = False,
    # If None, this will be inferred from the inputs and mod.graph.nodes if mod is a graph module,
    # but the inferred result might be wrong.
    dynamic_shapes: Optional[bool] = None,
    kwargs=None,
) -> Tuple[torch.fx.GraphModule, GraphSignature]:
    """
    This function takes in a module, and returns:
    (1) an FX graph that can be exported
    (2) some metadata about the graph

    If `trace_joint=True` we will return a joint graph of the forward + backward.

    The traced FX graph will have the following properties compared to the original module:
    (1) Inputs and outputs to the module will be pytree-flattened
    (2) Parameters and buffers on the module will be lifted into graph inputs,
        graph_inputs = (*parameters, *buffers, *user_inputs)
    (3) The graph will be fully functionalized
    (4) Any input mutations will be converted into additional outputs in the graph,
        meaning whoever calls this graph is responsible for applying the mutations
        back to the original inputs.
    (5) If trace_joint is True, the graph will return parameter gradients in addition to user outputs.
        The graph output will look like:
        graph_outputs = (*updated_inputs, *user_outputs, *param_gradients)

    There are also several restrictions on what modules can use this API. In particular:
    (1) If trace_joint is specified, we expect the loss function to be **fused**
        into the module forward. One of the outputs to the forward must be a scalar loss,
        which is specified with `output_loss_index`.
        All other outputs to the forward are presumed to not require gradients.
    (2) This API cannot capture optimizers (although in theory we could build an API for this).
    (3) Metadata mutations on params/buffers/inputs are banned.
    (4) Data mutations on anything that requires gradients are banned (parameters).
    (5) If an input is mutated, it is not allowed to alias any other inputs.
    (6) Parameters must not be duplicated.
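
    A minimal usage sketch (illustrative only; the module and input shapes are made up):

    >>> mod = nn.Linear(3, 3)
    >>> gm, signature = aot_export_module(mod, [torch.randn(2, 3)], trace_joint=False)
    >>> print(gm)         # functionalized ATen graph, with params/buffers lifted to inputs
    >>> print(signature)  # GraphSignature describing params, buffers, user inputs/outputs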
    """
    if pre_dispatch and trace_joint:
        raise RuntimeError("pre_dispatch is not supported when trace_joint is True.")
    named_parameters = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))

    params_and_buffers = {
        **dict(named_parameters),
        **dict(named_buffers),
    }
    params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
    params_and_buffers_flat = tuple(params_and_buffers_flat)
    params_len = len(params_and_buffers_flat)

    kwargs = kwargs or {}

    functional_call = create_functional_call(
        mod, params_spec, params_len, store_orig_mod=True
    )

    num_fw_outs = None

    if trace_joint:
        # This helper effectively just adds some extra asserts about what the backward will look like:
        # outputs must include a scalar loss that we compute gradients w.r.t.
        # We don't compute gradients w.r.t. anything else, so we check (rather than
        # silently detach) any other output tensors that require gradients.
        def fn_to_trace(*args):
            nonlocal num_fw_outs
            out = functional_call(*args)
            if output_loss_index is None:
                raise RuntimeError(
                    """\
If trace_joint=True, it is required that one of your forward outputs is a scalar loss.
You must specify which output is the loss with output_loss_index."""
                )
            if isinstance(out, torch.Tensor):
                out = (out,)
            if not isinstance(out, (tuple, list)):
                raise RuntimeError(
                    f"Expected forward output to be either a tensor or a list/tuple of tensors. Found {type(out)}"
                )

            for i, o in enumerate(out):
                # We only want to create a backward graph w.r.t. the loss that the user passed in.
                # This implies that every other output should not require gradients.
                # Instead of silently detaching those outputs for the user (which could
                # change semantics behind their back), we raise an error here and ask the
                # user to detach them explicitly.
                if o.requires_grad and i != output_loss_index:
                    raise RuntimeError(
                        f"""\
Found an output of the forward that requires gradients, that was not the scalar loss.
We require all outputs to the forward that are not the scalar loss to not require gradient,
because we will only compute a backward graph against the scalar loss.
You can fix this by calling .detach() on each of your forward outputs that is not the loss.
You specified that output index {output_loss_index} is the loss, but we found that
the output at index {i} requires gradients."""
                    )
            out_loss = out[output_loss_index]
            num_fw_outs = len(out)
            if not out_loss.requires_grad:
                raise RuntimeError(
                    f"""\
The output at index {output_loss_index} was marked as the loss, but it does not require gradients"""
                )
            if out_loss.numel() != 1:
                raise RuntimeError(
                    f"""\
We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}"""
                )
            return out

        ctx = nullcontext
    else:
        # Run under no_grad, so our tracing machinery only traces an inference graph.
        # However if pre_dispatch=True, we want to correctly trace set_grad_enabled calls for training.
1232 ctx = nullcontext if pre_dispatch else torch.no_grad 1233 fn_to_trace = functional_call 1234 1235 full_args = [] 1236 # First, the params 1237 # NB: It is REQUIRED that parameters come first, Inductor infers "fixed" 1238 # parameters by looking at the difference in parameter count outside 1239 # and inside AOTAutograd, and assumes the prefix of arguments are fixed 1240 # arguments 1241 full_args.extend(params_and_buffers_flat) 1242 # Next, the input args 1243 full_args.extend(args) 1244 1245 with ctx(): 1246 fx_g, metadata, in_spec, out_spec = _aot_export_function( 1247 fn_to_trace, 1248 full_args, 1249 decompositions=decompositions, 1250 num_params_buffers=params_len, 1251 no_tangents=True, 1252 pre_dispatch=pre_dispatch, 1253 dynamic_shapes=dynamic_shapes, 1254 kwargs=kwargs, 1255 ) 1256 if trace_joint: 1257 1258 def flattened_joint(*args): 1259 # The idea here is that the joint graph that AOTAutograd creates has some strict properties: 1260 # (1) It accepts two arguments (primals, tangents), and pytree_flattens them 1261 # (2) It returns a tuple of (fw_outs, gradients) 1262 # This is a very useful convention for anyone who wants to partition the joint graph 1263 # into a separate forward and backward graph. 1264 # However, 1265 # (1) for people exporting a single joint graph, it would be preferable not to have 1266 # any pytrees in the graph. 1267 # (2) We are guaranteed in the aot_export_module case that the forward outputs a loss, 1268 # and there are therefore no tangents that are needed to run the joint graph. 1269 # (3) AOTAutograd creates a grad_input for every input in the forward, 1270 # including None's for inputs that are not grad-requiring tensors. 1271 # we don't want these in our export graph. 1272 # and there are therefore no tangents that are needed to run the joint graph. 1273 # This function "fixes" both of the above by removing any tangent inputs, 1274 # and removing pytrees from the original FX graph. 1275 fake_tangents = [ 1276 None 1277 for _ in range( 1278 metadata.num_outputs + metadata.num_mutated_inp_runtime_indices 1279 ) 1280 ] 1281 fw_outs, gradients = fx_g(args, fake_tangents) 1282 assert len(gradients) == len(args) 1283 output_gradients = [] 1284 for i, (a, grad) in enumerate(zip(args, gradients)): 1285 if isinstance(a, torch.Tensor) and a.requires_grad: 1286 assert ( 1287 grad is not None 1288 ), """\ 1289Found a parameter that did not receive a gradient. 1290"This is most likely a bug, but if this needs to be supported please comment on this Github issue: 1291https://github.com/pytorch/pytorch/issues/101192 1292""" 1293 output_gradients.append(grad) 1294 else: 1295 assert grad is None 1296 return *fw_outs, *output_gradients 1297 1298 fx_g = make_fx(flattened_joint)(*full_args) 1299 1300 user_args_flat = pytree.arg_tree_leaves(*args, **kwargs) 1301 return fx_g, create_graph_signature( 1302 fx_g, 1303 metadata, 1304 in_spec, 1305 out_spec, 1306 user_args_flat=user_args_flat, 1307 params_and_buffers_flat=params_and_buffers_flat, 1308 param_names=list(named_parameters.keys()), 1309 buffer_names=list(named_buffers.keys()), 1310 trace_joint=trace_joint, 1311 num_user_fw_outs=num_fw_outs, 1312 loss_index=output_loss_index, 1313 ) 1314 1315 1316def aot_export_joint_simple( 1317 func: Callable, 1318 args, 1319 *, 1320 trace_joint: bool, 1321 # It looks like the main consequence of this API is that for dynamic shapes, 1322 # it will assume that parms/buffers are static. 1323 # With the new inferred dynamic shapes API, maybe this doesn't matter? 
1324 num_params_buffers: int = 0, 1325 decompositions: Optional[Dict] = None, 1326) -> torch.fx.GraphModule: 1327 """ 1328 A simplified version of export. Used by higher order operators. 1329 1330 This function makes a high-level "no calling convention changes" guarantee: 1331 - If no inputs require grad (so we export an inference graph), 1332 there are *no* calling convention change between the exported graph, and "func". 1333 - If at least one input requires grad (so we trace out and export a joint fw-bw graph), 1334 Then if you were partition the graph into a separate forward and backward graph, 1335 The forward graph will have no calling convention changes compared to "func". 1336 1337 The above also relies on some strong restrictions around which functions this API accepts: 1338 (1) `args` cannot contain any pytrees (they must have been pytree_flattened already) 1339 (2) `func` cannot mutate any inputs 1340 (3) The outputs of `func` cannot alias any inputs. 1341 1342 Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops. 1343 """ 1344 if trace_joint: 1345 ctx = nullcontext 1346 else: 1347 # Run under no_grad, so our tracing machinery only traces an inference graph. 1348 ctx = torch.no_grad 1349 1350 with ctx(): 1351 fx_g, metadata, in_spec, out_spec = _aot_export_function( 1352 func, 1353 args, 1354 decompositions=decompositions, 1355 ) 1356 in_spec, _kw_in_spec = in_spec.children_specs 1357 # At this point, we can just directly return the (joint or inference graph) that we traced. 1358 # First though: a bunch of assertions to make sure that our graph doesn't require 1359 # any calling convention changes compared to the original function. 1360 # These restrictions are *in addition to* the general restrictions on export. 1361 1362 # No input mutations 1363 if ( 1364 len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata]) 1365 != 0 1366 ): 1367 raise RuntimeError( 1368 f"aot_export_joint_simple does not support input mutations. {str(metadata)}" 1369 ) 1370 # No output aliasing 1371 if ( 1372 len([x for x in metadata.output_info if x.output_type != OutputType.non_alias]) 1373 != 0 1374 ): 1375 raise RuntimeError( 1376 f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}" 1377 ) 1378 # No pytrees 1379 if in_spec.is_leaf(): 1380 raise RuntimeError( 1381 f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}" 1382 ) 1383 if not all(child.is_leaf() for child in in_spec.children_specs): 1384 raise RuntimeError( 1385 f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}" 1386 ) 1387 if out_spec.is_leaf(): 1388 raise RuntimeError( 1389 f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}" 1390 ) 1391 if not all(child.is_leaf() for child in out_spec.children_specs): 1392 raise RuntimeError( 1393 f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}" 1394 ) 1395 # TODO: we might have to temporarily patch config.functionalize_rng 1396 # so that it doesn't run when we're exporting a higher order op. 1397 1398 if config.debug_assert: 1399 # Smoke test that after partitioning, we can run the forward without any calling convention changes. 
1400 fw_module, bw_module = aot_config.default_partition( # noqa: F821 1401 fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos) # noqa: F821 1402 ) 1403 # Attempt to run the fw_module with the original user inputs 1404 fake_mode = detect_fake_mode(args) 1405 if fake_mode is None: 1406 fake_mode = FakeTensorMode() 1407 with fake_mode: 1408 fw_module(*args) 1409 return fx_g 1410 1411 1412# Private for now because we aren't providing a contract on what to return 1413# for joint graphs (we could when there's a clearer use case) 1414# In the future, we may need to add more export API's that provide their own strong guarantees. 1415# This is meant as a general helper function for handling various export-y use cases. 1416def _aot_export_function( 1417 func: Callable, 1418 args, 1419 *, 1420 num_params_buffers: int = 0, 1421 decompositions: Optional[Dict] = None, 1422 # If we're exporting a joint graph and we don't want any tangent inputs in the graph 1423 # (because we are backpropping through a scalar 1 loss), 1424 # we need to explicitly specify not to include tangents in the graph. 1425 # It's not enough just to check that our tangent is a scalar, since we also 1426 # need to know if it is a 1 (no need to make it a graph input), or something else 1427 # (requiring it to be a graph input). 1428 # We don't know this info at trace time though, so we need to make it an explicit config. 1429 no_tangents: bool = False, 1430 pre_dispatch: bool = False, 1431 # If None, `dynamic_shapes` will be infered from inputs, but the inferred result might be wrong. 1432 dynamic_shapes: Optional[bool] = None, 1433 kwargs=None, 1434) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]: 1435 kwargs = kwargs or {} 1436 1437 flat_fn, out_spec = create_tree_flattened_fn(func, args, kwargs) 1438 flat_args, in_spec = pytree.tree_flatten((args, kwargs)) 1439 1440 if dynamic_shapes is None: 1441 # Try to infer `dynamic_shapes from inputs and graph nodes 1442 fake_mode = detect_fake_mode(flat_args) 1443 if ( 1444 fake_mode is None 1445 and hasattr(func, "_orig_mod") 1446 and isinstance(func._orig_mod, torch.fx.GraphModule) 1447 ): 1448 vals = [ 1449 node.meta["val"] 1450 for node in func._orig_mod.graph.nodes 1451 if "val" in node.meta 1452 ] 1453 fake_mode = detect_fake_mode(vals) 1454 dynamic_shapes = fake_mode is not None and fake_mode.shape_env is not None 1455 1456 # The export use case doesn't care about several bits of AOTConfig 1457 # (1) compilers (we just export the graph) 1458 # (2) partitioners (export is only full graph, user can partition themselves) 1459 aot_config = AOTConfig( 1460 fw_compiler=None, 1461 bw_compiler=None, 1462 inference_compiler=None, 1463 partition_fn=None, 1464 decompositions=decompositions, 1465 num_params_buffers=num_params_buffers, 1466 aot_id=next(AOT_COUNTER), 1467 # For now there's no use case involving keeping input mutations in the graph 1468 # (which we can only do in the inference case anyway). 1469 # We can add this later if we need to. 
1470 keep_inference_input_mutations=False, 1471 dynamic_shapes=dynamic_shapes, 1472 aot_autograd_arg_pos_to_source=None, 1473 is_export=True, 1474 no_tangents=no_tangents, 1475 pre_dispatch=pre_dispatch, 1476 ) 1477 fake_mode, shape_env = construct_fake_mode(flat_args, aot_config) 1478 fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env) 1479 1480 fx_g, meta = create_aot_dispatcher_function( 1481 flat_fn, 1482 fake_flat_args, 1483 aot_config, 1484 fake_mode, 1485 shape_env, 1486 ) 1487 return fx_g, meta, in_spec, out_spec.spec 1488 1489 1490@contextmanager 1491def _detect_attribute_assignment(mod: torch.nn.Module): 1492 # Do not allow assignment of tensor attributes during export unless 1493 # the attribute is registered as a buffer. 1494 1495 STD_ATTRS = { 1496 "_backward_hooks", 1497 "_backward_pre_hooks", 1498 "_buffers", 1499 "_forward_hooks", 1500 "_forward_hooks_always_called", 1501 "_forward_hooks_with_kwargs", 1502 "_forward_pre_hooks", 1503 "_forward_pre_hooks_with_kwargs", 1504 "_is_full_backward_hook", 1505 "_load_state_dict_post_hooks", 1506 "_load_state_dict_pre_hooks", 1507 "_modules", 1508 "_non_persistent_buffers_set", 1509 "_parameters", 1510 "_state_dict_hooks", 1511 "_state_dict_pre_hooks", 1512 "training", 1513 } 1514 1515 def _get_attributes(mod): 1516 # return any attributes of a module that are not standard attributes 1517 return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} 1518 1519 # save state of attributes before enter 1520 snapshot = pytree.tree_map(lambda x: x, _get_attributes(mod)) 1521 try: 1522 yield 1523 finally: 1524 # after exit, compare state of attributes with snapshot 1525 # to detect which tensor attributes were assigned 1526 assigned_tensor_attributes = [] 1527 1528 def _collect_assigned_tensor_attributes(kp, v, _v): 1529 if _v is not v: 1530 attr, *rest = kp 1531 if isinstance(v, torch.Tensor): 1532 assigned_tensor_attributes.append( 1533 f"self.{attr.key}{pytree.keystr(rest)}" 1534 ) 1535 # TODO(avik): Assigning all other types are allowed right now. 1536 # Maybe in the future we want to limit this to primitive types? 1537 1538 pytree.tree_map_with_path( 1539 _collect_assigned_tensor_attributes, snapshot, _get_attributes(mod) 1540 ) 1541 # restore state of all attributes (including, e.g., of primitive types) 1542 mod.__dict__.update(snapshot) 1543 1544 if assigned_tensor_attributes: 1545 if len(assigned_tensor_attributes) > 1: 1546 noun, verb = "attributes", "were" 1547 else: 1548 noun, verb = "attribute", "was" 1549 raise ValueError( 1550 f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. " 1551 "Such attributes must be registered as buffers using the `register_buffer` API " 1552 "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." 1553 ) 1554 1555 1556compiled_function = aot_function 1557compiled_module = aot_module 1558
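

# A rough sketch of how a custom torch.compile backend typically drives the entry
# points in this file (hedged illustration; `my_compiler`, `my_backend`, and `fn`
# are made-up names, and make_boxed_func is also re-exported via functorch.compile):
#
# from functorch.compile import make_boxed_func
#
# def my_compiler(fx_g, example_inputs):
#     # inspect / transform fx_g here, then return a callable that follows the
#     # boxed (single list argument) calling convention
#     return make_boxed_func(fx_g.forward)
#
# def my_backend(gm, example_inputs):
#     return aot_module_simplified(gm, example_inputs, fw_compiler=my_compiler)
#
# compiled_fn = torch.compile(fn, backend=my_backend)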