from __future__ import annotations

import dataclasses
import hashlib
import logging
import os
import os.path
import threading
from typing import Dict, List, Optional, Tuple
from typing_extensions import override

import torch
from torch.utils._triton import has_triton_package

from ..remote_cache import (
    JsonDataTy,
    RemoteCache,
    RemoteCacheBackend,
    RemoteCacheJsonSerde,
)


if has_triton_package():
    from triton import Config

log = logging.getLogger(__name__)


# Metadata dict supplied by the caller; this module only does `.get(...)`
# lookups on it.
_InductorMetaTy = Dict[str, object]


@dataclasses.dataclass
class AutotuneCache:
    """Best-config cache made of an optional local tier and an optional remote tier.

    Each tier is stored as a ``(cache, key)`` tuple, or ``None`` when that
    tier is disabled. For the local tier the key is a filesystem path; for
    the remote tier it is the basename of ``filename``.
    """

    configs_hash: str
    filename: str
    local_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None
    remote_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None

    @staticmethod
    def create(
        inductor_meta: _InductorMetaTy, filename: str, configs_hash: str
    ) -> Optional[AutotuneCache]:
        """Create an AutotuneCache.

        Returns:
            The new cache, or ``None`` if neither the local nor the remote
            tier could be set up.
        """
        cache = AutotuneCache(configs_hash, filename)
        cache._setup_local_cache(inductor_meta, filename)
        cache._setup_remote_autotune_cache(inductor_meta, filename)
        if cache.local_cache or cache.remote_cache:
            return cache
        return None

    def _read(self, inductor_meta: _InductorMetaTy) -> Optional[Dict[str, JsonDataTy]]:
        """Return the stored best-config dict from the most local tier that has one.

        The local tier is consulted first, then the remote tier. A tier's
        value is used only if it is a non-empty dict. ``inductor_meta`` is
        accepted for interface compatibility and is not read here.
        """
        for tier in (self.local_cache, self.remote_cache):
            if not tier:
                continue
            cache, key = tier
            best_config = cache.get(key)
            if best_config and isinstance(best_config, dict):
                return best_config
        return None

    def read_best(
        self, inductor_meta: _InductorMetaTy, configs: List[Config]
    ) -> Optional[Config]:
        """Read the cached best config and resolve it to a ``Config``.

        Returns:
            The matching config (see ``_load_cached_autotuning``), or ``None``
            when nothing usable is cached.
        """
        best = self._read(inductor_meta)
        if not best:
            return None
        return _load_cached_autotuning(best, self.configs_hash, configs, inductor_meta)

    def _setup_local_cache(self, inductor_meta: _InductorMetaTy, filename: str) -> None:
        """Set up local filesystem caching unless ``autotune_local_cache`` is falsy.

        The cache key is ``filename`` with its extension replaced by
        ``.best_config``.
        """
        if not inductor_meta.get("autotune_local_cache", True):
            return

        cache_filename = os.path.splitext(filename)[0] + ".best_config"
        local_cache = RemoteCache(_LocalAutotuneCacheBackend(), RemoteCacheJsonSerde())
        self.local_cache = (local_cache, cache_filename)

    def _setup_remote_autotune_cache(
        self, inductor_meta: _InductorMetaTy, filename: str
    ) -> None:
        """Set up remote caching if it is enabled and a remote cache can be built."""
        if not _should_use_remote_autotune_cache(inductor_meta):
            return

        remote_cache = _create_cache(
            inductor_meta,
            self.configs_hash,
            "FbRemoteAutotuneCache",
            "RemoteAutotuneCache",
            "autotune-best-config-v2",
        )
        if not remote_cache:
            return

        # we already sha256 hash the source contents
        remote_cache_key = os.path.basename(filename)
        self.remote_cache = (remote_cache, remote_cache_key)

    def save(
        self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False
    ) -> None:
        """Save ``config`` to every enabled cache tier.

        Args:
            config: The winning config; its ``kwargs``, ``num_warps`` and
                ``num_stages`` are stored.
            time_taken_ns: Tuning time in nanoseconds; stored in whole
                milliseconds (floor division).
            found_by_coordesc: Recorded alongside the config so a later read
                can tell how the config was found.
        """
        data = {
            **config.kwargs,
            "num_warps": config.num_warps,
            "num_stages": config.num_stages,
            "configs_hash": self.configs_hash,
            "found_by_coordesc": found_by_coordesc,
            "time_taken_ms": time_taken_ns // 1000000,  # Convert from NS to MS
        }

        if local_cache := self.local_cache:
            cache, key = local_cache
            cache.put(key, data)

            if log.isEnabledFor(logging.DEBUG):
                type_str = "coordesc" if found_by_coordesc else "heuristic"
                log.debug("Save %s tuning result to %s", type_str, key)

        if remote_cache := self.remote_cache:
            cache, key = remote_cache
            cache.put(key, data)


def _should_use_remote_autotune_cache(inductor_meta: Dict[str, object]) -> bool:
    """Decide whether the remote autotune cache should be used.

    An explicit (non-``None``) ``autotune_remote_cache`` entry always wins.
    Otherwise the remote cache is used only for fbcode builds that are not
    fb unit tests and not HIP, where the fb remote-cache module is
    importable and its ``REMOTE_CACHE_VERSION`` is at least the
    ``pytorch/remote_cache:autotune_memcache_version`` knob value.
    """
    config = inductor_meta.get("autotune_remote_cache")
    if config is not None:
        return bool(config)
    if not inductor_meta.get("is_fbcode"):
        return False
    if torch._utils_internal.is_fb_unit_test():
        return False
    if inductor_meta.get("is_hip"):
        return False

    try:
        from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
    except ModuleNotFoundError:
        return False

    return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
        "pytorch/remote_cache:autotune_memcache_version"
    )


def _load_cached_autotuning(
    best_config: Dict[str, JsonDataTy],
    configs_hash: str,
    configs: List[Config],
    inductor_meta: Dict[str, object],
) -> Optional[Config]:
    """Turn a cached best-config dict back into a ``Config``.

    Args:
        best_config: The cached dict. It is not modified.
        configs_hash: Expected hash; an entry whose ``configs_hash`` differs
            is treated as a miss.
        configs: Candidate configs to match the cached entry against.
        inductor_meta: Read for ``coordinate_descent_tuning``.

    Returns:
        * a freshly built ``Config`` flagged ``found_by_coordesc`` when
          coordinate descent tuning is on and the entry was found by it;
        * otherwise the single element of ``configs`` whose kwargs,
          ``num_warps`` and ``num_stages`` all equal the cached values;
        * ``None`` on a miss, a hash mismatch, or when zero or several
          candidates match.
    """
    if best_config is None:
        return None

    # Work on a shallow copy: the keys popped below must not disappear from
    # the caller's dict, otherwise re-reading the same object would look like
    # a hash mismatch.
    best_config = dict(best_config)

    if best_config.pop("configs_hash", None) != configs_hash:
        return None

    # Remove time taken for comparison
    best_config.pop("time_taken_ms", None)

    if inductor_meta.get("coordinate_descent_tuning") and best_config.pop(
        "found_by_coordesc", False
    ):
        num_warps = best_config.pop("num_warps")
        num_stages = best_config.pop("num_stages")
        triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
        triton_config.found_by_coordesc = True
        return triton_config

    cached_warps = best_config.get("num_warps")
    cached_stages = best_config.get("num_stages")
    matching_configs = [
        cfg
        for cfg in configs
        if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
        and cfg.num_warps == cached_warps
        and cfg.num_stages == cached_stages
    ]
    # Anything other than exactly one match is ambiguous, so report a miss.
    if len(matching_configs) != 1:
        return None

    return matching_configs[0]


def _create_cache(
    inductor_meta: Dict[str, object],
    configs_hash: str,
    fb_cache_cls: str,
    oss_cache_cls: str,
    salt: str,
) -> Optional[RemoteCache[JsonDataTy]]:
    """Instantiate the remote cache class for this build flavour.

    The cache is constructed with the SHA-256 hex digest of
    ``backend_hash + configs_hash + salt``. ``fb_cache_cls`` is looked up in
    ``torch._inductor.fb.remote_cache`` when ``is_fbcode`` is set, otherwise
    ``oss_cache_cls`` is looked up in ``torch._inductor.remote_cache``.

    Returns:
        The cache instance, or ``None`` if ``backend_hash`` is missing or the
        cache could not be created (the failure is logged as a warning).
    """
    backend_hash = inductor_meta.get("backend_hash", None)
    if backend_hash is None:
        log.debug(
            "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache"
        )
        return None

    assert isinstance(backend_hash, str)

    key = backend_hash + configs_hash + salt
    key = hashlib.sha256(key.encode("utf-8")).hexdigest()

    try:
        if inductor_meta.get("is_fbcode"):
            import torch._inductor.fb.remote_cache

            cache_cls = getattr(torch._inductor.fb.remote_cache, fb_cache_cls)
        else:
            import torch._inductor.remote_cache

            cache_cls = getattr(torch._inductor.remote_cache, oss_cache_cls)
        return cache_cls(key)
    except Exception:
        # Remote caching is best-effort: log and carry on without it.
        log.warning("Unable to create a remote cache", exc_info=True)
        return None


class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]):
    """Cache backend that stores each entry as a file; the key is the file path."""

    @override
    def get(self, key: str) -> Optional[bytes]:
        """Return the bytes stored at path ``key``, or ``None`` if the file does not exist."""
        try:
            with open(key, "rb") as fd:
                return fd.read()
        except FileNotFoundError:
            return None

    @override
    def put(self, key: str, data: bytes) -> None:
        """Write ``data`` to path ``key``, replacing any existing file atomically.

        The data is first written to a sibling temporary file and then moved
        into place with ``os.replace``, so a reader never sees a partially
        written entry and an interrupted write cannot truncate an existing
        one. The parent directory is created if it is missing.

        Raises:
            OSError: If the directory or file cannot be written.
        """
        directory = os.path.dirname(key)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # pid + thread id keeps concurrent writers from sharing a temp file.
        tmp_path = f"{key}.{os.getpid()}.{threading.get_ident()}.tmp"
        try:
            with open(tmp_path, "wb") as fd:
                fd.write(data)
            os.replace(tmp_path, key)
        except BaseException:
            # Do not leave a stray temp file behind on failure.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise