# mypy: allow-untyped-defs
"""
Functions in this module do most of the "work" of AOTAutograd.
An aot_dispatch_* function:
- Takes in the input flat_fn, flat_args, and some metadata
- Runs a set of pre compile wrappers (e.g. argument deduping)
- Runs the actual compiler
- Wraps the returned callable in a set of post compile wrappers
- Returns the wrapped callable and metadata.
"""

import itertools
import logging
import traceback
from contextlib import nullcontext
from typing import Any, Callable, List, Optional, Sequence, Tuple

import torch
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import lazy_format_graph_code
from torch._guards import CompileContext, TracingContext
from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals
from torch.multiprocessing.reductions import StorageWeakRef

from .. import config
from .autograd_cache import (
    AOTAutogradCache,
    AOTAutogradCacheEntry,
    CompiledBackward,
    CompiledForward,
)
from .dispatch_and_compile_graph import (
    aot_dispatch_autograd_graph,
    aot_dispatch_base_graph,
)
from .logging_utils import track_graph_compiling
from .runtime_wrappers import (
    AOTDedupeWrapper,
    AOTDispatchAutograd,
    AOTDispatchSubclassWrapper,
    AOTSyntheticBaseWrapper,
    AutogradLazyBackwardCompileInfo,
    CompilerWrapper,
    DebugAssertWrapper,
    EffectTokensWrapper,
    FakifiedOutWrapper,
    FunctionalizedRngRuntimeWrapper,
    make_runtime_safe,
    post_compile,
    pre_compile,
    RuntimeWrapper,
)
from .schemas import AOTConfig, MutationType, ViewAndMutationMeta
from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta
from .utils import _get_symint_hints, make_boxed_func, strict_zip, unlift_tokens


zip = strict_zip

log = logging.getLogger(__name__)
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")

aten = torch.ops.aten

# Returns a Callable and a ViewAndMutationMeta.
# Currently, only export needs the ViewAndMutationMeta after this function.
DispatchReturn = Tuple[Callable, ViewAndMutationMeta]


def _create_wrappers_for_dispatch(needs_autograd: bool) -> List[CompilerWrapper]:
    """
    Wrappers that run on every dispatch function
    """
    return [AOTDedupeWrapper(), AOTSyntheticBaseWrapper(trace_joint=needs_autograd)]


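# The sketch below is illustrative only and is never called by the real dispatch
# paths in this module. Under the simplifying assumption of an "identity compiler"
# (we just box the traced function rather than invoking aot_config.fw_compiler),
# it shows how the pre_compile / compile / post_compile pipeline described in the
# module docstring composes. The name `_example_dispatch_sketch` is hypothetical.
def _example_dispatch_sketch(
    flat_fn: Callable,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> DispatchReturn:
    wrappers = _create_wrappers_for_dispatch(needs_autograd=False)
    # Pre compile wrappers may rewrite the function, its args, and the metadata
    # (e.g. deduplicating repeated tensor arguments, merging aliased inputs).
    flat_fn, flat_args, fw_metadata = pre_compile(
        wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
    )
    # A real dispatch function traces a graph here and hands it to a compiler; for
    # illustration we treat the (transformed) function itself as the compiled artifact.
    compiled = make_boxed_func(flat_fn)
    # Post compile wrappers mirror the pre compile ones at runtime and return the
    # wrapped callable together with the (possibly updated) metadata.
    return post_compile(wrappers, compiled, aot_config, runtime_metadata=fw_metadata)

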
# Export's dispatching logic is unique in a few ways: it only needs the "graph"
# bits of aot_autograd, and doesn't need to do any specific wrapping.
def aot_dispatch_export(
    flat_fn: Callable,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
    needs_autograd: bool,
) -> DispatchReturn:
    wrappers = _create_wrappers_for_dispatch(needs_autograd)
    flat_fn, flat_args, fw_metadata = pre_compile(
        wrappers,
        flat_fn,
        flat_args,
        aot_config,
        fw_metadata=fw_metadata,
    )
    if needs_autograd and not aot_config.pre_dispatch:
        graph, _, _ = aot_dispatch_autograd_graph(
            flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
        )
    else:
        graph, _, _ = aot_dispatch_base_graph(
            flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
        )

    # NB: the wrappers that run in pre_compile for export are
    # either a no-op, because they're not needed, or will raise a runtime error,
    # since they don't support export.
    # We still run these wrappers pre compile to make sure they're not needed,
    # but we technically don't need to run them post compile at all here.
    compiled_fn, fw_metadata = post_compile(
        wrappers, graph, aot_config, runtime_metadata=fw_metadata
    )

    # Therefore, since no wrappers run, we don't get back a callable - we get back the raw fx graph
    # (either a joint or an inference-only graph)
    assert isinstance(compiled_fn, torch.fx.GraphModule)
    return compiled_fn, fw_metadata


def aot_dispatch_base(
    flat_fn,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> DispatchReturn:
    """
    Handles functions that don't need autograd. Runs wrappers and compiles with fw_compiler.
    """
    wrappers = _create_wrappers_for_dispatch(needs_autograd=False)
    flat_fn, flat_args, fw_metadata = pre_compile(
        wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
    )

    fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph(  # type: ignore[misc]
        flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
    )

    fakified_out_wrapper = FakifiedOutWrapper()
    (
        fw_module,
        updated_flat_args,
        fw_metadata,
    ) = fakified_out_wrapper.pre_compile(
        fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
    )
    functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
    (
        fw_module,
        updated_flat_args,
        fw_metadata,
    ) = functionalized_rng_wrapper.pre_compile(
        fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
    )

    disable_amp = torch._C._is_any_autocast_enabled()
    context = torch._C._DisableAutocast if disable_amp else nullcontext

    with context(), track_graph_compiling(aot_config, "inference"):
        compiler = (
            aot_config.inference_compiler
            if aot_config.inference_compiler is not None
            else aot_config.fw_compiler
        )

        if tracing_context := torch._guards.TracingContext.try_get():
            tracing_context.fw_metadata = (
                fw_metadata
                if maybe_subclass_meta is None
                else maybe_subclass_meta.fw_metadata
            )

        with TracingContext.report_output_strides() as fwd_output_strides:
            compiled_fw = compiler(fw_module, updated_flat_args)

        if fakified_out_wrapper.needs_post_compile:
            fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

        make_runtime_safe(fw_metadata, maybe_subclass_meta)

        # However, RuntimeWrapper does not expect the rng offsets in the
        # output. So, we have to create another wrapper and take out the offset.
        # As a result, we have to account for non-boxed_call compilers as well.
        if not hasattr(compiled_fw, "_boxed_call"):
            compiled_fw = make_boxed_func(compiled_fw)

        # Create a wrapper to set up the rng functionalize and fakified out bits
        compiled_fw = functionalized_rng_wrapper.post_compile(
            compiled_fw, aot_config, runtime_metadata=fw_metadata
        )

        if config.enable_autograd_cache and aot_config.cache_key:
            if fw_key := getattr(compiled_fw, "_fx_graph_cache_key", None):
                entry = AOTAutogradCacheEntry(
                    compiled_fw=CompiledForward(fw_key),
                    compiled_bw=None,
                    runtime_metadata=fw_metadata,
                    dispatch_wrappers=wrappers,
                    maybe_subclass_meta=maybe_subclass_meta,
                    num_fw_outs_saved_for_bw=None,
                    indices_of_inps_to_detach=[],
                )
                AOTAutogradCache.save(aot_config.cache_key, entry)

        compiled_fw = fakified_out_wrapper.post_compile(
            compiled_fw,
            aot_config,
            runtime_metadata=fw_metadata,
        )

        compiled_fw = EffectTokensWrapper().post_compile(
            compiled_fw,
            aot_config,
            runtime_metadata=fw_metadata,
        )

        # Why do we need to pass in num_fw_outs_saved_for_bw?
        # See Note: [Partitioner handling for Subclasses, Part 2]
        compiled_fw = AOTDispatchSubclassWrapper(
            trace_joint=False,
            # TODO: once we use pre_compile this will be flat_fn at the top of this function
            fw_only=None,
            maybe_subclass_meta=maybe_subclass_meta,
            num_fw_outs_saved_for_bw=None,
        ).post_compile(
            compiled_fw,
            aot_config,  # not used
            runtime_metadata=fw_metadata,
        )

        if not hasattr(compiled_fw, "_boxed_call"):
            compiled_fw = make_boxed_func(compiled_fw)

        compiled_fn = RuntimeWrapper(
            indices_of_inps_to_detach=[],
            trace_joint=False,
            disable_amp=disable_amp,
        ).post_compile(
            compiled_fw,
            aot_config,
            runtime_metadata=fw_metadata,
        )

    compiled_fn = post_compile(
        wrappers, compiled_fn, aot_config, runtime_metadata=fw_metadata
    )
    return compiled_fn


def collect_fw_donated_buffer_idxs(
    fw_ins: List[Optional[FakeTensor]],
    user_fw_outs: List[Optional[FakeTensor]],
    bw_outs: List[Optional[FakeTensor]],
    saved_tensors: List[FakeTensor],
) -> List[int]:
    """
    Checks if the saved tensors are donated buffers, which means a saved tensor is not
    an alias of any tensors in fw_ins, user_fw_outs, and bw_outs.
    """

    storage_refs = set()
    for t in itertools.chain(fw_ins, user_fw_outs, bw_outs):
        if isinstance(t, FakeTensor):
            storage_refs.add(StorageWeakRef(t.untyped_storage()))

    num_saved_tensor = len(saved_tensors)
    donated_buffer_idxs = []
    for i in range(num_saved_tensor):
        t = saved_tensors[i]
        if StorageWeakRef(t.untyped_storage()) not in storage_refs:
            donated_buffer_idxs.append(i)

    return donated_buffer_idxs


def collect_bw_donated_buffer_idxs(
    fw_module: torch.fx.GraphModule,
    bw_module: torch.fx.GraphModule,
    fw_metadata: ViewAndMutationMeta,
) -> List[int]:
    """
    Collects backward donated buffer indexes from fw_module and bw_module.
    """

    fw_ins = fw_module.graph.find_nodes(op="placeholder")
    bw_outs = next(reversed(bw_module.graph.find_nodes(op="output"))).args[0]
    fw_outs = next(reversed(fw_module.graph.find_nodes(op="output"))).args[0]

    fw_ins = [n.meta["val"] if hasattr(n, "meta") else None for n in fw_ins]
    fw_outs = [n.meta["val"] if hasattr(n, "meta") else None for n in fw_outs]
    bw_outs = [n.meta["val"] if hasattr(n, "meta") else None for n in bw_outs]

    user_fw_outs = fw_outs[: fw_metadata.num_forward]
    saved_tensors = fw_outs[fw_metadata.tensors_saved_for_backwards_slice]

    fw_donated_buffer = collect_fw_donated_buffer_idxs(
        fw_ins,
        user_fw_outs,
        bw_outs,
        saved_tensors,
    )

    assert fw_metadata.num_symints_saved_for_bw is not None
    return [fw_metadata.num_symints_saved_for_bw + i for i in fw_donated_buffer]


def aot_dispatch_autograd(
    flat_fn,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> DispatchReturn:
    """
    Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers,
    and returns a wrapped torch.autograd.Function with a forward and backward.
    """
    wrappers = _create_wrappers_for_dispatch(needs_autograd=True)
    flat_fn, flat_args, fw_metadata = pre_compile(
        wrappers,
        flat_fn,
        flat_args,
        aot_config,
        fw_metadata=fw_metadata,
    )

    fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
    fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
        flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
    )

    # Copied from aot_dispatch_autograd_graph.
    disable_amp = torch._C._is_any_autocast_enabled()

    if aot_config.enable_log:
        aot_joint_log.info(
            "%s",
            lazy_format_graph_code(
                "Joint graph",
                fx_g,
                aot_config.aot_id,
                include_stride=True,
                include_device=True,
                colored=True,
            ),
        )
        trace_structured(
            "aot_joint_graph",
            payload_fn=lambda: fx_g.print_readable(
                print_output=False, include_stride=True, include_device=True
            ),
        )

    with torch.no_grad():
        inner_meta = (
            fw_metadata
            if maybe_subclass_meta is None
            else maybe_subclass_meta.fw_metadata
        )
        with track_graph_compiling(aot_config, "joint"):
            # See Note: [Partitioner handling for Subclasses, Part 1]
            # See Note: [Recomputing subclass mutation handling]
            mutated_inp_runtime_indices = (
                compute_inner_mutated_inp_indices_from_subclass_meta(
                    fw_metadata, inner_meta
                )
            )
            num_tokens = len(fw_metadata.tokens)
            num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices)
            num_inner_fwd_outputs = (
                num_mutated_inp_runtime_indices
                + inner_meta.num_outputs
                + inner_meta.num_intermediate_bases
                + inner_meta.num_outputs_rng_offset
                + num_tokens  # See Note [Side-Effectful Tokens in AOTAutograd]
            )
            fw_module, bw_module = aot_config.partition_fn(
                fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
            )

            # See Note [Side-Effectful Tokens in AOTAutograd]
            if config.unlift_effect_tokens and (
                num_tokens > 0 or fw_metadata.num_backward_tokens > 0
            ):
                unlift_tokens(fw_module, fw_metadata, aot_config, bw_module)

                num_inner_fwd_outputs -= num_tokens
                joint_inputs = (
                    joint_inputs[0][num_tokens:],
                    joint_inputs[1],
                )

            fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0]
            # We only need to bookkeep the symints that are saved for bw, not any symints
            # the user forward might have returned in its own output
            fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:]
            num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw)
            symint_outs_saved_for_bw = [
                n for n in fw_outs_saved_for_bw if is_sym_node(n)
            ]
            fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
            inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
            num_symints_saved_for_bw = len(symint_outs_saved_for_bw)

            if torch._functorch.config.donated_buffer:
                fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs(
                    fw_module,
                    bw_module,
                    inner_meta,
                )
                inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs

        if aot_config.enable_log:
            aot_graphs_log.info(
                "aot_config id: %s, fw_metadata=%s, inner_meta=%s",
                str(aot_config.aot_id),
                str(fw_metadata),
                str(inner_meta),
            )

        # Note [Detaching inputs that never need gradients]
        # See https://github.com/pytorch/pytorch/issues/97745
        # Suppose we have a function like this that we want to compile:
        #
        #   def f(x, y):
        #       return torch.mul(x, y.detach())
        #
        # What gradients should we compute for x and y?
        # By default, AOTAutograd will compute a gradient for **every** input that requires gradients,
        # and so we'll compute:
        #    x_grad_input = y
        #    y_grad_input = None
        # Does this preserve the semantics of eager mode?
        # Unfortunately, no.
        # Doing the above will cause autograd to **continue** to backprop the autograd tape
        # that was generated from constructing y.
        #
        # This is **different** from what would have happened in eager mode.
        # In eager mode, if we backprop through the output of this function, autograd will only traverse
        # the bit of the autograd tape corresponding to "x".
        # In particular, if a user had previously backpropped through y's autograd tape,
        # and then they try to backprop through the output of the above function,
        # then we'll hit the dreaded "Trying to backward through the graph a second time" error.
        #
        # You might think: if autograd sees that a gradient is None, shouldn't it stop early,
        # instead of continuing the backprop through the ancestors of that node in the graph?
        #
        # Autograd has two passes:
        # (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed
        # (2) a second pass that actually goes ahead and executes each node when it becomes ready,
        #     propagating gradients
        # By the time we're executing a node and we see that it produces a None, the set of nodes to execute
        # is already locked-in.
        #
        # The fix: instead, we can recognize statically that the graph we're compiling will never contribute
        # gradients to y, and prevent autograd from trying to traverse y's autograd tape at all.
        # We can do this by manually detach'ing y before sending it through the `CompiledFunction`.
        #
        # Note that this solution is not bulletproof.
        # It's possible to construct a case where eager may or may not have tried to autograd through y,
        # depending on the actual grad_outputs that were passed in during the backward.
        # There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`,
        # allowing autograd to re-use the graph.
        #
        # An example of this case is:
        #   def f(x):
        #       return x.detach() * 2, x * 3
        # If we backprop only through outs[0], eager mode would stop at the detach()
        # and never send a grad through x.
        # But the custom autograd function doesn't know that: it will materialize zero grads for x * 3
        # and we will end up with a zero grad at x.
        # If we later backprop through the second output, this will also require backprop'ing through x,
        # meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time.
        _indices_of_inps_to_detach: List[int] = []

        # reversed() since we expect output at end of graph
        bw_output = next(reversed(bw_module.graph.find_nodes(op="output")))
        bw_outs: Sequence[torch.fx.Node] = bw_output.args[0]  # type: ignore[assignment]

        # TODO: we should apply the below "detach inputs if their gradients are statically known to be None"
        # optimization even if we have subclass inputs/outputs (we do not handle this today).
        # Computing which of our inputs get None gradients is a bit more complicated
        # if any of our inputs are subclasses. Why?
        # (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses.
        # (b) The grad_outputs that we AOT computed in our backward graph are the desugared tensors,
        #     so we need to figure out which subclass fw inputs they map to.
        if maybe_subclass_meta is None:
            num_backward_tokens: int = inner_meta.num_backward_tokens
            assert (
                len(bw_outs)
                == len(fw_metadata.input_info)
                + inner_meta.num_outputs_rng_offset
                + num_backward_tokens
            )
            bw_outs_no_rng_no_tokens = bw_outs
            if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0:
                bw_outs_no_rng_no_tokens = bw_outs[
                    : -(inner_meta.num_outputs_rng_offset + num_backward_tokens)
                ]
            assert len(bw_outs_no_rng_no_tokens) == len(fw_metadata.input_info)

            for i, bw_out in enumerate(bw_outs_no_rng_no_tokens):
                # If our input experiences a metadata mutation inside the graph (e.g. set_()),
                # we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation
                metadata_mutation_in_graph = (
                    fw_metadata.input_info[i].mutation_type
                    == MutationType.MUTATED_IN_GRAPH
                    and fw_metadata.input_info[i].mutates_storage_metadata
                )
                is_non_leaf = (
                    fw_metadata.input_info[i].requires_grad
                    and not fw_metadata.input_info[i].is_leaf
                )
                if bw_out is None and not metadata_mutation_in_graph and is_non_leaf:
                    _indices_of_inps_to_detach.append(i)

        if aot_config.enable_log:
            aot_graphs_log.info(
                "%s",
                lazy_format_graph_code(
                    "Forward graph",
                    fw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
            aot_graphs_log.info(
                "%s",
                lazy_format_graph_code(
                    "Backward graph",
                    bw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
            trace_structured(
                "aot_forward_graph",
                payload_fn=lambda: fw_module.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )
            trace_structured(
                "aot_backward_graph",
                payload_fn=lambda: bw_module.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )

        # AMP is already traced out in the joint graph.
        # We do not wish to reapply it accidentally in the compiler.
        with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
            # flat_args at this point might still be subclasses;
            # make sure to pass the unwrapped fake tensors into the compiler!
            adjusted_flat_args = joint_inputs[0]

            fakified_out_wrapper = FakifiedOutWrapper()
            (
                fw_module,
                adjusted_flat_args,
                fw_metadata,
            ) = fakified_out_wrapper.pre_compile(
                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
            )

            functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
                return_new_outs=False
            )
            (
                fw_module,
                adjusted_flat_args,
                fw_metadata,
            ) = functionalized_rng_wrapper.pre_compile(
                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
            )
            if tracing_context := torch._guards.TracingContext.try_get():
                tracing_context.fw_metadata = inner_meta

            with TracingContext.report_output_strides() as fwd_output_strides:
                compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)

            if not hasattr(compiled_fw_func, "_boxed_call"):
                compiled_fw_func = make_boxed_func(compiled_fw_func)

            if fakified_out_wrapper.needs_post_compile:
                fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

            compiled_fw_func = EffectTokensWrapper().post_compile(
                compiled_fw_func,
                aot_config,
                runtime_metadata=fw_metadata,
            )

            compiled_fw_func = AOTDispatchSubclassWrapper(
                fw_only=None,
                trace_joint=False,
                maybe_subclass_meta=maybe_subclass_meta,
                num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
            ).post_compile(
                compiled_fw_func,
                aot_config,  # not used
                runtime_metadata=fw_metadata,
            )

            compiled_fw_func = functionalized_rng_wrapper.post_compile(
                compiled_fw_func, aot_config, runtime_metadata=fw_metadata
            )
            compiled_fw_func = fakified_out_wrapper.post_compile(
                compiled_fw_func,
                aot_config,
                runtime_metadata=fw_metadata,
            )

        # NB: It's important to compile backwards ahead of time, as this may
        # add extra guards which we need to apply to the Dynamo cache at
        # forwards
        with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast():
            placeholder_list = fx_placeholder_vals(bw_module)

            forward_saved_for_backwards_strides = None
            if fwd_output_strides is not None:
                forward_saved_for_backwards_strides = fwd_output_strides[
                    inner_meta.tensors_saved_for_backwards_slice
                ]

            # saved activations can have different strides from eager if
            # the compiler does layout optimization. We should restride the
            # tensor passed in for compiling the backward graph using the
            # saved tensor's stride.
            for i in range(len(placeholder_list)):
                ph_arg = placeholder_list[i]
                if not isinstance(ph_arg, torch.Tensor):
                    continue

                if forward_saved_for_backwards_strides is None:
                    continue

                real_stride = None
                # Per all_args calling convention
                j = i - num_symints_saved_for_bw
                if 0 <= j < len(forward_saved_for_backwards_strides):
                    real_stride = forward_saved_for_backwards_strides[j]
                if real_stride is None:
                    continue

                # Comparing ph_arg.stride() with real_stride directly may
                # cause dynamic dimensions in ph_arg to be specialized to a static
                # value. Use the hints to avoid that.
                if _get_symint_hints(ph_arg.stride()) != real_stride:
                    # Note that here we use the stride of the real tensor to
                    # restride a FakeTensor. This does not cause trouble
                    # for dynamic shapes, since this code path only gets
                    # executed if layout optimization is enabled. And we
                    # disable layout optimization for dynamic shapes right
                    # now.
                    #
                    # A solution that decides a stride order based on the real
                    # tensor's stride and then applies that stride order to
                    # the FakeTensor does not work smoothly, since some
                    # tensors' layouts are not 'dense'. E.g. mixnet_l has a
                    # tensor with size [8, 64, 112, 112] and strides
                    # (2408448, 1, 21504, 192). The solution mentioned would
                    # decide a stride of (802816, 1, 7168, 64) for this
                    # tensor, which is wrong.
                    placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)

            compiled_bw_func = None
            if num_symints_saved_for_bw > 0:
                try:
                    compiled_bw_func = aot_config.bw_compiler(
                        bw_module, placeholder_list
                    )
                except Exception as e:
                    exc = e
                    trace_structured(
                        "artifact",
                        metadata_fn=lambda: {
                            "name": "eager_compile_backwards_failure",
                            "encoding": "string",
                        },
                        payload_fn=lambda: "\n".join(traceback.format_exception(exc)),
                    )
                    log.warning(
                        "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
                        exc_info=True,
                    )
            # Compiled autograd will run the bw_module in the backward pass,
            # so recompilation needs to happen anyway if the backward pass is ever
            # called.
            #
            # The reason we do the GraphModule recompilation here is because
            # lazy recompilation will cause issues in the backward pass
            # with compiled autograd.
            #
            # Do the _LazyGraphModule.force_recompile here rather than when
            # bw_module is first generated by the partitioner, because bw_module.recompile
            # may be called in some code path later and turn _LazyGraphModule.forward
            # back into the lazy version. One example is when dynamic shapes are enabled
            # upfront: the bw_compiler will be called above, which can cause extra
            # graph module recompilation on bw_module.
            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
                from torch.fx._lazy_graph_module import _LazyGraphModule

                _LazyGraphModule.force_recompile(bw_module)

    saved_context = TracingContext.try_get()
    saved_compile_context = CompileContext.try_get()

    backward_state_indices = [
        idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
    ]
    assert len(backward_state_indices) <= 1

    lazy_backward_info = AutogradLazyBackwardCompileInfo(
        bw_module,
        placeholder_list,
        saved_context,
        saved_compile_context,
    )

    make_runtime_safe(fw_metadata, maybe_subclass_meta)

    try_save_cache_entry: Optional[Callable] = None
    if config.enable_autograd_cache:

        def try_save_cache_entry(compiled_bw_func, _fw_metadata):  # noqa: F811
            fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None)
            bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None)
            if aot_config.cache_key and fw_key and bw_key:
                entry = AOTAutogradCacheEntry(
                    CompiledForward(fw_key),
                    CompiledBackward(
                        bw_key, backward_state_indices, num_symints_saved_for_bw
                    ),
                    _fw_metadata,
                    wrappers,
                    maybe_subclass_meta,
                    num_fw_outs_saved_for_bw,
                    _indices_of_inps_to_detach,
                )
                AOTAutogradCache.save(aot_config.cache_key, entry)

        if compiled_bw_func is not None:
            # If we already compiled the backward, we can save the cache entry
            # right away instead of waiting for the lazy backward compile.
            try_save_cache_entry(compiled_bw_func, fw_metadata)
            try_save_cache_entry = None

    compiled_fn = AOTDispatchAutograd.post_compile(
        compiled_fw_func,
        compiled_bw_func,
        maybe_subclass_meta,
        num_symints_saved_for_bw,
        backward_state_indices,
        disable_amp,
        _indices_of_inps_to_detach,
        lazy_backward_info,
        aot_config,
        fw_metadata=fw_metadata,
        try_save_cache_entry=try_save_cache_entry,
    )

    if config.debug_assert:
        flat_requires_grad: List[Optional[bool]] = [
            a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
        ]
        compiled_fn = DebugAssertWrapper(
            flat_requires_grad=flat_requires_grad
        ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata)

    compiled_fn = post_compile(
        wrappers,
        compiled_fn,
        aot_config,
        runtime_metadata=fw_metadata,
    )
    return compiled_fn