xref: /aosp_15_r20/external/pytorch/torch/fx/tensor_type.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]
3
4from ._compatibility import compatibility
5
6
@compatibility(is_backward_compatible=False)
class TensorType:
    """
    TensorType defines a type for tensors, which consists of a list of dimensions.

    Two TensorTypes compare equal when their dimensions are pairwise equal
    (a tuple and a list holding the same dimensions are treated alike), and
    equal TensorTypes hash alike, so an instance can be used as a dict key or
    set member as long as every one of its dimensions is itself hashable.

    Example:
        class M(torch.nn.Module):
            def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))):
                return torch.add(x, y)
    """

    def __init__(self, dim):
        # ``dim`` is stored exactly as given (typically a tuple of dimensions).
        self.__origin__ = TensorType
        self.__args__ = dim

    def __repr__(self):
        return f'TensorType[{self.__args__}]'

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            # Compare as lists so that tuple- and list-backed dims are equal.
            return list(self.__args__) == list(other.__args__)
        else:
            return False

    def __hash__(self):
        # Defining __eq__ disables the inherited __hash__, which would make
        # instances unusable in sets and as dict keys. Hash the dimensions as
        # a tuple so the result agrees with the list-based __eq__ above.
        # Raises TypeError if a dimension is unhashable, as a tuple would.
        # The hash follows __args__, so do not mutate it while the instance
        # is stored in a hash-based container.
        return hash(tuple(self.__args__))

    @staticmethod
    def __class_getitem__(*args):
        # Supports ``TensorType[1, 2, 3]``: subscripting passes the key as a
        # single argument (a tuple when several dimensions are given).
        # Declared static, so no class object is received here and the result
        # is always a plain TensorType.
        if len(args) == 1 and isinstance(args[0], tuple):
            args = args[0]
        return TensorType(tuple(args))
35
36
class _DynType:
    """
    _DynType defines a type which stands for the absence of type information.

    All instances of the class compare equal to one another and share the
    same hash, so ``Dyn`` (and tuples of dimensions containing it) can be
    used as dict keys and set members.
    """

    def __init__(self) -> None:
        self.__name__ = '_DynType'

    def __eq__(self, other):
        return isinstance(other, self.__class__)

    def __hash__(self):
        # Defining __eq__ disables the inherited __hash__. Instances carry no
        # distinguishing state, so hashing the class keeps equal instances
        # in the same hash bucket.
        return hash(self.__class__)

    def __str__(self):
        return "Dyn"

    def __repr__(self):
        return "Dyn"


# Module-level instance used to denote an unknown type or dimension.
Dyn = _DynType()
55
@compatibility(is_backward_compatible=False)
def is_consistent(t1, t2):
    """
    Decide whether ``t1`` is consistent with ``t2`` (written ``t1 ~ t2``).

    Consistency is reflexive and symmetric, but it is not transitive.
    Returns True when the two types are consistent and False otherwise.

    Example:
        Dyn ~ TensorType((1,2,3))
        int ~ Dyn
        int ~ int
        TensorType((1,Dyn,3)) ~ TensorType((1,2,3))
    """
    # Equal types, Dyn on either side, or a Var on either side are consistent
    # without looking any further.
    if (
        t1 == t2
        or t1 == Dyn
        or t2 == Dyn
        or isinstance(t1, Var)
        or isinstance(t2, Var)
    ):
        return True

    # Beyond the cases above, only two tensor types can be consistent.
    if not (isinstance(t1, TensorType) and isinstance(t2, TensorType)):
        return False

    left_dims = t1.__args__
    right_dims = t2.__args__
    if len(left_dims) != len(right_dims):
        return False

    # Same rank: every pair of corresponding dimensions must be consistent.
    for left, right in zip(left_dims, right_dims):
        if not is_consistent(left, right):
            return False
    return True
80
81
@compatibility(is_backward_compatible=False)
def is_more_precise(t1, t2):
    """
    A binary relation denoted by <= that determines if t1 is more precise than t2.
    The relation is reflexive and transitive.
    returns True if t1 is more precise than t2 and False otherwise.

    Every type is more precise than Dyn, and a tensor type is more precise
    than another of the same rank when each of its dimensions is more precise
    than the corresponding one.

    Example:
        TensorType((1,2,3)) <= Dyn
        int <= Dyn
        int <= int
        TensorType((1,2,3)) <= TensorType((1,Dyn,3))
    """
    # Reflexivity, and Dyn as the least precise type on the right-hand side.
    if t1 == t2 or isinstance(t2, _DynType):
        return True

    # Otherwise only two tensor types can be related.
    if not (isinstance(t1, TensorType) and isinstance(t2, TensorType)):
        return False

    # Ranks must match, and precision is checked dimension by dimension.
    return len(t1.__args__) == len(t2.__args__) and all(
        is_more_precise(elem1, elem2)
        for elem1, elem2 in zip(t1.__args__, t2.__args__)
    )
106