import contextlib
import copy
import hashlib
import inspect
import io
import pickle
import tokenize
import unittest
import warnings
from types import FunctionType, ModuleType
from typing import Any, Callable, Dict, NoReturn, Optional, Set, Union
from typing_extensions import deprecated
from unittest import mock


# Types saved/loaded in configs
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)


def install_config_module(module: ModuleType) -> None:
    """
    Converts a module-level config into a `ConfigModule()`.

    Public values whose type is in ``CONFIG_TYPES`` are moved out of the module
    namespace into ``module._config`` (their initial values are also recorded
    in ``module._default``).  A nested ``class name:`` defined in the module
    becomes a ``SubConfigProxy`` whose entries are stored under dotted keys,
    e.g. ``"triton.cudagraphs"``.  Finally the module's class is swapped to a
    ``ConfigModule`` subclass so attribute access is routed through
    ``_config``.

    See _config_typing.pyi for instructions on how to get the converted module to typecheck.

    Raises:
        AssertionError: if a value is not a config type, not a class defined
            in ``module``, and not one of the skipped kinds (dunder names,
            modules, functions, objects from ``typing``).
    """

    class ConfigModuleInstance(ConfigModule):
        # Attributes kept on the module object itself instead of in _config.
        _bypass_keys = {"_is_dirty", "_hash_digest"}

    def visit(
        source: Union[ModuleType, type],
        dest: Union[ModuleType, SubConfigProxy],
        prefix: str,
    ) -> None:
        """Walk the module structure and move everything to module._config"""
        for key, value in list(source.__dict__.items()):
            if (
                key.startswith("__")
                or isinstance(value, (ModuleType, FunctionType))
                or (hasattr(value, "__module__") and value.__module__ == "typing")
            ):
                continue

            name = f"{prefix}{key}"
            if isinstance(value, CONFIG_TYPES):
                config[name] = value
                default[name] = value
                if dest is module:
                    # Drop the plain module attribute so that, after the class
                    # swap below, lookups fall through to ConfigModule.__getattr__.
                    delattr(module, key)
            elif isinstance(value, type):
                assert value.__module__ == module.__name__
                # a subconfig with `class Blah:` syntax
                proxy = SubConfigProxy(module, f"{name}.")
                visit(value, proxy, f"{name}.")
                if dest is module:
                    setattr(dest, key, proxy)
                else:
                    # SubConfigProxy.__setattr__ forwards to the config module,
                    # so nested proxies are stored in the proxy's own __dict__.
                    dest.__dict__[key] = proxy
            else:
                raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")

    config: Dict[str, Any] = {}
    default: Dict[str, Any] = {}

    compile_ignored_keys = get_assignments_with_compile_ignored_comments(module)

    visit(module, module, "")
    # Assigned while `module` is still a plain ModuleType, so these land
    # directly in the module __dict__.
    module._config = config  # type: ignore[attr-defined]
    module._default = default  # type: ignore[attr-defined]
    module._allowed_keys = set(config.keys())  # type: ignore[attr-defined]
    module._compile_ignored_keys = compile_ignored_keys  # type: ignore[attr-defined]
    module.__class__ = ConfigModuleInstance
    # After the swap, these go through ConfigModule.__setattr__ via _bypass_keys.
    module._is_dirty = True  # type: ignore[attr-defined]
    module._hash_digest = None  # type: ignore[attr-defined]


COMPILE_IGNORED_MARKER = "@compile_ignored"


def get_assignments_with_compile_ignored_comments(module: ModuleType) -> Set[str]:
    """Return the names assigned directly below a ``COMPILE_IGNORED_MARKER`` comment.

    The source of ``module`` (obtained with ``inspect.getsource``) is
    tokenized.  When an ``=`` appears on the line immediately after a comment
    containing the marker, the first name token seen since that comment is
    recorded, so both ``foo = ...`` and ``foo: Bar = ...`` yield ``"foo"``.

    Raises:
        AssertionError: if a marker comment is still pending (not followed by
            an assignment on the next line) when another marker comment or the
            end of the source is reached.
    """
    source_code = inspect.getsource(module)
    assignments = set()

    # Tokenize the source code to retrieve comments
    tokens = tokenize.tokenize(io.BytesIO(source_code.encode("utf-8")).readline)
    # (comment text, line number) of the pending marker comment, or ("", -1)
    current_comment = "", -1
    prev_name = ""

    for token in tokens:
        if token.type == tokenize.COMMENT:
            prev_name = ""
            maybe_current = token.string.strip()
            if COMPILE_IGNORED_MARKER in maybe_current:
                assert current_comment == (
                    "",
                    -1,
                ), f"unconsumed {COMPILE_IGNORED_MARKER}"
                current_comment = maybe_current, token.start[0]
        elif token.type == tokenize.NAME:
            # Only accept the first name token, to handle if you have
            # something like foo: Bar = ...
            if not prev_name:
                prev_name = token.string
        elif token.type == tokenize.OP and token.string == "=":
            # Check if the current assignment follows a comment
            # with COMPILE_IGNORED_MARKER
            if (
                COMPILE_IGNORED_MARKER in current_comment[0]
                and current_comment[1] == token.start[0] - 1
            ):
                assignments.add(prev_name)
                current_comment = "", -1  # reset
            prev_name = ""
    assert current_comment == ("", -1), f"unconsumed {COMPILE_IGNORED_MARKER}"
    return assignments


class ConfigModule(ModuleType):
    # NOTE: This should be kept in sync with _config_typing.pyi.

    # The default values of the configuration settings.  This can be used to
    # determine if the config has been changed or not.
    _default: Dict[str, Any]
    # The actual configuration settings.  E.g., torch._dynamo.config.debug
    # would live as "debug" in the key, and torch._inductor.config.triton.cudagraphs
    # maps as "triton.cudagraphs"
    _config: Dict[str, Any]
    # Keys that may be assigned through attribute access.
    _allowed_keys: Set[str]
    # Attribute names stored on the module object itself instead of in _config.
    _bypass_keys: Set[str]
    # Keys excluded from get_hash().
    _compile_ignored_keys: Set[str]
    # True when a hashed (non-compile-ignored) value may have changed since
    # the last get_hash() call.
    _is_dirty: bool
    # Cached result of get_hash().
    _hash_digest: Optional[bytes]

    def __init__(self) -> None:
        raise NotImplementedError(
            f"use {__name__}.install_config_module(sys.modules[__name__])"
        )

    def __setattr__(self, name: str, value: object) -> None:
        """Store ``value`` in ``_config``; bypass keys go on the module itself.

        Raises:
            AttributeError: if ``name`` is neither a bypass key nor an allowed
                config key.
        """
        if name in self._bypass_keys:
            super().__setattr__(name, value)
        elif name not in self._allowed_keys:
            raise AttributeError(f"{self.__name__}.{name} does not exist")
        else:
            self._config[name] = value
            if name not in self._compile_ignored_keys:
                # Invalidate the cached get_hash() digest.
                super().__setattr__("_is_dirty", True)

    def __getattr__(self, name: str) -> Any:
        try:
            return self._config[name]
        except KeyError as e:
            # make hasattr() work properly
            raise AttributeError(f"{self.__name__}.{name} does not exist") from e

    def __delattr__(self, name: str) -> None:
        # must support delete because unittest.mock.patch deletes
        # then recreate things
        del self._config[name]
        if name not in self._compile_ignored_keys:
            # Invalidate the cached get_hash() digest.
            super().__setattr__("_is_dirty", True)

    def save_config(self) -> bytes:
        """Convert config to a pickled blob.

        Keys listed in the ``_save_config_ignore`` config entry (if present)
        are omitted.
        """
        config = dict(self._config)
        for key in config.get("_save_config_ignore", ()):
            config.pop(key)
        return pickle.dumps(config, protocol=2)

    def save_config_portable(self) -> Dict[str, Any]:
        """Convert config to portable format.

        Returns a key-sorted dict that skips private (``_``-prefixed) keys and
        keys starting with any prefix listed in the
        ``_cache_config_ignore_prefix`` config entry (if present).
        """
        # Hoisted: the prefix list does not change while iterating.
        ignored_prefixes = tuple(self._config.get("_cache_config_ignore_prefix", ()))
        config: Dict[str, Any] = {}
        for key in sorted(self._config):
            if key.startswith("_"):
                continue
            if key.startswith(ignored_prefixes):
                continue
            config[key] = self._config[key]
        return config

    def codegen_config(self) -> str:
        """Convert config to Python statements that replicate current config.
        This does NOT include config settings that are at default values.
        """
        lines = []
        mod = self.__name__
        for k, v in self._config.items():
            if k in self._config.get("_save_config_ignore", ()):
                if v != self._default[k]:
                    warnings.warn(f"Skipping serialization of {k} value {v}")
                continue
            if v == self._default[k]:
                continue
            lines.append(f"{mod}.{k} = {v!r}")
        return "\n".join(lines)

    def get_hash(self) -> bytes:
        """Hashes the configs that are not compile_ignored.

        The digest is cached and only recomputed when ``_is_dirty`` is set (or
        no digest exists yet).  Every mutation path in this class sets
        ``_is_dirty`` when a non-compile-ignored key changes.
        """
        if self._is_dirty or self._hash_digest is None:
            dict_to_hash = {
                k: v
                for k, v in self._config.items()
                if k not in self._compile_ignored_keys
            }
            string_to_hash = repr(sorted(dict_to_hash.items()))
            self._hash_digest = hashlib.md5(string_to_hash.encode("utf-8")).digest()
            self._is_dirty = False
        return self._hash_digest

    @deprecated(
        "`config.to_dict()` has been deprecated. It may no longer change the underlying config."
        " use `config.shallow_copy_dict()` or `config.get_config_copy()` instead",
        category=FutureWarning,
    )
    def to_dict(self) -> Dict[str, Any]:
        return self.shallow_copy_dict()

    def shallow_copy_dict(self) -> Dict[str, Any]:
        return {**self._config}

    def load_config(self, maybe_pickled_config: Union[bytes, Dict[str, Any]]) -> None:
        """Restore from a prior call to save_config() or shallow_copy_dict().

        NOTE(review): non-dict input is passed to ``pickle.loads``, which can
        execute arbitrary code; only load blobs from a trusted save_config().
        """
        if not isinstance(maybe_pickled_config, dict):
            config = pickle.loads(maybe_pickled_config)
        else:
            config = maybe_pickled_config
        self._config.update(config)
        if any(key not in self._compile_ignored_keys for key in config):
            # Invalidate the cached get_hash() digest.
            self._is_dirty = True

    def get_config_copy(self) -> Dict[str, Any]:
        return copy.deepcopy(self._config)

    def patch(
        self,
        arg1: Optional[Union[str, Dict[str, Any]]] = None,
        arg2: Any = None,
        **kwargs: Any,
    ) -> "ContextDecorator":
        """
        Decorator and/or context manager to make temporary changes to a config.

        As a decorator:

            @config.patch("name", val)
            @config.patch(name1=val1, name2=val2)
            @config.patch({"name1": val1, "name2": val2})
            def foo(...):
                ...

        As a context manager:

            with config.patch("name", val):
                ...

        Because ``arg2=None`` selects the dict form, ``patch("name", None)`` is
        rejected; use ``patch({"name": None})`` or ``patch(name=None)``.

        Raises:
            AssertionError: if the arguments match none of the forms above.
            KeyError: on entry, if a patched key does not exist (the config is
                left unchanged in that case).
        """
        changes: Dict[str, Any]
        if arg1 is not None:
            if arg2 is not None:
                assert isinstance(arg1, str)
                # patch("key", True) syntax
                changes = {arg1: arg2}
            else:
                assert isinstance(arg1, dict)
                # patch({"key": True}) syntax
                changes = arg1
            assert not kwargs
        else:
            # patch(key=True) syntax
            changes = kwargs
            assert arg2 is None
        assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
        prior: Dict[str, Any] = {}
        config = self
        # Whether any patched key participates in get_hash().
        dirty = False

        class ConfigPatch(ContextDecorator):
            def __enter__(self) -> None:
                assert not prior
                nonlocal dirty
                # KeyError on invalid entry.  All prior values are read before
                # anything is mutated, so a bad key leaves `prior` and the
                # config untouched.
                saved = {key: config._config[key] for key in changes}
                # Consider every key, not just the last one iterated.
                dirty = any(key not in config._compile_ignored_keys for key in changes)
                prior.update(saved)
                config._config.update(changes)
                # Only ever set the flag: clearing it here would discard a
                # pending change made before this patch was entered.
                if dirty:
                    config._is_dirty = True

            def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore[no-untyped-def]
                config._config.update(prior)
                if dirty:
                    config._is_dirty = True
                prior.clear()

        return ConfigPatch()

    def _make_closure_patcher(self, **changes: Any) -> Any:
        """
        A lower-overhead version of patch() for things on the critical path.

        Usage:

            # do this off the critical path
            change_fn = config._make_closure_patcher(foo=True)

            ...

            revert = change_fn()
            try:
              ...
            finally:
                revert()

        Unlike patch(), keys are not validated up front: a missing key raises
        KeyError when change_fn() is called.
        """
        config = self._config
        # Decided once here, off the critical path.
        dirty = any(key not in self._compile_ignored_keys for key in changes)

        def change() -> Callable[[], None]:
            prior = {k: config[k] for k in changes}
            config.update(changes)
            if dirty:
                self._is_dirty = True

            def revert() -> None:
                config.update(prior)
                if dirty:
                    self._is_dirty = True

            return revert

        return change


class ContextDecorator(contextlib.ContextDecorator):
    """
    Same as contextlib.ContextDecorator, but with support for
    `unittest.TestCase`
    """

    def __enter__(self) -> None:
        raise NotImplementedError("NYI")

    def __exit__(self, exc_type, exc_val, exc_tb) -> NoReturn:  # type: ignore[no-untyped-def]
        raise NotImplementedError("NYI")

    def __call__(self, func: Callable[[Any], Any]) -> Any:
        if isinstance(func, type) and issubclass(func, unittest.TestCase):
            # Wrap the whole TestCase class: enter in setUpClass, exit in
            # tearDownClass.  `self` below is this decorator instance.

            class _TestCase(func):  # type: ignore[valid-type, misc]
                @classmethod
                def setUpClass(cls) -> None:
                    self.__enter__()
                    try:
                        super().setUpClass()
                    except Exception:
                        self.__exit__(None, None, None)
                        raise

                @classmethod
                def tearDownClass(cls) -> None:
                    try:
                        super().tearDownClass()
                    finally:
                        self.__exit__(None, None, None)

            _TestCase.__name__ = func.__name__
            _TestCase.__qualname__ = func.__qualname__
            _TestCase.__module__ = func.__module__

            return _TestCase

        return super().__call__(func)


class SubConfigProxy:
    """
    Shim to redirect to main config.
    `config.triton.cudagraphs` maps to _config["triton.cudagraphs"]
    """

    def __init__(self, config: object, prefix: str):
        # `super().__setattr__` to bypass custom `__setattr__`
        super().__setattr__("_config", config)
        super().__setattr__("_prefix", prefix)

    def __setattr__(self, name: str, value: object) -> None:
        return self._config.__setattr__(self._prefix + name, value)

    def __getattr__(self, name: str) -> Any:
        return self._config.__getattr__(self._prefix + name)

    def __delattr__(self, name: str) -> None:
        return self._config.__delattr__(self._prefix + name)


def patch_object(obj: object, name: str, value: object) -> object:
    """
    Workaround `mock.patch.object` issue with ConfigModule

    Returns ``obj.patch(name, value)`` for a ConfigModule, otherwise
    ``mock.patch.object(obj, name, value)``.
    """
    if isinstance(obj, ConfigModule):
        return obj.patch(name, value)
    return mock.patch.object(obj, name, value)