# mypy: allow-untyped-defs
from torch.fx.experimental.graph_gradual_typechecker import Refine
from torch.fx.tensor_type import TensorType
from torch.fx.experimental.unification import Var, unify  # type: ignore[attr-defined]


def _refine_and_substitute(traced):
    """
    Run one round of symbolic inference on ``traced``.

    Generates constraints with ``Refine``, solves the equality constraints by
    unification, and writes the resulting substitution back into the node
    types of ``traced.graph``.

    Returns the ``Refine`` instance so callers can query it further.
    """
    r = Refine(traced)
    r.refine()
    mgu = unify_eq(r.constraints)
    substitute_all_types(traced.graph, mgu)
    return r


def infer_symbolic_types_single_pass(traced):
    """
    Calls our symbolic inferencer once.
    """
    _refine_and_substitute(traced)


def infer_symbolic_types(traced):
    """
    Calls our symbolic inferencer twice.

    This is useful when one pass is not enough to infer all the
    information, such as in the case of broadcasting. After the second
    pass, ``symbolic_relations`` is invoked on that pass's ``Refine``
    instance.
    """
    _refine_and_substitute(traced)
    r = _refine_and_substitute(traced)
    r.symbolic_relations()


def convert_eq(list_of_eq):
    """
    Convert equality constraints into the format expected by the
    unification library.

    Each constraint must expose ``lhs`` and ``rhs`` attributes. Returns a
    pair ``(lhs, rhs)`` of tuples where position ``i`` of each tuple holds
    the corresponding side of the ``i``-th constraint.
    """
    lhs = []
    rhs = []
    # Single pass so that one-shot iterables are also accepted.
    for eq in list_of_eq:
        lhs.append(eq.lhs)
        rhs.append(eq.rhs)
    return tuple(lhs), tuple(rhs)


def unify_eq(list_of_eq):
    """
    Apply unification to a set of equality constraints.

    Returns the result of ``unify`` on the paired left- and right-hand
    sides, which is used as a mapping from variables to their solutions
    (the most general unifier).
    NOTE(review): the unification library may return ``False`` when the
    constraints cannot be unified; that value is passed through unchanged.
    """
    lhs, rhs = convert_eq(list_of_eq)
    return unify(lhs, rhs)


def substitute_solution_one_type(mapping, t):
    """
    Apply the most general unifier ``mapping`` to a single type ``t``.

    - ``Var``: replaced by its solution, if it has one.
    - ``TensorType``: each dimension found in ``mapping`` is replaced
      (one level only; dimensions are not recursed into).
    - ``list`` / ``tuple``: substituted element-wise, recursively, keeping
      the container type.
    - Anything else is returned unchanged.
    """
    if isinstance(t, Var):
        return mapping[t] if t in mapping else t

    if isinstance(t, TensorType):
        new_dims = tuple(
            mapping[dim] if dim in mapping else dim for dim in t.__args__
        )
        return TensorType(new_dims)

    if isinstance(t, list):
        return [substitute_solution_one_type(mapping, typ) for typ in t]

    if isinstance(t, tuple):
        return tuple(substitute_solution_one_type(mapping, typ) for typ in t)

    return t


def substitute_all_types(graph, mapping):
    """
    Apply the most general unifier to all node types in ``graph``.

    First, chains inside ``mapping`` are collapsed in place until a fixed
    point is reached: whenever a key maps to a value that is itself a key
    (e.g. ``a -> b`` and ``b -> 5``), the value is replaced by that key's
    value (giving ``a -> 5``). The loop stops once a full sweep changes
    nothing. Then every node's ``type`` is rewritten with
    ``substitute_solution_one_type``.

    Note: ``mapping`` is mutated.
    """
    changed = True
    while changed:
        changed = False
        # Only values are reassigned here, never keys, so iterating the
        # dict while updating it is safe.
        for k in mapping:
            old_mapping_val = mapping[k]
            if mapping[k] in mapping:
                new_key = mapping[k]
                mapping[k] = mapping[new_key]
                if old_mapping_val != mapping[k]:
                    changed = True

    for n in graph.nodes:
        n.type = substitute_solution_one_type(mapping, n.type)


def check_for_type_equality(g1, g2):
    """
    A check equality to be used in fixed points.

    We do not use graph equality but instead type equality: the graphs are
    considered equal when they have the same number of nodes and the nodes
    at each position carry equal ``type`` values.
    """
    nodes1 = list(g1.nodes)
    nodes2 = list(g2.nodes)
    # zip() would silently drop trailing nodes; graphs of different sizes
    # cannot be type-equal.
    if len(nodes1) != len(nodes2):
        return False
    for n, m in zip(nodes1, nodes2):
        if n.type != m.type:
            return False
    return True