xref: /aosp_15_r20/external/pytorch/test/inductor/mock_cache.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2from __future__ import annotations
3
4import contextlib
5import dataclasses
6import sys
7import threading
8from typing import Any, Callable, Dict, Generator, Optional, Type, TYPE_CHECKING
9from typing_extensions import override, Self
10from unittest.mock import patch
11
12import torch
13from torch._inductor import config
14from torch._inductor.remote_cache import RemoteCacheBackend
15
16
17if TYPE_CHECKING:
18    from types import TracebackType
19
20
@dataclasses.dataclass
class Stats:
    """Mutable counter bundle for one mock cache: puts, get-hits, get-misses."""

    num_put: int = 0
    num_get_hit: int = 0
    num_get_miss: int = 0

    def __iadd__(self, other: Stats) -> Self:
        # Fold another Stats into this one, field by field, mutating in place.
        for name in ("num_put", "num_get_hit", "num_get_miss"):
            setattr(self, name, getattr(self, name) + getattr(other, name))
        return self

    def reset(self) -> None:
        # Zero out every counter.
        self.num_put = 0
        self.num_get_hit = 0
        self.num_get_miss = 0

    def __str__(self) -> str:
        # Note: intentionally keeps the trailing ", " of the original format.
        return (
            f"puts: {self.num_put}, "
            f"misses: {self.num_get_miss}, "
            f"hits: {self.num_get_hit}, "
        )
46
47
# The cache states are thread-local so if we're running multiple tests at once
# they won't cross-contaminate. However, they need to be "global" because we
# allow code to create new cache clients that refer to the same cache (because
# it's a remote cache).
52
53
class _GlobalStats(Stats, threading.local):
    """Thread-local aggregate of per-cache Stats (autotune / fx_graph / triton).

    Thread-local so concurrently running tests don't contaminate each other's
    counters; module-global so every cache client reports into the same place.
    """

    def __init__(self) -> None:
        self.autotune = Stats()
        self.fx_graph = Stats()
        self.triton = Stats()

    def reset(self) -> None:
        """Zero every per-cache counter."""
        for stat in (self.autotune, self.fx_graph, self.triton):
            stat.reset()

    def update(self, name: str, delta: Stats) -> None:
        """Fold `delta` into the Stats attribute called `name`."""
        target = getattr(self, name)
        # Stats.__iadd__ mutates and returns the same object, so this updates
        # the existing attribute rather than rebinding a copy.
        target += delta

    def report(self):
        """Dump all counters to stderr."""
        for line in (
            "Cache Stats:",
            f"  autotune: {self.autotune}",
            f"  fx_graph: {self.fx_graph}",
            f"  triton:   {self.triton}",
        ):
            print(line, file=sys.stderr)
74
75
76global_stats = _GlobalStats()
77
78
class MockBackend(RemoteCacheBackend[Any]):
    """In-memory cache backend that records hit/miss/put counts in
    `global_stats` under a given stat name ("autotune", "fx_graph", ...)."""

    def __init__(self, name: str, cache: Dict[str, object]) -> None:
        self._cache = cache
        self._name = name

    @staticmethod
    def with_name(name: str) -> Callable[[], MockBackend]:
        """Return a zero-arg factory for backends named `name`.

        Every backend produced by one factory shares a single dict, emulating
        one remote cache accessed through many independent clients.
        """
        shared: Dict[str, object] = {}

        def make() -> MockBackend:
            return MockBackend(name, shared)

        return make

    @override
    def get(self, key: str) -> Optional[Any]:
        """Look up `key`, recording a hit or a miss against our stat name."""
        hit = key in self._cache
        delta = Stats(num_get_hit=1) if hit else Stats(num_get_miss=1)
        global_stats.update(self._name, delta)
        return self._cache.get(key) if hit else None

    @override
    def put(self, key: str, data: Any) -> None:
        """Store `data` under `key`, recording the put."""
        global_stats.update(self._name, Stats(num_put=1))
        self._cache[key] = data
106
107
# Names of the inductor config flags that enable each cache. PatchCaches.setUp
# saves and force-disables every one of these; tearDown restores the saved
# values afterwards.
_CACHE_CONFIG_EN = (
    "fx_graph_cache",
    "fx_graph_remote_cache",
    "autotune_local_cache",
    "autotune_remote_cache",
    # "bundled_autotune_cache",
)
116
117
class PatchCaches(contextlib.AbstractContextManager):
    """Context manager that reroutes inductor's remote caches to MockBackend
    instances so tests can assert on `global_stats`."""

    @classmethod
    def setUp(cls):
        # If this test is using PatchCaches then disable all the caches by
        # default, letting the tests turn them on explicitly. This is because
        # tests using PatchCaches will often want to check stats explicitly.
        cls._savedCacheState = {}
        for name in _CACHE_CONFIG_EN:
            if hasattr(config, name):
                cls._savedCacheState[name] = getattr(config, name)
            setattr(config, name, False)

    @classmethod
    def tearDown(cls):
        # Drop our override, then restore whatever value (if any) was saved.
        for name in _CACHE_CONFIG_EN:
            delattr(config, name)
            if name in cls._savedCacheState:
                setattr(config, name, cls._savedCacheState[name])

    def __init__(self) -> None:
        self._stack = contextlib.ExitStack()

    def __enter__(self) -> Self:
        global_stats.reset()
        self._stack.__enter__()

        # (patch target, stat name) pairs; each target's backend_override_cls
        # is replaced with a MockBackend factory reporting under the stat name.
        targets = [
            (
                "torch._inductor.remote_cache.RemoteAutotuneCache.backend_override_cls",
                "autotune",
            ),
            (
                "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls",
                "fx_graph",
            ),
        ]
        if config.is_fbcode():
            targets += [
                (
                    "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls",
                    "autotune",
                ),
                (
                    "torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls",
                    "fx_graph",
                ),
                (
                    "triton.fb.fb_memcache.FbMemcacheRemoteKernelCache.backend_override_cls",
                    "triton",
                ),
            ]

        for target, stat_name in targets:
            self._stack.enter_context(
                patch(target, MockBackend.with_name(stat_name))
            )

        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self._stack.__exit__(exc_type, exc_value, traceback)
185
186
@contextlib.contextmanager
def patch_fbcode(state: bool) -> Generator[None, None, None]:
    """Temporarily make torch look like an fbcode (state=True) or OSS
    (state=False) build.

    Fbcode-ness is detected via `torch.version.git_version`, which only
    non-fbcode (OSS) builds define; toggling is done by deleting or faking
    that attribute and undoing the change on exit.
    """
    currently_fbcode = not hasattr(torch.version, "git_version")
    if state == currently_fbcode:
        # Already in the requested mode; nothing to toggle.
        yield
    elif state:
        # OSS -> fake fbcode: hide git_version, restore it afterwards.
        saved = torch.version.git_version
        delattr(torch.version, "git_version")
        try:
            yield
        finally:
            torch.version.git_version = saved
    else:
        # fbcode -> fake OSS: invent a git_version, remove it afterwards.
        torch.version.git_version = "12345+"
        try:
            yield
        finally:
            delattr(torch.version, "git_version")
210