# mypy: allow-untyped-defs
from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]

from ._compatibility import compatibility


@compatibility(is_backward_compatible=False)
class TensorType:
    """
    TensorType defines a type for tensors, which consists of a list of dimensions.

    The dimensions are stored as ``__args__`` and ``__origin__`` is set to
    ``TensorType``, mirroring the attributes of ``typing`` generic aliases.
    Instances can be built either by calling ``TensorType((1, 2, 3))`` or by
    subscripting ``TensorType[1, 2, 3]``.

    Example:
        class M(torch.nn.Module):
            def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))):
                return torch.add(x, y)
    """

    def __init__(self, dim):
        """
        Args:
            dim: the sequence of dimensions. It is stored as given, without
                copying or conversion.
        """
        self.__origin__ = TensorType
        self.__args__ = dim

    def __repr__(self):
        return f'TensorType[{self.__args__}]'

    def __eq__(self, other):
        """
        Compare dimensions element-wise with another TensorType.

        Both sides are converted to lists first, so a tuple-backed and a
        list-backed TensorType with the same dimensions compare equal.
        Any non-TensorType operand compares unequal.

        NOTE: this class defines ``__eq__`` without ``__hash__``, so Python
        makes instances unhashable.
        """
        if isinstance(other, self.__class__):
            return list(self.__args__) == list(other.__args__)
        else:
            return False

    @staticmethod
    def __class_getitem__(*args):
        """
        Support ``TensorType[d0, d1, ...]`` subscription syntax.

        Because this is a staticmethod, ``args`` holds only the subscript.
        ``TensorType[1, 2]`` arrives as ``((1, 2),)`` and is unwrapped to
        ``(1, 2)``. ``TensorType[1]`` arrives as ``(1,)`` and is kept as-is.
        """
        if len(args) == 1 and isinstance(args[0], tuple):
            args = args[0]
        return TensorType(tuple(args))


class _DynType:
    """
    _DynType defines a type which stands for the absence of type information.

    All instances compare equal to each other. Use the module-level ``Dyn``
    instance rather than creating new ones. Like ``TensorType``, it defines
    ``__eq__`` without ``__hash__``, so instances are unhashable.
    """
    def __init__(self) -> None:
        self.__name__ = '_DynType'

    def __eq__(self, other):
        return isinstance(other, self.__class__)

    def __str__(self):
        return "Dyn"

    def __repr__(self):
        return "Dyn"


# Shared instance representing an unknown type or dimension.
Dyn = _DynType()

@compatibility(is_backward_compatible=False)
def is_consistent(t1, t2):
    """
    A binary relation denoted by ~ that determines if t1 is consistent with t2.
    The relation is reflexive, symmetric but not transitive.
    returns True if t1 and t2 are consistent and False otherwise.

    ``Dyn`` and unification variables (``Var``) are consistent with any type.
    Two TensorTypes are consistent when they have the same rank and their
    dimensions are pairwise consistent.

    Example:
        Dyn ~ TensorType((1,2,3))
        int ~ Dyn
        int ~ int
        TensorType((1,Dyn,3)) ~ TensorType((1,2,3))
    """

    if t1 == t2:
        return True

    # Unknown types (Dyn) and not-yet-solved type variables match anything.
    if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
        return True

    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
        # Check ranks first: zip() would silently truncate the longer side.
        return len(t1.__args__) == len(t2.__args__) and \
            all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))
    else:
        return False


@compatibility(is_backward_compatible=False)
def is_more_precise(t1, t2):
    """
    A binary relation denoted by <= that determines if t1 is more precise than t2.
    The relation is reflexive and transitive.
    returns True if t1 is more precise than t2 and False otherwise.

    Every type is more precise than ``Dyn``. A TensorType is more precise
    than another of the same rank when each of its dimensions is more precise
    than the corresponding one. Unlike ``is_consistent``, unification
    variables (``Var``) receive no special treatment here.

    Example:
        TensorType((1,2,3)) <= Dyn
        int <= Dyn
        int <= int
        TensorType((1,2,3)) <= TensorType((1,Dyn,3))
    """
    if t1 == t2:
        return True

    # Dyn carries no information, so anything is at least as precise as it.
    # The check is one-directional: a Dyn t1 is NOT more precise than a concrete t2.
    if isinstance(t2, _DynType):
        return True

    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
        # Check ranks first: zip() would silently truncate the longer side.
        return len(t1.__args__) == len(t2.__args__) and \
            all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))

    else:
        return False