xref: /aosp_15_r20/external/pytorch/torch/torch_version.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Any, Iterable
2
3from torch._vendor.packaging.version import InvalidVersion, Version
4from torch.version import __version__ as internal_version
5
6
# Names exported by ``from ... import *``: only the version wrapper class.
__all__ = [
    "TorchVersion",
]
8
9
10class TorchVersion(str):
11    """A string with magic powers to compare to both Version and iterables!
12    Prior to 1.10.0 torch.__version__ was stored as a str and so many did
13    comparisons against torch.__version__ as if it were a str. In order to not
14    break them we have TorchVersion which masquerades as a str while also
15    having the ability to compare against both packaging.version.Version as
16    well as tuples of values, eg. (1, 2, 1)
17    Examples:
18        Comparing a TorchVersion object to a Version object
19            TorchVersion('1.10.0a') > Version('1.10.0a')
20        Comparing a TorchVersion object to a Tuple object
21            TorchVersion('1.10.0a') > (1, 2)    # 1.2
22            TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
23        Comparing a TorchVersion object against a string
24            TorchVersion('1.10.0a') > '1.2'
25            TorchVersion('1.10.0a') > '1.2.1'
26    """
27
28    # fully qualified type names here to appease mypy
29    def _convert_to_version(self, inp: Any) -> Any:
30        if isinstance(inp, Version):
31            return inp
32        elif isinstance(inp, str):
33            return Version(inp)
34        elif isinstance(inp, Iterable):
35            # Ideally this should work for most cases by attempting to group
36            # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
37            # Examples:
38            #   * (1)         -> Version("1")
39            #   * (1, 20)     -> Version("1.20")
40            #   * (1, 20, 1)  -> Version("1.20.1")
41            return Version(".".join(str(item) for item in inp))
42        else:
43            raise InvalidVersion(inp)
44
45    def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
46        try:
47            return getattr(Version(self), method)(self._convert_to_version(cmp))
48        except BaseException as e:
49            if not isinstance(e, InvalidVersion):
50                raise
51            # Fall back to regular string comparison if dealing with an invalid
52            # version like 'parrot'
53            return getattr(super(), method)(cmp)
54
55
56for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
57    setattr(
58        TorchVersion,
59        cmp_method,
60        lambda x, y, method=cmp_method: x._cmp_wrapper(y, method),
61    )
62
# Wrap the raw build version string in TorchVersion so it supports
# version-aware comparisons against strings, tuples and Version objects.
__version__ = TorchVersion(
    internal_version,
)
64