xref: /aosp_15_r20/external/pytorch/torch/_inductor/compile_fx.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import contextlib
import functools
import itertools
import logging
import os
import sys
import time
import warnings
from itertools import count

from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from unittest import mock

import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools

import torch.fx
import torch.utils._pytree as pytree

from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo import (
    compiled_autograd,
    config as dynamo_config,
    logging as dynamo_logging,
    utils as dynamo_utils,
)
from torch._dynamo.utils import (
    counters,
    detect_fake_mode,
    flatten_graph_inputs,
    lazy_format_graph_code,
)
from torch._functorch import config as functorch_config
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
from torch._inductor.cudagraph_utils import (
    BoxedDeviceIndex,
    get_placeholders,
    log_cudagraph_skip_and_bump_counter,
)

from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._inductor.utils import (
    BoxedBool,
    count_tangents,
    fresh_inductor_cache,
    should_assume_input_aligned,
    tensor_is_aligned,
)
from torch._logging import trace_structured
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import compile_time_strobelight_meta
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

from .._dynamo.backends.common import aot_autograd
from ..fx._lazy_graph_module import _use_lazy_graph_module  # type: ignore[attr-defined]
from ..fx.graph import _PyTreeCodeGen
from . import config, metrics
from .debug import DebugContext
from .decomposition import select_decomp_table
from .fx_passes.joint_graph import joint_graph_passes
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
from .fx_passes.pre_grad import pre_grad_passes
from .graph import GraphLowering
from .ir import ExternKernelNode
from .utils import (
    get_cloned_parameter_buffer_name,
    has_incompatible_cudagraph_ops,
    maybe_get_suppress_shape_guards_ctx,
    output_node,
)
from .virtualized import V

if config.is_fbcode():
    from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
else:
    # no-op decorator
    def time_and_log(attr: str):
        return dynamo_utils.identity


log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
ALIGNMENT = 16


# copy_ fails when trying to write to tensors with internal memory overlap.
# For expanded dimensions (a dimension whose size was expanded from 1), we
# can select a single element from that dimension and write to it; because
# every element along that dimension aliases the same storage, this
# effectively writes to all values of that dimension of the input tensor.
def get_expanded_dims(t):
    if not isinstance(t, torch.Tensor):
        return None
    return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]


def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
    for expanded_dim in expanded_dims:
        t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
    return t
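

# Illustrative example (comment-only sketch; the tensor below is hypothetical):
#   t = torch.ones(1, 4).expand(3, 4)    # size (3, 4), stride (0, 1)
#   get_expanded_dims(t)                  # -> [0]  (dim 0: stride 0, size 3)
#   index_expanded_dims(t, [0]).shape     # -> torch.Size([1, 4])
# Writing through the sliced view writes the single backing row that every
# element along the expanded dimension aliases.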


def complex_memory_overlap(t: torch.Tensor) -> bool:
    # If torch._debug_has_internal_overlap thinks this tensor potentially has
    # memory overlap internally, dig deeper to find out whether it's true.
    #
    # Call squeeze() so that dimensions of size 1 do not cause false positives.
    t = index_expanded_dims(t, get_expanded_dims(t)).squeeze()
    if torch._debug_has_internal_overlap(t) != 0:
        strides = t.stride()
        sizes = t.shape
        indices = list(range(len(strides)))
        indices = [x for _, x in sorted(zip(strides, indices))]
        for i in range(len(strides)):
            prev_stride = 1 if i == 0 else strides[indices[i - 1]]
            prev_size = 1 if i == 0 else sizes[indices[i - 1]]
            if strides[indices[i]] < prev_stride * prev_size:
                return True
    return False
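

# Illustrative example (comment-only, hypothetical tensor): a strided view
# whose rows overlap in storage is flagged as complex overlap:
#   t = torch.empty(10).as_strided((2, 4), (2, 1))
#   complex_memory_overlap(t)  # -> True: row 1 starts at offset 2, inside row 0
# whereas a plain contiguous tensor returns False.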


def get_static_input_idxs(num_fixed):
    # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
    # of cudagraphs. Rather than copying these into cudagraph-owned memory
    # like we do for normal inputs on each run, we will re-record a cudagraph if these
    # parameter locations change.
    context = torch._guards.TracingContext.try_get()
    fixed = list(range(num_fixed))
    if not context or not context.fw_metadata:
        return fixed

    return fixed + context.fw_metadata.static_parameter_indices
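

# Illustrative example (comment-only, hypothetical values): with num_fixed=2
# and fw_metadata.static_parameter_indices == [5, 7], this returns
# [0, 1, 5, 7]: the first two graph inputs plus the inlined parameters are
# all treated as address-stable for cudagraph purposes.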


@functools.lru_cache(None)
def _step_logger():
    return dynamo_logging.get_step_logger(log)


@functools.lru_cache(None)
def _warn_tf32_disabled():
    if (
        torch.cuda.is_available()
        and not torch.backends.cuda.matmul.allow_tf32
        and torch.cuda.get_device_capability() >= (8, 0)
    ):
        warnings.warn(
            "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
            "Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
        )


def _unlift_graph(mod, gm, graph_signature):
    from torch.export.unflatten import _assign_attr, _AttrKind

    state_dict = {}
    for name, param in mod.named_parameters(remove_duplicate=False):
        state_dict[name] = param
        _assign_attr(
            param,
            gm,
            name,
            attr_kind=_AttrKind.PARAMETER,
        )
    for name, buffer in mod.named_buffers(remove_duplicate=False):
        state_dict[name] = buffer
        _assign_attr(
            buffer,
            gm,
            name,
            attr_kind=_AttrKind.BUFFER,
        )

    placeholder_nodes = gm.graph.find_nodes(op="placeholder")
    lifted_inputs = []

    # In AOTI, module parameters and buffers are not lifted as graph inputs.
    # As a result, mutating a buffer has a side effect that makes its initial
    # value diverge from eager mode, so we clone buffers here to keep a copy.
    # We do not clone parameters; that would be needed if we wanted to
    # support training.
    for node in placeholder_nodes:
        node_name = node.name
        if node_name in graph_signature.inputs_to_parameters:
            parameter_name = graph_signature.inputs_to_parameters[node_name]
            lifted_inputs.append(parameter_name)
        elif node_name in graph_signature.inputs_to_buffers:
            buffer_name = graph_signature.inputs_to_buffers[node_name]
            lifted_inputs.append(buffer_name)
            gm.meta[
                get_cloned_parameter_buffer_name(buffer_name)
            ] = clone_preserve_strides(state_dict[buffer_name])
        else:
            assert node_name in graph_signature.user_inputs
            lifted_inputs.append(None)

    from torch.export._unlift import _unlift

    outputs = list(gm.graph.nodes)[-1].args[0]
    mutated_outputs = []
    buffer_mutations = graph_signature.buffers_to_mutate
    user_input_mutations = graph_signature.user_inputs_to_mutate
    output_tokens = graph_signature.output_tokens
    for idx, out in enumerate(outputs):
        value = None

        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
            if out.name in buffer_mutations:
                value = buffer_mutations[out.name]
            elif out.name in user_input_mutations:
                value = user_input_mutations[out.name]

        mutated_outputs.append(value)

    unlifted_gm = _unlift(
        gm,
        lifted_inputs,
        mutated_outputs,
        pytree.LeafSpec(),
        None,
        state_dict,
        {},
    )
    return unlifted_gm


def _get_subgraph_names(gm):
    for node in sorted(
        itertools.chain(
            gm.graph.find_nodes(op="call_function", target=torch.ops.higher_order.cond),
            gm.graph.find_nodes(
                op="call_function", target=torch.ops.higher_order.while_loop
            ),
        )
    ):
        if node.target == torch.ops.higher_order.cond:
            true_subgraph_name = node.args[1].name
            false_subgraph_name = node.args[2].name
            yield true_subgraph_name
            yield false_subgraph_name
        elif node.target == torch.ops.higher_order.while_loop:
            cond_subgraph_name = node.args[0].name
            body_subgraph_name = node.args[1].name
            yield cond_subgraph_name
            yield body_subgraph_name
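

# Illustrative sketch (comment-only, hypothetical graph): for a gm containing
# a single higher-order cond call,
#   out = torch.ops.higher_order.cond(pred, true_gm, false_gm, operands)
# node.args[1] and node.args[2] are the get_attr nodes for the two branch
# submodules, so _get_subgraph_names(gm) yields their attribute names,
# e.g. "true_graph_0" and "false_graph_0".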


def _recursive_pre_grad_passes(gm, example_inputs):
    for subgraph_name in _get_subgraph_names(gm):
        subgraph = getattr(gm, subgraph_name)
        # we don't have example inputs for subgraphs, so pass None here
        new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None)
        setattr(gm, subgraph_name, new_subgraph)
    return pre_grad_passes(gm, example_inputs)


def _recursive_joint_graph_passes(gm):
    for subgraph_name in _get_subgraph_names(gm):
        subgraph = getattr(gm, subgraph_name)
        _recursive_joint_graph_passes(subgraph)
    joint_graph_passes(gm)


def _recursive_post_grad_passes(gm, is_inference: bool = False):
    for subgraph_name in _get_subgraph_names(gm):
        subgraph = getattr(gm, subgraph_name)
        _recursive_post_grad_passes(subgraph, is_inference)
    post_grad_passes(gm, is_inference)


def split_const_gm(
    gm: torch.fx.GraphModule,
) -> Tuple[torch.fx.GraphModule, Dict[str, int]]:
    """
    This function takes a GraphModule input "gm".
    The gm is split into two components:
      1) const_gm, which consists of the subgraph of gm that can be constant folded.
      2) gm (modified in place), which is the graph after constant folding.

    const_output_index maps the corresponding node name in gm to the
    output index of const_gm.
    Returns (const_gm, const_output_index).
    """
    from torch._inductor.constant_folding import (
        CONST_MODULE_TAG,
        META_TAG,
        MODULE_TAG,
        replace_node_with_constant,
        run_and_get_constant_graph,
    )

    const_gm = run_and_get_constant_graph(gm)
    const_result = const_gm()

    const_outputs = {
        x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
    }

    to_erase_node = []
    to_replace_node = []
    const_output_index = {}
    for node in gm.graph.nodes:
        if node.name in const_outputs:
            to_replace_node.append(node)
        elif node.meta[META_TAG] == CONST_MODULE_TAG:
            to_erase_node.append(node)

    for node in to_replace_node:
        new_const_name = "_FOLDED_CONST_" + node.name
        replace_node_with_constant(
            gm,
            node,
            const_result[const_outputs[node.name]],
            new_const_name,
        )
        const_output_index[new_const_name] = const_outputs[node.name]
    for node in to_erase_node[::-1]:
        if node.users:
            for n in node.users:
                assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty."
        else:
            gm.graph.erase_node(node)
    gm.recompile()

    return const_gm, const_output_index
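

# Illustrative usage sketch (comment-only, hypothetical module): for a gm in
# which some subgraph depends only on parameters/buffers,
#   const_gm, const_output_index = split_const_gm(gm)
# const_gm can be run once at load time and its outputs stashed; the
# remaining gm then reads the folded values "_FOLDED_CONST_<name>" as
# attributes instead of recomputing them on every call.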


def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
    aten = torch.ops.aten
    tf32_ops = {
        aten.mm.default,
        aten.addmm.default,
        aten.bmm.default,
        aten.baddbmm.default,
    }
    for target in tf32_ops:
        for node in gm.graph.find_nodes(op="call_function", target=target):
            if (
                isinstance(node.meta.get("val", None), torch.Tensor)
                and node.meta["val"].dtype == torch.float32
                and node.meta["val"].device.type == "cuda"
            ):
                return True
    return False


def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]):
    """
    For the CPU backend, enabling comprehensive padding causes some unit tests
    to fail due to a changing number of generated kernels. Skip it for now.
    """
    has_cuda = any(
        t.device.type == "cuda" for t in example_inputs if isinstance(t, torch.Tensor)
    )

    if config.comprehensive_padding and not has_cuda:
        perf_hint_log.info("Skip comprehensive padding on CPU")
        return config.patch(comprehensive_padding=False)
    else:
        return contextlib.nullcontext()


def fake_tensor_prop(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    force_allow_non_fake_inputs: bool = False,
):
    """
    If we cannot detect a fake mode from the context of the inputs, create one.

    The fake mode that was used (detected or created) is returned.
    """
    fake_mode = detect_fake_mode(example_inputs)
    if not fake_mode:
        fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
        FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
    else:
        ctx = (
            contextlib.nullcontext()
            if not force_allow_non_fake_inputs
            else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
        )
        with ctx:  # type: ignore[attr-defined]
            FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
                *example_inputs
            )

    return fake_mode


def should_use_remote_fx_graph_cache():
    if config.fx_graph_remote_cache:
        return True
    if not config.is_fbcode():
        return False
    if torch.version.hip is not None:
        return False

    try:
        from triton.fb.fb_memcache import MEMCACHE_VERSION
    except ModuleNotFoundError:
        return False

    return MEMCACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
        "pytorch/remote_cache:fx_graph_memcache_version"
    )


# pass config dict back to user
def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
    with config.patch(config_patches):
        return config.get_config_copy()
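

# Illustrative example (comment-only, assuming the flat dotted-key layout of
# torch._inductor.config): the patch is only live while the copy is taken,
# so the global config is left untouched:
#   d = get_patched_config_dict({"triton.cudagraphs": True})
#   d["triton.cudagraphs"]    # -> True
#   config.triton.cudagraphs  # -> the unpatched global value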


def with_fresh_cache_if_config(f):
    # Note: the config must be consulted at call time (not decoration time),
    # and f must be invoked *inside* the fresh-cache context for the fresh
    # cache to have any effect.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if config.force_disable_caches:
            with fresh_inductor_cache():
                return f(*args, **kwargs)
        return f(*args, **kwargs)

    return wrapper


@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
@time_and_log(attr="compilation time (in seconds)")
# Need this decorator for compile_fx_inner even if we already have one for
# compile_fx. The reason is that compilation of the backward graph may happen
# after compile_fx returns, and we may want to use the _LazyGraphModule for
# compiling the backward graph as well.
@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
@with_fresh_cache_if_config
@dynamo_utils.dynamo_timed(phase_name="inductor_compile", fwd_only=False)
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    static_input_idxs: Optional[List[int]] = None,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    """
    Inductor API that compiles a single graph.

    If you change the argument list for this function, make sure you
    also update the call to save_args_for_compile_fx_inner below accordingly.
    """
    if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
        # trigger the real recompilation for _LazyGraphModule before returning
        # the forward method.
        from torch.fx._lazy_graph_module import _LazyGraphModule

        _LazyGraphModule.force_recompile(gm)
        return make_boxed_func(gm.forward)

    if static_input_idxs is None:
        static_input_idxs = []

    assert isinstance(
        next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
    ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"

    if config.save_args:
        save_args_for_compile_fx_inner(
            gm,
            example_inputs,
            cudagraphs=cudagraphs,
            static_input_idxs=static_input_idxs,
            is_backward=is_backward,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            is_inference=is_inference,
            boxed_forward_device_index=boxed_forward_device_index,
            user_visible_outputs=user_visible_outputs,
            layout_opt=layout_opt,
        )

    if cudagraphs is None:
        cudagraphs = BoxedBool(config.triton.cudagraphs)

    # Inputs to fx_codegen_and_compile
    # Anything that affects codegen should go here, so if the signature
    # of fx_codegen_and_compile changes, the dict should be updated accordingly
    graph_kwargs = {
        "cudagraphs": cudagraphs,
        "static_input_idxs": static_input_idxs,
        "is_backward": is_backward,
        "graph_id": graph_id,
        "cpp_wrapper": cpp_wrapper,
        "aot_mode": aot_mode,
        "is_inference": is_inference,
        "user_visible_outputs": user_visible_outputs,
        "layout_opt": layout_opt,
        "extern_node_serializer": extern_node_serializer,
    }

    start = time.time()

    fx_graph_remote_cache = should_use_remote_fx_graph_cache()
    inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)
    if (
        not config.force_disable_caches
        and (config.fx_graph_cache or fx_graph_remote_cache)
        and not aot_mode
    ):
        for i, input in enumerate(example_inputs):
            if (
                isinstance(input, torch.Tensor)
                and input.device.type == "cuda"
                and i in static_input_idxs
            ):
                input._is_inductor_static = True  # type: ignore[attr-defined]

        compiled_graph = FxGraphCache.load(
            fx_codegen_and_compile,
            gm,
            example_inputs,
            graph_kwargs,
            inputs_to_check,
            local=config.fx_graph_cache,
            remote=fx_graph_remote_cache,
        )
    else:
        compiled_graph = fx_codegen_and_compile(
            gm, example_inputs, **graph_kwargs  # type: ignore[arg-type]
        )

    log.debug("FX codegen and compilation took %.3fs", time.time() - start)

    # check cudagraph disabling reasons from inductor lowering
    if cudagraphs and compiled_graph.disabled_cudagraphs_reason:
        if "cuda" in compiled_graph.device_types:
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
            )
        else:
            counters["inductor"]["cudagraph_skips"] += 1
        BoxedBool.disable(cudagraphs)

    # Return the output strides to the caller via TracingContext
    context = torch._guards.TracingContext.try_get()
    if context is not None and context.output_strides is not None:
        assert len(context.output_strides) == 0
        context.output_strides.extend(compiled_graph.output_strides)

    if aot_mode:
        return compiled_graph

    if cudagraphs:
        # the output node's single argument is the tuple of output values
        output = output_node(gm)
        assert len(output.args) == 1
        stack_traces = [
            (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
            for arg in output.args[0]
        ]

        complex_memory_overlap_inputs = any(
            complex_memory_overlap(t)
            for t in example_inputs
            if isinstance(t, torch.Tensor)
        )

        if not config.triton.cudagraph_support_input_mutation:
            # Skip support for cudagraph-managed tensors
            from torch._inductor.cudagraph_utils import (
                check_for_mutation_ignore_cuda_graph_managed_tensor,
            )

            has_mutation_str = check_for_mutation_ignore_cuda_graph_managed_tensor(
                gm, compiled_graph, static_input_idxs
            )
            has_mutation = has_mutation_str is not None

            if has_mutation:
                compiled_graph.disabled_cudagraphs_reason = has_mutation_str
        else:
            # Check mutation later to support cudagraph-managed tensors
            has_mutation = None

        cudagraph_tests = [
            (not has_mutation, "mutated inputs"),
            (not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
            (not complex_memory_overlap_inputs, "complex memory overlap"),
            (
                all(
                    isinstance(t, (torch.Tensor, torch.SymInt)) for t in example_inputs
                ),
                "non-Tensor inputs",
            ),
        ]
        cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]

        if not cudagraph_fail_reasons:
            if not config.triton.cudagraph_trees:
                # Force specialize all inputs so that CUDA graphs will work
                for t in example_inputs:
                    if isinstance(t, torch.SymInt):
                        int(t)  # guard

            if (
                boxed_forward_device_index is not None
                and not is_inference
                and not is_backward
            ):
                boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

            compiled_graph.current_callable = cudagraphify(
                compiled_graph.current_callable,
                example_inputs,
                static_input_idxs=static_input_idxs,
                device_index=next(iter(compiled_graph.device_idxs)),
                stack_traces=stack_traces,
                is_backward=is_backward,
                is_inference=is_inference,
                constants=tuple(compiled_graph.constants.values()),
                placeholders=tuple(get_placeholders(gm.graph)),
                mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
            )
        else:
            BoxedBool.disable(cudagraphs)

            # See [Backward Generation Handling]
            # if we cudagraph'd the forward and set the device, we need to let
            # the cudagraph manager know we are running the backward even if we
            # will not run it in cudagraphs
            if is_backward and config.triton.cudagraph_trees:
                assert boxed_forward_device_index is not None
                assert boxed_forward_device_index.value is not None
                compiled_graph_callable = compiled_graph.current_callable

                manager = torch._inductor.cudagraph_trees.get_manager(
                    boxed_forward_device_index.value, create_if_none_exists=False
                )
                # should already exist from forward
                assert manager is not None

                def compiled_artifact(new_inputs):
                    manager.set_to_running_backward()  # type: ignore[union-attr]
                    return compiled_graph_callable(new_inputs)

                compiled_graph.current_callable = compiled_artifact

            if "cuda" in compiled_graph.device_types:
                # prefer the richer disabled_cudagraphs_reason because it carries a stack trace
                # TODO: migrate all disable reasons to stack trace, refactor
                if compiled_graph.disabled_cudagraphs_reason:
                    log_cudagraph_skip_and_bump_counter(
                        compiled_graph.disabled_cudagraphs_reason
                    )
                else:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {cudagraph_fail_reasons}"
                    )

    # cudagraphs does its own aligning of inputs
    if not cudagraphs:
        new_callable = align_inputs_from_check_idxs(
            compiled_graph.current_callable, inputs_to_check
        )
        if new_callable is not compiled_graph.current_callable:
            compiled_graph.current_callable = new_callable

    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )

    # aot autograd needs to know to pass in inputs as a list
    compiled_graph._boxed_call = True
    return compiled_graph
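

# Illustrative note on the boxed calling convention (comment-only sketch,
# hypothetical caller): a boxed callable takes a single list of arguments
# and may clear it to free input memory early:
#   args = [inp1, inp2]
#   out = compiled_graph(args)   # compiled_graph._boxed_call is True
#   # args may now be empty; the callee owned (and consumed) the list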


@dynamo_utils.preserve_rng_state()
def fx_codegen_and_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    static_input_idxs: Optional[List[int]] = None,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    # Use a dict with None values rather than a set for deterministic
    # iteration order, just in case.
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    if is_tf32_warning_applicable(gm):
        _warn_tf32_disabled()

    # lift the maximum depth of the Python interpreter stack
    # to accommodate large/deep models
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

    _step_logger()(
        logging.INFO,
        "torchinductor compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )
    V.debug.fx_graph(gm, example_inputs)
    # TODO: Should we actually dump this?  It should be redundant with the aot
    # structured logs...
    # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))

    shape_env = _shape_env_from_inputs(example_inputs)

    # Convert view to reshape in the graph. This is necessary primarily for
    # layout optimization. Do it unconditionally for uniformity.
    #
    # It's needed because when we do layout optimization, a tensor that is
    # contiguous in eager mode may become a channels-last tensor. A view op
    # that could previously be applied to the contiguous tensor may no longer
    # be applicable to the channels-last tensor, and an error like
    #   RuntimeError: view size is not compatible with input tensor's size and stride
    #   (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
    # will be raised.
    #
    # Replace view ops with reshape ops in this case.
    # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this.
    #
    # Also this has to be done before FakeTensorProp below to avoid the failed
    # .view() call.
    view_to_reshape(gm)

    # It is safe to run FakeTensorProp under no_grad because by the time
    # we're in inductor, we assume that AOTAutograd has already "taken care"
    # of autograd, so there should be no more autograd-related API's in the
    # graph.
    with torch.no_grad():
        fake_mode = fake_tensor_prop(gm, example_inputs)

    # pattern matcher passes might not preserve striding information
    # on node.meta["val"]. if in the future we rely on these being
    # correct we will need to fix.

    with V.set_fake_mode(fake_mode):
        # has some issues with memory in training
        _recursive_post_grad_passes(gm, is_inference=is_inference)
        V.debug.fx_graph_transformed(gm, example_inputs)
        post_grad_graphs_log.debug(
            "%s",
            lazy_format_graph_code(
                "AFTER POST GRAD", gm, include_stride=True, include_device=True
            ),
        )
        trace_structured(
            "inductor_post_grad_graph",
            payload_fn=lambda: gm.print_readable(
                print_output=False, include_stride=True, include_device=True
            ),
        )
        if config.is_fbcode():
            log_optimus_to_scuba(
                extra_logging={"pt2_configs": str(get_patched_config_dict())}
            )

    with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding(
        example_inputs
    ):
        const_output_index = None
        const_graph = None
        const_code = None

        if aot_mode and config.aot_inductor.use_runtime_constant_folding:
            const_gm, const_output_index = split_const_gm(gm)

            const_graph = GraphLowering(
                const_gm,
                example_inputs=[],
                shape_env=shape_env,
                graph_id=graph_id,
                cpp_wrapper=cpp_wrapper,
                aot_mode=aot_mode,
                user_visible_outputs=user_visible_outputs,
                extern_node_serializer=extern_node_serializer,
                is_inference=is_inference,
                is_const_graph=True,
            )
            with V.set_graph_handler(const_graph):
                assert cpp_wrapper, "AOT mode only supports C++ wrapper"
                const_graph.run()

                const_code, _ = const_graph.codegen_with_cpp_wrapper()

        graph = GraphLowering(
            gm,
            # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
            # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
            # we currently use fake tensors and defake them later.
            example_inputs=example_inputs,
            shape_env=shape_env,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            user_visible_outputs=user_visible_outputs,
            extern_node_serializer=extern_node_serializer,
            is_inference=is_inference,
            const_output_index=const_output_index,
            const_code=const_code,
            const_module=const_graph,
        )
        metrics_helper = metrics.CachedMetricsHelper()
        with V.set_graph_handler(graph):
            graph.run(*example_inputs)
            output_strides: List[Optional[Tuple[int, ...]]] = []
            if graph.graph_outputs is not None:
                # We'll put the output strides in the compiled graph so we
                # can later return them to the caller via TracingContext
                for out in graph.graph_outputs:
                    if (
                        hasattr(out, "layout")
                        and len(free_unbacked_symbols(out.layout.stride)) == 0
                    ):
                        output_strides.append(
                            tuple(
                                V.graph.sizevars.size_hint(s) for s in out.layout.stride
                            )
                        )
                    else:
                        output_strides.append(None)

            _check_triton_bf16_support(graph)
            compiled_fn = graph.compile_to_fn()
            num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
            metrics.num_bytes_accessed += num_bytes
            metrics.node_runtimes += node_runtimes
            metrics.nodes_num_elem += nodes_num_elem

            if (
                cudagraphs
                and config.triton.cudagraph_skip_dynamic_graphs
                and not V.graph.disable_cudagraphs_reason
                and torch._inductor.utils.any_is_symbolic(*example_inputs)
            ):
                stack_trace = None
                for node in gm.graph.nodes:
                    meta_val = node.meta.get("val", None)
                    if (
                        node.op == "placeholder"
                        or not isinstance(meta_val, torch.Tensor)
                        or not torch._inductor.utils.any_is_symbolic(meta_val)
                    ):
                        continue

                    if stack_trace := node.meta.get("stack_trace", None):
                        break
                disable = "graph with symbolic shapes inputs and config.triton.cudagraph_skip_dynamic_graphs=True."
                if stack_trace:
                    disable = f"{disable} Found from {stack_trace}\n"
                else:
                    disable = f"{disable}\n"
                V.graph.disable_cudagraphs_reason = disable

            if V.aot_compilation is True:
                return compiled_fn

            if cudagraphs and not V.graph.disable_cudagraphs_reason:
                from torch._inductor.cudagraph_utils import (
                    check_lowering_disable_cudagraph,
                )

                V.graph.disable_cudagraphs_reason = check_lowering_disable_cudagraph(
                    V.graph.device_node_mapping
                )

            compiled_graph = CompiledFxGraph(
                compiled_fn,
                graph,
                output_strides,
                V.graph.disable_cudagraphs_reason,
                metrics_helper.get_deltas(),
            )

    return compiled_graph


def clone_preserve_strides(x: torch.Tensor):
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
    )
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())
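

# Illustrative example (comment-only, hypothetical tensor): for x with
# size (2, 3) and stride (3, 1),
#   needed_size = (2 - 1) * 3 + (3 - 1) * 1 + 1 = 6
# i.e. the index of the last reachable element plus one, so the flattened
# clone covers exactly the storage the strided view can touch, and the
# result keeps x's exact size/stride layout.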


def copy_misaligned_inputs(
    new_inputs: List[torch.Tensor], check_inputs_idxs: Sequence[int]
) -> None:
    for i in check_inputs_idxs:
        if new_inputs[i].data_ptr() % ALIGNMENT:
            new_inputs[i] = clone_preserve_strides(new_inputs[i])


def get_input_idxs_to_check(
    inputs: Union[List[torch.Tensor], Sequence[int]],
    static_input_idxs: Sequence[int],
) -> Sequence[int]:
    """
    This function runs at compile time, and generates a list of indices for which we
    might need to do a copy to preserve alignment requirements.
    """
    ids_to_check = []

    for i, input in enumerate(inputs):
        if not isinstance(input, torch.Tensor):
            # non-tensors don't need alignment
            continue
        if input.device.type != "cuda":
            # right now we only care for cuda tensors
            continue
        with maybe_get_suppress_shape_guards_ctx():
            # suppress guards so that tensor_is_aligned and should_assume_input_aligned
            # do not add guards on input's storage offset
            if i in static_input_idxs and tensor_is_aligned(input):
                continue
            if not should_assume_input_aligned(input):
                continue

        # if we get here, then
        # (a) our triton code assumes that the input is aligned
        # (b) we can't be sure ahead of time that the input will actually be aligned.
        # therefore, at runtime, we'll need to check that the input is aligned
        # (and if not, clone it to make it aligned.)
        ids_to_check.append(i)

    return ids_to_check
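

# Illustrative example (comment-only, hypothetical inputs, assuming
# should_assume_input_aligned holds for the cuda tensors): with
#   inputs = [cuda_tensor_a, cpu_tensor, 42, cuda_tensor_b]
#   static_input_idxs = [3]        # and cuda_tensor_b already aligned
# only index 0 is returned: the CPU tensor and the int never need the
# alignment check, and the aligned static input is trusted to stay put.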


def align_inputs_from_check_idxs(
    model: Callable[[List[torch.Tensor]], Any], inputs_to_check: Sequence[int]
):
    if len(inputs_to_check) == 0:
        return model

    def run(new_inputs):
        copy_misaligned_inputs(new_inputs, inputs_to_check)
        return model(new_inputs)

    return run


@dynamo_utils.dynamo_timed
def cudagraphify(
    model: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
    *,
    device_index: int,
    stack_traces: List[Optional[str]],
    is_backward: bool,
    is_inference: bool,
    constants: Tuple[torch.Tensor, ...] = (),
    placeholders: Tuple[torch.fx.Node, ...] = (),
    mutated_input_idxs: Tuple[int, ...] = (),
):
    from torch._inductor.cudagraph_trees import (
        cudagraphify_impl as new_cudagraphify_impl,
    )

    cudagraphify_fn: Callable[..., Any]
    if config.triton.cudagraph_trees:
        cudagraphify_fn = functools.partial(
            new_cudagraphify_impl,
            device_index=device_index,
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=constants,
            placeholders=placeholders,
            mutated_input_idxs=mutated_input_idxs,
        )
    else:
        cudagraphify_fn = cudagraphify_impl

    # if using fake tensors, defer cudagraphs until we get real inputs at runtime
    if not any(isinstance(inp, FakeTensor) for inp in inputs):
        return cudagraphify_fn(model, inputs, static_input_idxs)

    compiled_fn = None

    def run(new_inputs):
        nonlocal compiled_fn
        if compiled_fn is None:
            with dynamo_utils.preserve_rng_state():
                compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
        return compiled_fn(new_inputs)

    return run


def remove_unaligned_input_idxs(
    inputs: Union[List[torch.Tensor], Sequence[int]],
    static_input_idxs: Sequence[int],
):
    """
    Static inputs are assumed to stay at a fixed, aligned address. Drop any
    index from static_input_idxs whose tensor is not ALIGNMENT-byte aligned;
    such inputs are then treated as regular inputs (checked and copied at
    runtime if needed).
    """
    aligned_static_input_idxs = []
    for idx in static_input_idxs:
        # the static indices index into `inputs`
        input = inputs[idx]
        if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0:
            aligned_static_input_idxs.append(idx)
    if len(aligned_static_input_idxs) != len(static_input_idxs):
        return aligned_static_input_idxs
    return static_input_idxs
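

# Illustrative example (comment-only, hypothetical values): with
#   static_input_idxs = [0, 5] and inputs[5] misaligned,
# this returns [0]; index 5 is dropped from the static set, so the
# cudagraph machinery will fall back to checking/copying that input rather
# than assuming its address is stable.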


def static_input(x: torch.Tensor):
    """
    Allocate a fresh tensor matching x's size and strides (contents are not
    copied here; values are copied in later via index_expanded_dims_and_copy_).
    """
    # TODO(jansel): figure out why this version doesn't work:
    # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
    )
    buffer = torch.empty(needed_size, dtype=x.dtype, device=x.device)
    return torch.as_strided(buffer, x.size(), x.stride())


def index_expanded_dims_and_copy_(
    dst: torch.Tensor,
    src: torch.Tensor,
    expanded_dims: List[int],
):
    "Index into expanded dimensions of both dst and src then copy_"
    dst = index_expanded_dims(dst, expanded_dims)
    src = index_expanded_dims(src, expanded_dims)
    dst.copy_(src)


def cudagraphify_impl(
    model: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
):
    """
    Assumes inputs[static_input_idxs[i]] are always the same memory address
    """
    check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
    static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
    copy_misaligned_inputs(inputs, check_input_idxs)

    assert isinstance(inputs, list)

    inps_expanded_dims = [
        get_expanded_dims(x) if idx not in static_input_idxs else []
        for idx, x in enumerate(inputs)
    ]

    # allocate static tensor inputs
    static_inputs = [
        x
        if not isinstance(x, torch.Tensor)
        else static_input(x)
        if idx not in static_input_idxs
        else x.detach()
        for idx, x in enumerate(inputs)
    ]

    # copy over input values for fresh allocations
    for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
        if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
            index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    # copy static_inputs because it will be cleared in model
    with torch.cuda.stream(stream):
        model(list(static_inputs))
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
        static_outputs = model(list(static_inputs))
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    if config.size_asserts:

        def run(new_inputs):
            assert len(static_inputs) == len(new_inputs)
            for idx, (dst, src, expanded_dims) in enumerate(
                zip(static_inputs, new_inputs, inps_expanded_dims)
            ):
                if not isinstance(dst, torch.Tensor):
                    pass
                elif idx in static_input_idxs:
                    assert dst.data_ptr() == src.data_ptr()
                else:
                    # TODO - could make one single op of multiple slices
                    # and avoid dispatch.
                    # Could also pre-index the `dst` tensors
                    index_expanded_dims_and_copy_(dst, src, expanded_dims)
            new_inputs.clear()
            graph.replay()
            return static_outputs

    else:
        copy_indices = [
            idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
        ]

        def run(new_inputs):
            for idx in copy_indices:
                expanded_dims = inps_expanded_dims[idx]
                index_expanded_dims_and_copy_(
                    static_inputs[idx], new_inputs[idx], expanded_dims
                )
            new_inputs.clear()
            graph.replay()
            return static_outputs

    return align_inputs_from_check_idxs(run, check_input_idxs)
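

# Minimal sketch of the record/replay pattern used above (comment-only,
# hypothetical shapes; requires a CUDA device):
#   static_inp = torch.zeros(8, device="cuda")
#   g = torch.cuda.CUDAGraph()
#   with torch.cuda.graph(g):
#       static_out = static_inp * 2          # captured, not executed
#   static_inp.copy_(torch.arange(8., device="cuda"))
#   g.replay()                               # static_out now holds inp * 2
# Replays always read/write the same memory, which is why non-static inputs
# must be copied into the recorded buffers before each replay.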


def compile_fx_aot(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile: Callable[..., Any] = compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
):
    config_patches: Dict[str, Any] = (
        {"cpp_wrapper": True}
        if config_patches is None
        else {**config_patches, "cpp_wrapper": True}
    )
    if (
        "aot_inductor.output_path" not in config_patches
        and not config.aot_inductor.output_path
    ):
        config_patches = {
            **config_patches,
            "aot_inductor.output_path": code_hash(model_.code),
        }

    extern_node_serializer = config_patches.pop("extern_node_serializer", None)
    with V.set_aot_compilation(True):
        compiled_lib_path = compile_fx(
            model_,
            example_inputs_,
            inner_compile=functools.partial(
                inner_compile,
                aot_mode=True,
                extern_node_serializer=extern_node_serializer,
            ),
            config_patches=config_patches,
        )
        assert os.path.exists(
            compiled_lib_path
        ), f"AOTInductor compiled library does not exist at {compiled_lib_path}"
        return compiled_lib_path


_graph_counter = count(0)


def fw_compiler_freezing(
    aot_autograd_model: torch.fx.GraphModule,
    aot_example_inputs: List[torch.Tensor],
    dynamo_model: torch.fx.GraphModule,
    num_example_inputs: int,
    inner_compile: Callable[..., Any],
    cudagraphs: BoxedBool,
    graph_id: int,
    forward_device: BoxedDeviceIndex,
):
    from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

    # partition_fn won't be called
    _recursive_joint_graph_passes(aot_autograd_model)

    layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True)
    if layout_opt:
        # make sure meta['val'] is properly set up
        fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
        convert_conv_weights_to_channels_last(aot_autograd_model)

    opt_model, preserved_arg_indices = freeze(
        dynamo_model,
        aot_autograd_model,
        aot_example_inputs,  # type: ignore[arg-type]
    )

    aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
    num_fixed = len(preserved_arg_indices) - num_example_inputs

    fake_mode = detect_fake_mode(aot_example_inputs)

    # for freezing, all graph outputs should be user visible
    *_, model_outputs_node = opt_model.graph.nodes
    model_outputs = model_outputs_node.args[0]
    user_visible_outputs = dict.fromkeys(
        n.name for n in model_outputs if isinstance(n, torch.fx.Node)
    )

    static_input_idxs = list(range(num_fixed))
    # constant params will be real tensors, not fake
    tracing_context = torch._guards.TracingContext.try_get()
    if tracing_context is not None:
        params_flat = tracing_context.params_flat
        assert params_flat is not None
        for i in range(len(params_flat)):
            if i not in preserved_arg_indices:
                params_flat[i] = None

        if tracing_context.fw_metadata:
            static_input_idxs += tracing_context.fw_metadata.static_parameter_indices

    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
        optimized_function = inner_compile(
            opt_model,
            aot_example_inputs,
            static_input_idxs=static_input_idxs,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=True,
            boxed_forward_device_index=forward_device,
            layout_opt=layout_opt,
            user_visible_outputs=user_visible_outputs,
        )

    # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
    # that drops constant-ified params
    if V.aot_compilation is True:
        return optimized_function

    def wrapper(args):
        args_new = [args[i] for i in preserved_arg_indices]
        args.clear()
        return optimized_function(args_new)

    wrapper._boxed_call = True  # type: ignore[attr-defined]

    return wrapper


@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
def compile_fx(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile: Callable[..., Any] = compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
    decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
):
    """Main entry point for compiling a given FX graph"""
    if config_patches:
        with config.patch(config_patches):
            return compile_fx(
                model_,
                example_inputs_,
                # need extra layer of patching as backwards is compiled out of scope
                inner_compile=config.patch(config_patches)(inner_compile),
                decompositions=decompositions,
            )

    if config.cpp_wrapper:
        with config.patch(
            {
                "cpp_wrapper": False,
                "triton.autotune_cublasLt": False,
                "triton.cudagraphs": False,
                "triton.store_cubin": True,
            }
        ), V.set_real_inputs(example_inputs_):
            inputs_ = example_inputs_
            if isinstance(model_, torch.fx.GraphModule):
                fake_inputs = [
                    node.meta.get("val")
                    for node in model_.graph.nodes
                    if node.op == "placeholder"
                ]
                if all(v is not None for v in fake_inputs):
                    # Validate devices before switching to fake tensors.
                    for idx, fi, i in zip(count(), fake_inputs, inputs_):
                        if fi.device != i.device:
                            raise ValueError(
                                f"Device mismatch between fake input and example input at position #{idx}: "
                                f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
                                "make sure torch.export() and torch.aot_compile() run on the same device."
                            )
                    inputs_ = fake_inputs
            return compile_fx(
                model_,
                inputs_,
                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
                decompositions=decompositions,
            )

    recursive_compile_fx = functools.partial(
        compile_fx,
        inner_compile=inner_compile,
        decompositions=decompositions,
    )

    if not graph_returns_tuple(model_):
        return make_graph_return_tuple(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    if isinstance(model_, torch.fx.GraphModule):
        if isinstance(model_.graph._codegen, _PyTreeCodeGen):
            # this graph is the result of dynamo.export()
            return handle_dynamo_export_graph(
                model_,
                example_inputs_,
                recursive_compile_fx,
            )

        model_ = _recursive_pre_grad_passes(model_, example_inputs_)

    if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
        return flatten_graph_inputs(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    assert not config._raise_error_for_testing
    num_example_inputs = len(example_inputs_)
    cudagraphs = BoxedBool(config.triton.cudagraphs)
    forward_device = BoxedDeviceIndex(None)

    graph_id = next(_graph_counter)

    decompositions = (
        decompositions if decompositions is not None else select_decomp_table()
    )

    @dynamo_utils.dynamo_timed
    def fw_compiler_base(
        model: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        is_inference: bool,
    ):
        if is_inference:
            # partition_fn won't be called
            _recursive_joint_graph_passes(model)

        fixed = torch._inductor.utils.num_fw_fixed_arguments(
            num_example_inputs, len(example_inputs)
        )

        user_visible_outputs = {}

        if config.keep_output_stride:
            model_outputs_node = output_node(model)
            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
            num_model_outputs = len(model_outputs)

            context = torch._guards.TracingContext.try_get()
            # See Note [User Outputs in the inductor graph]
            if context is not None and context.fw_metadata and not is_inference:
                original_output_start_index = (
                    context.fw_metadata.num_mutated_inp_runtime_indices
                )
            else:
                original_output_start_index = 0

            if isinstance(model_, torch.fx.GraphModule):
                *_, orig_model_outputs_node = model_.graph.nodes
                assert orig_model_outputs_node.op == "output"
                orig_model_outputs, _ = pytree.tree_flatten(
                    orig_model_outputs_node.args
                )
                num_orig_model_outputs = len(orig_model_outputs)
            else:
                num_orig_model_outputs = num_model_outputs

            assert num_orig_model_outputs <= num_model_outputs

            # Note [User Outputs in the inductor graph]
            # We make the following assumptions:
            # For inference
            #   len(orig_model_outputs) == len(model_outputs)
            # For training
            #   len(orig_model_outputs) <= len(model_outputs)
            # During training, most of the time model_outputs starts with the
            # original module's outputs followed by saved activations.
            # But this may not hold if the model has in-place updated tensors:
            # AOTAutograd will make those tensors be returned before the
            # original module's outputs.
            # To be safe, we use the original_output_start_index field
            # set by AOTAutograd to decide where the original module outputs start.
            orig_output_end_idx = original_output_start_index + num_orig_model_outputs
            # Sanity check: we are about to splice out the "user" outputs from the full set
            # of "graph" outputs. Make sure we're within bounds.
            assert orig_output_end_idx <= num_model_outputs
1411
1412            user_visible_outputs = dict.fromkeys(
1413                n.name
1414                for n in model_outputs[original_output_start_index:orig_output_end_idx]
1415                if isinstance(n, torch.fx.Node)
1416            )

        return inner_compile(
            model,
            example_inputs,
            static_input_idxs=get_static_input_idxs(fixed),
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=is_inference,
            boxed_forward_device_index=forward_device,
            user_visible_outputs=user_visible_outputs,
        )

    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)

    if config.freezing and not torch.is_grad_enabled():
        inference_compiler = functools.partial(
            fw_compiler_freezing,
            dynamo_model=model_,
            num_example_inputs=num_example_inputs,
            inner_compile=inner_compile,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            forward_device=forward_device,
        )
    else:
        inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

    def partition_fn(graph, joint_inputs, **kwargs):
        _recursive_joint_graph_passes(graph)
        return min_cut_rematerialization_partition(
            graph, joint_inputs, **kwargs, compiler="inductor"
        )

    @compile_time_strobelight_meta(phase_name="bw_compiler")
    @dynamo_utils.dynamo_timed
    def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        user_visible_outputs = {}

        if config.bw_outputs_user_visible:
            model_outputs_node = output_node(model)
            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
            user_visible_outputs = dict.fromkeys(
                n.name for n in model_outputs if isinstance(n, torch.fx.Node)
            )
        fixed = count_tangents(model)
        return inner_compile(
            model,
            example_inputs,
            static_input_idxs=list(range(fixed)),
            cudagraphs=cudagraphs,
            is_backward=True,
            graph_id=graph_id,
            boxed_forward_device_index=forward_device,
            user_visible_outputs=user_visible_outputs,
        )

    # TODO: can add logging before/after the call to create_aot_dispatcher_function
    # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
    # once torchdynamo is merged into pytorch

    fake_mode = detect_fake_mode(example_inputs_) or torch._subclasses.FakeTensorMode(
        allow_non_fake_inputs=True
    )
    tracing_context = (
        torch._guards.TracingContext.try_get()
        or torch._guards.TracingContext(fake_mode)
    )

    if V.aot_compilation is True:
        with functorch_config.patch(unlift_effect_tokens=True):
            gm, graph_signature = aot_export_module(
                model_,
                example_inputs_,
                trace_joint=False,
                decompositions=decompositions,
            )
        unlifted_gm = _unlift_graph(model_, gm, graph_signature)
        if "dynamo_flat_name_to_original_fqn" in model_.meta:
            unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
                "dynamo_flat_name_to_original_fqn"
            ]

        # Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515)
        # In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into
        # _sfdp_init() to register patterns.
        # When fallback_random is set to True, the sdpa patterns will be traced at runtime.
        # If amp is turned on, the traced FP32 patterns will contain prims.convert_element_type,
        # making them identical to the generated FP16 patterns.
        disable_amp = torch._C._is_any_autocast_enabled()
        context = torch._C._DisableAutocast if disable_amp else contextlib.nullcontext
        with V.set_fake_mode(fake_mode), compiled_autograd.disable(), context():
            return inference_compiler(unlifted_gm, example_inputs_)

    with V.set_fake_mode(fake_mode), torch._guards.tracing(
        tracing_context
    ), compiled_autograd.disable(), functorch_config.patch(unlift_effect_tokens=True):
        return aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            inference_compiler=inference_compiler,
            decompositions=decompositions,
            partition_fn=partition_fn,
            keep_inference_input_mutations=True,
        )(model_, example_inputs_)

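# Illustrative sketch (an assumed example, not used by compile_fx): the
# aot_autograd(...) call above is the same backend factory used to build
# custom Dynamo backends, so a stripped-down variant of the wiring can be
# exercised directly with torch.compile. noop_compiler is hypothetical.
def _example_custom_aot_backend():
    def noop_compiler(gm: torch.fx.GraphModule, example_inputs):
        # Return the graph's forward as a boxed callable, the calling
        # convention AOTAutograd expects from inner compilers.
        return make_boxed_func(gm.forward)

    backend = aot_autograd(fw_compiler=noop_compiler)
    compiled = torch.compile(lambda x: x.cos(), backend=backend)
    return compiled(torch.randn(4))
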

def _shape_env_from_inputs(inputs: List[torch.Tensor]):
    fake_mode = detect_fake_mode(inputs)

    # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
    # pass in real inputs for now.
    # if len(inputs) > 0:
    # assert fake_mode is not None, breakpoint()

    if fake_mode is not None:
        return fake_mode.shape_env

    # When there are no tensor inputs, get the shape_env from the first SymInt.
    for input in inputs:
        if isinstance(input, torch.SymInt):
            return input.node.shape_env

    # TODO(voz): Should we always have one anyway?
    return None

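
# A minimal sketch (illustrative only; the helper name is an assumption) of
# the behavior documented above: fake tensors carry their FakeTensorMode's
# ShapeEnv, which _shape_env_from_inputs recovers from the example inputs.
def _example_shape_env_lookup():
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    fake_mode = FakeTensorMode(shape_env=ShapeEnv())
    with fake_mode:
        t = torch.empty(2, 3)  # a fake tensor, no real storage
    assert _shape_env_from_inputs([t]) is fake_mode.shape_env
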

def graph_returns_tuple(gm: torch.fx.GraphModule):
    """True if an FX graph returns a tuple"""
    if not isinstance(gm, torch.fx.GraphModule):
        return True  # can't check this, assume true
    (rv,) = output_node(gm).args
    if isinstance(rv, (list, tuple)):
        return True
    if (
        isinstance(rv, torch.fx.node.Node)
        and hasattr(rv.target, "_schema")
        and len(rv.target._schema.returns) > 1
        and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
    ):
        # for graphs whose result is one node with multiple outputs
        return True
    return False

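
# Hedged example (not used by this module; M is a made-up module): a graph
# traced from a forward that returns a bare tensor is reported as not
# returning a tuple, which is exactly what make_graph_return_tuple below fixes.
def _example_graph_returns_tuple():
    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    gm = torch.fx.symbolic_trace(M())
    assert not graph_returns_tuple(gm)
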

def make_graph_return_tuple(
    gm: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    compile_gm: Callable[..., Any],
):
    """
    Mutate gm so it returns a tuple.  This is only needed for graphs
    not created by torchdynamo that return non-tuples.
    """
    node = output_node(gm)
    (rv,) = node.args
    rv, spec = pytree.tree_flatten(rv)
    with gm.graph.inserting_before(node):
        gm.graph.output(rv)
    gm.graph.erase_node(node)
    assert graph_returns_tuple(gm)

    compiled_fn = compile_gm(gm, inputs)

    @functools.wraps(compiled_fn)
    def wrapper(*args, **kwargs):
        return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)

    return wrapper

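
# Usage sketch (illustrative; the stub "compiler" below is an assumption, the
# real caller passes an inductor compile function): after the graph is
# rewritten to return a flat list, the wrapper unflattens results back into
# the original pytree shape.
def _example_make_graph_return_tuple():
    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    def stub_compile(gm: torch.fx.GraphModule, example_inputs):
        gm.recompile()  # pick up the mutated graph before running it
        return gm.forward

    x = torch.randn(3)
    fn = make_graph_return_tuple(torch.fx.symbolic_trace(M()), [x], stub_compile)
    assert torch.equal(fn(x), x + 1)
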

def handle_dynamo_export_graph(
    gm: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    compile_gm: Callable[..., Any],
):
    """
    `torch._dynamo.export` embeds pytrees in the FX graph codegen object;
    convert that to a normal FX graph so inductor can compile it.
    """
    codegen = gm.graph._codegen
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.recompile()

    compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs))

    @functools.wraps(compiled_fn)
    def wrapper(*args):
        return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))

    return wrapper

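
# Hedged sketch (illustrative; f and the stub compiler are assumptions): a
# graph produced by torch._dynamo.export carries a pytree-aware codegen, so
# its dict output only survives because the wrapper re-applies process_outputs.
def _example_handle_dynamo_export_graph():
    def f(x):
        return {"out": x.sin()}

    x = torch.randn(3)
    gm = torch._dynamo.export(f)(x).graph_module
    fn = handle_dynamo_export_graph(gm, [x], lambda g, flat_inputs: g.forward)
    assert torch.equal(fn(x)["out"], x.sin())
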

def _check_triton_bf16_support(graph: GraphLowering) -> None:
    def warn_and_skip(device) -> None:
        from torch._dynamo.exc import SkipFrame

        device_props = torch.cuda.get_device_properties(device)
        warnings.warn(
            f"{device_props.name} does not support bfloat16 compilation natively, skipping"
        )
        raise SkipFrame("BF16 is not supported")

    for inp in graph.graph_inputs.values():
        device = getattr(inp, "get_device", lambda: torch.device("meta"))()
        if device.type != "cuda" or inp.get_dtype() != torch.bfloat16:
            continue
        # Print a warning and skip the frame if attempting to compile for
        # bfloat16 on a device without hardware support for the dtype
        if torch.cuda.is_bf16_supported(including_emulation=False):
            return
        warn_and_skip(device)

    for out in graph.graph_outputs:
        device = getattr(out, "get_device", lambda: torch.device("meta"))()
        if device.type != "cuda" or out.get_dtype() != torch.bfloat16:
            continue
        # Print a warning and skip the frame if attempting to compile for
        # bfloat16 on a device without hardware support for the dtype
        if torch.cuda.is_bf16_supported(including_emulation=False):
            return
        warn_and_skip(device)

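
# Hedged sketch (the helper name is an assumption; requires a CUDA build to
# be meaningful): how a caller might probe native bf16 support the same way
# _check_triton_bf16_support does, instead of failing mid-compile.
def _example_bf16_probe() -> bool:
    if not torch.cuda.is_available():
        return False
    # Same check as above: require real hardware support, not emulation.
    return torch.cuda.is_bf16_supported(including_emulation=False)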