# mypy: ignore-errors

import contextlib
import functools
import logging
from unittest.mock import patch

import torch
from torch._dynamo import disable
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils._python_dispatch import _disable_current_modes


log = logging.getLogger(__name__)


class AotAutograd:
    """Wraps aot_module_simplified so it can be used as a TorchDynamo backend."""

    def __init__(self, **kwargs) -> None:
        self.__name__ = "compiler_fn"
        self.kwargs = kwargs

    def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
        if kwargs:
            log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)

        if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
            return flatten_graph_inputs(
                gm,
                example_inputs,
                self,
            )

        # Hack to get around circular import problems with aot_eager_decomp_partition
        if callable(self.kwargs.get("decompositions")):
            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

        # NB: don't delete this counter increment
        counters["aot_autograd"]["total"] += 1
        use_fallback = False

        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        # OK, attempt to compile

        def _wrapped_bw_compiler(*args, **kwargs):
            # Stop TorchDynamo from trying to compile our generated backwards pass
            return disable(disable(bw_compiler)(*args, **kwargs))

        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
        self.kwargs["inference_compiler"] = (
            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
        )

        from functorch.compile import nop
        from torch._inductor.debug import enable_aot_logging

        # Debug asserts slow down compile time noticeably, so only default
        # them on when the aot_eager backend is used.
        if self.kwargs.get("fw_compiler", None) == nop:
            patch_config = patch("functorch.compile.config.debug_assert", True)
        else:
            patch_config = contextlib.nullcontext()

        try:
            # NB: NOT cloned!
            with enable_aot_logging(), patch_config:
                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                counters["aot_autograd"]["ok"] += 1
                return disable(cg)
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise


def aot_autograd(**kwargs):
    return AotAutograd(**kwargs)


def mem_efficient_fusion_kwargs(use_decomps):
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # These are taken from memory_efficient_fusion()
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs


def fake_tensor_unsupported(fn):
    """
    Decorator for backends that need real inputs. We swap out fake
    tensors for zero tensors.
    """

    @functools.wraps(fn)
    def wrapper(model, inputs, **kwargs):
        with _disable_current_modes():
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper


def device_from_inputs(example_inputs) -> torch.device:
    # Return the device of the first example input that has one.
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device


def dtype_from_inputs(example_inputs) -> torch.dtype:
    # Return the dtype of the first example input that has one.
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype