# mypy: allow-untyped-defs
"""
Functions in this module do most of the "work" of AOTAutograd.
An aot_dispatch_* function:
- Takes in the input flat_fn, flat_args, and some metadata
- Runs a set of pre compile wrappers (e.g. argument deduping)
- Runs the actual compiler
- Wraps the returned callable in a set of post compile wrappers
- Returns the wrapped callable and metadata.
"""

import itertools
import logging
import traceback
from contextlib import nullcontext
from typing import Any, Callable, List, Optional, Sequence, Tuple

import torch
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import lazy_format_graph_code
from torch._guards import CompileContext, TracingContext
from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals
from torch.multiprocessing.reductions import StorageWeakRef

from .. import config
from .autograd_cache import (
    AOTAutogradCache,
    AOTAutogradCacheEntry,
    CompiledBackward,
    CompiledForward,
)
from .dispatch_and_compile_graph import (
    aot_dispatch_autograd_graph,
    aot_dispatch_base_graph,
)
from .logging_utils import track_graph_compiling
from .runtime_wrappers import (
    AOTDedupeWrapper,
    AOTDispatchAutograd,
    AOTDispatchSubclassWrapper,
    AOTSyntheticBaseWrapper,
    AutogradLazyBackwardCompileInfo,
    CompilerWrapper,
    DebugAssertWrapper,
    EffectTokensWrapper,
    FakifiedOutWrapper,
    FunctionalizedRngRuntimeWrapper,
    make_runtime_safe,
    post_compile,
    pre_compile,
    RuntimeWrapper,
)
from .schemas import AOTConfig, MutationType, ViewAndMutationMeta
from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta
from .utils import _get_symint_hints, make_boxed_func, strict_zip, unlift_tokens


zip = strict_zip

log = logging.getLogger(__name__)
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")

aten = torch.ops.aten

# Returns a Callable and a ViewAndMutationMeta.
# Currently, only export needs the ViewAndMutationMeta after this function.
DispatchReturn = Tuple[Callable, ViewAndMutationMeta]


def _create_wrappers_for_dispatch(needs_autograd: bool) -> List[CompilerWrapper]:
    """
    Wrappers that run on every dispatch function
    """
    return [AOTDedupeWrapper(), AOTSyntheticBaseWrapper(trace_joint=needs_autograd)]
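    # (For context, and not a guarantee of behavior: AOTDedupeWrapper handles the
    # case where the same tensor is passed in multiple times, and
    # AOTSyntheticBaseWrapper handles mutated inputs that alias one another by
    # tracing with a single synthetic base plus views. See runtime_wrappers.py
    # for the authoritative descriptions.)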


# Export's dispatching logic is unique in a few ways: it only needs the "graph"
# bits of aot_autograd, and doesn't need to do any specific wrapping.
def aot_dispatch_export(
    flat_fn: Callable,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
    needs_autograd: bool,
) -> DispatchReturn:
    wrappers = _create_wrappers_for_dispatch(needs_autograd)
    flat_fn, flat_args, fw_metadata = pre_compile(
        wrappers,
        flat_fn,
        flat_args,
        aot_config,
        fw_metadata=fw_metadata,
    )
    if needs_autograd and not aot_config.pre_dispatch:
        graph, _, _ = aot_dispatch_autograd_graph(
            flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
        )
    else:
        graph, _, _ = aot_dispatch_base_graph(
            flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
        )

    # NB: each wrapper that runs in pre_compile for export is either a no-op
    # (because it isn't needed) or raises a runtime error (because it doesn't
    # support export).
    # We still run these wrappers in pre_compile to make sure that they aren't needed,
    # but we technically don't need to run them post compile at all here.
    compiled_fn, fw_metadata = post_compile(
        wrappers, graph, aot_config, runtime_metadata=fw_metadata
    )

    # Therefore, since no wrappers run, we don't get back a callable - we get back
    # the raw fx graph (either a joint or an inference-only graph).
    assert isinstance(compiled_fn, torch.fx.GraphModule)
    return compiled_fn, fw_metadata


def aot_dispatch_base(
    flat_fn,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> DispatchReturn:
    """
    Handles functions that don't need autograd. Runs wrappers and compiles with fw_compiler.
    """
    wrappers = _create_wrappers_for_dispatch(needs_autograd=False)
    flat_fn, flat_args, fw_metadata = pre_compile(
        wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
    )

    fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph(  # type: ignore[misc]
        flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
    )

    fakified_out_wrapper = FakifiedOutWrapper()
    (
        fw_module,
        updated_flat_args,
        fw_metadata,
    ) = fakified_out_wrapper.pre_compile(
        fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
    )
    functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
    (
        fw_module,
        updated_flat_args,
        fw_metadata,
    ) = functionalized_rng_wrapper.pre_compile(
        fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
    )

    disable_amp = torch._C._is_any_autocast_enabled()
    context = torch._C._DisableAutocast if disable_amp else nullcontext

    with context(), track_graph_compiling(aot_config, "inference"):
        compiler = (
            aot_config.inference_compiler
            if aot_config.inference_compiler is not None
            else aot_config.fw_compiler
        )

        if tracing_context := torch._guards.TracingContext.try_get():
            tracing_context.fw_metadata = (
                fw_metadata
                if maybe_subclass_meta is None
                else maybe_subclass_meta.fw_metadata
            )

        with TracingContext.report_output_strides() as fwd_output_strides:
            compiled_fw = compiler(fw_module, updated_flat_args)

        if fakified_out_wrapper.needs_post_compile:
            fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

    make_runtime_safe(fw_metadata, maybe_subclass_meta)

    # The RNG functionalization wrapper adds rng offsets to the forward outputs.
    # However, RuntimeWrapper does not expect the rng offsets in the output, so we
    # have to create another wrapper and take the offset back out. As a result, we
    # also have to account for compilers that don't produce boxed callables.
    if not hasattr(compiled_fw, "_boxed_call"):
        compiled_fw = make_boxed_func(compiled_fw)
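    # (For reference: a "boxed" callable takes its arguments as a single list,
    # i.e. compiled_fw(list_of_args) rather than compiled_fw(*args);
    # make_boxed_func wraps a positional-arg callable accordingly and marks it
    # with _boxed_call.)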

    # Create a wrapper to set up the rng functionalization and fakified-out bits
    compiled_fw = functionalized_rng_wrapper.post_compile(
        compiled_fw, aot_config, runtime_metadata=fw_metadata
    )

    if config.enable_autograd_cache and aot_config.cache_key:
        if fw_key := getattr(compiled_fw, "_fx_graph_cache_key", None):
            entry = AOTAutogradCacheEntry(
                compiled_fw=CompiledForward(fw_key),
                compiled_bw=None,
                runtime_metadata=fw_metadata,
                dispatch_wrappers=wrappers,
                maybe_subclass_meta=maybe_subclass_meta,
                num_fw_outs_saved_for_bw=None,
                indices_of_inps_to_detach=[],
            )
            AOTAutogradCache.save(aot_config.cache_key, entry)

    compiled_fw = fakified_out_wrapper.post_compile(
        compiled_fw,
        aot_config,
        runtime_metadata=fw_metadata,
    )

    compiled_fw = EffectTokensWrapper().post_compile(
        compiled_fw,
        aot_config,
        runtime_metadata=fw_metadata,
    )

    # Why do we need to pass in num_fw_outs_saved_for_bw?
    # See Note: [Partitioner handling for Subclasses, Part 2]
    compiled_fw = AOTDispatchSubclassWrapper(
        trace_joint=False,
        # TODO: once we use pre_compile this will be flat_fn at the top of this function
        fw_only=None,
        maybe_subclass_meta=maybe_subclass_meta,
        num_fw_outs_saved_for_bw=None,
    ).post_compile(
        compiled_fw,
        aot_config,  # not used
        runtime_metadata=fw_metadata,
    )

    if not hasattr(compiled_fw, "_boxed_call"):
        compiled_fw = make_boxed_func(compiled_fw)

    compiled_fn = RuntimeWrapper(
        indices_of_inps_to_detach=[],
        trace_joint=False,
        disable_amp=disable_amp,
    ).post_compile(
        compiled_fw,
        aot_config,
        runtime_metadata=fw_metadata,
    )

    compiled_fn = post_compile(
        wrappers, compiled_fn, aot_config, runtime_metadata=fw_metadata
    )
    return compiled_fn


def collect_fw_donated_buffer_idxs(
    fw_ins: List[Optional[FakeTensor]],
    user_fw_outs: List[Optional[FakeTensor]],
    bw_outs: List[Optional[FakeTensor]],
    saved_tensors: List[FakeTensor],
) -> List[int]:
    """
    Returns the indices of saved tensors that are donated buffers. A saved tensor
    is a donated buffer if it is not an alias of any tensor in fw_ins,
    user_fw_outs, or bw_outs.
    """

    storage_refs = set()
    for t in itertools.chain(fw_ins, user_fw_outs, bw_outs):
        if isinstance(t, FakeTensor):
            storage_refs.add(StorageWeakRef(t.untyped_storage()))

    num_saved_tensor = len(saved_tensors)
    donated_buffer_idxs = []
    for i in range(num_saved_tensor):
        t = saved_tensors[i]
        if StorageWeakRef(t.untyped_storage()) not in storage_refs:
            donated_buffer_idxs.append(i)

    return donated_buffer_idxs


def collect_bw_donated_buffer_idxs(
    fw_module: torch.fx.GraphModule,
    bw_module: torch.fx.GraphModule,
    fw_metadata: ViewAndMutationMeta,
) -> List[int]:
    """
    Collects backward donated buffer indices from fw_module and bw_module.
    """

    fw_ins = fw_module.graph.find_nodes(op="placeholder")
    bw_outs = next(reversed(bw_module.graph.find_nodes(op="output"))).args[0]
    fw_outs = next(reversed(fw_module.graph.find_nodes(op="output"))).args[0]

    fw_ins = [n.meta["val"] if hasattr(n, "meta") else None for n in fw_ins]
    fw_outs = [n.meta["val"] if hasattr(n, "meta") else None for n in fw_outs]
    bw_outs = [n.meta["val"] if hasattr(n, "meta") else None for n in bw_outs]

    user_fw_outs = fw_outs[: fw_metadata.num_forward]
    saved_tensors = fw_outs[fw_metadata.tensors_saved_for_backwards_slice]

    fw_donated_buffer = collect_fw_donated_buffer_idxs(
        fw_ins,
        user_fw_outs,
        bw_outs,
        saved_tensors,
    )

    assert fw_metadata.num_symints_saved_for_bw is not None
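    # The backward's saved values are laid out as [saved symints, saved tensors],
    # so a donated saved-tensor index i maps to overall index
    # num_symints_saved_for_bw + i.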
    return [fw_metadata.num_symints_saved_for_bw + i for i in fw_donated_buffer]


def aot_dispatch_autograd(
    flat_fn,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> DispatchReturn:
    """
    Autograd logic. Generates a joint graph, partitions it, manipulates the inputs
    with various wrappers, and returns a wrapped torch.autograd.Function with a
    forward and backward.
    """
    wrappers = _create_wrappers_for_dispatch(needs_autograd=True)
    flat_fn, flat_args, fw_metadata = pre_compile(
        wrappers,
        flat_fn,
        flat_args,
        aot_config,
        fw_metadata=fw_metadata,
    )

    fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
    fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
        flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
    )

    # Copied from aot_dispatch_autograd_graph.
    disable_amp = torch._C._is_any_autocast_enabled()

    if aot_config.enable_log:
        aot_joint_log.info(
            "%s",
            lazy_format_graph_code(
                "Joint graph",
                fx_g,
                aot_config.aot_id,
                include_stride=True,
                include_device=True,
                colored=True,
            ),
        )
        trace_structured(
            "aot_joint_graph",
            payload_fn=lambda: fx_g.print_readable(
                print_output=False, include_stride=True, include_device=True
            ),
        )

    with torch.no_grad():
        inner_meta = (
            fw_metadata
            if maybe_subclass_meta is None
            else maybe_subclass_meta.fw_metadata
        )
        with track_graph_compiling(aot_config, "joint"):
            # See Note: [Partitioner handling for Subclasses, Part 1]
            # See Note: [Recomputing subclass mutation handling]
            mutated_inp_runtime_indices = (
                compute_inner_mutated_inp_indices_from_subclass_meta(
                    fw_metadata, inner_meta
                )
            )
            num_tokens = len(fw_metadata.tokens)
            num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices)
            num_inner_fwd_outputs = (
                num_mutated_inp_runtime_indices
                + inner_meta.num_outputs
                + inner_meta.num_intermediate_bases
                + inner_meta.num_outputs_rng_offset
                + num_tokens  # See Note [Side-Effectful Tokens in AOTAutograd]
            )
            fw_module, bw_module = aot_config.partition_fn(
                fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
            )

            # See Note [Side-Effectful Tokens in AOTAutograd]
            if config.unlift_effect_tokens and (
                num_tokens > 0 or fw_metadata.num_backward_tokens > 0
            ):
                unlift_tokens(fw_module, fw_metadata, aot_config, bw_module)

                num_inner_fwd_outputs -= num_tokens
                joint_inputs = (
                    joint_inputs[0][num_tokens:],
                    joint_inputs[1],
                )

            fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0]
            # we only need to bookkeep the symints that are saved for bw, not any symints
            # the user forward might have returned in its own output
            fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:]
            num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw)
            symint_outs_saved_for_bw = [
                n for n in fw_outs_saved_for_bw if is_sym_node(n)
            ]
            fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
            inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
            num_symints_saved_for_bw = len(symint_outs_saved_for_bw)

            if torch._functorch.config.donated_buffer:
                fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs(
                    fw_module,
                    bw_module,
                    inner_meta,
                )
                inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs

        if aot_config.enable_log:
            aot_graphs_log.info(
                "aot_config id: %s, fw_metadata=%s, inner_meta=%s",
                str(aot_config.aot_id),
                str(fw_metadata),
                str(inner_meta),
            )

        # Note [Detaching inputs that never need gradients]
        # See https://github.com/pytorch/pytorch/issues/97745
        # Suppose we have a function like this that we want to compile:
        #
        # def f(x, y):
        #     return torch.mul(x, y.detach())
        #
        # What gradients should we compute for x and y?
        # By default, AOTAutograd will compute a gradient for **every** input that requires gradients,
        # and so we'll compute:
        #    x_grad_input = y
        #    y_grad_input = None
        # Does this preserve the semantics of eager mode?
        # Unfortunately, no.
        # Doing the above will cause autograd to **continue** to backprop the autograd tape
        # that was generated from constructing y.
        #
        # This is **different** from what would have happened in eager mode.
        # In eager mode, if we backprop through the output of this function, autograd will only traverse
        # the bit of the autograd tape corresponding to "x".
        # In particular, if a user had previously backpropped through y's autograd tape,
        # and then they try to backprop through the output of the above function,
        # then we'll hit the dreaded "Trying to backward through the graph a second time" error.
        #
        # You might think: if autograd sees that a gradient is None, shouldn't it stop early,
        # instead of continuing the backprop through the ancestors of that node in the graph?
        #
        # Autograd has two passes:
        # (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed
        # (2) a second pass that actually goes ahead and executes each node when it becomes ready,
        #     propagating gradients
        # By the time we're executing a node and we see that it produces a None, the set of nodes to execute
        # is already locked-in.
        #
        # The fix: instead, we can recognize statically that the graph we're compiling will never contribute
        # gradients to y, and prevent autograd from trying to traverse y's autograd tape at all.
        # We can do this by manually detach'ing y before sending it through the `CompiledFunction`.
        #
        # Note that this solution is not bulletproof.
        # It's possible to construct a case where eager may or may not have tried to autograd through y,
        # depending on the actual grad_outputs that were passed in during the backward.
        # There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`,
        # allowing autograd to re-use the graph.
        #
        # An example of this case is:
        # def f(x):
        #     return x.detach() * 2, x * 3
        # If we backward only through the first output in eager, autograd stops at the detach
        # and no grad should be sent through x.
        # But the custom autograd function doesn't know that: it will materialize zero grads for x * 3
        # and we will end up with a zero grad at x.
        # If we later backprop through the second output, this will also require backprop'ing through x,
        # meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time.
        _indices_of_inps_to_detach: List[int] = []

        # reversed() since we expect output at end of graph
        bw_output = next(reversed(bw_module.graph.find_nodes(op="output")))
        bw_outs: Sequence[torch.fx.Node] = bw_output.args[0]  # type: ignore[assignment]

        # TODO: we should apply the below "detach inputs if their gradients are statically known to be None"
        # optimization even if we have subclass inputs/outputs (we do not handle this today).
        # Computing which of our inputs get None gradients is a bit more complicated
        # if any of our inputs are subclasses. Why?
        # (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses.
        # (b) The grad_outputs that we AOT computed in our backward graph are the desugared tensors,
        #     so we need to figure out which subclass fw inputs they map to.
        if maybe_subclass_meta is None:
            num_backward_tokens: int = inner_meta.num_backward_tokens
            assert (
                len(bw_outs)
                == len(fw_metadata.input_info)
                + inner_meta.num_outputs_rng_offset
                + num_backward_tokens
            )
            bw_outs_no_rng_no_tokens = bw_outs
            if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0:
                bw_outs_no_rng_no_tokens = bw_outs[
                    : -(inner_meta.num_outputs_rng_offset + num_backward_tokens)
                ]
            assert len(bw_outs_no_rng_no_tokens) == len(fw_metadata.input_info)

            for i, bw_out in enumerate(bw_outs_no_rng_no_tokens):
                # If our input experiences a metadata mutation inside the graph (e.g. set_()),
                # we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation
                metadata_mutation_in_graph = (
                    fw_metadata.input_info[i].mutation_type
                    == MutationType.MUTATED_IN_GRAPH
                    and fw_metadata.input_info[i].mutates_storage_metadata
                )
                is_non_leaf = (
                    fw_metadata.input_info[i].requires_grad
                    and not fw_metadata.input_info[i].is_leaf
                )
                if bw_out is None and not metadata_mutation_in_graph and is_non_leaf:
                    _indices_of_inps_to_detach.append(i)

        if aot_config.enable_log:
            aot_graphs_log.info(
                "%s",
                lazy_format_graph_code(
                    "Forward graph",
                    fw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
            aot_graphs_log.info(
                "%s",
                lazy_format_graph_code(
                    "Backward graph",
                    bw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
            trace_structured(
                "aot_forward_graph",
                payload_fn=lambda: fw_module.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )
            trace_structured(
                "aot_backward_graph",
                payload_fn=lambda: bw_module.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )

        # AMP is already traced out in the joint graph. We do not wish to reapply
        # it accidentally in the compiler.
        with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
            # flat_args at this point might still be subclasses -
            # make sure to pass the unwrapped fake tensors into the compiler!
            adjusted_flat_args = joint_inputs[0]

            fakified_out_wrapper = FakifiedOutWrapper()
            (
                fw_module,
                adjusted_flat_args,
                fw_metadata,
            ) = fakified_out_wrapper.pre_compile(
                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
            )

            functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
                return_new_outs=False
            )
            (
                fw_module,
                adjusted_flat_args,
                fw_metadata,
            ) = functionalized_rng_wrapper.pre_compile(
                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
            )
            if tracing_context := torch._guards.TracingContext.try_get():
                tracing_context.fw_metadata = inner_meta

            with TracingContext.report_output_strides() as fwd_output_strides:
                compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)

            if not hasattr(compiled_fw_func, "_boxed_call"):
                compiled_fw_func = make_boxed_func(compiled_fw_func)

            if fakified_out_wrapper.needs_post_compile:
                fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

            compiled_fw_func = EffectTokensWrapper().post_compile(
                compiled_fw_func,
                aot_config,
                runtime_metadata=fw_metadata,
            )

            compiled_fw_func = AOTDispatchSubclassWrapper(
                fw_only=None,
                trace_joint=False,
                maybe_subclass_meta=maybe_subclass_meta,
                num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
            ).post_compile(
                compiled_fw_func,
                aot_config,  # not used
                runtime_metadata=fw_metadata,
            )

            compiled_fw_func = functionalized_rng_wrapper.post_compile(
                compiled_fw_func, aot_config, runtime_metadata=fw_metadata
            )
            compiled_fw_func = fakified_out_wrapper.post_compile(
                compiled_fw_func,
                aot_config,
                runtime_metadata=fw_metadata,
            )

        # NB: It's important to compile backwards ahead of time, as this may
        # add extra guards which we need to apply to the Dynamo cache at
        # forwards
        with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast():
            placeholder_list = fx_placeholder_vals(bw_module)

            forward_saved_for_backwards_strides = None
            if fwd_output_strides is not None:
                forward_saved_for_backwards_strides = fwd_output_strides[
                    inner_meta.tensors_saved_for_backwards_slice
                ]

            # saved activations can have different strides than eager if
            # the compiler does layout optimization. We should restride the
            # tensors passed in for compiling the backward graph using the
            # saved tensors' strides.
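            # (Hypothetical illustration: if the compiler converts a contiguous
            # [N, C, H, W] activation to channels-last, its runtime strides become
            # (C*H*W, 1, W*C, C), and the FakeTensor placeholder below is
            # restrided to match before the backward is compiled.)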
            for i in range(len(placeholder_list)):
                ph_arg = placeholder_list[i]
                if not isinstance(ph_arg, torch.Tensor):
                    continue

                if forward_saved_for_backwards_strides is None:
                    continue

                real_stride = None
                # Per the all_args calling convention, the backward's inputs are laid
                # out as [symints saved for bw, tensors saved for bw, ...], so tensor
                # placeholder i corresponds to saved-tensor index i - num_symints_saved_for_bw.
                j = i - num_symints_saved_for_bw
                if 0 <= j < len(forward_saved_for_backwards_strides):
                    real_stride = forward_saved_for_backwards_strides[j]
                if real_stride is None:
                    continue

                # Comparing ph_arg.stride() with real_stride directly may
                # cause dynamic dimensions in ph_arg being specialized to a static
                # value. Use the hints to avoid that.
                if _get_symint_hints(ph_arg.stride()) != real_stride:
                    # Note that here we use the stride of the real tensor to
                    # restride a FakeTensor. This does not cause trouble
                    # for dynamic shapes since this code path only gets
                    # executed if layout optimization is enabled, and we
                    # disable layout optimization for dynamic shapes right
                    # now.
                    #
                    # A solution that decides the stride order based on the real
                    # tensor's stride and then applies that stride order to
                    # the FakeTensor does not work smoothly, since some
                    # tensors' layouts are not 'dense'. E.g. mixnet_l has a
                    # tensor with size [8, 64, 112, 112] and strides
                    # (2408448, 1, 21504, 192). The solution mentioned would
                    # decide a stride of (802816, 1, 7168, 64) for this
                    # tensor, which is wrong.
                    placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)

            compiled_bw_func = None
            if num_symints_saved_for_bw > 0:
                try:
                    compiled_bw_func = aot_config.bw_compiler(
                        bw_module, placeholder_list
                    )
                except Exception as e:
                    exc = e
                    trace_structured(
                        "artifact",
                        metadata_fn=lambda: {
                            "name": "eager_compile_backwards_failure",
                            "encoding": "string",
                        },
                        payload_fn=lambda: "\n".join(traceback.format_exception(exc)),
                    )
                    log.warning(
                        "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
                        exc_info=True,
                    )
            # Compiled autograd will run the bw_module in the backward pass,
            # so recompilation needs to happen anyway if the backward pass is ever
            # called.
            #
            # The reason we do the GraphModule recompilation here is that
            # lazy recompilation would cause issues in the backward pass
            # with compiled autograd.
            #
            # Do the _LazyGraphModule.force_recompile here rather than when
            # bw_module is first generated by the partitioner, because bw_module.recompile
            # may be called in some code path later and cause _LazyGraphModule.forward
            # to become the lazy version again. One example: when dynamic shapes are enabled
            # upfront, the bw_compiler is called above, which can cause an extra
            # graph module recompilation on bw_module.
            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
                from torch.fx._lazy_graph_module import _LazyGraphModule

                _LazyGraphModule.force_recompile(bw_module)

    saved_context = TracingContext.try_get()
    saved_compile_context = CompileContext.try_get()

    backward_state_indices = [
        idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
    ]
    assert len(backward_state_indices) <= 1

    lazy_backward_info = AutogradLazyBackwardCompileInfo(
        bw_module,
        placeholder_list,
        saved_context,
        saved_compile_context,
    )

    make_runtime_safe(fw_metadata, maybe_subclass_meta)

    try_save_cache_entry: Optional[Callable] = None
    if config.enable_autograd_cache:

        def try_save_cache_entry(compiled_bw_func, _fw_metadata):  # noqa: F811
            fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None)
            bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None)
            if aot_config.cache_key and fw_key and bw_key:
                entry = AOTAutogradCacheEntry(
                    CompiledForward(fw_key),
                    CompiledBackward(
                        bw_key, backward_state_indices, num_symints_saved_for_bw
                    ),
                    _fw_metadata,
                    wrappers,
                    maybe_subclass_meta,
                    num_fw_outs_saved_for_bw,
                    _indices_of_inps_to_detach,
                )
                AOTAutogradCache.save(aot_config.cache_key, entry)

        if compiled_bw_func is not None:
            # If we already compiled the backward, we can save the cache entry
            # right now instead of waiting for the lazy backward compile.
            try_save_cache_entry(compiled_bw_func, fw_metadata)
            try_save_cache_entry = None

    compiled_fn = AOTDispatchAutograd.post_compile(
        compiled_fw_func,
        compiled_bw_func,
        maybe_subclass_meta,
        num_symints_saved_for_bw,
        backward_state_indices,
        disable_amp,
        _indices_of_inps_to_detach,
        lazy_backward_info,
        aot_config,
        fw_metadata=fw_metadata,
        try_save_cache_entry=try_save_cache_entry,
    )

    if config.debug_assert:
        flat_requires_grad: List[Optional[bool]] = [
            a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
        ]
        compiled_fn = DebugAssertWrapper(
            flat_requires_grad=flat_requires_grad
        ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata)

    compiled_fn = post_compile(
        wrappers,
        compiled_fn,
        aot_config,
        runtime_metadata=fw_metadata,
    )
    return compiled_fn