xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/migrate_gradual_types/util.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
3    BVar
4from torch.fx.experimental.migrate_gradual_types.operation import op_leq
5
6
def gen_tvar(curr):
    """
    Create a fresh tensor variable from the running counter.

    :param curr: the current value of the variable counter
    :return: a pair ``(TVar(curr + 1), curr + 1)`` -- the new tensor
        variable and the advanced counter the caller should keep using
    """
    next_id = curr + 1
    return TVar(next_id), next_id
15
16
def gen_dvar(curr):
    """
    Create a fresh dimension variable from the running counter.

    :param curr: the current value of the variable counter
    :return: a pair ``(DVar(curr + 1), curr + 1)`` -- the new dimension
        variable and the advanced counter the caller should keep using
    """
    next_id = curr + 1
    return DVar(next_id), next_id
25
def gen_bvar(curr):
    """
    Create a fresh boolean variable from the running counter.

    :param curr: the current value of the variable counter
    :return: a pair ``(BVar(curr + 1), curr + 1)`` -- the new boolean
        variable and the advanced counter the caller should keep using
    """
    next_id = curr + 1
    return BVar(next_id), next_id
34
def gen_tensor_dims(n, curr):
    """
    Create ``n`` fresh dimension variables, one after another.

    Each variable comes from ``gen_dvar``, so the counter is threaded
    through every call and the variables receive consecutive ids.

    :param n: how many dimension variables to create
    :param curr: the current value of the variable counter
    :return: a pair ``(dims, counter)`` -- the list of new dimension
        variables in creation order and the counter after the last one
    """
    counter = curr
    dimension_vars = []
    for _ in range(n):
        new_var, counter = gen_dvar(counter)
        dimension_vars.append(new_var)
    return dimension_vars, counter
47
48
def gen_nat_constraints(list_of_dims):
    """
    Constrain each dimension to be a natural number.

    For every ``d`` in ``list_of_dims``, this builds
    ``BinConstraintD(0, d, op_leq)``. Presumably that encodes
    ``0 <= d``, judging by the operator name; confirm this against
    ``op_leq``.

    :param list_of_dims: the dimension variables to constrain
    :return: a list with one constraint per dimension, in input order
    """
    constraints = []
    for dim in list_of_dims:
        constraints.append(BinConstraintD(0, dim, op_leq))
    return constraints