xref: /aosp_15_r20/external/pytorch/torch/utils/_config_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import contextlib
2import copy
3import hashlib
4import inspect
5import io
6import pickle
7import tokenize
8import unittest
9import warnings
10from types import FunctionType, ModuleType
11from typing import Any, Callable, Dict, NoReturn, Optional, Set, Union
12from typing_extensions import deprecated
13from unittest import mock
14
15
# Types saved/loaded in configs.  install_config_module treats a module-level
# value as a config entry only when it is an instance of one of these types.
CONFIG_TYPES = (
    int,
    float,
    bool,
    type(None),
    str,
    list,
    set,
    tuple,
    dict,
)
18
19
def install_config_module(module: ModuleType) -> None:
    """
    Converts a module-level config into a `ConfigModule()`.

    Every module-level value whose type is in ``CONFIG_TYPES`` is moved into
    ``module._config`` and removed from the module ``__dict__``; an independent
    deep copy of its initial value is recorded in ``module._default``.  A class
    defined in the module becomes a ``SubConfigProxy`` whose entries are stored
    under dotted keys (e.g. ``"triton.cudagraphs"``).  Dunder names, modules,
    functions and ``typing`` objects are skipped; any other value raises
    ``AssertionError``.

    See _config_typing.pyi for instructions on how to get the converted module to typecheck.
    """

    class ConfigModuleInstance(ConfigModule):
        # Attributes stored directly on the module rather than in _config.
        _bypass_keys = {"_is_dirty", "_hash_digest"}

    def visit(
        source: Union[ModuleType, type],
        dest: Union[ModuleType, SubConfigProxy],
        prefix: str,
    ) -> None:
        """Walk the module structure and move everything to module._config"""
        # Snapshot the items: entries are deleted from the module while looping.
        for key, value in list(source.__dict__.items()):
            if (
                key.startswith("__")
                or isinstance(value, (ModuleType, FunctionType))
                or (hasattr(value, "__module__") and value.__module__ == "typing")
            ):
                continue

            name = f"{prefix}{key}"
            if isinstance(value, CONFIG_TYPES):
                config[name] = value
                # Keep an independent copy: sharing the same list/set/dict
                # object would make in-place mutations of the live value
                # invisible when comparing against the default.
                default[name] = copy.deepcopy(value)
                if dest is module:
                    delattr(module, key)
            elif isinstance(value, type):
                assert value.__module__ == module.__name__, (
                    f"subconfig class {key} must be defined in {module.__name__}, "
                    f"not {value.__module__}"
                )
                # a subconfig with `class Blah:` syntax
                proxy = SubConfigProxy(module, f"{name}.")
                visit(value, proxy, f"{name}.")
                if dest is module:
                    setattr(dest, key, proxy)
                else:
                    # Write straight into the instance dict to bypass
                    # SubConfigProxy.__setattr__ forwarding.
                    dest.__dict__[key] = proxy
            else:
                raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")

    config: Dict[str, Any] = {}
    default: Dict[str, Any] = {}

    # NOTE(review): this returns bare assignment names, while subconfig keys
    # are stored dotted ("triton.cudagraphs"), so a compile-ignored marker
    # inside a subconfig class presumably never matches its key — TODO confirm.
    compile_ignored_keys = get_assignments_with_compile_ignored_comments(module)

    visit(module, module, "")
    module._config = config  # type: ignore[attr-defined]
    module._default = default  # type: ignore[attr-defined]
    module._allowed_keys = set(config.keys())  # type: ignore[attr-defined]
    module._compile_ignored_keys = compile_ignored_keys  # type: ignore[attr-defined]
    module.__class__ = ConfigModuleInstance
    # From here on attribute writes go through ConfigModule.__setattr__; the
    # two keys below are in _bypass_keys so they land on the module itself.
    module._is_dirty = True  # type: ignore[attr-defined]
    module._hash_digest = None  # type: ignore[attr-defined]
75
76
COMPILE_IGNORED_MARKER = "@compile_ignored"


def get_assignments_with_compile_ignored_comments(module: ModuleType) -> Set[str]:
    """Gets all the keys (i.e. assignments) preceded by an @compile_ignored comment.

    A comment containing ``COMPILE_IGNORED_MARKER`` must be directly followed,
    on the very next source line, by an assignment.  The first name on that
    line is collected, so ``foo: Bar = ...`` yields ``"foo"``.

    Args:
        module: a module whose source is available to ``inspect.getsource``.

    Returns:
        The set of assigned names that carry the marker.

    Raises:
        AssertionError: if a marker comment is not consumed by an assignment
            on the following line (including two markers in a row).
    """
    source_code = inspect.getsource(module)
    assignments: Set[str] = set()

    # Tokenize the already-decoded text directly.  Re-encoding it as UTF-8 and
    # using tokenize.tokenize() would make the tokenizer honour the module's
    # own coding cookie, mis-decoding the bytes if that cookie is not UTF-8.
    tokens = tokenize.generate_tokens(io.StringIO(source_code).readline)
    # (comment text, line number) of a pending marker; ("", -1) means none.
    current_comment = "", -1
    prev_name = ""

    for token in tokens:
        if token.type == tokenize.COMMENT:
            prev_name = ""
            maybe_current = token.string.strip()
            if COMPILE_IGNORED_MARKER in maybe_current:
                assert current_comment == (
                    "",
                    -1,
                ), f"unconsumed {COMPILE_IGNORED_MARKER} on line {current_comment[1]}"
                current_comment = maybe_current, token.start[0]
        elif token.type == tokenize.NAME:
            # Only accept the first name token, to handle if you have
            # something like foo: Bar = ...
            if not prev_name:
                prev_name = token.string
        elif token.type == tokenize.OP and token.string == "=":
            # Check if the current assignment follows a marker comment on the
            # immediately preceding line.
            if (
                COMPILE_IGNORED_MARKER in current_comment[0]
                and current_comment[1] == token.start[0] - 1
            ):
                assignments.add(prev_name)
                current_comment = "", -1  # reset
            prev_name = ""
    assert current_comment == (
        "",
        -1,
    ), f"unconsumed {COMPILE_IGNORED_MARKER} on line {current_comment[1]}"
    return assignments
116    return assignments
117
118
class ConfigModule(ModuleType):
    # NOTE: This should be kept in sync with _config_typing.pyi.

    # The default values of the configuration settings.  This can be used to
    # determine if the config has been changed or not.
    _default: Dict[str, Any]
    # The actual configuration settings.  E.g., torch._dynamo.config.debug
    # would live as "debug" in the key, and torch._inductor.config.triton.cudagraphs
    # maps as "triton.cudagraphs"
    _config: Dict[str, Any]
    # Keys that may be assigned with attribute syntax (``config.foo = ...``).
    _allowed_keys: Set[str]
    # Attribute names stored directly on the module object, not in _config.
    _bypass_keys: Set[str]
    # Keys excluded from get_hash().
    _compile_ignored_keys: Set[str]
    # True when _hash_digest may be stale and must be recomputed.
    _is_dirty: bool
    # Cached digest returned by get_hash(); None until first computed.
    _hash_digest: Optional[bytes]

    def __init__(self) -> None:
        raise NotImplementedError(
            f"use {__name__}.install_config_module(sys.modules[__name__])"
        )

    def __setattr__(self, name: str, value: object) -> None:
        """Store ``config.name = value`` in ``_config``.

        Raises:
            AttributeError: if ``name`` is neither a bypass key nor a known
                config key.
        """
        if name in self._bypass_keys:
            super().__setattr__(name, value)
        elif name not in self._allowed_keys:
            raise AttributeError(f"{self.__name__}.{name} does not exist")
        else:
            self._config[name] = value
            # Changing a hashed key invalidates the cached get_hash() digest.
            if name not in self._compile_ignored_keys:
                self._is_dirty = True

    def __getattr__(self, name: str) -> Any:
        """Look up ``name`` in ``_config`` (only reached when normal lookup fails)."""
        try:
            return self._config[name]
        except KeyError as e:
            # make hasattr() work properly
            raise AttributeError(f"{self.__name__}.{name} does not exist") from e

    def __delattr__(self, name: str) -> None:
        """Remove ``name`` from ``_config`` (raises KeyError if absent)."""
        # must support delete because unittest.mock.patch deletes
        # then recreate things
        del self._config[name]
        if name not in self._compile_ignored_keys:
            self._is_dirty = True

    def save_config(self) -> bytes:
        """Convert config to a pickled blob.

        Keys listed in the ``_save_config_ignore`` entry are omitted.
        """
        config = dict(self._config)
        for key in config.get("_save_config_ignore", ()):
            config.pop(key)
        return pickle.dumps(config, protocol=2)

    def save_config_portable(self) -> Dict[str, Any]:
        """Convert config to portable format.

        Returns a new dict ordered by key that skips private keys (leading
        ``_``) and keys starting with any prefix in the
        ``_cache_config_ignore_prefix`` entry, if that entry exists.
        """
        # Hoisted out of the loop; a config without the entry ignores nothing.
        ignore_prefixes = self._config.get("_cache_config_ignore_prefix", ())
        config: Dict[str, Any] = {}
        for key in sorted(self._config):
            if key.startswith("_"):
                continue
            if any(key.startswith(e) for e in ignore_prefixes):
                continue
            config[key] = self._config[key]
        return config

    def codegen_config(self) -> str:
        """Convert config to Python statements that replicate current config.
        This does NOT include config settings that are at default values.
        """
        lines = []
        mod = self.__name__
        for k, v in self._config.items():
            if k in self._config.get("_save_config_ignore", ()):
                if v != self._default[k]:
                    warnings.warn(f"Skipping serialization of {k} value {v}")
                continue
            if v == self._default[k]:
                continue
            lines.append(f"{mod}.{k} = {v!r}")
        return "\n".join(lines)

    def get_hash(self) -> bytes:
        """Hashes the configs that are not compile_ignored.

        The digest is cached and only recomputed when ``_is_dirty`` is set or
        no digest has been computed yet.
        """
        if self._is_dirty or self._hash_digest is None:
            dict_to_hash = {
                k: v
                for k, v in self._config.items()
                if k not in self._compile_ignored_keys
            }
            string_to_hash = repr(sorted(dict_to_hash.items()))
            self._hash_digest = hashlib.md5(string_to_hash.encode("utf-8")).digest()
            self._is_dirty = False
        return self._hash_digest

    @deprecated(
        "`config.to_dict()` has been deprecated. It may no longer change the underlying config."
        " use `config.shallow_copy_dict()` or `config.get_config_copy()` instead",
        category=FutureWarning,
    )
    def to_dict(self) -> Dict[str, Any]:
        """Deprecated alias of ``shallow_copy_dict()``."""
        return self.shallow_copy_dict()

    def shallow_copy_dict(self) -> Dict[str, Any]:
        """Return a new dict with the current config (values not copied)."""
        return {**self._config}

    def load_config(self, maybe_pickled_config: Union[bytes, Dict[str, Any]]) -> None:
        """Restore from a prior call to save_config() or shallow_copy_dict()"""
        if not isinstance(maybe_pickled_config, dict):
            # NOTE(security): pickle.loads can execute arbitrary code; only
            # pass blobs produced by save_config() from a trusted source.
            config = pickle.loads(maybe_pickled_config)
        else:
            config = maybe_pickled_config
        self._config.update(config)
        # Loaded values may differ from the current ones, so the cached hash
        # can no longer be trusted.
        self._is_dirty = True

    def get_config_copy(self) -> Dict[str, Any]:
        """Return a deep copy of the current config."""
        return copy.deepcopy(self._config)

    def patch(
        self,
        arg1: Optional[Union[str, Dict[str, Any]]] = None,
        arg2: Any = None,
        **kwargs: Any,
    ) -> "ContextDecorator":
        """
        Decorator and/or context manager to make temporary changes to a config.

        As a decorator:

            @config.patch("name", val)
            @config.patch(name1=val1, name2=val2)
            @config.patch({"name1": val1, "name2": val2})
            def foo(...):
                ...

        As a context manager:

            with config.patch("name", val):
                ...

        The positional ``patch("name", val)`` form cannot patch a value to
        ``None`` (``arg2=None`` means "not given"); use ``patch({"name": None})``.
        Entering raises KeyError for a key that is not in the config.
        """
        changes: Dict[str, Any]
        if arg1 is not None:
            if arg2 is not None:
                assert isinstance(arg1, str)
                # patch("key", True) syntax
                changes = {arg1: arg2}
            else:
                assert isinstance(arg1, dict), (
                    f"expected a dict of changes, got {type(arg1)}; "
                    "use patch({key: None}) to patch a value to None"
                )
                # patch({"key": True}) syntax
                changes = arg1
            assert not kwargs
        else:
            # patch(key=True) syntax
            changes = kwargs
            assert arg2 is None
        assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
        prior: Dict[str, Any] = {}
        config = self
        # Whether any patched key participates in get_hash().
        dirty = False

        class ConfigPatch(ContextDecorator):
            def __enter__(self) -> None:
                assert not prior
                nonlocal dirty
                # Snapshot every current value before mutating anything, so an
                # invalid key (KeyError) leaves both config and `prior` intact.
                snapshot = {key: config._config[key] for key in changes}
                prior.update(snapshot)
                dirty = any(key not in config._compile_ignored_keys for key in changes)
                config._config.update(changes)
                # Never clear an invalidation that was already pending.
                config._is_dirty = config._is_dirty or dirty

            def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore[no-untyped-def]
                config._config.update(prior)
                # Restoring hashed keys invalidates any digest taken inside.
                config._is_dirty = config._is_dirty or dirty
                prior.clear()

        return ConfigPatch()

    def _make_closure_patcher(self, **changes: Any) -> Any:
        """
        A lower-overhead version of patch() for things on the critical path.

        Usage:

            # do this off the critical path
            change_fn = config._make_closure_patcher(foo=True)

            ...

            revert = change_fn()
            try:
              ...
            finally:
                revert()

        """
        config = self._config
        owner = self
        # Computed once here, off the critical path.
        dirty = any(key not in self._compile_ignored_keys for key in changes)

        def change() -> Callable[[], None]:
            prior = {k: config[k] for k in changes}
            config.update(changes)
            if dirty:
                owner._is_dirty = True

            def revert() -> None:
                config.update(prior)
                if dirty:
                    owner._is_dirty = True

            return revert

        return change
322
323
class ContextDecorator(contextlib.ContextDecorator):
    """
    A ``contextlib.ContextDecorator`` that can also decorate a
    ``unittest.TestCase`` subclass.

    Ordinary callables are wrapped per call, exactly as the stdlib base class
    does.  For a ``TestCase`` subclass, a subclass of it is returned whose
    ``setUpClass`` enters the context and whose ``tearDownClass`` exits it.
    Concrete subclasses must implement ``__enter__`` and ``__exit__``.
    """

    def __enter__(self) -> None:
        raise NotImplementedError("NYI")

    def __exit__(self, exc_type, exc_val, exc_tb) -> NoReturn:  # type: ignore[no-untyped-def]
        raise NotImplementedError("NYI")

    def __call__(self, func: Callable[[Any], Any]) -> Any:
        wraps_test_case = isinstance(func, type) and issubclass(func, unittest.TestCase)
        if not wraps_test_case:
            # Plain callables get contextlib's per-call wrapping.
            return super().__call__(func)

        ctx = self

        class _Wrapped(func):  # type: ignore[valid-type, misc]
            @classmethod
            def setUpClass(cls) -> None:
                ctx.__enter__()
                try:
                    super().setUpClass()
                except Exception:
                    # unittest skips tearDownClass when setUpClass fails, so
                    # undo the enter here before propagating.
                    ctx.__exit__(None, None, None)
                    raise

            @classmethod
            def tearDownClass(cls) -> None:
                try:
                    super().tearDownClass()
                finally:
                    ctx.__exit__(None, None, None)

        # Present the wrapper under the original class's identity.
        for attr in ("__name__", "__qualname__", "__module__"):
            setattr(_Wrapped, attr, getattr(func, attr))

        return _Wrapped
362        return super().__call__(func)
363
364
class SubConfigProxy:
    """
    Shim to redirect to main config.
    `config.triton.cudagraphs` maps to _config["triton.cudagraphs"]

    Attribute reads, writes and deletes on the proxy are forwarded to the
    owning config object's ``__getattr__``/``__setattr__``/``__delattr__``
    with ``prefix`` prepended to the attribute name.
    """

    def __init__(self, config: object, prefix: str):
        # `super().__setattr__` to bypass custom `__setattr__`
        super().__setattr__("_config", config)
        super().__setattr__("_prefix", prefix)

    def __setattr__(self, name: str, value: object) -> None:
        return self._config.__setattr__(self._prefix + name, value)

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal lookup fails.  If the proxy's own
        # bookkeeping attributes are missing (an instance created without
        # __init__, e.g. by copy.copy or unpickling), forwarding would look up
        # `self._config` again and recurse forever.  Fail like a normal
        # missing attribute instead.
        if name in ("_config", "_prefix"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return self._config.__getattr__(self._prefix + name)

    def __delattr__(self, name: str) -> None:
        return self._config.__delattr__(self._prefix + name)
384
385
def patch_object(obj: object, name: str, value: object) -> object:
    """
    Workaround `mock.patch.object` issue with ConfigModule

    For a ``ConfigModule`` the change is applied through
    ``ConfigModule.patch``; any other object falls back to
    ``unittest.mock.patch.object``.  Either result can be used as a context
    manager or a decorator.
    """
    if isinstance(obj, ConfigModule):
        # Pass a dict: the positional patch(name, value) form treats
        # value=None as "no value given" and fails its argument check.
        return obj.patch({name: value})
    return mock.patch.object(obj, name, value)
393