xref: /aosp_15_r20/external/pytorch/torch/_dynamo/backends/common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: ignore-errors

import contextlib
import functools
import logging
from unittest.mock import patch

import torch
from torch._dynamo import disable
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils._python_dispatch import _disable_current_modes


log = logging.getLogger(__name__)


class AotAutograd:
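    """
    Callable TorchDynamo backend that compiles graphs with AOT Autograd.

    The kwargs captured at construction time (``fw_compiler`` and optionally
    ``bw_compiler``, ``inference_compiler``, ``partition_fn``,
    ``decompositions``, ...) are forwarded to ``aot_module_simplified`` when
    the backend is called on a GraphModule.
    """
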
    def __init__(self, **kwargs) -> None:
        self.__name__ = "compiler_fn"
        self.kwargs = kwargs

    def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
        if kwargs:
            log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)

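        # aot_module_simplified expects a flat list of tensor inputs; if any
        # example input is a container (list/tuple/dict), wrap the graph so the
        # inputs are flattened and re-enter this backend on the flattened form.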
        if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
            return flatten_graph_inputs(
                gm,
                example_inputs,
                self,
            )

        # Hack to get around circular import problems with aot_eager_decomp_partition
        if callable(self.kwargs.get("decompositions")):
            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

        # NB: don't delete counter increment
        counters["aot_autograd"]["total"] += 1
        # use_fallback is currently always False; the branch below is kept as a
        # placeholder for falling back to the uncompiled gm when a graph cannot
        # go through AOT Autograd (e.g. because of mutation).
        use_fallback = False

        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        # OK, attempt to compile.

        def _wrapped_bw_compiler(*args, **kwargs):
            # stop TorchDynamo from trying to compile our generated backwards pass
            return disable(disable(bw_compiler)(*args, **kwargs))

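        # Default bw_compiler and inference_compiler to fw_compiler when they are
        # not given, and route the backward compiler through the wrapper above so
        # that neither the compiler call nor the compiled backward is traced.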
        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
        self.kwargs["inference_compiler"] = (
            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
        )

        from functorch.compile import nop
        from torch._inductor.debug import enable_aot_logging

        # debug asserts slow down compile time noticeably,
        # so only default them on when the aot_eager backend is used.
        if self.kwargs.get("fw_compiler", None) == nop:
            patch_config = patch("functorch.compile.config.debug_assert", True)
        else:
            patch_config = contextlib.nullcontext()

        try:
            # NB: gm is NOT cloned before being handed to AOT Autograd!
            with enable_aot_logging(), patch_config:
                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                counters["aot_autograd"]["ok"] += 1
                return disable(cg)
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise


def aot_autograd(**kwargs):
    return AotAutograd(**kwargs)
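
# Usage sketch (illustrative only; ``nop`` is just one possible forward
# compiler, and ``fn`` stands for any function to be compiled):
#
#     from functorch.compile import nop
#
#     my_backend = aot_autograd(fw_compiler=nop)
#     opt_fn = torch.compile(fn, backend=my_backend)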


def mem_efficient_fusion_kwargs(use_decomps):
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs
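
# These kwargs are intended to be fed into aot_autograd, e.g. (illustrative):
# aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))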


def fake_tensor_unsupported(fn):
    """
    Decorator for backends that need real inputs.  We swap out fake
    tensors for zero tensors.
    """

    @functools.wraps(fn)
    def wrapper(model, inputs, **kwargs):
        with _disable_current_modes():
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper
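
# Usage sketch (illustrative; ``my_backend`` is a hypothetical backend function):
#
#     @fake_tensor_unsupported
#     def my_backend(gm, example_inputs):
#         # example_inputs here are real zero tensors, not FakeTensors
#         return gm.forward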


# The helpers below return the device/dtype of the first example input that
# carries the attribute, falling through to an implicit None otherwise.
def device_from_inputs(example_inputs) -> torch.device:
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device


def dtype_from_inputs(example_inputs) -> torch.dtype:
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype