1# Owner(s): ["module: inductor"] 2from __future__ import annotations 3 4import contextlib 5import dataclasses 6import sys 7import threading 8from typing import Any, Callable, Dict, Generator, Optional, Type, TYPE_CHECKING 9from typing_extensions import override, Self 10from unittest.mock import patch 11 12import torch 13from torch._inductor import config 14from torch._inductor.remote_cache import RemoteCacheBackend 15 16 17if TYPE_CHECKING: 18 from types import TracebackType 19 20 21@dataclasses.dataclass 22class Stats: 23 num_put: int = 0 24 num_get_hit: int = 0 25 num_get_miss: int = 0 26 27 def __iadd__(self, other: Stats) -> Self: 28 self.num_put += other.num_put 29 self.num_get_hit += other.num_get_hit 30 self.num_get_miss += other.num_get_miss 31 return self 32 33 def reset(self) -> None: 34 self.num_put = 0 35 self.num_get_hit = 0 36 self.num_get_miss = 0 37 38 def __str__(self) -> str: 39 return "".join( 40 ( 41 f"puts: {self.num_put}, ", 42 f"misses: {self.num_get_miss}, ", 43 f"hits: {self.num_get_hit}, ", 44 ) 45 ) 46 47 48# The cache states are thread-local so if we're running multiple tests at once 49# they won't cross contaminate. However - it needs to be "global" because we 50# allow code to create new cache clients which refer to the same cache (because 51# it's a remote cache). 
class _GlobalStats(Stats, threading.local):
    """Per-thread aggregate of cache stats, one `Stats` bucket per cache name.

    Subclasses `threading.local` so that tests running concurrently in
    different threads each see their own counters (see the comment above).
    """

    def __init__(self) -> None:
        # One bucket per mocked cache; `update()` looks these up by name.
        self.autotune = Stats()
        self.fx_graph = Stats()
        self.triton = Stats()

    def reset(self) -> None:
        """Zero every bucket (called on entering PatchCaches)."""
        self.autotune.reset()
        self.fx_graph.reset()
        self.triton.reset()

    def update(self, name: str, delta: Stats) -> None:
        """Add `delta` into the bucket named `name` ("autotune", "fx_graph" or "triton")."""
        stat = getattr(self, name)
        stat += delta

    def report(self) -> None:
        """Print a human-readable summary of all buckets to stderr."""
        print("Cache Stats:", file=sys.stderr)
        print(f" autotune: {self.autotune}", file=sys.stderr)
        print(f" fx_graph: {self.fx_graph}", file=sys.stderr)
        print(f" triton: {self.triton}", file=sys.stderr)


# Shared (but thread-local, see above) stats instance that every MockBackend
# reports into.
global_stats = _GlobalStats()


class MockBackend(RemoteCacheBackend[Any]):
    """In-memory stand-in for a remote cache backend.

    Stores entries in a plain dict and records every get/put into
    `global_stats` under the bucket given by `name`.
    """

    def __init__(self, name: str, cache: Dict[str, object]) -> None:
        # `cache` is shared by every backend produced by the same
        # `with_name()` factory, emulating one remote cache with many clients.
        self._cache = cache
        self._name = name

    @staticmethod
    def with_name(name: str) -> Callable[[], MockBackend]:
        """Return a zero-arg factory whose backends all share a single dict."""
        cache: Dict[str, object] = {}

        def wrapper() -> MockBackend:
            return MockBackend(name, cache)

        return wrapper

    @override
    def get(self, key: str) -> Optional[Any]:
        """Fetch `key`, recording a hit or a miss in `global_stats`."""
        if key in self._cache:
            global_stats.update(self._name, Stats(num_get_hit=1))
            return self._cache.get(key)
        else:
            global_stats.update(self._name, Stats(num_get_miss=1))
            return None

    @override
    def put(self, key: str, data: Any) -> None:
        """Store `data` under `key`, recording a put in `global_stats`."""
        global_stats.update(self._name, Stats(num_put=1))
        self._cache[key] = data


# List of configs for each cache
# (names of attributes on torch._inductor.config toggled by PatchCaches).
_CACHE_CONFIG_EN = (
    "fx_graph_cache",
    "fx_graph_remote_cache",
    "autotune_local_cache",
    "autotune_remote_cache",
    # "bundled_autotune_cache",
)


class PatchCaches(contextlib.AbstractContextManager):
    """Context manager that reroutes inductor's remote caches to MockBackends.

    Use `setUp`/`tearDown` from a test class to disable the caches by default,
    and enter the context manager to install the mock backends and reset
    `global_stats` so a test can make exact assertions on cache traffic.
    """

    @classmethod
    def setUp(cls) -> None:
        # If this test is using PatchCaches then disable all the caches by
        # default, letting the tests turn them on explicitly. This is because
        # tests using PatchCaches will often want to check stats explicitly.
        cls._savedCacheState = {}
        for name in _CACHE_CONFIG_EN:
            # Remember the original value (only if the attribute exists on
            # this build's config) so tearDown can restore it.
            if hasattr(config, name):
                cls._savedCacheState[name] = getattr(config, name)
            setattr(config, name, False)

    @classmethod
    def tearDown(cls) -> None:
        # Restore cache defaults
        for name in _CACHE_CONFIG_EN:
            # setUp unconditionally setattr'd every name, so the delattr is
            # safe; it drops our override before restoring any saved value.
            delattr(config, name)
            if name in cls._savedCacheState:
                setattr(config, name, cls._savedCacheState[name])

    def __init__(self) -> None:
        # ExitStack collects all the mock.patch contexts entered in __enter__
        # so they unwind together in __exit__.
        self._stack = contextlib.ExitStack()

    def __enter__(self) -> Self:
        """Reset stats and patch each remote-cache class to use a MockBackend."""
        global_stats.reset()
        self._stack.__enter__()

        ctx = patch(
            "torch._inductor.remote_cache.RemoteAutotuneCache.backend_override_cls",
            MockBackend.with_name("autotune"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls",
            MockBackend.with_name("fx_graph"),
        )
        self._stack.enter_context(ctx)

        if config.is_fbcode():
            # fbcode builds route through separate cache classes; patch those
            # too so the same stats buckets are populated.
            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls",
                MockBackend.with_name("autotune"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls",
                MockBackend.with_name("fx_graph"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "triton.fb.fb_memcache.FbMemcacheRemoteKernelCache.backend_override_cls",
                MockBackend.with_name("triton"),
            )
            self._stack.enter_context(ctx)

        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Unwind all patches entered in __enter__."""
        self._stack.__exit__(exc_type, exc_value, traceback)


@contextlib.contextmanager
def patch_fbcode(state: bool) -> Generator[None, None, None]:
    """Temporarily make the build look like fbcode (True) or OSS (False).

    fbcode-ness is signaled by the absence of `torch.version.git_version`,
    so this adds/removes that attribute and restores it on exit.
    """
    if hasattr(torch.version, "git_version"):
        # Currently non-fbcode
        if state:
            old = torch.version.git_version
            delattr(torch.version, "git_version")
            try:
                yield
            finally:
                torch.version.git_version = old
        else:
            yield
    else:
        # Currently fbcode
        if state:
            yield
        else:
            # Any non-empty value works; presence of the attribute is what
            # marks the build as OSS.
            torch.version.git_version = "12345+"
            try:
                yield
            finally:
                delattr(torch.version, "git_version")