xref: /aosp_15_r20/external/pytorch/torch/backends/opt_einsum/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import sys
3import warnings
4from contextlib import contextmanager
5from functools import lru_cache as _lru_cache
6from typing import Any
7
8from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
9
10
# opt_einsum is an optional dependency: fall back to None when it is not
# installed so the rest of this module can degrade gracefully (torch.einsum
# then contracts left-to-right instead of computing an optimal path).
try:
    import opt_einsum as _opt_einsum  # type: ignore[import]
except ImportError:
    _opt_einsum = None
15
16
@_lru_cache
def is_available() -> bool:
    r"""Return a bool indicating if opt_einsum is currently available."""
    # Cached via lru_cache: availability cannot change after module import,
    # so the check is computed at most once.
    found = _opt_einsum is not None
    return found
21
22
def get_opt_einsum() -> Any:
    r"""Return the imported opt_einsum package when it is available, else None."""
    return _opt_einsum
26
27
def _set_enabled(_enabled: bool) -> None:
    """Setter backing the ``enabled`` flag exposed through OptEinsumModule.

    Raises ValueError when asked to enable path optimization while the
    opt_einsum package is not installed.
    """
    # Enabling without opt_einsum present is a user error: there is no
    # optimal-path calculation to turn on.  (is_available() is pure and
    # cached, so short-circuit order does not change behavior.)
    if _enabled and not is_available():
        raise ValueError(
            f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap "
            "the benefits of calculating an optimal path for einsum. torch.einsum will "
            "fall back to contracting from left to right. To enable this optimal path "
            "calculation, please install opt-einsum."
        )
    global enabled
    enabled = _enabled
38
39
def _get_enabled() -> bool:
    # Getter backing the ``enabled`` property installed by OptEinsumModule;
    # reads the module-level global.
    return enabled
42
43
def _set_strategy(_strategy: str) -> None:
    """Validate ``_strategy`` and store it in the module-level ``strategy`` global.

    Raises:
        ValueError: if opt_einsum is unavailable, if the backend is not
            enabled, or if ``_strategy`` is not one of auto/greedy/optimal.
    """
    # A strategy is meaningless when no path calculation can happen at all.
    if not is_available():
        raise ValueError(
            f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please install opt_einsum or unset `strategy`."
        )
    # ...and equally meaningless while the backend is switched off.
    if not enabled:
        raise ValueError(
            f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please set `enabled` to `True` as well or unset `strategy`."
        )
    if _strategy not in ("auto", "greedy", "optimal"):
        raise ValueError(
            f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}"
        )
    global strategy
    strategy = _strategy
63
64
def _get_strategy() -> str:
    # Getter backing the ``strategy`` property installed by OptEinsumModule;
    # reads the module-level global.
    return strategy
67
68
def set_flags(_enabled=None, _strategy=None):
    """Update ``enabled``/``strategy`` and return their previous values.

    Passing ``None`` for either argument leaves that flag unchanged.  The
    returned tuple ``(enabled, strategy)`` can be fed back via ``set_flags(*t)``
    to restore the prior state (as the ``flags`` context manager does).
    """
    previous = (enabled, strategy if is_available() else None)
    if _enabled is not None:
        _set_enabled(_enabled)
    if _strategy is not None:
        _set_strategy(_strategy)
    return previous
76
77
@contextmanager
def flags(enabled=None, strategy=None):
    r"""Context manager that temporarily overrides ``enabled`` and ``strategy``.

    The previous flag values are restored on exit, even if the body raises.
    Mutation is wrapped in ``__allow_nonbracketed_mutation`` because direct
    assignment to these backend flags is otherwise guarded by torch.backends.
    """
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled, strategy)
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)
88
89
90# The magic here is to allow us to intercept code like this:
91#
92#   torch.backends.opt_einsum.enabled = True
93
94
class OptEinsumModule(PropModule):
    """PropModule subclass that replaces this module in ``sys.modules`` so that
    attribute access like ``torch.backends.opt_einsum.enabled = True`` is routed
    through the getter/setter pairs defined above."""

    def __init__(self, m, name):
        super().__init__(m, name)

    # NOTE: these ``global`` statements execute at class-creation time and
    # rebind the *module-level* names (not class attributes) to ContextProp
    # objects — presumably PropModule resolves these when intercepting
    # attribute access; verify against torch.backends.ContextProp/PropModule.
    global enabled
    enabled = ContextProp(_get_enabled, _set_enabled)
    global strategy
    strategy = None
    # Only expose a strategy property when opt_einsum is installed; otherwise
    # the name stays None and _set_strategy would reject any value anyway.
    if is_available():
        strategy = ContextProp(_get_strategy, _set_strategy)
105
106
# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

# Initialize the flag defaults: path optimization is enabled exactly when
# opt_einsum is importable, and the default strategy is "auto" in that case.
# (is_available() already returns a bool, so assign it directly rather than
# via the redundant `True if ... else False` ternary.)
enabled = is_available()
strategy = "auto" if is_available() else None
113