xref: /aosp_15_r20/external/pytorch/torch/_inductor/runtime/autotune_cache.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
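"""
Caching of Triton autotuning results ("best configs") for torch._inductor
generated kernels, on the local filesystem and/or in an optional remote cache.
"""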
from __future__ import annotations

import dataclasses
import hashlib
import logging
import os
import os.path
from typing import Dict, List, Optional, Tuple
from typing_extensions import override

import torch
from torch.utils._triton import has_triton_package

from ..remote_cache import (
    JsonDataTy,
    RemoteCache,
    RemoteCacheBackend,
    RemoteCacheJsonSerde,
)


if has_triton_package():
    from triton import Config

log = logging.getLogger(__name__)


_InductorMetaTy = Dict[str, object]


@dataclasses.dataclass
class AutotuneCache:
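    """
    Best autotuning configuration for a generated kernel, cached on the local
    filesystem (next to the kernel source) and/or in a remote cache.

    Typical usage, as a sketch (the surrounding names come from the caller):

        cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
        if cache and (best := cache.read_best(inductor_meta, configs)):
            ...  # reuse the previously saved best config
        else:
            ...  # benchmark `configs`, then cache.save(winner, time_taken_ns)
    """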
    configs_hash: str
    filename: str
    local_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None
    remote_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None

    # Create an AutotuneCache. Returns None if none of the caches can be used.
    @staticmethod
    def create(
        inductor_meta: _InductorMetaTy, filename: str, configs_hash: str
    ) -> Optional[AutotuneCache]:
        cache = AutotuneCache(configs_hash, filename)
        cache._setup_local_cache(inductor_meta, filename)
        cache._setup_remote_autotune_cache(inductor_meta, filename)
        if cache.local_cache or cache.remote_cache:
            return cache
        else:
            return None

    # Read the best config options from the most local cache available and return them.
    def _read(self, inductor_meta: _InductorMetaTy) -> Optional[Dict[str, JsonDataTy]]:
        if local_cache := self.local_cache:
            cache, key = local_cache
            if best_config := cache.get(key):
                if isinstance(best_config, dict):
                    return best_config

        if remote_cache := self.remote_cache:
            cache, key = remote_cache
            if best_config := cache.get(key):
                if isinstance(best_config, dict):
                    return best_config

        return None

    # Read the best config options from the most local cache available and figure out
    # which of the given `configs` represents that option.
    def read_best(
        self, inductor_meta: _InductorMetaTy, configs: List[Config]
    ) -> Optional[Config]:
        if best := self._read(inductor_meta):
            return _load_cached_autotuning(
                best, self.configs_hash, configs, inductor_meta
            )
        return None

    # Set up local filesystem caching information
    def _setup_local_cache(self, inductor_meta: _InductorMetaTy, filename: str) -> None:
        if not inductor_meta.get("autotune_local_cache", True):
            return

        cache_filename = os.path.splitext(filename)[0] + ".best_config"
        local_cache = RemoteCache(_LocalAutotuneCacheBackend(), RemoteCacheJsonSerde())
        self.local_cache = (local_cache, cache_filename)

    # Set up remote caching information
    def _setup_remote_autotune_cache(
        self, inductor_meta: _InductorMetaTy, filename: str
    ) -> None:
        if not _should_use_remote_autotune_cache(inductor_meta):
            return

        remote_cache = _create_cache(
            inductor_meta,
            self.configs_hash,
            "FbRemoteAutotuneCache",
            "RemoteAutotuneCache",
            "autotune-best-config-v2",
        )
        if not remote_cache:
            return

        # we already sha256 hash the source contents
        remote_cache_key = os.path.basename(filename)
        self.remote_cache = (remote_cache, remote_cache_key)

    # Save the config in the caches
    def save(
        self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False
    ) -> None:
        data = {
            **config.kwargs,
            "num_warps": config.num_warps,
            "num_stages": config.num_stages,
            "configs_hash": self.configs_hash,
            "found_by_coordesc": found_by_coordesc,
            "time_taken_ms": time_taken_ns // 1000000,  # Convert from NS to MS
        }

        if local_cache := self.local_cache:
            cache, key = local_cache
            cache.put(key, data)

            if log.isEnabledFor(logging.DEBUG):
                type_str = "coordesc" if found_by_coordesc else "heuristic"
                log.debug("Save %s tuning result to %s", type_str, key)

        if remote_cache := self.remote_cache:
            cache, key = remote_cache
            cache.put(key, data)


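# Decide whether the remote autotune cache should be used. An explicit
# "autotune_remote_cache" entry in inductor_meta (True or False) wins outright;
# otherwise it is enabled only in fbcode builds, outside fb unit tests, not on
# HIP, and only if REMOTE_CACHE_VERSION is at least the JustKnobs-configured
# "pytorch/remote_cache:autotune_memcache_version".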
def _should_use_remote_autotune_cache(inductor_meta: Dict[str, object]) -> bool:
    if (config := inductor_meta.get("autotune_remote_cache")) is not None:
        return bool(config)
    if not inductor_meta.get("is_fbcode"):
        return False
    if torch._utils_internal.is_fb_unit_test():
        return False
    if inductor_meta.get("is_hip"):
        return False

    try:
        from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
    except ModuleNotFoundError:
        return False

    return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
        "pytorch/remote_cache:autotune_memcache_version"
    )


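# Validate a cached best config against the current configs_hash and map it
# back onto one of the candidate `configs`. For example (hypothetical values),
# a cached {"XBLOCK": 64, "num_warps": 4, "num_stages": 2} selects the single
# candidate whose kwargs, num_warps and num_stages all match; if zero or
# several candidates match, None is returned. Coordinate-descent results are
# instead reconstructed as a fresh Config marked found_by_coordesc.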
def _load_cached_autotuning(
    best_config: Dict[str, JsonDataTy],
    configs_hash: str,
    configs: List[Config],
    inductor_meta: Dict[str, object],
) -> Optional[Config]:
    if best_config is None:
        return None
    if best_config.pop("configs_hash", None) != configs_hash:
        return None

    # Remove time taken for comparison
    best_config.pop("time_taken_ms", None)

    if inductor_meta.get("coordinate_descent_tuning") and best_config.pop(
        "found_by_coordesc", False
    ):
        num_warps = best_config.pop("num_warps")
        num_stages = best_config.pop("num_stages")
        triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
        triton_config.found_by_coordesc = True
        return triton_config

    matching_configs = [
        cfg
        for cfg in configs
        if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
        and cfg.num_warps == best_config.get("num_warps")
        and cfg.num_stages == best_config.get("num_stages")
    ]
    if len(matching_configs) != 1:
        return None

    return matching_configs[0]


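# Build the RemoteCache used for autotuning results. The cache key is
# sha256(backend_hash + configs_hash + salt), so the same backend and candidate
# configs always map to the same key; the fbcode or OSS cache class is then
# looked up by name and instantiated with that key.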
def _create_cache(
    inductor_meta: Dict[str, object],
    configs_hash: str,
    fb_cache_cls: str,
    oss_cache_cls: str,
    salt: str,
) -> Optional[RemoteCache[JsonDataTy]]:
    backend_hash = inductor_meta.get("backend_hash", None)
    if backend_hash is None:
        log.debug(
            "backend_hash is not passed in the inductor_meta, unable to use autotune remote cache"
        )
        return None

    assert isinstance(backend_hash, str)

    key = backend_hash + configs_hash + salt
    key = hashlib.sha256(key.encode("utf-8")).hexdigest()

    try:
        if inductor_meta.get("is_fbcode"):
            import torch._inductor.fb.remote_cache

            cache_cls = getattr(torch._inductor.fb.remote_cache, fb_cache_cls)
            return cache_cls(key)
        else:
            import torch._inductor.remote_cache

            cache_cls = getattr(torch._inductor.remote_cache, oss_cache_cls)
            return cache_cls(key)
    except Exception:
        log.warning("Unable to create a remote cache", exc_info=True)
        return None


class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]):
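    """
    RemoteCacheBackend that uses the local filesystem: the cache key is treated
    as a file path, put() (over)writes that file, and get() returns None if it
    does not exist yet.

    Round-trip sketch (hypothetical path; mirrors _setup_local_cache above):

        local = RemoteCache(_LocalAutotuneCacheBackend(), RemoteCacheJsonSerde())
        local.put("/tmp/kernel.best_config", {"num_warps": 4, "num_stages": 2})
        assert local.get("/tmp/kernel.best_config")["num_warps"] == 4
    """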
    @override
    def get(self, key: str) -> Optional[bytes]:
        try:
            with open(key, "rb") as fd:
                return fd.read()
        except FileNotFoundError:
            return None

    @override
    def put(self, key: str, data: bytes) -> None:
        with open(key, "wb") as fd:
            fd.write(data)