1from typing import Any, Iterable 2 3from torch._vendor.packaging.version import InvalidVersion, Version 4from torch.version import __version__ as internal_version 5 6 7__all__ = ["TorchVersion"] 8 9 10class TorchVersion(str): 11 """A string with magic powers to compare to both Version and iterables! 12 Prior to 1.10.0 torch.__version__ was stored as a str and so many did 13 comparisons against torch.__version__ as if it were a str. In order to not 14 break them we have TorchVersion which masquerades as a str while also 15 having the ability to compare against both packaging.version.Version as 16 well as tuples of values, eg. (1, 2, 1) 17 Examples: 18 Comparing a TorchVersion object to a Version object 19 TorchVersion('1.10.0a') > Version('1.10.0a') 20 Comparing a TorchVersion object to a Tuple object 21 TorchVersion('1.10.0a') > (1, 2) # 1.2 22 TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1 23 Comparing a TorchVersion object against a string 24 TorchVersion('1.10.0a') > '1.2' 25 TorchVersion('1.10.0a') > '1.2.1' 26 """ 27 28 # fully qualified type names here to appease mypy 29 def _convert_to_version(self, inp: Any) -> Any: 30 if isinstance(inp, Version): 31 return inp 32 elif isinstance(inp, str): 33 return Version(inp) 34 elif isinstance(inp, Iterable): 35 # Ideally this should work for most cases by attempting to group 36 # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH) 37 # Examples: 38 # * (1) -> Version("1") 39 # * (1, 20) -> Version("1.20") 40 # * (1, 20, 1) -> Version("1.20.1") 41 return Version(".".join(str(item) for item in inp)) 42 else: 43 raise InvalidVersion(inp) 44 45 def _cmp_wrapper(self, cmp: Any, method: str) -> bool: 46 try: 47 return getattr(Version(self), method)(self._convert_to_version(cmp)) 48 except BaseException as e: 49 if not isinstance(e, InvalidVersion): 50 raise 51 # Fall back to regular string comparison if dealing with an invalid 52 # version like 'parrot' 53 return getattr(super(), method)(cmp) 54 55 56for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]: 57 setattr( 58 TorchVersion, 59 cmp_method, 60 lambda x, y, method=cmp_method: x._cmp_wrapper(y, method), 61 ) 62 63__version__ = TorchVersion(internal_version) 64