# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
    BVar
from torch.fx.experimental.migrate_gradual_types.operation import op_leq


def _fresh_var(make_var, curr):
    """
    Advance a counter by one and wrap the new value in a variable
    :param make_var: the variable constructor, called with the new counter value
    :param curr: the current counter
    :return: the constructed variable and the updated counter
    """
    next_id = curr + 1
    return make_var(next_id), next_id


def gen_tvar(curr):
    """
    Generate a fresh tensor variable
    :param curr: the current counter
    :return: a TVar built from curr + 1, and the updated counter (curr + 1)
    """
    return _fresh_var(TVar, curr)


def gen_dvar(curr):
    """
    Generate a fresh dimension variable
    :param curr: the current counter
    :return: a DVar built from curr + 1, and the updated counter (curr + 1)
    """
    return _fresh_var(DVar, curr)


def gen_bvar(curr):
    """
    Generate a fresh boolean variable
    :param curr: the current counter
    :return: a BVar built from curr + 1, and the updated counter (curr + 1)
    """
    return _fresh_var(BVar, curr)


def gen_tensor_dims(n, curr):
    """
    Generate n fresh dimension variables with consecutive ids
    :param n: the number of dimensions (a non-positive n yields an empty list)
    :param curr: the current counter
    :return: the list [DVar(curr + 1), ..., DVar(curr + n)] and the updated counter
    """
    dims = []
    for _ in range(n):
        # Each generated dimension consumes the next id from the counter.
        curr += 1
        dims.append(DVar(curr))
    return dims, curr


def gen_nat_constraints(list_of_dims):
    """
    Generate natural-number constraints for dimensions
    :param list_of_dims: an iterable of dimensions
    :return: one BinConstraintD(0, dim, op_leq) (i.e. 0 <= dim) per dimension, in input order
    """
    return [BinConstraintD(0, dim, op_leq) for dim in list_of_dims]