# mypy: allow-untyped-defs
"""
Utils for caching the outputs of AOTAutograd
"""
from __future__ import annotations

import json
import logging
import os
import pickle
import shutil
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch._dynamo.utils import counters, get_chromium_event_logger
from torch._functorch import config
from torch._inductor.codecache import (
    _ident,
    BypassFxGraphCache,
    CompiledFxGraph,
    extract_tensor_metadata_for_cache_key,
    FxGraphCache,
    FxGraphCachePickler,
    FxGraphHashDetails,
    write_atomic,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._logging import LazyString

from .runtime_wrappers import (
    AOTDispatchAutograd,
    AOTDispatchSubclassWrapper,
    CompilerWrapper,
    FunctionalizedRngRuntimeWrapper,
    post_compile,
    RuntimeWrapper,
    SubclassMeta,
)
from .schemas import AOTConfig, ViewAndMutationMeta  # noqa: F401


if TYPE_CHECKING:
    from torch._inductor.utils import BoxedBool
    from torch.fx.node import Node

log = logging.getLogger(__name__)


class BypassAOTAutogradCache(Exception):
    pass


# Used to signify that FXGraphCache missed when AOTAutogradCache tried to use it
class FXGraphCacheMiss(BypassAOTAutogradCache):
    pass


def check_node_safe(node: Node):
    """
    Checks that the node only uses supported operators. We are starting with very
    conservative cacheability constraints, and will incrementally add more support.

    [Note: AOTAutograd Cacheability checks]
    - Our cache key is computed from the FX graph produced by Dynamo and the input example values
    - A node is "safe" if the same cache key results in a compiled artifact that has the same behavior
        (i.e., the set of inputs that go into our cache key is sufficient to distinguish its behavior)

    To accomplish this safety check, we consider the following functions to be safe:
        - Public functions under the torch, torch.functional, and torch.nn.functional modules: these are
        allowed in the graph by dynamo, so we can assume they are safe to cache.
        - Method calls on base tensor types
        - Any call_module that dynamo deemed safe to allow AOTAutograd to trace
        - Non-callable nodes, such as placeholder, output, and get_attr

    The test suite test_aot_autograd_cache.py::AOTAutogradCachePicklerTests tries its best to fully cover/specify this behavior.
    """
    SAFE_TORCH_MODULES = ("torch.functional", "torch.nn.functional")

    def is_public_torch_api(target):
        # Don't blindly allow private functions in the torch namespace
        is_private = target.__name__.startswith("_")
        return (
            getattr(target, "__module__", None) in SAFE_TORCH_MODULES and not is_private
        )

    def is_torch_function(target):
        if isinstance(target, torch._ops.OpOverload):
            return True
        if is_public_torch_api(target):
            return True
        is_builtin_fun_or_type = type(target).__name__ == "builtin_function_or_method"
        return is_builtin_fun_or_type

    def is_tensor(target):
        # Tensors always have example values in the meta field
        return "example_value" in target.meta

    # I'd love to use a match statement here, but it wasn't introduced until py3.10
    if node.op == "call_function":
        # We support only torch.* functions for now
        # We can probably add an allowlist of safe non-torch implementations as well
        if not is_torch_function(node.target):
            raise BypassAOTAutogradCache(
                f"Unsupported call_function target {node.target}"
            )
    elif node.op == "call_method":
        method_name = node.target
        method_target = node.args[0]
        # Only support method calls on base tensors
        if not is_tensor(method_target):
            raise BypassAOTAutogradCache(
                f"Unsupported call_method target {method_target}"
            )
        if (
            type(method_name) != str
            and type(method_name).__name__ != "method_descriptor"
        ):
            raise BypassAOTAutogradCache(
                f"Unsupported call_method method {node.target}: {method_name}"
            )
    # Cache safe
    elif node.op in ("placeholder", "get_attr", "call_module", "output"):
        # Assumptions today for call_module being a safe op:
        # (1) Today the only call_module ops that can show up in a graph come from "built-in nn modules"
        # that dynamo assumes are safe to trace. If dynamo assumes it is safe to blindly trace them, then
        # they should be safe to cache as well.
        # (2) In the steady state (some time in H2?) we shouldn't see these anymore, once we inline builtin nn modules by default.
        # (3) We do not allow user-made nn modules in the graph today, only function calls.
        pass
    else:
        raise BypassAOTAutogradCache(f"Unsupported node op {node.op}")


def check_cacheable(gm: torch.fx.GraphModule):
    """
    Checks that the graph module only uses supported operators
    """
    nodes = gm.graph.nodes
    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
        raise BypassAOTAutogradCache(
            "Cannot cache a graph with compiled autograd enabled"
        )

    if not torch._inductor.config.fx_graph_cache:
        raise BypassAOTAutogradCache("FX graph cache is not enabled")

    tracing_context = torch._guards.TracingContext.try_get()
    if tracing_context and tracing_context.fakify_first_call:
        raise BypassAOTAutogradCache(
            "Won't cache a graph with fakify_first_call enabled"
        )
    for node in nodes:
        check_node_safe(node)


class AOTAutogradCacheDetails(FxGraphHashDetails):
    """
    Object to capture all the details for a dynamo graph module relevant to computing
    a safe and stable cache key for AOTAutograd.
    """

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs,
        aot_config: AOTConfig,
        fx_config: Dict[str, BoxedBool],
    ):
        # FxGraphHashDetails contains all the keys related to inductor, plus some system info
        self.aot_config = aot_config
        self.grad_enabled = torch.is_grad_enabled()
        self.disable_amp = torch._C._is_any_autocast_enabled()
        self.deterministic_algorithms = torch.are_deterministic_algorithms_enabled()
        self.autograd_config = config.save_config()
        try:
            # TODO: example_inputs causes more cache misses than necessary
            # with dynamic shapes, because this is before we add
            # symints to tensor metadata. Improve this later.
            super().__init__(gm, example_inputs, fx_config, [])
        except BypassFxGraphCache as e:
            # Sometimes inductor configs are unpickleable and can fail
            raise BypassAOTAutogradCache from e

    def debug_lines(self) -> List[str]:
        return AOTAutogradCachePickler.debug_lines(self)


def _reduce_aot_config(aot_config: AOTConfig):
    """
    Reduce the config to a stable key for caching.
    """
    return (
        _ident,
        (
            aot_config.num_params_buffers,
            aot_config.keep_inference_input_mutations,
            aot_config.is_export,
            aot_config.no_tangents,
            aot_config.dynamic_shapes,
            aot_config.aot_autograd_arg_pos_to_source,
            aot_config.enable_log,
            aot_config.pre_dispatch,
        ),
    )


def _reduce_tensor(tensor):
    """
    Reduce the tensor to a stable key for caching.
    """
    return (
        _ident,
        (
            extract_tensor_metadata_for_cache_key(
                FxGraphCachePickler._device_map, tensor
            ),
        ),
    )


class AOTAutogradCachePickler(FxGraphCachePickler):
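    # Copy the parent pickler's dispatch table and register custom reducers, so
    # that hashing an AOTConfig or a torch.Tensor only contributes the stable
    # fields returned by _reduce_aot_config / _reduce_tensor above.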
    dispatch_table = FxGraphCachePickler.dispatch_table.copy()
    dispatch_table[AOTConfig] = _reduce_aot_config
    dispatch_table[torch.Tensor] = _reduce_tensor


def autograd_cache_key(
    gm: torch.fx.GraphModule,
    example_inputs,
    config: AOTConfig,
    fx_config: Dict[str, BoxedBool],
    # TODO: add args and parameters
) -> Tuple[str, List[str]]:
    """
    Generate a unique hash of the FX graph for caching.
    """
    check_cacheable(gm)
    details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
    # The prefix distinguishes among the other kinds of objects we cache
    key = "a" + AOTAutogradCachePickler.get_hash(details)
    debug_lines = details.debug_lines()
    log.debug(
        "Autograd graph cache hash details for key %s:\n%s",
        key,
        LazyString(lambda: "\n".join(debug_lines)),
    )
    return key, debug_lines
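
# Illustrative usage of autograd_cache_key (hypothetical values):
#
#   key, debug_lines = autograd_cache_key(gm, example_inputs, aot_config, fx_config)
#   # key is "a" followed by the pickler's hash digest, e.g. "a3f9c1...", and
#   # debug_lines is a human-readable breakdown of everything hashed into it.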


@dataclass
class FXGraphCacheLoadable:
    fx_graph_cache_key: str

    def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGraph:
        # [Note: AOTAutogradCache and FXGraphCache Guard interactions]
        # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments.
        # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph.
        # The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly
        # the same as the ones it passes to inductor, for both the forward and backward passes.
        # (This does not mean that the tensor values passed in are the same: only that their symints are.)
        # That is, AOTAutograd and Inductor never create new guards based on symints with sources
        # other than the ones dynamo passed in.
        result = FxGraphCache._lookup_graph(
            self.fx_graph_cache_key, example_inputs, local=True, remote_cache=None
        )
        if result is None:
            log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_key)
            counters["inductor"]["fxgraph_cache_miss"] += 1
            raise FXGraphCacheMiss
        FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"])
        counters["inductor"]["fxgraph_cache_hit"] += 1
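        # The compiled FX graph takes its inputs as a single list (the boxed
        # calling convention); mark it so downstream runtime wrappers call it correctly.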
        result._boxed_call = True
        return result


@dataclass
class CompiledForward(FXGraphCacheLoadable):
    """
    Cacheable entry for a forward function
    """


@dataclass
class CompiledBackward(FXGraphCacheLoadable):
    """
    Cacheable entry for a backward function
    """

    # Used by AOTDispatchAutograd.post_compile
    backward_state_indices: List[int]
    num_symints_saved_for_bw_: int


@dataclass
class AOTAutogradCacheEntry:
    """A single entry into the cache."""

    # Forward and backward info
    compiled_fw: CompiledForward
    compiled_bw: Optional[CompiledBackward]

    # Runtime metadata saved right before compilation
    runtime_metadata: ViewAndMutationMeta

    # Wrappers that run after each aot_dispatch_* function
    dispatch_wrappers: List[CompilerWrapper]

    # Used by AOTDispatchSubclassWrapper
    maybe_subclass_meta: Optional[SubclassMeta]
    num_fw_outs_saved_for_bw: Optional[int]

    # Used by RuntimeWrapper
    indices_of_inps_to_detach: List[int]

    # Turn cache entry into the original callable
    def wrap_post_compile(
        self,
        args: List[torch.Tensor],
        aot_config: AOTConfig,
        fx_config: Dict[str, BoxedBool],
    ) -> Callable:
        """
        This function takes a cache entry and carefully reconstructs the original callable
        that AOTAutograd returned the first time it was run. It does this by running the various
        post-compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers.

        In the inference path, this consists of the Subclass, FunctionalizedRngRuntime, and Runtime wrappers.
        In the autograd path, this consists of AOTDispatchAutograd.post_compile.

        The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd.

        Notably absent from the cached path are:
        - DebugAssertWrapper
        - FakifiedOutWrapper

        We'll handle these separately later on, if necessary.
        """
        compiled_fw_func = self.compiled_fw.load(args, fx_config)
        compiled_bw_func = None
        if self.compiled_bw is not None:
            compiled_bw_func = self.compiled_bw.load(args, fx_config)
            needs_autograd = True
        else:
            needs_autograd = False

        # Wrap the forward function in post-compile wrappers
        compiled_fw_func = AOTDispatchSubclassWrapper(
            trace_joint=needs_autograd,
            fw_only=None,
            maybe_subclass_meta=self.maybe_subclass_meta,
            num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw,
        ).post_compile(
            compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
        )

        # In the autograd case, FunctionalizedRngRuntimeWrapper should not modify outs
        return_new_outs = not needs_autograd
        compiled_fw_func = FunctionalizedRngRuntimeWrapper(
            return_new_outs=return_new_outs
        ).post_compile(
            compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
        )
        disable_amp = torch._C._is_any_autocast_enabled()

        if needs_autograd:
            assert self.compiled_bw is not None
            # This function is run on both cache miss and cache hit, either here
            # or in aot_dispatch_autograd. On a cache hit,
            # 1. the bw is already compiled
            # 2. we don't need to save to the cache again
            # so those corresponding arguments are set to None.
            compiled_function = AOTDispatchAutograd.post_compile(
                compiled_fw_func,
                compiled_bw_func,
                self.maybe_subclass_meta,
                self.compiled_bw.num_symints_saved_for_bw_,
                self.compiled_bw.backward_state_indices,
                disable_amp,
                self.indices_of_inps_to_detach,
                None,  # lazy_backward_info
                aot_config,
                fw_metadata=self.runtime_metadata,
                try_save_cache_entry=None,
            )
        else:
            compiled_function = RuntimeWrapper(
                indices_of_inps_to_detach=self.indices_of_inps_to_detach,
                trace_joint=False,
                disable_amp=disable_amp,
            ).post_compile(
                compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
            )

        compiled_function, _ = post_compile(
            self.dispatch_wrappers,
            compiled_function,
            aot_config,
            runtime_metadata=self.runtime_metadata,
        )

        return compiled_function


class AOTAutogradCache:
406    """
407    Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas
408    AOTAutogradCacheEntry handles the wrapping/unwrapping logic.
409
410    Cache Inputs (AOTAutogradCacheDetails)
411    - AOTAutogradCache takes in the following inputs, which are analogous to inputs given
412        to AOTAutograd by dynamo:
413        - A fx graph module generated by dynamo
414        - A list of args, which consists of:
415            - Symint inputs to the graph, generated by dynamo
416            - The **real tensor** inputs, which inductor uses for cudagraphs
417            - Notably, the real tensor inputs don't have symints in their metadata.
418        AOTAutograd then retraces those real tensor arguments into FakeTensors later during execution.
419        - A set of global configurations that affect AOTAutograd or Inductor behavior.
420
421    It then generates a cache key given these values. Notably, this means AOTAutogradCache currently
422    specializes on the sizes and strides of the real tensor inputs when dynamic shapes are turned on.
423    In a later PR, we'll likely generate the cache key based on the FakeTensors AOTAutograd generates
424    based on the real tensor inputs, which can contain symints.
425
426    # Cache Outputs (AOTAutogradCacheEntry)
427    - AOTAutogradCache caches the following values:
428        - The compiled forward and backward functions from inductor, via keys to the FXGraphCache
429        - Metadata to reconstruct the AOTModule from the compiled inductor artifacts
430        - See AOTAutogradCacheEntry for more info
431
432    [Note: Caching guards generated by AOTAutograd and Inductor]
433    AOTAutograd and inductor both can introduce new guards to the shape environment. FXGraphCache saves guards with each
434    compiled graph inductor generates. On a cache hit, AOTAutograd reloads the compiled forward and backward functions
435    from FXGraphCache, giving it new symint arguments from the input args.
436    FXGraphCache uses those symints and its saved guards to repopulate the ShapeEnv with guards.
437    **No new guards are generated into the shape env after inductor finishes compiling**, so the guards
438    saved by inductor are sufficient for correctness for both AOTAutograd and Inductor's caches.
439    """

    @staticmethod
    def clear():
        """Clear the cache"""
        try:
            shutil.rmtree(AOTAutogradCache._get_tmp_dir())
        except FileNotFoundError:
            pass

    @staticmethod
    def load(
        dispatch_and_compile: Callable,
        mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper],
        args,
        aot_config: AOTConfig,
        cudagraphs: BoxedBool,
    ) -> Callable:
        """
        Load a result from the cache, and reconstruct a runtime wrapper around the object
        """
        gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
        compiled_fn = None
        cache_key = None
        debug_lines: List[str] = []
        cache_event_time = time.time_ns()
        cache_state = None
        fx_config = {"cudagraphs": cudagraphs}
        try:
            cache_key, debug_lines = autograd_cache_key(gm, args, aot_config, fx_config)
            entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(cache_key)
            if entry is not None:
                compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
                log.info("AOTAutograd cache hit for key %s", cache_key)
                counters["aot_autograd"]["autograd_cache_hit"] += 1
                cache_state = "hit"
                cache_event_time = time.time_ns()
            if compiled_fn is None:
                log.info("AOTAutograd cache miss for key %s", cache_key)
                counters["aot_autograd"]["autograd_cache_miss"] += 1
                cache_state = "miss"
                cache_event_time = time.time_ns()
        # Count missing the FXGraphCache as a miss, not a bypass
        except FXGraphCacheMiss as e:
            counters["aot_autograd"]["autograd_cache_miss"] += 1
            # Special counter for when we hit the AOTAutograd cache
            # but fail on inductor guards
            counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
            if config.strict_autograd_cache:
                raise e
        except BypassAOTAutogradCache as e:
            cache_key = None
            counters["aot_autograd"]["autograd_cache_bypass"] += 1
            cache_state = "bypass"
            cache_event_time = time.time_ns()
            if config.strict_autograd_cache:
                raise e
        if compiled_fn is None:
            # Set the cache key so we can save a cache result later
            aot_config.cache_key = cache_key
            compiled_fn = dispatch_and_compile()
        cache_args = {
            "key": cache_key,
            "cache_state": cache_state,
            "components": debug_lines,
        }
        chromium_log = get_chromium_event_logger()
        chromium_log.log_instant_event(
            f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_args
        )
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "aotautograd_cache_hash",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(cache_args),
        )
        return compiled_fn

    @staticmethod
    def _get_tmp_dir() -> str:
        """
        Get the toplevel temporary directory for storing compiled graphs.
        """
        return os.path.join(cache_dir(), "aotautograd")

    @staticmethod
    def _lookup(key: str) -> Optional[AOTAutogradCacheEntry]:
        """Given a key generated by AOTAutogradCachePickler, look up its location in the cache."""
        subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key)
        if not os.path.exists(subdir):
            return None
        path = os.path.join(subdir, "entry")
        try:
            with open(path, "rb") as f:
                entry: AOTAutogradCacheEntry = pickle.load(f)
            return entry
        except Exception as e:
            log.warning("AOTAutograd cache unable to load compiled graph: %s", e)
            if config.strict_autograd_cache:
                raise e
            return None

    @staticmethod
    def save(key: str, entry: AOTAutogradCacheEntry):
        """Save a single entry into the cache."""
        try:
            content = pickle.dumps(entry)
        except Exception as e:
            log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e)
            if config.strict_autograd_cache:
                raise e
            return None
        subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key)
        if not os.path.exists(subdir):
            os.makedirs(subdir, exist_ok=True)
        path = os.path.join(subdir, "entry")
        log.info("Writing AOTAutograd cache entry to %s", path)
        write_atomic(path, content)
        counters["aot_autograd"]["autograd_cache_saved"] += 1
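
# For reference, the on-disk layout produced by _lookup()/save() above is:
#
#   <cache_dir()>/aotautograd/<key>/entry
#
# where <key> is the "a"-prefixed hash from autograd_cache_key() and "entry"
# is the pickled AOTAutogradCacheEntry written with write_atomic().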