# mypy: allow-untyped-defs
"""
Utils for caching the outputs of AOTAutograd
"""
from __future__ import annotations

import json
import logging
import os
import pickle
import shutil
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch._dynamo.utils import counters, get_chromium_event_logger
from torch._functorch import config
from torch._inductor.codecache import (
    _ident,
    BypassFxGraphCache,
    CompiledFxGraph,
    extract_tensor_metadata_for_cache_key,
    FxGraphCache,
    FxGraphCachePickler,
    FxGraphHashDetails,
    write_atomic,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._logging import LazyString

from .runtime_wrappers import (
    AOTDispatchAutograd,
    AOTDispatchSubclassWrapper,
    CompilerWrapper,
    FunctionalizedRngRuntimeWrapper,
    post_compile,
    RuntimeWrapper,
    SubclassMeta,
)
from .schemas import AOTConfig, ViewAndMutationMeta  # noqa: F401


if TYPE_CHECKING:
    from torch._inductor.utils import BoxedBool
    from torch.fx.node import Node

log = logging.getLogger(__name__)


class BypassAOTAutogradCache(Exception):
    pass


# Used to signify when FXGraphCache missed when AOTAutogradCache uses it
class FXGraphCacheMiss(BypassAOTAutogradCache):
    pass


def check_node_safe(node: Node):
    """
    Checks that the node only uses supported operators. We are starting with very
    conservative cacheability constraints, and incrementally adding more support as we expand.

    [Note: AOTAutograd Cacheability checks]
    - Our cache key is computed from the FX graph produced by Dynamo and the input example values
    - A node is "safe" if the same cache key results in a compiled artifact that has the same behavior
      (i.e., the set of inputs that go into our cache key is sufficient to distinguish its behavior)

    To accomplish this safety check, we consider the following functions to be safe:
      - Public functions under the modules torch, torch.functional, and torch.nn.functional: these are
        allowed in the graph by dynamo, so we can assume they are safe to cache.
      - Method calls on base tensor types
      - Any call_module that dynamo deemed safe to allow AOTAutograd to trace
      - Non-callable nodes, such as placeholder, output, get_attr

    The test suite test_aot_autograd_cache.py::AOTAutogradCachePicklerTests tries its best to fully
    cover/specify this behavior.
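
    As a minimal illustration (mirroring what check_cacheable below does), callers run this check
    over every node of the dynamo-produced graph before attempting to cache it:

        for node in gm.graph.nodes:
            check_node_safe(node)  # raises BypassAOTAutogradCache on unsupported nodes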
    """
    SAFE_TORCH_MODULES = ("torch.functional", "torch.nn.functional")

    def is_public_torch_api(target):
        # Don't blindly allow private functions in the torch namespace
        is_private = target.__name__.startswith("_")
        return (
            getattr(target, "__module__", None) in SAFE_TORCH_MODULES and not is_private
        )

    def is_torch_function(target):
        if isinstance(target, torch._ops.OpOverload):
            return True
        if is_public_torch_api(target):
            return True
        is_builtin_fun_or_type = type(target).__name__ == "builtin_function_or_method"
        return is_builtin_fun_or_type

    def is_tensor(target):
        # Tensors always have example values in the meta field
        return "example_value" in target.meta

    # I'd love to use a match statement here, but it wasn't introduced until py3.10
    if node.op == "call_function":
        # We support only torch.* functions for now
        # We can probably add an allowlist of safe non-torch implementations as well
        if not is_torch_function(node.target):
            raise BypassAOTAutogradCache(
                f"Unsupported call_function target {node.target}"
            )
    elif node.op == "call_method":
        method_name = node.target
        method_target = node.args[0]
        # Only support method calls on base tensors
        if not is_tensor(method_target):
            raise BypassAOTAutogradCache(
                f"Unsupported call_method target {method_target}"
            )
        if (
            type(method_name) != str
            and type(method_name).__name__ != "method_descriptor"
        ):
            raise BypassAOTAutogradCache(
                f"Unsupported call_method method {node.target}: {method_name}"
            )
    # Cache safe
    elif node.op in ("placeholder", "get_attr", "call_module", "output"):
        # Assumptions for treating call_module as a safe op today:
        # (1) Today, the only call_module ops that can show up in a graph come from "built-in-nn-modules"
        #     that dynamo assumes are safe to trace. If dynamo assumes they are safe to blindly trace,
        #     then they should be safe to cache as well.
        # (2) In the steady state (some time in H2?) we shouldn't see these anymore, once we inline
        #     builtin nn modules by default.
        # (3) We do not allow user-made nn modules in the graph today, only function calls.
        pass
    else:
        raise BypassAOTAutogradCache(f"Unsupported node op {node.op}")


def check_cacheable(gm: torch.fx.GraphModule):
    """
    Checks that the graph module only uses supported operators
    """
    nodes = gm.graph.nodes
    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
        raise BypassAOTAutogradCache(
            "Cannot cache a graph with compiled autograd enabled"
        )

    if not torch._inductor.config.fx_graph_cache:
        raise BypassAOTAutogradCache("FX graph cache is not enabled")

    tracing_context = torch._guards.TracingContext.try_get()
    if tracing_context and tracing_context.fakify_first_call:
        raise BypassAOTAutogradCache(
            "Won't cache a graph with fakify_first_call enabled"
        )
    for node in nodes:
        check_node_safe(node)


class AOTAutogradCacheDetails(FxGraphHashDetails):
    """
    Object to capture all the details for a dynamo graph module relevant to computing
    a safe and stable cache key for AOTAutograd.
    """

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs,
        aot_config: AOTConfig,
        fx_config: Dict[str, BoxedBool],
    ):
        # FxGraphHashDetails contains all the keys related to inductor. Also includes some system info.
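        # On top of the Inductor-related details captured by FxGraphHashDetails, we fold the
        # AOT config, the functorch config, and the global grad/autocast/determinism state into
        # the key, since each of these can change the behavior of the compiled artifact.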
        self.aot_config = aot_config
        self.grad_enabled = torch.is_grad_enabled()
        self.disable_amp = torch._C._is_any_autocast_enabled()
        self.deterministic_algorithms = torch.are_deterministic_algorithms_enabled()
        self.autograd_config = config.save_config()
        try:
            # TODO: example_inputs causes more cache misses than necessary
            # with dynamic shapes, because this is before we add
            # symints to tensor metadata. Improve this later.
            super().__init__(gm, example_inputs, fx_config, [])
        except BypassFxGraphCache as e:
            # Sometimes inductor configs are unpickleable and can fail
            raise BypassAOTAutogradCache from e

    def debug_lines(self) -> List[str]:
        return AOTAutogradCachePickler.debug_lines(self)


def _reduce_aot_config(aot_config: AOTConfig):
    """
    Reduce the config to a stable key for caching.
    """
    return (
        _ident,
        (
            aot_config.num_params_buffers,
            aot_config.keep_inference_input_mutations,
            aot_config.is_export,
            aot_config.no_tangents,
            aot_config.dynamic_shapes,
            aot_config.aot_autograd_arg_pos_to_source,
            aot_config.enable_log,
            aot_config.pre_dispatch,
        ),
    )


def _reduce_tensor(tensor):
    """
    Reduce the tensor to a stable key for caching.
    """
    return (
        _ident,
        (
            extract_tensor_metadata_for_cache_key(
                FxGraphCachePickler._device_map, tensor
            ),
        ),
    )


class AOTAutogradCachePickler(FxGraphCachePickler):
    dispatch_table = FxGraphCachePickler.dispatch_table.copy()
    dispatch_table[AOTConfig] = _reduce_aot_config
    dispatch_table[torch.Tensor] = _reduce_tensor


def autograd_cache_key(
    gm: torch.fx.GraphModule,
    example_inputs,
    config: AOTConfig,
    fx_config: Dict[str, BoxedBool],
    # TODO: add args and parameters
) -> Tuple[str, List[str]]:
    """
    Generate a unique hash of the FX graph for caching.
    """
    check_cacheable(gm)
    details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
    # The prefix distinguishes among the other kinds of objects we cache
    key = "a" + AOTAutogradCachePickler.get_hash(details)
    debug_lines = details.debug_lines()
    log.debug(
        "Autograd graph cache hash details for key %s:\n%s",
        key,
        LazyString(lambda: "\n".join(debug_lines)),
    )
    return key, debug_lines


@dataclass
class FXGraphCacheLoadable:
    fx_graph_cache_key: str

    def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGraph:
        # [Note: AOTAutogradCache and FXGraphCache Guard interactions]
        # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments.
        # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph.
        # The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly
        # the same as the ones it passes to inductor, for both the forward and backward passes.
        # (This does not mean that the tensor values passed in are the same: only that their symints are.)
        # That is, AOTAutograd and Inductor never create new guards based on symints with different sources
        # than those passed to it by dynamo.
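        # Look the graph up in the local FXGraphCache by its saved key; a miss here is surfaced
        # as FXGraphCacheMiss, which AOTAutogradCache.load counts as a cache (guard) miss rather
        # than a bypass.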
        result = FxGraphCache._lookup_graph(
            self.fx_graph_cache_key, example_inputs, local=True, remote_cache=None
        )
        if result is None:
            log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_key)
            counters["inductor"]["fxgraph_cache_miss"] += 1
            raise FXGraphCacheMiss
        FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"])
        counters["inductor"]["fxgraph_cache_hit"] += 1
        result._boxed_call = True
        return result


@dataclass
class CompiledForward(FXGraphCacheLoadable):
    """
    Cacheable entry for a forward function
    """


@dataclass
class CompiledBackward(FXGraphCacheLoadable):
    """
    Cacheable entry for a backward function
    """

    # Used by AOTDispatchAutograd.post_compile
    backward_state_indices: List[int]
    num_symints_saved_for_bw_: int


@dataclass
class AOTAutogradCacheEntry:
    """A single entry into the cache."""

    # Forward and backward info
    compiled_fw: CompiledForward
    compiled_bw: Optional[CompiledBackward]

    # Runtime metadata saved right before compilation
    runtime_metadata: ViewAndMutationMeta

    # Wrappers that run after each aot_dispatch_* function
    dispatch_wrappers: List[CompilerWrapper]

    # Used by AOTDispatchSubclassWrapper
    maybe_subclass_meta: Optional[SubclassMeta]
    num_fw_outs_saved_for_bw: Optional[int]

    # Used by RuntimeWrapper
    indices_of_inps_to_detach: List[int]

    # Turn cache entry into the original callable
    def wrap_post_compile(
        self,
        args: List[torch.Tensor],
        aot_config: AOTConfig,
        fx_config: Dict[str, BoxedBool],
    ) -> Callable:
        """
        This function takes a cache entry and carefully reconstructs the original callable
        that AOTAutograd returned the first time it was run. It does this by running the various
        post-compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers.

        In the inference path, this consists of the AOTDispatchSubclassWrapper, FunctionalizedRngRuntimeWrapper,
        and RuntimeWrapper steps.
        In the autograd path, this consists of AOTDispatchAutograd.post_compile.

        The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd.

        Notably absent from the cached path are:
        - DebugAssertWrapper
        - FakifiedOutWrapper

        These we'll handle separately later on, if necessary.
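
        Illustrative call (mirroring how AOTAutogradCache.load uses a looked-up entry; the
        fx_config dict carries the cudagraphs BoxedBool):

            compiled_fn = entry.wrap_post_compile(args, aot_config, {"cudagraphs": cudagraphs})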
        """
        compiled_fw_func = self.compiled_fw.load(args, fx_config)
        compiled_bw_func = None
        if self.compiled_bw is not None:
            compiled_bw_func = self.compiled_bw.load(args, fx_config)
            needs_autograd = True
        else:
            needs_autograd = False

        # Wrap the forward function in post compile wrappers
        compiled_fw_func = AOTDispatchSubclassWrapper(
            trace_joint=needs_autograd,
            fw_only=None,
            maybe_subclass_meta=self.maybe_subclass_meta,
            num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw,
        ).post_compile(
            compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
        )

        # In the autograd case, FunctionalizedRngRuntimeWrapper should not modify outs
        return_new_outs = not needs_autograd
        compiled_fw_func = FunctionalizedRngRuntimeWrapper(
            return_new_outs=return_new_outs
        ).post_compile(
            compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
        )
        disable_amp = torch._C._is_any_autocast_enabled()

        if needs_autograd:
            assert self.compiled_bw is not None
            # This function is run on both cache miss and cache hit, either here
            # or in aot_dispatch_autograd. On a cache hit,
            # 1. the bw is already compiled
            # 2. we don't need to save to the cache again
            # so those corresponding arguments are set to None.
            compiled_function = AOTDispatchAutograd.post_compile(
                compiled_fw_func,
                compiled_bw_func,
                self.maybe_subclass_meta,
                self.compiled_bw.num_symints_saved_for_bw_,
                self.compiled_bw.backward_state_indices,
                disable_amp,
                self.indices_of_inps_to_detach,
                None,  # lazy_backward_info
                aot_config,
                fw_metadata=self.runtime_metadata,
                try_save_cache_entry=None,
            )
        else:
            compiled_function = RuntimeWrapper(
                indices_of_inps_to_detach=self.indices_of_inps_to_detach,
                trace_joint=False,
                disable_amp=disable_amp,
            ).post_compile(
                compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
            )

        compiled_function, _ = post_compile(
            self.dispatch_wrappers,
            compiled_function,
            aot_config,
            runtime_metadata=self.runtime_metadata,
        )

        return compiled_function


class AOTAutogradCache:
    """
    Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas
    AOTAutogradCacheEntry handles the wrapping/unwrapping logic.

    Cache Inputs (AOTAutogradCacheDetails)
    - AOTAutogradCache takes in the following inputs, which are analogous to the inputs given
      to AOTAutograd by dynamo:
        - An fx graph module generated by dynamo
        - A list of args, which consists of:
            - Symint inputs to the graph, generated by dynamo
            - The **real tensor** inputs, which inductor uses for cudagraphs
            - Notably, the real tensor inputs don't have symints in their metadata.
              AOTAutograd then retraces those real tensor arguments into FakeTensors later during execution.
        - A set of global configurations that affect AOTAutograd or Inductor behavior.

    It then generates a cache key given these values. Notably, this means AOTAutogradCache currently
    specializes on the sizes and strides of the real tensor inputs when dynamic shapes are turned on.
    In a later PR, we'll likely generate the cache key based on the FakeTensors AOTAutograd generates
    based on the real tensor inputs, which can contain symints.
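
    As a rough illustration (variable names here are hypothetical), the key is computed via
    autograd_cache_key in this module:

        key, debug_lines = autograd_cache_key(gm, args, aot_config, {"cudagraphs": cudagraphs})
        # key is "a" followed by the hash of the pickled details (graph, example inputs, configs)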

    Cache Outputs (AOTAutogradCacheEntry)
    - AOTAutogradCache caches the following values:
        - The compiled forward and backward functions from inductor, via keys to the FXGraphCache
        - Metadata to reconstruct the AOTModule from the compiled inductor artifacts
        - See AOTAutogradCacheEntry for more info

    [Note: Caching guards generated by AOTAutograd and Inductor]
    AOTAutograd and inductor both can introduce new guards to the shape environment. FXGraphCache saves the guards
    with each compiled graph that inductor generates. On a cache hit, AOTAutograd reloads the compiled forward and
    backward functions from FXGraphCache, giving it new symint arguments from the input args.
    FXGraphCache uses those symints and its saved guards to repopulate the ShapeEnv with guards.
    **No new guards are generated into the shape env after inductor finishes compiling**, so the guards
    saved by inductor are sufficient for correctness for both AOTAutograd and Inductor's caches.
    """

    @staticmethod
    def clear():
        """Clear the cache"""
        try:
            shutil.rmtree(AOTAutogradCache._get_tmp_dir())
        except FileNotFoundError:
            pass

    @staticmethod
    def load(
        dispatch_and_compile: Callable,
        mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper],
        args,
        aot_config: AOTConfig,
        cudagraphs: BoxedBool,
    ) -> Callable:
        """
        Load a result from the cache, and reconstruct a runtime wrapper around the object
        """
        gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
        compiled_fn = None
        cache_key = None
        debug_lines: List[str] = []
        cache_event_time = time.time_ns()
        cache_state = None
        fx_config = {"cudagraphs": cudagraphs}
        try:
            cache_key, debug_lines = autograd_cache_key(gm, args, aot_config, fx_config)
            entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(cache_key)
            if entry is not None:
                compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
                log.info("AOTAutograd cache hit for key %s", cache_key)
                counters["aot_autograd"]["autograd_cache_hit"] += 1
                cache_state = "hit"
                cache_event_time = time.time_ns()
            if compiled_fn is None:
                log.info("AOTAutograd cache miss for key %s", cache_key)
                counters["aot_autograd"]["autograd_cache_miss"] += 1
                cache_state = "miss"
                cache_event_time = time.time_ns()
        # Count missing the FXGraphCache as a miss, not a bypass
        except FXGraphCacheMiss as e:
            counters["aot_autograd"]["autograd_cache_miss"] += 1
            # Special counter for when we pass the autograd cache but
            # fail on the inductor guards
            counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
            if config.strict_autograd_cache:
                raise e
        except BypassAOTAutogradCache as e:
            cache_key = None
            counters["aot_autograd"]["autograd_cache_bypass"] += 1
            cache_state = "bypass"
            cache_event_time = time.time_ns()
            if config.strict_autograd_cache:
                raise e
        if compiled_fn is None:
            # Set the cache key so we can save a cache result later
            aot_config.cache_key = cache_key
            compiled_fn = dispatch_and_compile()
        cache_args = {
            "key": cache_key,
            "cache_state": cache_state,
            "components": debug_lines,
        }
        chromium_log = get_chromium_event_logger()
        chromium_log.log_instant_event(
            f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_args
        )
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "aotautograd_cache_hash",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(cache_args),
        )
        return compiled_fn

    @staticmethod
    def _get_tmp_dir() -> str:
        """
        Get the toplevel temporary directory for storing compiled graphs.
        """
        return os.path.join(cache_dir(), "aotautograd")

    @staticmethod
    def _lookup(key: str) -> Optional[AOTAutogradCacheEntry]:
        """Given a key generated by AOTAutogradCachePickler, look up its location in the cache."""
        subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key)
        if not os.path.exists(subdir):
            return None
        path = os.path.join(subdir, "entry")
        try:
            with open(path, "rb") as f:
                entry: AOTAutogradCacheEntry = pickle.load(f)
            return entry
        except Exception as e:
            log.warning("AOTAutograd cache unable to load compiled graph: %s", e)
            if config.strict_autograd_cache:
                raise e
            return None

    @staticmethod
    def save(key: str, entry: AOTAutogradCacheEntry):
        """Save a single entry into the cache."""
        try:
            content = pickle.dumps(entry)
        except Exception as e:
            log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e)
            if config.strict_autograd_cache:
                raise e
            return None
        subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key)
        if not os.path.exists(subdir):
            os.makedirs(subdir, exist_ok=True)
        path = os.path.join(subdir, "entry")
        log.info("Writing AOTAutograd cache entry to %s", path)
        write_atomic(path, content)
        counters["aot_autograd"]["autograd_cache_saved"] += 1
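

# A minimal sketch of inspecting the on-disk cache (illustrative only, not part of the public
# API; it assumes this module is importable as torch._functorch._aot_autograd.autograd_cache):
#
#     from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
#
#     print(AOTAutogradCache._get_tmp_dir())  # cache_dir() + "/aotautograd", one subdirectory per key
#     entry = AOTAutogradCache._lookup("a" + some_hash)  # None unless an entry was saved under that key
#     AOTAutogradCache.clear()  # removes the entire aotautograd cache directory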