xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/graph_gradual_typechecker.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3from functools import reduce
4import torch
5import operator
6from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise
7from typing import Callable, Dict
8from torch.fx.node import Target, Node
9from torch.nn.modules.batchnorm import BatchNorm2d
10from torch.nn.modules.conv import Conv2d
11from torch.fx.experimental.refinement_types import Equality
12import itertools
13
14from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]
15
16import sympy
17
18_INFERENCE_RULES: Dict[Target, Callable] = {}
19_REFINEMENT_RULES: Dict[Target, Callable] = {}
20_RULES: Dict[Target, Callable] = {}
21
22
23def expand_to_tensor_dim(t, n):
24    """
25    Expand a type to the desired tensor dimension if possible
26    Raise an error otherwise.
27    - t is the given type
28    - n is a number of dimensions to expand to
29    """
30    if t == Dyn:
31        dims = [Dyn] * n
32        return TensorType(tuple(dims))
33    elif isinstance(t, TensorType):
34        if len(t.__args__) != n:
35            raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
36        return t
37    else:
38        raise TypeError(f'Cannot match the type {t}')
39
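# A minimal sketch of the expected behaviour (comments only, not executed):
#
#     expand_to_tensor_dim(Dyn, 3)                     # -> TensorType((Dyn, Dyn, Dyn))
#     expand_to_tensor_dim(TensorType((1, 2, 3)), 3)   # -> TensorType((1, 2, 3)), unchanged
#     expand_to_tensor_dim(TensorType((1, 2)), 3)      # raises TypeError (rank 2 != 3)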
40
41def broadcast_types(t1, t2):
42    """
43    Applies broadcasting to both given types such that they
44    become consistent with each other and returns the two new
45    resulting types
46    """
47
48    # if either type is Dyn, do nothing since the types are already consistent
49    if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
50        return t1, t2
51
52    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
53        s1 = len(t1.__args__)
54        s2 = len(t2.__args__)
55
56        new_t1 = list(t1.__args__)
57        new_t2 = list(t2.__args__)
58
59        # We make the types the same length which is the first requirement
60        # for consistency
61        if s1 > s2:
62            for i in range(s1 - s2):
63                new_t2.insert(0, 1)
64
65        elif s2 > s1:
66            for i in range(s2 - s1):
67                new_t1.insert(0, 1)
68
69        # we replace occurrences of "1" in each tensor with
70        # the corresponding type from the other tensor
71        for i, (x, y) in enumerate(zip(new_t1, new_t2)):
72            if x == 1:
73                new_t1[i] = y
74            elif y == 1:
75                new_t2[i] = x
76
77        # at this point our tensors should be consistent
78        # and we can apply the element-wise operation and find the right dimension
79        # for the output of the operation
80        (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
81        return (t1, t2)
82    else:
83        raise TypeError(f'Cannot broadcast types {t1} and {t2}')
84
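# A minimal sketch of broadcast_types: the shorter type is padded with leading 1s
# and every occurrence of 1 is matched against the other type (comments only):
#
#     broadcast_types(TensorType((2, 3, 4)), TensorType((1, 4)))
#     # -> (TensorType((2, 3, 4)), TensorType((2, 3, 4)))
#     broadcast_types(TensorType((1, 2, Dyn)), TensorType((1, 2, 1)))
#     # -> (TensorType((1, 2, Dyn)), TensorType((1, 2, Dyn)))
#     broadcast_types(Dyn, TensorType((1, 2)))
#     # -> (Dyn, TensorType((1, 2))), unchanged
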
85def register_inference_rule(call_target):
86    def register(fn):
87        if call_target in _INFERENCE_RULES:
88            raise RuntimeError(f'Inference rule already registered for {call_target}!')
89        _INFERENCE_RULES[call_target] = fn
90        return fn
91    return register
92
93def register_refinement_rule(call_target):
94    def register(fn):
95        if call_target in _REFINEMENT_RULES:
96            raise RuntimeError(f'Refinement rule already registered for {call_target}!')
97        _REFINEMENT_RULES[call_target] = fn
98        return fn
99    return register
100
101def register_algebraic_expressions_inference_rule(call_target):
102    def register(fn):
103        if call_target in _RULES:
104            raise RuntimeError(f'Rule already registered for {call_target}!')
105        _RULES[call_target] = fn
106        return fn
107    return register
108
109@register_inference_rule(torch.add)
110@register_inference_rule(operator.add)
111def add_inference_rule(n: Node):
112    """
113    Apply the addition inference rule. This includes:
114    - scalar addition
115    - broadcasting semantics
116
117    Note that we always return the least precise type between
118    the operands (after applying broadcasting) to be the final type of the operation
119
120    Note that we do not modify the operand types themselves after applying broadcasting
121    to them. We only use them to calculate the final type
122    """
123    assert isinstance(n.args[0], Node)
124    assert isinstance(n.args[1], Node)
125    t1 = n.args[0].type
126    t2 = n.args[1].type
127
128    # handle scalar addition
129    if t1 == int and isinstance(t2, TensorType):
130        n.type = t2
131        return n.type
132
133    # handle scalar addition
134    elif t2 == int and isinstance(t1, TensorType):
135        n.type = t1
136        return n.type
137
138    # we bring the new types to the point where
139    # we can check for consistency
140    # any inconsistency would not have been caused
141    # by broadcasting at this point
142    (new_t1, new_t2) = broadcast_types(t1, t2)
143
144    if new_t1 != t1 or new_t2 != t2:
145        n.meta['broadcast'] = True
146        n.meta[str(n.args[0])] = new_t1
147        n.meta[str(n.args[1])] = new_t2
148
149    else:
150        n.meta['broadcast'] = False
151
152    new_t1 = t1 if not n.meta['broadcast'] else new_t1
153    new_t2 = t2 if not n.meta['broadcast'] else new_t2
154
155    # we check for consistency between the new types
156    if is_consistent(new_t1, new_t2):
157        # we return the less precise type because
158        # broadcasting may have happened
159        # for operands with shape [1,2,Dyn] and [1,2,1]
160        # we have to assign the node [1,2,Dyn]
161        if is_more_precise(new_t1, new_t2):
162            n.type = new_t2
163        else:
164            n.type = new_t1
165        return n.type
166    else:
167        raise TypeError(f'Cannot add arguments {n.args[0]} ({n.args[0].type}) and {n.args[1]} ({n.args[1].type}) in node {n}.'
168                        f' Types should match.')
169
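# A sketch of the addition rule on annotated operands (comments only):
#
#     # operand types TensorType((1, 2, Dyn)) and TensorType((1, 2, 1)) both
#     # broadcast to (1, 2, Dyn), so the node is assigned the less precise
#     # type TensorType((1, 2, Dyn))
#
#     # scalar addition: int + TensorType((2, 3)) assigns the node TensorType((2, 3))
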
170@register_inference_rule(getattr)
171def get_attr_inference_rule(n: Node, traced):
172    """
173    The current getattr rule only handles the shape attribute.
174    It can be extended to other attributes.
175    The most representative type we have is "Dyn", but the system
176    can be extended with more types, such as a type to represent shapes.
177    """
178    attr_node = n.args[0]
179    attr_name = n.args[1]
180
181    if attr_name == "shape":
182        n.type = Dyn
183    else:
184        raise TypeError("Not yet implemented")
185
186    # TODO. We leave it like this till we add a type to represent tensor sizes
187    return n.type
188
189@register_inference_rule(torch.transpose)
190def transpose_inference_rule(n: Node):
191    """
192    We check that dimensions for the transpose operations
193    are within range of the tensor type of the node
194    """
195    if n.target == torch.transpose:
196        assert isinstance(n.args[0], Node)
197        t = n.args[0].type
198
199        assert isinstance(n.args[1], int)
200        assert isinstance(n.args[2], int)
201        dim1, dim2 = n.args[1], n.args[2]
202
203        if t == Dyn:
204            n.type = Dyn
205            return n.type
206
207        elif isinstance(t, TensorType):
208            if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
209                new_type = list(t.__args__)
210                new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
211                final = TensorType(new_type)
212                n.type = get_greatest_upper_bound(n.type, final)
213                return n.type
214            else:
215                raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
216        else:
217            raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
218
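# A sketch of the transpose rule (comments only): for an argument of type
# TensorType((1, 2, 3)) and dims 0 and 2, the candidate type is TensorType((3, 2, 1));
# if the node was previously annotated with Dyn, the greatest upper bound
# is TensorType((3, 2, 1)).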
219
220@register_inference_rule(torch.reshape)
221def reshape_inference_rule(n: Node):
222    """
223    Without dynamism, the rule checks that the
224    product of the elements of the argument tensor
225    type is equal to the product of the elements
226    of the required shape. We gradualize this rule
227    by adding a case to handle fully dynamic input
228    as well as input where some of the tensor dimensions
229    are unknown. In this case we check for divisibility
230    """
231    assert isinstance(n.args[0], Node)
232    t1 = n.args[0].type
233
234    assert isinstance(n.args[1], list)
235    t2 = n.args[1]
236    t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2])
237
238    # if we do not know the original tensor dimension,
239    # we return the required dimension
240    if t1 == Dyn:
241        n.type = t2_type
242        return t2_type
243
244    # if any of the dimensions are unknown,
245    # we check for divisibility
246    elif isinstance(t1, TensorType):
247        assert isinstance(t1, TensorType)
248        a = [e if e != Dyn else 1 for e in t1.__args__]
249        p1 = reduce(operator.mul, a)
250        p2 = reduce(operator.mul, t2)
251        if p1 % p2 == 0 or p2 % p1 == 0:
252            n.type = t2_type
253            return t2_type
254        else:
255            raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
256    else:
257        raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
258
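# A sketch of the gradual reshape rule (comments only):
#
#     # fully static: TensorType((2, 6)) reshaped to [3, 4] type-checks since 2 * 6 == 3 * 4
#     # fully dynamic: a Dyn input reshaped to [2, -1] is assigned TensorType((2, Dyn))
#     # partially static: TensorType((Dyn, 6)) reshaped to [2, 3] type-checks
#     # because the known product 6 is divisible by 2 * 3
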
259@register_inference_rule(BatchNorm2d)
260def bn2d_inference_rule(n: Node, module_instance):
261    """
262    Given a BatchNorm2d instance and a node, check the following conditions:
263    - the input type can be expanded to a size 4 tensor: t =  (x_1, x_2, x_3, x_4)
264    - the current node type can be expanded to a size 4 tensor: t' =  (x_1', x_2', x_3', x_4')
265    - t is consistent with t'
266    - x_2 is consistent with the module's num_features
267    - x_2' is consistent with the module's num_features
268    output type: the more precise type of t and t'
269    """
270    assert isinstance(n.args[0], Node)
271    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
272    arg_type = n.args[0].type
273    n.type = expand_to_tensor_dim(n.type, 4)
274
275    # we check the conditions on the incoming argument
276    # and any existing annotation
277    # we also check for consistency between both annotations
278    if is_consistent(arg_type.__args__[1], module_instance.num_features) and \
279            is_consistent(n.type.__args__[1], module_instance.num_features) and \
280            is_consistent(arg_type, n.type):
281
282        # we choose the more precise type
283        # to be the node type
284        # so if an incoming argument has more type information
285        # we set this node's type to be the argument type
286        n.type = get_greatest_upper_bound(arg_type, n.type)
287        return n.type
288    else:
289        raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')
290
291
292def calculate_out_dimension(d_in, module_instance, index):
293    """
294    For calculating h_out and w_out according to the Conv2d documentation
295    """
296    padding = (module_instance.padding, module_instance.padding) \
297        if isinstance(module_instance.padding, int) else module_instance.padding
298    kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \
299        if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size
300    stride = (module_instance.stride, module_instance.stride) \
301        if isinstance(module_instance.stride, int) else module_instance.stride
302    dilation = (module_instance.dilation, module_instance.dilation) \
303        if isinstance(module_instance.dilation, int) else module_instance.dilation
304
305    DIMENSION_TYPES = (int, sympy.Symbol)
306
307    if d_in == Dyn:
308        return Dyn
309
310    elif isinstance(d_in, DIMENSION_TYPES):
311        n = d_in + 2 * padding[index] - \
312            dilation[index] * \
313            (kernel_size[index] - 1) - 1
314
315        return (n // stride[index]) + 1
316
317    else:
318        raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}')
319
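# Worked example (assuming a hypothetical Conv2d(3, 8, kernel_size=3, stride=2, padding=1)):
# for d_in = 56 and index = 0,
#
#     n = 56 + 2 * 1 - 1 * (3 - 1) - 1 = 55
#     d_out = 55 // 2 + 1 = 28
#
# which matches H_out = floor((H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
# from the Conv2d documentation. A Dyn input dimension stays Dyn.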
320
321def get_greatest_upper_bound(type1, type2):
322    """
323    Get the most precise type that's consistent with the given types
324    """
325    if type1 == Dyn:
326        return type2
327    elif type2 == Dyn:
328        return type1
329    elif isinstance(type1, TensorType) and isinstance(type2, TensorType):
330        if not is_consistent(type1, type2):
331            raise TypeError(f'Inconsistent types {type1}, {type2}')
332        gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)]
333        return TensorType(tuple(gub))
334
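# A minimal sketch (comments only): the greatest upper bound keeps the more
# precise dimension at every position.
#
#     get_greatest_upper_bound(TensorType((1, 2, Dyn)), TensorType((1, Dyn, 3)))
#     # -> TensorType((1, 2, 3))
#     get_greatest_upper_bound(Dyn, TensorType((4, 5)))
#     # -> TensorType((4, 5))
#     get_greatest_upper_bound(TensorType((1, 2)), TensorType((1, 3)))
#     # raises TypeError, the types are inconsistent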
335
336@register_inference_rule(Conv2d)
337def conv2d_inference_rule(n: Node, module_instance):
338    """
339    Given a Conv2d instance and a node, check the following conditions:
340    - the input type can be expanded to a size 4 tensor: t =  (x_1, x_2, H, W)
341    - the current node type can be expanded to a size 4 tensor: t' =  (x_1', x_2', x_3', x_4')
342    - x_2 is consistent with the module's in_channels
343    - let o = (x_1, out_channels, H_out, W_out)
344    then the output is the greatest upper bound of o and the existing node type t'.
345    """
346    assert isinstance(n.args[0], Node)
347    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
348    arg_type = n.args[0].type
349    curr_node_type = expand_to_tensor_dim(n.type, 4)
350
351    if is_consistent(arg_type.__args__[1], module_instance.in_channels):
352        w_in = arg_type.__args__[3]
353        h_in = arg_type.__args__[2]
354        h_out = calculate_out_dimension(h_in, module_instance, 0)
355        w_out = calculate_out_dimension(w_in, module_instance, 1)
356        new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
357        gub = get_greatest_upper_bound(new_type, curr_node_type)
358        n.type = gub
359        return n.type
360    else:
361        raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')
362
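# Putting the pieces together (comments only, module parameters are hypothetical):
# a Conv2d(3, 8, kernel_size=3, stride=2, padding=1) applied to a node of type
# TensorType((1, 3, 56, 56)) yields TensorType((1, 8, 28, 28)); with an input of
# TensorType((1, 3, Dyn, 56)) the result is TensorType((1, 8, Dyn, 28)).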
363
364@register_inference_rule(torch.nn.ReLU)
365def relu_inference_rule(n: Node, module_instance):
366    """
367    Input and output shapes should be equal.
368    """
369    assert isinstance(n.args[0], Node)
370
371    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
372        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
373
374    if isinstance(n.args[0].type, TensorType):
375        n.type = get_greatest_upper_bound(n.args[0].type, n.type)
376    return n.type
377
378
379def maxpool2d_check(typ, module_instance):
380    """
381    Applies the maxpool2d shape information to the input;
382    this affects the last two dimensions
383    """
384    new_type_list = list(typ.__args__)
385    if len(new_type_list) == 4 or len(new_type_list) == 3:
386        w_in = new_type_list[-1]
387        h_in = new_type_list[-2]
388
389        h_out = calculate_out_dimension(h_in, module_instance, 0)
390        w_out = calculate_out_dimension(w_in, module_instance, 1)
391
392        new_type_list[-1] = w_out
393        new_type_list[-2] = h_out
394        return TensorType(tuple(new_type_list))
395
396    else:
397        raise TypeError(f'Wrong size {typ} for {module_instance}')
398
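# A sketch of maxpool2d_check (comments only, module parameters are hypothetical):
# it reuses calculate_out_dimension for the last two dimensions, so a
# MaxPool2d(kernel_size=2, stride=2) applied to TensorType((1, 8, 28, 28)) gives
# TensorType((1, 8, 14, 14)): n = 28 + 2 * 0 - 1 * (2 - 1) - 1 = 26 and 26 // 2 + 1 = 14.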
399
400@register_inference_rule(torch.nn.MaxPool2d)
401def maxpool2d_inference_rule(n: Node, module_instance):
402    """
403    Given a MaxPool2d instance and a node, check the following conditions:
404    - Input size matches size 3 or 4
405    - Current node type is consistent with the output type we will calculate
406    - Input size matches output size and the last two dimensions of the output
407      are w_out and h_out. The remaining dimensions are the same as the input
408    - Our final result is the greatest upper bound of the output we calculate
409      and the current node type.
410    """
411    assert isinstance(n.args[0], Node)
412
413    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
414        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
415    if isinstance(n.args[0].type, TensorType):
416        output = maxpool2d_check(n.args[0].type, module_instance)
417        n.type = get_greatest_upper_bound(output, n.type)
418    return n.type
419
420
421
422def linear_check(tensor_type, module_instance):
423    """
424    Checks that an input tensor type satisfies the conditions for linear operation
425    and returns the output type based on in and out features given by module_instance
426    """
427    if len(tensor_type.__args__) >= 2:
428        if is_consistent(module_instance.in_features, tensor_type.__args__[-1]):
429            new_type_args = list(tensor_type.__args__)
430            new_type_args[-1] = module_instance.out_features
431            return TensorType(tuple(new_type_args))
432        else:
433            raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}')
434    else:
435        raise TypeError(f'Type {tensor_type} must have rank 2 or more.')
436
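# A sketch of linear_check (comments only, module parameters are hypothetical):
# only the last dimension is constrained against in_features and rewritten to
# out_features, so a Linear(in_features=128, out_features=10) applied to
# TensorType((32, 128)) gives TensorType((32, 10)); TensorType((32, Dyn)) is also
# accepted, since Dyn is consistent with 128, and likewise gives TensorType((32, 10)).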
437
438@register_inference_rule(torch.nn.Linear)
439def linear_inference_rule(n: Node, module_instance):
440    """
441    Applies the shape information to the input then gets the greatest upper bound
442    of the resulting type and the existing type
443    """
444    assert isinstance(n.args[0], Node)
445    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
446        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
447    if isinstance(n.args[0].type, TensorType):
448        output_type = linear_check(n.args[0].type, module_instance)
449        n.type = get_greatest_upper_bound(output_type, n.type)
450    return n.type
451
452
453def adaptiveavgpool2d_check(tensor_type, module_instance):
454    output_size = module_instance.output_size
455    if isinstance(output_size, int):
456        output_size = [output_size, output_size]
457    elif isinstance(output_size, tuple):
458        output_size = list(output_size)
459        if output_size[0] is None:
460            output_size[0] = output_size[1]
461        if output_size[1] is None:
462            output_size[1] = output_size[0]
463
464    new_type_list = list(tensor_type.__args__)
465
466    if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3:
467        new_type_list[-1] = output_size[1]
468        new_type_list[-2] = output_size[0]
469
470        return TensorType(tuple(new_type_list))
471
472    else:
473        raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}')
474
475@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
476def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
477    """
478    The input and output sizes should be the same except for the last
479    two dimensions, which are replaced by the module's output_size (height and width)
480    """
481    assert isinstance(n.args[0], Node)
482    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
483        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
484    if isinstance(n.args[0].type, TensorType):
485        output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance)
486        n.type = get_greatest_upper_bound(n.type, output_type)
487    return n.type
488
489def flatten_check(tensor_type, start_dim, end_dim):
490    l = len(tensor_type.__args__)
491
492    start_dim = l if start_dim == -1 else abs(start_dim)
493    end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
494
495    if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim:
496        my_args = list(tensor_type.__args__)
497        lhs = my_args[0:start_dim]
498        rhs = my_args[end_dim:]
499        mid = my_args[start_dim:end_dim]
500        if Dyn in mid:
501            mid = [Dyn]
502        else:
503            mid = [reduce(operator.mul, my_args[start_dim:end_dim])]
504        new_type_list = lhs + mid + rhs
505        return TensorType(tuple(new_type_list))
506    else:
507        raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}')
508
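# A minimal sketch (comments only): flatten_check collapses the dimensions
# between start_dim and end_dim (inclusive) into a single dimension.
#
#     flatten_check(TensorType((1, 3, 4, 5)), 1, -1)
#     # -> TensorType((1, 60))
#     flatten_check(TensorType((2, Dyn, 4)), 1, -1)
#     # -> TensorType((2, Dyn)), a Dyn in the flattened range makes the product unknown
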
509@register_inference_rule(torch.flatten)
510def flatten_inference_rule(n: Node):
511    """
512    Applies the flatten shape information to the input then gets the
513    greatest upper bound of the resulting type and the existing type
514    """
515    assert isinstance(n.args[0], Node)
516
517    # set the default start and end dims
518    start_dim = 1
519    end_dim = -1
520
521    if len(n.args) > 1:
522        assert isinstance(n.args[1], int)
523        start_dim = n.args[1]
524
525    if len(n.args) > 2:
526        assert isinstance(n.args[2], int)
527        end_dim = n.args[2]
528
529    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
530        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
531
532    if isinstance(n.args[0].type, TensorType):
533        output_type = flatten_check(n.args[0].type, start_dim, end_dim)
534        n.type = get_greatest_upper_bound(output_type, n.type)
535
536    return n.type
537
538class GraphTypeChecker:
539    def __init__(self, env, traced):
540        self.env = env
541        self.traced = traced
542
543    def type_check(self):
544        """
545        A gradual type checker for graphs.
546        Effect: every node's type field will be
547        populated with a type after type checking is done
548        """
549        graph = self.traced.graph
550
551        # type check every node with gradual type rules
552        # if any node does not type check return false
553        for n in graph.nodes:
554            self.type_check_node(n)
555        return True
556
557    def type_check_node(self, n: Node):
558        """
559        Type check a given fx node.
560        Current operations:
561        - Reshape
562        - Transpose
563        - Add
564        - Relu
565        - conv2d
566        - batchnorm2d
567        - flatten
568        - maxpool2d
569        - adaptiveavgpool2d
570        - linear
571        """
572        if n.type is None:
573            n.type = Dyn
574
575        if n.op == 'placeholder':
576            return n.type
577
578        elif n.op == 'get_attr':
579            t = get_parameter(self.traced, n.target)  # type: ignore[arg-type]
580            if isinstance(t.data, torch.Tensor):
581                n.type = TensorType(t.data.shape)
582            return n.type
583
584        elif n.op == 'call_function':
585            if n.target == getattr:
586                assert getattr in _INFERENCE_RULES
587                return _INFERENCE_RULES[n.target](n, self.traced)
588
589            elif n.target in _INFERENCE_RULES:
590                return _INFERENCE_RULES[n.target](n)
591            else:
592                raise RuntimeError(f'No inference rule registered for target {n.target}!')
593
594        elif n.op == 'call_module':
595            module_instance = self.traced.get_submodule(n.target)
596            if type(module_instance) in _INFERENCE_RULES:
597                return _INFERENCE_RULES[type(module_instance)](n, module_instance)
598            else:
599                raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
600
601        elif n.op == 'output':
602            def get_node_type(a):
603                return a.type
604            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
605            return n.type
606
607        else:
608            raise NotImplementedError(f"Method {n.op} not yet implemented")
609
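# Typical usage (a sketch; `MyModel` stands in for any traceable nn.Module):
#
#     traced = torch.fx.symbolic_trace(MyModel())
#     tc = GraphTypeChecker({}, traced)
#     tc.type_check()
#     for node in traced.graph.nodes:
#         print(node.name, node.type)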
610
611@register_refinement_rule(Conv2d)
612def conv_refinement_rule(n: Node):
613    """
614    The equality constraints are between the first dimension of
615    the input and output
616    """
617    res = []
618    assert isinstance(n.args[0], Node)
619    arg_type = n.args[0].type
620    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
621        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
622    return res
623
624
625@register_refinement_rule(torch.nn.Linear)
626def linear_refinement_rule(n: Node):
627    """
628    The equality constraints are between the first dimension of
629    the input and output
630    """
631    res = []
632    assert isinstance(n.args[0], Node)
633    arg_type = n.args[0].type
634    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
635        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
636    return res
637
638@register_refinement_rule(BatchNorm2d)
639@register_refinement_rule(torch.nn.ReLU)
640def all_eq(n: Node):
641    """
642    For operations where the input shape is equal to the output shape
643    """
644    res = []
645    assert isinstance(n.args[0], Node)
646    arg_type = n.args[0].type
647    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
648        args1 = arg_type.__args__
649        args2 = n.type.__args__
650        res = [Equality(args1[i], args2[i]) for i in range(len(args1))]
651    return res
652
653
654@register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
655@register_refinement_rule(torch.nn.MaxPool2d)
656def first_two_eq(n: Node):
657    """
658    For operations where the first two dimensions of the input and output shape
659    are equal
660    """
661    res = []
662    assert isinstance(n.args[0], Node)
663    arg_type = n.args[0].type
664    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
665        args1 = arg_type.__args__
666        args2 = n.type.__args__
667        res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
668    return res
669
670
671@register_refinement_rule(torch.add)
672@register_refinement_rule(operator.add)
673def element_wise_eq(n: Node):
674    """
675    For element-wise operations; handles broadcasting.
676    Note that after applying broadcasting to the arguments
677    we are able to determine that certain dimensions have not been broadcast
678    when they are symbolically equal.
679
680    In this case, we can establish equality between those dimensions and the
681    corresponding output dimensions.
682
683    Note that it takes two iterations to reach this result. One iteration establishes
684    equality between certain dimensions of the operands (requiring the whole solver,
685    including unification) and another iteration establishes equality between the operands
686    and the resulting type, requiring another round of constraint generation and unification.
687    """
688    res = []
689    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
690        arg_type1 = n.args[0].type
691        arg_type2 = n.args[1].type
692        if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
693            args1, args2 = broadcast_types(arg_type1, arg_type2)
694            # by this point, we know that args1 and args2 are the same size.
695            a1 = args1.__args__
696            a2 = args2.__args__
697            a3 = n.type.__args__
698
699            # we would be here in the second iteration where we establish equality
700            # between operand type dimensions and the resulting type dimensions
701            r = []
702            for x, y, z in zip(a1, a2, a3):
703                if x == y:
704                    r.append(Equality(x, z))
705            res = r
706    return res
707
708
709@register_refinement_rule(torch.flatten)
710def flatten_refinement_rule(n: Node):
711    """
712    Generates equality constraints between the dimensions of the input and output
713    that will not be involved in the flatten operation
714    """
715    assert isinstance(n.args[0], Node)
716
717    eq_const = []
718
719    start_dim = 1
720    end_dim = -1
721
722    if len(n.args) > 1:
723        assert isinstance(n.args[1], int)
724        start_dim = n.args[1]
725
726    if len(n.args) > 2:
727        assert isinstance(n.args[2], int)
728        end_dim = n.args[2]
729
730    if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType):
731        l = len(n.type.__args__)
732        arg_type = n.args[0].type
733        start_dim = l if start_dim == -1 else start_dim
734        end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
735
736        for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]):
737            eq_const.append(Equality(t1, t2))
738
739        for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]):
740            eq_const.append(Equality(t1, t2))
741    return eq_const
742
743
744@register_algebraic_expressions_inference_rule(Conv2d)
745def conv_rule(n: Node, module_instance):
746    """
747    Represents the output in terms of an algebraic expression w.r.t.
748    the input when possible
749    """
750    assert isinstance(n.args[0], Node)
751    arg_type = n.args[0].type
752    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
753        w_in = arg_type.__args__[3]
754        h_in = arg_type.__args__[2]
755        h_out = calculate_out_dimension(h_in, module_instance, 0)
756        w_out = calculate_out_dimension(w_in, module_instance, 1)
757        new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out))
758        n.type = new_type
759        return new_type
760
761class Refine:
762    """
763    Symbolic shape inference.
764    Generates constraints over type variables.
765    Currently all constraints are equality constraints.
766    """
767    def __init__(self, traced):
768        self.constraints = []
769        self.traced = traced
770        self.symbol_iter = itertools.count(start=0, step=1)
771
772    def refine(self):
773        """
774        Generates constraints for
775        every node in the graph based on
776        the operation.
777        """
778        graph = self.traced.graph
779        for n in graph.nodes:
780            self.refine_node(n)
781        return True
782
783    def symbolic_relations(self):
784        """
785        Infers algebraic relations
786        """
787        graph = self.traced.graph
788        for n in graph.nodes:
789            self.infer_symbolic_relations(n)
790        return True
791
792    def replace_dyn_with_fresh_var(self, typ):
793        """
794        Replace all unknown types with fresh type variables.
795        """
796        if typ == Dyn:
797            new_symbol = Var(next(self.symbol_iter))
798            return new_symbol
799        elif isinstance(typ, TensorType):
800            new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__]
801            return TensorType(tuple(new_args))
802        elif isinstance(typ, list):
803            return [self.replace_dyn_with_fresh_var(t) for t in typ]
804        elif isinstance(typ, tuple):
805            return tuple(self.replace_dyn_with_fresh_var(t) for t in typ)
806        else:
807            return typ
808
809
810    def convert_to_sympy_symbols(self, typ):
811        """
812        Replace all type variables with sympy symbols.
813        """
814        if isinstance(typ, Var):
815            return sympy.symbols(str(typ))
816        elif isinstance(typ, TensorType):
817            new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
818            return TensorType(tuple(new_args))
819        elif isinstance(typ, list):
820            return [self.convert_to_sympy_symbols(t) for t in typ]
821        elif isinstance(typ, tuple):
822            return tuple(self.convert_to_sympy_symbols(t) for t in typ)
823        else:
824            return typ
825
826    def refine_node(self, n: Node):
827        """
828        Generates equality constraints for
829        call_module and call_function nodes and adds them to self.constraints.
830        Models the relation between input and output dimensions
831        using constraints in case they are both tensors.
832        All operations used in resnet50 are defined.
833        """
834        if n.type is None:
835            n.type = Dyn
836
837        n.type = self.replace_dyn_with_fresh_var(n.type)
838
839        if n.op == 'call_function':
840            if n.target in _REFINEMENT_RULES:
841                self.constraints += _REFINEMENT_RULES[n.target](n)
842            else:
843                pass
844
845        if n.op == 'call_module':
846            module_instance = self.traced.get_submodule(n.target)
847            if type(module_instance) in _REFINEMENT_RULES:
848                self.constraints += _REFINEMENT_RULES[type(module_instance)](n)
849            else:
850                pass
851
852        if n.op == 'output':
853            def get_node_type(a):
854                return a.type
855            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
856            return n.type
857
858        else:
859            pass
860
861    def infer_symbolic_relations(self, n: Node):
862        n.type = self.convert_to_sympy_symbols(n.type)
863        if n.op == 'call_function':
864            if n.target in _RULES:
865                return _RULES[n.target](n)
866            else:
867                pass
868
869        if n.op == 'call_module':
870            module_instance = self.traced.get_submodule(n.target)
871            if type(module_instance) in _RULES:
872                return _RULES[type(module_instance)](n, module_instance)
873            else:
874                pass
875
876        if n.op == 'output':
877            def get_node_type(a):
878                return a.type
879            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
880            return n.type
881
882        else:
883            pass
884
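# Typical usage (a sketch; `MyModel` is again a hypothetical traceable module and
# the generated constraints would normally be handed to a unification-based solver):
#
#     traced = torch.fx.symbolic_trace(MyModel())
#     GraphTypeChecker({}, traced).type_check()
#     r = Refine(traced)
#     r.refine()
#     print(r.constraints)   # list of Equality constraints over type variables
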
885def get_parameter(traced, target: str):
886    """
887    Returns the parameter given by ``target`` if it exists,
888    otherwise throws an error.
889
890    See the docstring for ``get_submodule`` for a more detailed
891    explanation of this method's functionality as well as how to
892    correctly specify ``target``.
893
894    Args:
895        target: The fully-qualified string name of the Parameter
896            to look for. (See ``get_submodule`` for how to specify a
897            fully-qualified string.)
898
899    Returns:
900        torch.nn.Parameter: The Parameter referenced by ``target``
901
902    Raises:
903        AttributeError: If the target string references an invalid
904            path or resolves to something that is not an
905            ``nn.Parameter``
906    """
907    module_path, _, param_name = target.rpartition(".")
908
909    mod: torch.nn.Module = traced.get_submodule(module_path)
910
911    if not hasattr(mod, param_name):
912        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
913
914    param: torch.nn.Parameter = getattr(mod, param_name)
915
916    return param
917