# mypy: allow-untyped-defs
import sys
import warnings
from contextlib import contextmanager
from functools import lru_cache as _lru_cache
from typing import Any

from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule


try:
    import opt_einsum as _opt_einsum  # type: ignore[import]
except ImportError:
    _opt_einsum = None


@_lru_cache
def is_available() -> bool:
    r"""Return a bool indicating if opt_einsum is currently available."""
    return _opt_einsum is not None


def get_opt_einsum() -> Any:
    r"""Return the opt_einsum package if opt_einsum is currently available, else None."""
    return _opt_einsum


def _set_enabled(_enabled: bool) -> None:
    """Validate and store the module-level ``enabled`` flag.

    Raises:
        ValueError: if ``_enabled`` is truthy while opt_einsum is not installed.
    """
    if not is_available() and _enabled:
        raise ValueError(
            f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap "
            "the benefits of calculating an optimal path for einsum. torch.einsum will "
            "fall back to contracting from left to right. To enable this optimal path "
            "calculation, please install opt-einsum."
        )
    global enabled
    enabled = _enabled


def _get_enabled() -> bool:
    """Return the current module-level ``enabled`` flag."""
    return enabled


def _set_strategy(_strategy: str) -> None:
    """Validate and store the module-level ``strategy`` value.

    Raises:
        ValueError: if opt_einsum is not installed, if ``enabled`` is currently
            falsy, or if ``_strategy`` is not one of "auto", "greedy", "optimal".
    """
    if not is_available():
        raise ValueError(
            f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please install opt_einsum or unset `strategy`."
        )
    if not enabled:
        raise ValueError(
            f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please set `enabled` to `True` as well or unset `strategy`."
        )
    if _strategy not in ["auto", "greedy", "optimal"]:
        raise ValueError(
            f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}"
        )
    global strategy
    strategy = _strategy


def _get_strategy() -> str:
    """Return the current module-level ``strategy`` value (None when opt_einsum is unavailable)."""
    return strategy


def set_flags(_enabled=None, _strategy=None):
    """Set ``enabled`` and/or ``strategy`` and return their previous values.

    Arguments left as ``None`` are not changed. The returned tuple is
    ``(enabled, strategy)`` as it was before this call, with ``strategy``
    reported as ``None`` when opt_einsum is unavailable.

    The two updates are applied one after the other, so if the ``strategy``
    update raises, an ``enabled`` change made by the same call is not rolled
    back.
    """
    orig_flags = (enabled, None if not is_available() else strategy)
    if _enabled is not None:
        _set_enabled(_enabled)
    if _strategy is not None:
        _set_strategy(_strategy)
    return orig_flags


def _restore_flags(orig_enabled, orig_strategy) -> None:
    """Write back a ``(enabled, strategy)`` snapshot returned by :func:`set_flags`.

    The snapshot describes a state that was already in effect, so it is stored
    directly rather than re-validated through ``_set_enabled``/``_set_strategy``.
    Re-validating fails for the ordinary state "installed but disabled"
    (``enabled=False`` with ``strategy="auto"``): ``_set_strategy`` rejects any
    strategy while ``enabled`` is falsy.
    """
    global enabled, strategy
    enabled = orig_enabled
    # A None strategy in the snapshot means opt_einsum was unavailable, in which
    # case _set_strategy always raises and ``strategy`` cannot have changed.
    if orig_strategy is not None:
        strategy = orig_strategy


@contextmanager
def flags(enabled=None, strategy=None):
    """Temporarily override ``enabled`` and/or ``strategy`` inside a ``with`` block.

    Arguments left as ``None`` are not changed. The previous values are
    restored on exit, including when the body raises and when applying the
    overrides themselves raises (e.g. a ``strategy`` given together with
    ``enabled=False``).
    """
    # Calling set_flags with no arguments only snapshots the current values.
    orig_flags = set_flags()
    try:
        with __allow_nonbracketed_mutation():
            set_flags(enabled, strategy)
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            _restore_flags(*orig_flags)


# The magic here is to allow us to intercept code like this:
#
#   torch.backends.opt_einsum.enabled = True


class OptEinsumModule(PropModule):
    """Module object installed in ``sys.modules`` in place of this module.

    NOTE(review): attribute lookups are presumably forwarded to the wrapped
    original module ``m`` by ``PropModule`` — confirm in torch.backends.
    """

    def __init__(self, m, name):
        super().__init__(m, name)

    # NOTE(review): because of the ``global`` statements, these assignments bind
    # module-level names, not class attributes, so no ``ContextProp`` descriptor
    # ends up on this class. Both names are overwritten by the plain
    # ``enabled``/``strategy`` assignments at the bottom of this module.
    global enabled
    enabled = ContextProp(_get_enabled, _set_enabled)
    global strategy
    strategy = None
    if is_available():
        strategy = ContextProp(_get_strategy, _set_strategy)


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

# Defaults: enabled exactly when opt_einsum is installed; "auto" strategy when
# installed, otherwise no strategy at all.
enabled = is_available()
strategy = "auto" if is_available() else None