1# mypy: ignore-errors
2import copy
3import itertools
4from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK
5from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \
6    Transpose
7from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound
8from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound
9from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool
10from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape
11from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect
12from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching
13from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq
14from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod
15from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar
16from torch.fx.tensor_type import TensorType, Dyn
17from typing import Callable, Dict, List
18
# Registry of transformation rules, keyed by the ``call_target`` each rule was
# registered for and looked up by ``type(constraint)`` in transform_constraint.
_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}


def register_transformation_rule(call_target):
    """
    Build a decorator that stores the decorated function in
    ``_TRANSFORMATION_RULES`` under ``call_target``.
    Args:
        call_target: key the rule is registered for

    Returns: a decorator that registers its argument and returns it unchanged.
    The decorator raises RuntimeError if ``call_target`` already has a rule.
    """
    def decorator(rule_fn):
        already_registered = call_target in _TRANSFORMATION_RULES
        if already_registered:
            raise RuntimeError(f'Transformation rule already registered for {call_target}!')

        _TRANSFORMATION_RULES[call_target] = rule_fn
        return rule_fn

    return decorator
29
30
def valid_index(index, dims):
    """
    Report whether ``index`` can be used to subscript ``dims``.
    Args:
        index: the index to probe
        dims: list of dimensions

    Returns: T() if ``dims[index]`` succeeds, F() if it raises IndexError
    """
    try:
        dims[index]
    except IndexError:
        return F()
    else:
        return T()
40
41
@register_transformation_rule(Transpose)
def transform_transpose(constraint, counter):
    """
    Rewrite a Transpose constraint, similar to a sequence of two index-selects:
    the input is equated with a tensor over generated dimensions and the output
    with the same dimensions where the entries at ``index1`` and ``index2``
    are exchanged.
    Args:
        constraint: Transpose constraint
        counter: variable tracking

    Returns: a conjunction of the simplified constraints, and the updated counter
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    first, second = constraint.index1, constraint.index2

    first_ok = valid_index(first, dims)
    second_ok = valid_index(second, dims)

    # Only swap when both indices exist. Otherwise the output keeps the input
    # dimensions and the F() placed in the conjunction makes the clause false.
    swapped_dims = copy.deepcopy(dims)
    if first_ok == T() and second_ok == T():
        swapped_dims[first], swapped_dims[second] = dims[second], dims[first]

    input_eq = BinConstraintT(constraint.input_var, TensorType(dims), op_eq)
    output_eq = BinConstraintT(constraint.output, TensorType(swapped_dims), op_eq)

    conjuncts = [input_eq]
    conjuncts.extend(gen_nat_constraints(dims))
    conjuncts.extend([first_ok, second_ok, output_eq])
    return Conj(conjuncts), counter
62
63
@register_transformation_rule(IndexSelect)
def transform_index_select(constraint, counter):
    """
    The constraints consider the given tensor size, check if the index is valid
    and if so, generate a constraint for replacing the input dimension
    with the required dimension.
    Args:
        constraint: IndexSelect constraint
        counter: variable tracking

    Returns: a conjunction of the simplified constraints, and the updated counter
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    is_valid_index = valid_index(constraint.index, dims)
    nat_constraints = gen_nat_constraints(dims)

    # Always start from a copy of the input dimensions so that ``new_dims`` is
    # defined even for an invalid index. In that case no dimension is replaced
    # and the clause is false because it contains ``is_valid_index`` (F()).
    new_dims = copy.deepcopy(dims)
    if is_valid_index == T():
        new_dims[constraint.index] = constraint.dim_replace

    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                                   *nat_constraints,
                                   is_valid_index,
                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])

    return transformed_constraint, counter
88
89
@register_transformation_rule(GetItem)
def transform_get_item(constraint, counter):
    """
    Generate an equality of the form:
    t = [a1, ..., an]
    and a check that the given index is valid for this particular tensor size.
    If the index is valid, also generate a constraint that binds the result
    to the selected dimension.
    Note that the Dyn input case is handled in a previous step.
    Args:
        constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
        counter: variable tracking
    Returns: simplified constraints for GetItem, and the updated counter

    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    index_ok = valid_index(constraint.index, dims)

    conjuncts = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq)]
    conjuncts.extend(gen_nat_constraints(dims))
    conjuncts.append(index_ok)

    # With an invalid index the clause is already UNSAT through ``index_ok``,
    # so the item constraint is only emitted when the lookup can succeed.
    if index_ok == T():
        selected_dim = dims[constraint.index]
        conjuncts.append(BinConstraintD(constraint.res, selected_dim, op_eq))

    return Conj(conjuncts), counter
122
def valid_index_tensor(index, dims):
    """
    Check a tuple index against a list of dimensions: having more slice
    entries than dimensions is a type error.
    Args:
        index: iterable of index entries
        dims: list of dimensions

    Returns: F() if the number of slices exceeds len(dims), T() otherwise
    """
    slice_count = sum(1 for entry in index if isinstance(entry, slice))
    return T() if slice_count <= len(dims) else F()
136
@register_transformation_rule(GetItemTensor)
def transform_get_item_tensor(constraint, counter):
    """
    When the index is a tuple, then the output will be a tensor
    TODO: we have to check if this is the case for all HF models

    The cases we are covering here are a tuple with one of:
     - slice with default argument
     - None

     None appends 1 to the input tensor dimensions
     so each occurrence of 'None' increases the rank by 1

     slice with default arguments does not change the rank

    Args:
        constraint: GetItemTensor constraint whose ``index_tuple`` is a tuple
        counter: variable tracking

    Returns: the simplified constraint and the updated counter. The constraint
    is F() when the index has more slices than the tensor has dimensions or
    when the result would have more than 4 dimensions.

    Raises:
        NotImplementedError: if the tuple contains anything other than None
        or a default slice
    """
    index_tuple = constraint.index_tuple
    assert isinstance(index_tuple, tuple)

    # generate the dimensions of the input tensor
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    # reject index entries we do not model
    for entry in index_tuple:
        if entry is None:
            continue
        elif entry == slice(None, None, None):
            continue
        else:
            raise NotImplementedError('Method not yet implemented')

    # Check the index before laying out the result. With more slices than
    # dimensions, a trailing None would fall outside the placeholder list below
    # and raise IndexError instead of producing a false clause.
    is_valid_index = valid_index_tensor(index_tuple, dims)
    if not (is_valid_index == T()):
        return F(), counter

    # generate a place-holder list of the right rank
    # where "slice" does not contribute to the rank and "None" does
    resulting_tensor_dims = (index_tuple.count(None) + len(dims)) * [None]

    # put 1 at the location of every None in the index
    for position, entry in enumerate(index_tuple):
        if entry is None:
            resulting_tensor_dims[position] = 1

    # fill the remaining locations with the input dimensions, in order
    remaining_dims = iter(dims)
    resulting_tensor_dims = [next(remaining_dims) if slot is None else slot
                             for slot in resulting_tensor_dims]

    # check if the resulting tensor is within bounds
    # NOTE(review): 4 presumably mirrors MAX_TENSOR_RANK - confirm before replacing the literal
    if len(resulting_tensor_dims) > 4:
        return F(), counter

    constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                   BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
                   *nat_constraints,
                   is_valid_index]
    return Conj(constraints), counter
197
198
@register_transformation_rule(BinConstraintT)
def generate_binconstraint_t(constraint, counter):
    """
    Transform binary constraints for tensors
    Args:
        constraint: BinConstraintT to simplify
        counter: variable tracking

    Returns: the simplified constraint and the updated counter. Constraints
    that no case applies to are returned unchanged.
    """

    # precision constraints
    if constraint.op == op_precision:
        if constraint.lhs == Dyn:
            return T(), counter

        if isinstance(constraint.lhs, TensorType):
            old_dims = constraint.lhs.__args__

            # without any Dyn dimension, precision is plain equality
            if all(d != Dyn for d in old_dims):
                return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter

            # otherwise introduce one variable per dimension and relate it
            # to the corresponding old dimension
            new_dims = []
            for _ in old_dims:
                dim, counter = gen_dvar(counter)
                new_dims.append(dim)

            new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision)
                                   for new_dim, old_dim in zip(new_dims, old_dims)]
            new_dim_constraints.append(BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq))
            new_dim_constraints.extend(BinConstraintD(1, new_dim, op_leq) for new_dim in new_dims)
            return Conj(new_dim_constraints), counter

        # Neither Dyn nor a TensorType: leave the constraint untouched rather
        # than falling off the end and returning None, which callers cannot
        # unpack into (constraint, counter).
        return constraint, counter

    # matching
    if constraint.op == op_matching:
        assert isinstance(constraint.rhs, TensorType)
        d1 = constraint.rhs.__args__[0]
        d2 = constraint.rhs.__args__[1]
        d3 = constraint.rhs.__args__[2]
        d4 = constraint.rhs.__args__[3]

        # either everything is Dyn, or lhs is exactly the rank-4 tensor
        all_dyn = Conj([BinConstraintT(constraint.lhs, Dyn, op_eq),
                        BinConstraintD(d1, Dyn, op_eq),
                        BinConstraintD(d2, Dyn, op_eq),
                        BinConstraintD(d3, Dyn, op_eq),
                        BinConstraintD(d4, Dyn, op_eq)])
        return Disj([all_dyn,
                     BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter

    if constraint.op == op_consistency:
        c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)])
        [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter)

        return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter

    if constraint.op == op_leq:
        # lhs is Dyn or a tensor of any rank from 1 up to rhs
        assert isinstance(constraint.rhs, int)
        disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
        for rank in range(1, constraint.rhs + 1):
            dims = []
            for _ in range(rank):
                dim_var, counter = gen_dvar(counter)
                dims.append(dim_var)
            disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
        return Disj(disj), counter

    return constraint, counter
261
262
@register_transformation_rule(BinConstraintD)
def generate_binconstraint_d(constraint, counter):
    """
    Transform binary constraints for dimensions
    Args:
        constraint: BinConstraintD to simplify
        counter: variable tracking

    Returns: the simplified constraint and the (unchanged) counter. Constraints
    that no case applies to are returned unchanged.
    """
    if constraint.op == op_precision:
        # an integer is only as precise as itself; Dyn is less precise than anything
        if isinstance(constraint.lhs, int):
            return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter
        if constraint.lhs == Dyn:
            return T(), counter
        # Any other lhs: keep the constraint as is instead of implicitly
        # returning None, which callers cannot unpack into (constraint, counter).
        return constraint, counter

    if constraint.op == op_consistency:
        return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq),
                     BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter

    return constraint, counter
280
281
@register_transformation_rule(Conj)
def generate_conj(constraint, counter):
    """
    Transform conjunctions by transforming every conjunct in order,
    threading the counter through each transformation.
    Args:
        constraint: Conj constraint
        counter: variable tracking

    Returns: a Conj of the transformed conjuncts, and the updated counter
    """
    transformed = []
    for conjunct in constraint.conjucts:
        result, counter = transform_constraint(conjunct, counter)
        transformed.append(result)
    return Conj(transformed), counter
292
293
@register_transformation_rule(Disj)
def generate_disj(constraint, counter):
    """
    Transform disjunctions by transforming every disjunct in order,
    threading the counter through each transformation.
    Args:
        constraint: Disj constraint
        counter: variable tracking

    Returns: a Disj of the transformed disjuncts, and the updated counter
    """
    transformed = []
    for disjunct in constraint.disjuncts:
        result, counter = transform_constraint(disjunct, counter)
        transformed.append(result)
    return Disj(transformed), counter
304
305
@register_transformation_rule(TGreatestUpperBound)
def generate_gub(constraint, counter):
    """
    Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound
    on dimensions
    Args:
        constraint: TGreatestUpperBound constraint
        counter: variable tracking

    Returns: a disjunction of the Dyn case and the four cases produced by
    gen_greatest_upper_bound, and the updated counter
    """
    # Dyn case: one of the operands is Dyn and so is the result
    some_operand_dyn = Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq),
                             BinConstraintT(constraint.rhs2, Dyn, op_eq)])
    dyn_case = Conj([some_operand_dyn, BinConstraintT(constraint.res, Dyn, op_eq)])

    tensor_cases, counter = gen_greatest_upper_bound(constraint, counter)
    case_1, case_2, case_3, case_4 = tensor_cases

    return Disj([dyn_case, case_1, case_2, case_3, case_4]), counter
318
319
@register_transformation_rule(DGreatestUpperBound)
def generate_d_gub(constraint, counter):
    """
    Transform greatest upper bound for dimensions into equality constraints
    Args:
        constraint: DGreatestUpperBound constraint
        counter: variable tracking

    Returns: a disjunction of three cases, and the unchanged counter
    """
    res, left, right = constraint.res, constraint.rhs1, constraint.rhs2

    # left operand is Dyn: the result is the right operand
    left_dyn = Conj([BinConstraintD(left, Dyn, op_eq), BinConstraintD(res, right, op_eq)])
    # right operand is Dyn: the result is the left operand
    right_dyn = Conj([BinConstraintD(right, Dyn, op_eq), BinConstraintD(res, left, op_eq)])
    # both operands agree: the result is that common value
    operands_equal = Conj([BinConstraintD(right, left, op_eq), BinConstraintD(res, left, op_eq)])

    return Disj([left_dyn, right_dyn, operands_equal]), counter
329
330
@register_transformation_rule(CalcConv)
def generate_calc_conv(constraint, counter):
    """
    Transform convolution constraints
    Args:
        constraint: CalcConv constraint
        counter: variable tracking

    Returns: a conjunction describing the rank-4 convolution result, and the updated counter
    """
    d, counter = gen_tensor_dims(4, counter)
    out_dims = [d[0], d[1], d[2], d[3]]

    # the convolution result is a tensor of size 4
    result_is_rank4 = BinConstraintT(constraint.conv_result, TensorType(out_dims), op_eq)

    # the second dimension of the output is equal to the output channels
    channels = Conj([BinConstraintD(d[1], constraint.c_out, op_eq),
                     BinConstraintD(d[1], Dyn, op_neq)])

    # the input corresponds to the output in the first dimension of the convolution
    first_dim = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    height, width = calc_last_two_dims(constraint, d)

    leq_constraints = Conj([BinConstraintD(0, dim, op_leq) for dim in out_dims])

    return Conj([result_is_rank4, channels, first_dim, height, width, leq_constraints]), counter
353
354
@register_transformation_rule(CalcMaxPool)
def generate_calc_maxpool(constraint, counter):
    """
    Transform maxpool constraints
    Args:
        constraint: CalcMaxPool constraint
        counter: variable tracking

    Returns: a conjunction describing the rank-4 maxpool result, and the updated counter
    """
    d, counter = gen_tensor_dims(4, counter)
    out_dims = [d[0], d[1], d[2], d[3]]

    # the maxpool result is a tensor of size 4
    result_is_rank4 = BinConstraintT(constraint.maxpool_result, TensorType(out_dims), op_eq)

    # the input corresponds to the output in the first and second dimension of maxpool
    second_dim = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq)
    first_dim = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    height, width = calc_last_two_dims(constraint, d)

    leq_constraints = Conj([BinConstraintD(0, dim, op_leq) for dim in out_dims])

    return Conj([result_is_rank4, second_dim, first_dim, height, width, leq_constraints]), counter
377
378
@register_transformation_rule(CalcProduct)
def generate_calc_product(constraint, counter):
    """
    Transform flatten constraints: the dimensions in [start, end) are collapsed
    into a single dimension while the ones before and after are kept.
    Args:
        constraint: CalcProduct constraint
        counter: variable tracking

    Returns: a conjunction of the disjunction over all Dyn / non-Dyn
    possibilities for the collapsed dimensions with the boundary check,
    and the updated counter
    """
    start, end = constraint.start, constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened
    n = len(dims)

    # the range check is decided right here rather than left to the solver
    in_bounds = 0 <= start < end <= n
    c_boundary = T() if in_bounds else F()

    before, mid, after = dims[0:start], dims[start:end], dims[end:]

    all_constraints = []
    for possibility in generate_all_int_dyn_dim_possibilities(mid):
        possibility = list(possibility)
        all_static = all(c.op == op_neq for c in possibility)

        if all_static:
            # no dynamic dimension: the collapsed dimension is the product
            new_var, counter = gen_dvar(counter)
            extra = [Conj([BinConstraintD(new_var, Prod(mid), op_eq),
                           BinConstraintD(new_var, Dyn, op_neq)])]
            flattened_dims = before + [new_var] + after
        else:
            # some collapsed dimension is dynamic, so the result is dynamic too
            extra = []
            flattened_dims = before + [Dyn] + after

        if len(flattened_dims) > 4:
            all_constraints.append(F())
        else:
            flattened_eq = BinConstraintT(flattened, TensorType(flattened_dims), op_eq)
            all_constraints.append(Conj([flattened_eq] + extra + possibility))

    return Conj([Disj(all_constraints), c_boundary]), counter
425
426
@register_transformation_rule(CanReshape)
def generate_reshape(constraint, counter):
    """
    Transform reshape constraints
    The source is either Dyn or a tensor of rank 1 to 4 whose dimensions are
    compatible with the target shape.
    Args:
        constraint: CanReshape constraint with a ``src`` and a ``target`` type
        counter: variable tracking

    Returns: the simplified constraint and the updated counter
    """
    d, counter = gen_tensor_dims(4, counter)

    d1 = d[0]
    d2 = d[1]
    d3 = d[2]
    d4 = d[3]

    target = constraint.target.__args__

    is_fully_static = all(dim != Dyn for dim in target)

    # dynamic tensor
    c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
    c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
    c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
    c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
    c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)

    d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
    d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)

    d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
    d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)

    d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
    d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)

    # these must constrain the fourth dimension (d4), not repeat d3
    d4_eq_dyn = BinConstraintD(d4, Dyn, op_eq)
    d4_neq_dyn = BinConstraintD(d4, Dyn, op_neq)

    nat_d1 = BinConstraintD(0, d1, op_leq)
    nat_d2 = BinConstraintD(0, d2, op_leq)
    nat_d3 = BinConstraintD(0, d3, op_leq)
    nat_d4 = BinConstraintD(0, d4, op_leq)

    if is_fully_static:
        # size 1 tensor
        c3_tensor1 = Disj([d1_eq_dyn,
                           (Conj([d1_neq_dyn,
                                  BinConstraintD(d1, Prod(target), op_eq)]))])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # size 2 tensor
        all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)])

        # size 3 tensor
        all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)])

        # size 4 tensor
        all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)])

        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
                     nat_d1, nat_d2, nat_d3, nat_d4]), counter

    # then there must be exactly one occurrence of dyn
    else:
        # the known part of the target: every dimension that is not Dyn
        new_target = [n for n in target if n != Dyn]

        # tensor 1
        c3_tensor1 = Disj([d1_eq_dyn,
                           (Conj([d1_neq_dyn,
                                  is_dim_div_by_target(new_target, d1)]))])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # tensor 2
        c21 = Disj([d1_eq_dyn, d2_eq_dyn])
        c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))])
        all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])

        # tensor 3
        c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
        c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))])
        all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])

        # tensor 4
        c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
        c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))])
        all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])

        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
                     nat_d1, nat_d2, nat_d3, nat_d4]), counter
517
518
@register_transformation_rule(ApplyBroadcasting)
def generate_broadcasting(constraint, counter):
    """
    Transform broadcasting constraints
    Args:
        constraint: ApplyBroadcasting constraint relating two inputs to their
        broadcasted results
        counter: variable tracking

    Returns: a conjunction of the disjunction over all considered cases (Dyn
    inputs, and ranks 1 through 4) with the natural-number constraints of every
    rank, and the updated counter
    """
    e11, e12 = constraint.res1, constraint.res2
    e1, e2 = constraint.input1, constraint.input2

    # dyn possibility: one input is Dyn and both results equal their inputs
    results_equal_inputs = [BinConstraintT(e1, e11, op_eq), BinConstraintT(e2, e12, op_eq)]
    e1_dyn_constraint = Conj([BinConstraintT(e1, Dyn, op_eq), *results_equal_inputs])
    e2_dyn_constraint = Conj([BinConstraintT(e2, Dyn, op_eq), *results_equal_inputs])

    # tensor possibility, rank 1: only the first returned constraint is used
    rank_1_constraint, _, _, nat_dims_1, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter)

    cases = [e1_dyn_constraint, e2_dyn_constraint, rank_1_constraint]
    nat_dims = list(nat_dims_1)

    # ranks 2 to 4 each contribute a no-padding case and one padding case per argument
    for rank in (2, 3, 4):
        no_padding, padding_arg1, padding_arg2, nat_dims_rank, counter = \
            gen_broadcasting_constraints(e1, e2, e11, e12, rank, counter)
        cases.extend([no_padding, padding_arg1, padding_arg2])
        nat_dims.extend(nat_dims_rank)

    return Conj([Disj(cases), *nat_dims]), counter
574
575
def transform_constraint(constraint: Constraint, counter: int):
    """
    Transforms a constraint into a simpler constraint by dispatching on the
    exact type of the constraint.
    Ex: precision and consistency are transformed to equality
    Args:
        constraint: constraint to be transformed
        counter: for variable tracking

    Returns: the transformed constraint and the updated counter. A constraint
    whose type has no registered rule is returned unchanged.

    """
    try:
        rule = _TRANSFORMATION_RULES[type(constraint)]
    except KeyError:
        # no rule registered for this type
        return constraint, counter

    return rule(constraint, counter)
592
593
594
595
def calc_last_two_dims(constraint, d: List[DVar]):
    """
    Generates constraints for the last two dimensions of a convolution or a maxpool output.
    For each of the two dimensions, either the input and output dimensions are
    both Dyn, or neither is and the output equals
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
    Args:
        constraint: CalcConv or CalcMaxPool
        d: The list of output dimensions

    Returns: Constraints for calculating the last two dimensions of the output

    """

    assert isinstance(constraint, (CalcConv, CalcMaxPool))

    in_dims = (constraint.matching_constraint[2], constraint.matching_constraint[3])
    out_dims = (d[2], d[3])

    # transform parameters into tuples in case they are not already
    params = []
    for value in (constraint.padding, constraint.kernel, constraint.stride, constraint.dilation):
        params.append((value, value) if isinstance(value, int) else value)
    padding, kernel, stride, dilation = params

    results = []
    for axis in (0, 1):
        in_dim, out_dim = in_dims[axis], out_dims[axis]

        both_dyn = Conj([BinConstraintD(out_dim, Dyn, op_eq), BinConstraintD(in_dim, Dyn, op_eq)])
        neither_dyn = Conj([BinConstraintD(out_dim, Dyn, op_neq), BinConstraintD(in_dim, Dyn, op_neq)])

        padded = BinConstraintD(in_dim, BinConstraintD(2, padding[axis], op_mul), op_add)
        kernel_extent = BinConstraintD(dilation[axis], BinConstraintD(kernel[axis], 1, op_sub), op_mul)
        numerator = BinConstraintD(BinConstraintD(padded, kernel_extent, op_sub), 1, op_sub)
        out_size = BinConstraintD(BinConstraintD(numerator, stride[axis], op_div), 1, op_add)

        results.append(Disj([both_dyn, Conj([neither_dyn, BinConstraintD(out_dim, out_size, op_eq)])]))

    return results[0], results[1]
643
644
def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
    """
    Generate all possibilities of being equal or not equal to dyn for my_list
    Args:
        my_list: List of tensor dimensions

    Returns: A list of tuples of constraints, one constraint per dimension.
    Each tuple corresponds to one possibility about the values of the
    dimension variables
    """
    # for every dimension there are two choices: equal to dyn, or not equal to dyn
    per_dim_choices = [
        [BinConstraintD(dim, Dyn, op_eq), BinConstraintD(dim, Dyn, op_neq)]
        for dim in my_list
    ]
    # the cross product picks one choice per dimension
    return list(itertools.product(*per_dim_choices))
663
664
def is_target_div_by_dim(target: List[int], dim: List[DVar]):
    """
    Generate a constraint stating that the product of the target dimensions
    is divisible by ``dim``
    Args:
        target: Target dimensions
        dim: Input dimensions

    Returns: The constraint Prod(target) mod dim == 0

    """
    remainder = BinConstraintD(Prod(target), dim, op_mod)
    return BinConstraintD(remainder, 0, op_eq)
676
677
def is_dim_div_by_target(target: List[int], dim: List[DVar]):
    """
    Generate a constraint stating that ``dim`` is divisible by the product
    of the target dimensions
    Args:
        target: Target dimensions
        dim:  Input dimensions

    Returns: The constraint dim mod Prod(target) == 0

    """
    remainder = BinConstraintD(dim, Prod(target), op_mod)
    return BinConstraintD(remainder, 0, op_eq)
689
690
def gen_all_reshape_possibilities(list_of_dims, target):
    """
    Consider all possibilities what the input dimensions could be (number or dynamic)
    Then generate the appropriate constraints using multiplication or mod depending on the possibility
    The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn
    for the input. Target is fixed because at most one dimension could be dyn.
    We have different cases for this.

    Args:
        list_of_dims: The input list of dimensions
        target: The tensor we want to reshape to

    Returns: A disjunction of transformed reshape constraints

    """
    all_constraints = []

    for possibility in generate_all_int_dyn_dim_possibilities(list_of_dims):
        possibility = list(possibility)

        for dim_constraint in possibility:
            assert isinstance(dim_constraint, BinConstraintD)

        # the dimensions which this possibility assumes are not dyn
        static_dims = [dim_constraint.lhs for dim_constraint in possibility
                       if dim_constraint.op == op_neq]

        if not static_dims:
            # every input dimension is dyn: nothing further to check
            extra = []
        elif len(static_dims) < len(list_of_dims):
            # some dimensions are dyn: the known part must divide the target
            extra = [is_target_div_by_dim(target, Prod(static_dims))]
        else:
            # no dyn dimension: the element counts must agree exactly
            extra = [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)]

        all_constraints.append(Conj(possibility + extra))

    return Disj(all_constraints)
730
731
def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False):
    """
    Apply broadcasting to the 'index' dimension of tensor_input1.
    Args:
        tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1
        tensor_input2: represents the second input
        res1: broadcasted result 1
        res2: broadcasted result 2
        index: the index to broadcast
        padding: If padding was used, then tensor_input1[index] does not exist

    Returns: a conjunction stating that both results agree with tensor_input2
    at 'index' and, when no padding was used, that tensor_input1[index] is 1

    """
    # a missing (None) dimension is only legal when padding was applied
    if tensor_input1[index] is None:
        assert padding

    if padding:
        # we don't set the input dimension to 1, since it doesn't exist.
        conjuncts = []
    else:
        # then the inputs are the same length so they all have dimensions at "index"
        conjuncts = [BinConstraintD(tensor_input1[index], 1, op_eq)]

    conjuncts.append(BinConstraintD(res1[index], res2[index], op_eq))
    conjuncts.append(BinConstraintD(res2[index], tensor_input2[index], op_eq))
    return Conj(conjuncts)
760
761
def apply_padding(e1_var: TVar,
                  e11: BinConstraintT,
                  e2: BinConstraintT,
                  e12: BinConstraintT,
                  d2: List[DVar],
                  d11: List[DVar],
                  d12: List[DVar],
                  counter: int):
    """
    We are considering the possibility where one input has less dimensions than
    another input, so we apply padding to the broadcasted results

    Args:
        e1_var: Variable representing the first input where padding will be
        e11: constraint of the form e11 = Tensortype[d1, ..., dn]
        e2:  constraint of the form e2 = Tensortype[d1, ..., dn]
        e12: constraint of the form e12 = Tensortype[d1, ..., dn]
        d2: Tensor variables for the second input
        d11: Tensor variables for the broadcasted first input
        d12: Tensor variables for the broadcasted second input
        counter: variable tracking

    Returns: A new constraint whose goal is to apply padding to the broadcasted result,
    and the updated counter

    """
    full_rank = len(d2)
    res = []

    # consider every rank for the first input that is smaller than the rank of the second
    for rank in range(1, full_rank):
        d1, counter = gen_tensor_dims(rank, counter)
        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)
        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)

        # pad the shorter input with None so we can pass it to the broadcasting helper function
        missing = full_rank - rank
        simulate_padding = [None] * missing
        assert len(simulate_padding + d1) == len(d2)

        # for every padded position, the results follow the second input
        broadcast_padding = [broadcast_dim(simulate_padding, d2, d11, d12, position, True)
                             for position in range(missing)]

        # we consider the possibilities for broadcasting for every remaining dimension.
        # The padded positions were handled above, so they are left out here
        all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(
            d1, d2[missing:], d11[missing:], d12[missing:])

        # combine all constraints into a conjunction
        conjuncts = [e1, e11, e2, e12]
        conjuncts.extend(broadcast_padding)
        conjuncts.append(all_broadcasting_possibilities)
        conjuncts.extend(nat_constraints)
        res.append(Conj(conjuncts))

    return Disj(res), counter
824
825
def no_broadcast_dim_with_index(d1: List[DVar],
                                d2: List[DVar],
                                d3: List[DVar],
                                d4: List[DVar],
                                i: int):
    """
    Args:
        d1: input 1
        d2: input 2
        d3: simulated broadcasting for input 1
        d4: simulated broadcasting for input 2
        i: the index of the dimension being considered

    Returns: Constraints for when no broadcasting occurs at dimension i
    """
    # either both dimensions are 1 ...
    both_one = Conj([BinConstraintD(d1[i], 1, op_eq),
                     BinConstraintD(d2[i], 1, op_eq)])

    # ... or neither of them is
    neither_one = Conj([BinConstraintD(d1[i], 1, op_neq),
                        BinConstraintD(d2[i], 1, op_neq)])

    same_kind = Disj([both_one, neither_one])

    # the simulated broadcasting leaves each input dimension untouched
    first_unchanged = BinConstraintD(d1[i], d3[i], op_eq)
    second_unchanged = BinConstraintD(d2[i], d4[i], op_eq)

    return Conj([same_kind, first_unchanged, second_unchanged])
851
852
853
def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
    """
    Generate one list of dimension variables per tensor by calling
    gen_tensor_dims repeatedly and threading the counter through the calls
    Args:
        num_tensors: how many dimension lists to generate
        dim_size: the number of dimensions in each list
        counter: variable tracking

    Returns: The list of generated dimension lists and the updated counter

    """
    dims_per_tensor = []

    # the counter has to flow from one call into the next, so this stays a loop
    for _ in range(num_tensors):
        tensor_dims, counter = gen_tensor_dims(dim_size, counter)
        dims_per_tensor.append(tensor_dims)

    return dims_per_tensor, counter
872
873
def create_equality_constraints_for_broadcasting(e1: TVar,
                                                 e2: TVar,
                                                 e11: TVar,
                                                 e12: TVar,
                                                 d1: List[DVar],
                                                 d2: List[DVar],
                                                 d11: List[DVar],
                                                 d12: List[DVar]):
    """
    Equate each tensor variable with a TensorType built from its dimension variables
    Args:
        e1: Input 1
        e2: Input 2
        e11: Broadcasted input 1
        e12: Broadcasted input 2
        d1: Variables that store dimensions for e1
        d2: Variables that store dimensions for e2
        d11: Variables that store dimensions for e11
        d12: Variables that store dimensions for e12

    Returns: Four equality constraints, in the order e1, e11, e2, e12

    """
    # note the order: each input is immediately followed by its broadcasted version
    variables_with_dims = ((e1, d1), (e11, d11), (e2, d2), (e12, d12))

    return [BinConstraintT(tensor_var, TensorType(dims), op_eq)
            for tensor_var, dims in variables_with_dims]
903
904
def gen_consistency_constraints(constraint: Constraint, counter: int):
    """
    Expand a consistency constraint on two tensors into one candidate per rank,
    for every rank from 1 to MAX_TENSOR_RANK
    Args:
        constraint: Consistency constraint on tensors
        counter: for variable tracking

    Returns: A list with one Conj per rank (equality constraints on the tensors plus
        consistency constraints on their dimensions) and the updated counter

    """
    candidates = []

    for rank in range(1, MAX_TENSOR_RANK + 1):
        lhs_dims, counter = gen_tensor_dims(rank, counter)
        rhs_dims, counter = gen_tensor_dims(rank, counter)

        nat_constraints = gen_nat_constraints(lhs_dims + rhs_dims)

        # both sides are tensors of the current rank
        tensor_equalities = [BinConstraintT(constraint.lhs, TensorType(lhs_dims), op_eq),
                             BinConstraintT(constraint.rhs, TensorType(rhs_dims), op_eq)]

        # consistency is then required dimension by dimension
        dim_consistency = [BinConstraintD(lhs_dim, rhs_dim, op_consistency)
                           for lhs_dim, rhs_dim in zip(lhs_dims, rhs_dims)]

        candidates.append(Conj(tensor_equalities + dim_consistency + nat_constraints))

    return candidates, counter
931
932
def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
    """
    Expand a greatest upper bound constraint on tensors into one candidate per
    rank, for every rank from 1 to MAX_TENSOR_RANK
    Args:
        constraint: Greatest upper bound on tensors
        counter: variable tracking

    Returns: A list with one Conj per rank (equality constraints on the three tensors
        plus one DGreatestUpperBound per dimension) and the updated counter

    """

    all_constraints = []

    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims1, counter = gen_tensor_dims(rank, counter)
        c1tensor = TensorType(dims1)

        dims2, counter = gen_tensor_dims(rank, counter)
        c2tensor = TensorType(dims2)

        dims3, counter = gen_tensor_dims(rank, counter)
        c3tensor = TensorType(dims3)

        c = [BinConstraintT(constraint.rhs1, c1tensor, op_eq),
             BinConstraintT(constraint.rhs2, c2tensor, op_eq),
             BinConstraintT(constraint.res, c3tensor, op_eq)] + \
            gen_nat_constraints(dims1 + dims2 + dims3)

        assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__)

        # walk the three tensors in lockstep; this avoids an index variable that
        # would shadow the rank of the enclosing loop
        for res_dim, rhs1_dim, rhs2_dim in zip(c3tensor.__args__,
                                               c1tensor.__args__,
                                               c2tensor.__args__):
            c.append(DGreatestUpperBound(res_dim, rhs1_dim, rhs2_dim))

        all_constraints.append(Conj(c))
    return all_constraints, counter
969
970
def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]):
    """
    Generate broadcasting constraints assuming no padding. For every index of d1
    there are three alternatives: the first input is broadcast, the second input
    is broadcast, or no broadcasting occurs at that index.
    Args:
        d1: input1 dimensions
        d2: input2 dimensions
        d11: broadcasted input1 dimensions
        d12: broadcasted input2 dimensions

    Returns: A Conj holding, for each index, the Disj of the three alternatives

    """
    per_dimension = [
        Disj([
            # input 1 is broadcast at this index
            broadcast_dim(d1, d2, d11, d12, position),
            # input 2 is broadcast at this index (roles swapped)
            broadcast_dim(d2, d1, d12, d11, position),
            # neither input is broadcast at this index
            no_broadcast_dim_with_index(d1, d2, d11, d12, position),
        ])
        for position in range(len(d1))
    ]

    return Conj(per_dimension)
997
998
def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int):
    """
    Simulates broadcasting on e1 and e2 and returns the results
    respectively in e11 and e12. Because of gradual types,
    e1 and e2 may not be equal. Similarly, e11 and e12 may not
    be equal. e11 and e12 should be guaranteed to be consistent
    as they represent the shapes of the tensors to be added after
    broadcasting.
    Args:
        e1: TVar representing the type of input 1
        e2: TVar representing the type of input 2
        e11: TVar representing the broadcasted input 1
        e12: TVar representing the broadcasted input 2
        i: The rank of the resulting type of addition
        counter: for variable tracking

    Returns: A tuple with the constraint for the no padding case, the constraint for
        padding the first input, the constraint for padding the second input, the
        constraints from gen_nat_constraints on all generated dimensions, and the
        updated counter

    """
    dim_lists, counter = gen_lists_of_dims(4, i, counter)
    in1_dims, in2_dims, out1_dims, out2_dims = dim_lists

    nat_dims_i = gen_nat_constraints([dim for dims in dim_lists for dim in dims])

    tensor_equalities = create_equality_constraints_for_broadcasting(e1, e2, e11, e12,
                                                                     in1_dims, in2_dims,
                                                                     out1_dims, out2_dims)

    # the helper returns each input followed by its broadcasted version
    in1_eq, out1_eq, in2_eq, out2_eq = tensor_equalities

    # both inputs already have rank i: consider broadcasting at every dimension
    same_rank = Conj([*tensor_equalities,
                      generate_all_broadcasting_possibilities_no_padding(in1_dims, in2_dims,
                                                                         out1_dims, out2_dims)])

    # the first input has a smaller rank and needs padding
    pad_first, counter = apply_padding(e1, out1_eq, in2_eq, out2_eq,
                                       in2_dims, out1_dims, out2_dims, counter)

    # the second input has a smaller rank and needs padding (roles swapped)
    pad_second, counter = apply_padding(e2, out2_eq, in1_eq, out1_eq,
                                        in1_dims, out2_dims, out1_dims, counter)

    return same_rank, pad_first, pad_second, nat_dims_i, counter
1041