xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/migrate_gradual_types/constraint.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
3    op_mod, op_gt, op_lt, op_neq, op_eq
4from torch.fx.tensor_type import TensorType, Dyn
5
6
class Constraint:
    """Abstract marker base class for all gradual-typing constraints."""
9
10
class Conj(Constraint):
    """Conjunction (logical AND) of a list of constraints."""

    def __init__(self, conjuncts):
        """
        :param conjuncts: list of constraints that must all hold
        """
        # NOTE: the attribute name is misspelled ("conjucts") but it is read
        # by other modules, so it is kept for backward compatibility.
        self.conjucts = conjuncts

    def __eq__(self, other):
        # Fixed: the original compared the same attribute twice
        # ("a == b and a == b"); a single comparison suffices.
        if isinstance(other, Conj):
            return self.conjucts == other.conjucts
        else:
            return False

    def __repr__(self):
        return f'And({self.conjucts})'
26
27
class Disj(Constraint):
    """Disjunction (logical OR) of a list of constraints."""

    def __init__(self, disjuncts):
        """
        :param disjuncts: list of constraints, at least one of which must hold
        """
        self.disjuncts = disjuncts

    def __eq__(self, other):
        # Fixed: the original compared the same attribute twice
        # ("a == b and a == b"); a single comparison suffices.
        if isinstance(other, Disj):
            return self.disjuncts == other.disjuncts
        else:
            return False

    def __repr__(self):
        return f'Or({self.disjuncts})'
43
44
class Prod(Constraint):
    """Product of a list of dimensions."""

    def __init__(self, products):
        """
        :param products: list of dimensions to multiply
        """
        self.products = products

    def __eq__(self, other):
        # Fixed: the original compared the same attribute twice
        # ("a == b and a == b"); a single comparison suffices.
        if isinstance(other, Prod):
            return self.products == other.products
        else:
            return False

    def __repr__(self):
        return f'Product({self.products})'
60
61
class T(Constraint):
    """The trivially satisfied (always true) constraint."""

    def __eq__(self, other):
        # All T instances are interchangeable; equality is by type alone.
        return isinstance(other, T)

    def __repr__(self):
        return 'True'
74
class F(Constraint):
    """The unsatisfiable (always false) constraint."""

    def __eq__(self, other):
        # All F instances are interchangeable; equality is by type alone.
        return isinstance(other, F)

    def __repr__(self):
        return 'False'
87
88
class BinaryConstraint(Constraint):
    """
    Base class for every binary operation between two constraint operands.
    """

    def __init__(self, lhs, rhs, op):
        """
        :param lhs: left-hand side of the constraint
        :param rhs: right-hand side of the constraint
        :param op: string representing the operation
        """
        self.lhs = lhs
        self.rhs = rhs
        self.op = op

    def __eq__(self, other):
        if not isinstance(other, BinaryConstraint):
            return False
        return (self.lhs, self.rhs, self.op) == (other.lhs, other.rhs, other.op)

    def __repr__(self):
        return f'({self.lhs} {self.op} {self.rhs})'
111
112
class BinConstraintT(BinaryConstraint):
    """
    Binary constraint whose operands are tensor-level expressions.
    """

    def __init__(self, lhs, rhs, op):
        # Each operand must be a tensor variable, a concrete tensor type,
        # an int, or the dynamic type Dyn.
        for operand in (lhs, rhs):
            assert isinstance(operand, (TVar, TensorType, int)) or operand == Dyn
        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        # Delegates to BinaryConstraint: equality of lhs, rhs and op.
        return super().__eq__(other)
124
125
class BinConstraintD(BinaryConstraint):
    """
    Binary constraint whose operands are dimension-level expressions.
    """

    def __init__(self, lhs, rhs, op):
        # Each side must be an algebraic expression, a dimension,
        # or a boolean expression over dimensions.
        for operand in (lhs, rhs):
            assert is_algebraic_expression(operand) or is_dim(operand) or is_bool_expr(operand)
        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        # Delegates to BinaryConstraint: equality of lhs, rhs and op.
        return super().__eq__(other)
138
139
140
class TGreatestUpperBound(Constraint):
    """
    Greatest upper bound for tensors with dynamic type.
    """

    def __init__(self, res, rhs1, rhs2):
        """
        :param res: tensor variable that stores the result of the output
        :param rhs1: tensor or tensor variable
        :param rhs2: tensor or tensor variable
        """
        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        # \u2294 is the square-cup (join) symbol.
        return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}'

    def __eq__(self, other):
        if not isinstance(other, TGreatestUpperBound):
            return False
        return (self.res, self.rhs1, self.rhs2) == (other.res, other.rhs1, other.rhs2)
163
164
class DGreatestUpperBound(Constraint):
    """
    Greatest upper bound for dimensions.
    """

    def __init__(self, res, rhs1, rhs2):
        """
        :param res: dimension variable to store the result
        :param rhs1: dimension variable 1
        :param rhs2: dimension variable 2
        """
        # All three operands must denote dimensions (DVar, int, or Dyn).
        for d in (res, rhs1, rhs2):
            assert is_dim(d)

        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        # \u2294 is the square-cup (join) symbol.
        return f'{self.res} = {self.rhs1}\u2294{self.rhs2}'

    def __eq__(self, other):
        if not isinstance(other, DGreatestUpperBound):
            return False
        return (self.res, self.rhs1, self.rhs2) == (other.res, other.rhs1, other.rhs2)
191
192
class CanReshape(Constraint):
    """
    Constraint stating that *src* can be reshaped to *target*.
    """

    def __init__(self, src, target):
        """
        :param src: tensor variable
        :param target: tensor
        """
        self.src = src
        self.target = target

    def __repr__(self):
        return f'can-reshape({self.src}, {self.target})'

    def __eq__(self, other):
        if not isinstance(other, CanReshape):
            return False
        return (self.src, self.target) == (other.src, other.target)
213
214
class IndexSelect(Constraint):
    """Constraint describing an index_select-style dimension replacement."""

    def __init__(self, tensor_size, input_var, dim_replace, index, output):
        """
        Args:
            tensor_size: tensor size we are considering
            input_var: input to index_select
            dim_replace: the dimension of the output at "index"
            index: location of the dimensions to replace in the input
            output: variable to store the result
        """
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        # dim_replace is either a dimension variable or the dynamic type.
        assert isinstance(dim_replace, DVar) or dim_replace == Dyn
        assert isinstance(index, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.dim_replace = dim_replace
        self.index = index
        self.output = output

    def __repr__(self):
        parts = (
            f' {self.output} = ',
            f'IndexSelect({self.input_var}, ',
            f'tensor_size: {self.tensor_size}, ',
            f'{self.dim_replace}, ',
            f'{self.index})',
        )
        return ''.join(parts)

    def __eq__(self, other):
        if not isinstance(other, IndexSelect):
            return False
        mine = (self.tensor_size, self.dim_replace, self.index, self.output, self.input_var)
        theirs = (other.tensor_size, other.dim_replace, other.index, other.output, other.input_var)
        return mine == theirs
254
255
class Transpose(Constraint):
    """Constraint describing a swap of two dimensions of a tensor."""

    def __init__(self, tensor_size, input_var, index1, index2, output):
        """
        Args:
            tensor_size: current tensor size
            input_var: variable holding the input tensor
            index1: first dimension to swap
            index2: second dimension to swap
            output: variable that stores the result
        """
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        assert isinstance(index1, int)
        assert isinstance(index2, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.index1 = index1
        self.index2 = index2
        self.output = output

    def __repr__(self):
        parts = (
            f' {self.output} = ',
            f'Transpose({self.input_var}, ',
            f'tensor_size: {self.tensor_size}, ',
            f'{self.index1}, ',
            f'{self.index2})',
        )
        return ''.join(parts)

    def __eq__(self, other):
        if not isinstance(other, Transpose):
            return False
        mine = (self.tensor_size, self.index1, self.index2, self.output, self.input_var)
        theirs = (other.tensor_size, other.index1, other.index2, other.output, other.input_var)
        return mine == theirs
295
296
class GetItem(Constraint):
    """Constraint for getting a single dimension out of a tensor."""

    def __init__(self, tensor_size, index, res, input_var):
        """
        :param tensor_size: actual number
        :param index: actual number representing the index
        :param res: dimension variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        assert isinstance(res, DVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index = index
        self.input_var = input_var

    def __repr__(self):
        return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'

    def __eq__(self, other):
        if not isinstance(other, GetItem):
            return False
        mine = (self.res, self.tensor_size, self.index, self.input_var)
        theirs = (other.res, other.tensor_size, other.index, other.input_var)
        return mine == theirs
325
class GetItemTensor(Constraint):
    """
    Constraint for getting an item given a tensor size. Unlike GetItem,
    the index argument here is a tuple, so the result is a tensor.
    """

    def __init__(self, tensor_size, index_tuple, res, input_var):
        """
        :param tensor_size: actual number representing the rank
        :param index_tuple: tuple for indexing
        :param res: tensor variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        assert isinstance(res, TVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index_tuple = index_tuple
        self.input_var = input_var

    def __repr__(self):
        return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'

    def __eq__(self, other):
        if not isinstance(other, GetItemTensor):
            return False
        mine = (self.res, self.tensor_size, self.index_tuple, self.input_var)
        theirs = (other.res, other.tensor_size, other.index_tuple, other.input_var)
        return mine == theirs
356
class CalcConv(Constraint):
    """Constraint describing the output size calculation of a convolution."""

    def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param conv_result: the convolution result
        :param input_var: input to convolution
        :param c_out: output channel type
        :param kernel: kernel tuple
        :param padding: padding tuple
        :param stride: stride tuple
        :param dilation: dilation tuple
        :param matching_constraint_vars: stored as ``matching_constraint``
        """
        self.conv_result = conv_result
        self.input_var = input_var
        self.c_out = c_out
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        return (f'{self.conv_result} = calc-conv({self.input_var}, {self.c_out}, '
                f'{self.kernel}, {self.padding}, {self.stride}, {self.dilation})')

    def __eq__(self, other):
        if not isinstance(other, CalcConv):
            return False
        mine = (self.conv_result, self.input_var, self.c_out, self.kernel,
                self.padding, self.stride, self.dilation, self.matching_constraint)
        theirs = (other.conv_result, other.input_var, other.c_out, other.kernel,
                  other.padding, other.stride, other.dilation, other.matching_constraint)
        return mine == theirs
390
391
class CalcMaxPool(Constraint):
    """Constraint describing the output size calculation of a max-pool."""

    def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param maxpool_result: the result of maxpool
        :param input_var: input to convolution
        :param kernel: kernel tuple
        :param padding: padding tuple
        :param stride: stride tuple
        :param dilation: dilation tuple
        :param matching_constraint_vars: stored as ``matching_constraint``
        """
        self.maxpool_result = maxpool_result
        self.input_var = input_var
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        # The double space after the first comma reproduces the original output.
        return (f'{self.maxpool_result} = calc-maxpool({self.input_var},  '
                f'{self.kernel}, {self.padding}, {self.stride}, {self.dilation})')

    def __eq__(self, other):
        if not isinstance(other, CalcMaxPool):
            return False
        mine = (self.maxpool_result, self.input_var, self.kernel,
                self.padding, self.stride, self.dilation, self.matching_constraint)
        theirs = (other.maxpool_result, other.input_var, other.kernel,
                  other.padding, other.stride, other.dilation, other.matching_constraint)
        return mine == theirs
423
424
class ApplyBroadcasting(Constraint):
    """Constraint tying two input tensors to their broadcast results."""

    def __init__(self, res1, res2, input1, input2):
        """
        :param res1: resulting tensor 1
        :param res2: resulting tensor 2
        :param input1: tensor variable 1
        :param input2: tensor variable 2
        """
        self.res1 = res1
        self.res2 = res2
        self.input1 = input1
        self.input2 = input2

    def __eq__(self, other):
        if not isinstance(other, ApplyBroadcasting):
            return False
        mine = (self.res1, self.res2, self.input1, self.input2)
        theirs = (other.res1, other.res2, other.input1, other.input2)
        return mine == theirs

    def __repr__(self):
        return f'{self.res1}, {self.res2} = apply-broadcasting({self.input1}, {self.input2})'
449
450
class CalcProduct(Constraint):
    """
    Given correct dimensions, calculate the product for flatten accounting for Dyn.
    """

    def __init__(self, start, end, flattened, dims_to_flatten):
        """
        :param start: start index
        :param end: end index
        :param flattened: variable to store the product
        :param dims_to_flatten: the type which we will flatten
        """
        assert isinstance(dims_to_flatten, list)
        assert isinstance(flattened, TVar)
        assert isinstance(start, int)
        assert isinstance(end, int)

        self.start = start
        self.end = end
        self.dims_to_flatten = dims_to_flatten
        self.flattened = flattened

    def __eq__(self, other):
        if not isinstance(other, CalcProduct):
            return False
        mine = (self.start, self.end, self.dims_to_flatten, self.flattened)
        theirs = (other.start, other.end, other.dims_to_flatten, other.flattened)
        return mine == theirs

    def __repr__(self):
        return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'
482
483
class TVar:
    """
    Tensor variable with no tensor constructor.
    """

    def __init__(self, tvar):
        """
        :param tvar: identifier of the tensor variable
        """
        self.tvar = tvar

    def __repr__(self):
        return f'TV({self.tvar})'

    def __eq__(self, other):
        return isinstance(other, TVar) and self.tvar == other.tvar
502
503
class DVar:
    """
    Dimension variable.
    """

    def __init__(self, c):
        """
        :param c: character or number identifying the variable
        """
        self.c = c

    def __repr__(self):
        return f'DV({self.c})'

    def __eq__(self, other):
        return isinstance(other, DVar) and self.c == other.c
522
523
class BVar:
    """
    Boolean variable.
    """

    def __init__(self, c):
        """
        :param c: character or number identifying the variable
        """
        self.c = c

    def __repr__(self):
        return f'BV({self.c})'

    def __eq__(self, other):
        return isinstance(other, BVar) and self.c == other.c
542
543
def is_algebraic_expression(constraint):
    """Return True iff *constraint* is an arithmetic expression over dimensions."""
    if isinstance(constraint, BinConstraintD):
        return constraint.op in (op_add, op_sub, op_div, op_mul, op_mod)
    return isinstance(constraint, Prod)
549
550
def is_bool_expr(constraint):
    """Return True iff *constraint* is a boolean expression over dimensions."""
    if isinstance(constraint, BinConstraintD):
        return constraint.op in (op_gt, op_lt, op_neq, op_eq)
    return isinstance(constraint, (BVar, Conj, Disj))
556
def is_dim(d):
    """Return True iff *d* denotes a dimension: a DVar, an int, or Dyn."""
    if isinstance(d, (DVar, int)):
        return True
    return d == Dyn
559