# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from functools import reduce
import torch
import operator
from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise
from typing import Callable, Dict
from torch.fx.node import Target, Node
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d
from torch.fx.experimental.refinement_types import Equality
import itertools

from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]

import sympy

# Registries mapping a call target (function or module class) to its rule.
_INFERENCE_RULES: Dict[Target, Callable] = {}
_REFINEMENT_RULES: Dict[Target, Callable] = {}
_RULES: Dict[Target, Callable] = {}


def expand_to_tensor_dim(t, n):
    """
    Expand a type to the desired tensor dimension if possible.
    Raise an error otherwise.
    - t is the given type
    - n is a number of dimensions to expand to

    Dyn becomes a TensorType of ``n`` Dyn dimensions; a TensorType is
    returned unchanged when it already has rank ``n``.

    Raises:
        TypeError: if ``t`` is a TensorType of a different rank, or is
            neither Dyn nor a TensorType.
    """
    if t == Dyn:
        dims = [Dyn] * n
        return TensorType(tuple(dims))
    elif isinstance(t, TensorType):
        if len(t.__args__) != n:
            raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
        return t
    else:
        raise TypeError(f'Cannot match the type {t}')


def broadcast_types(t1, t2):
    """
    Applies broadcasting to both given types such that they
    become consistent with each other and returns two new
    resulting types.

    - If either type is Dyn or a type variable, both are returned unchanged.
    - If both are TensorTypes, the lower-rank one is left-padded with 1s and
      every 1 is replaced with the corresponding dimension of the other type.

    Raises:
        TypeError: for any other combination of types.
    """
    # if either type is Dyn, do nothing since the types are already consistent
    if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
        return t1, t2

    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
        s1 = len(t1.__args__)
        s2 = len(t2.__args__)

        # We make the types the same length (left-padding with 1s),
        # which is the first requirement for consistency.
        # A negative repeat count yields an empty list.
        new_t1 = [1] * (s2 - s1) + list(t1.__args__)
        new_t2 = [1] * (s1 - s2) + list(t2.__args__)

        # we replace occurrences of "1" in each tensor with
        # the corresponding type from the other tensor
        for i, (x, y) in enumerate(zip(new_t1, new_t2)):
            if x == 1:
                new_t1[i] = y
            elif y == 1:
                new_t2[i] = x

        # at this point our tensors should be consistent
        # and we can apply the element-wise operation and find the right dimension
        # for the output of the operation
        return TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
    else:
        raise TypeError(f'Cannot broadcast types {t1} and {t2}')


def register_inference_rule(call_target):
    """Decorator registering ``fn`` as the type inference rule for ``call_target``."""
    def register(fn):
        if call_target in _INFERENCE_RULES:
            raise RuntimeError(f'Inference rule already registered for {call_target}!')
        _INFERENCE_RULES[call_target] = fn
        return fn
    return register


def register_refinement_rule(call_target):
    """Decorator registering ``fn`` as the constraint-generating refinement rule for ``call_target``."""
    def register(fn):
        if call_target in _REFINEMENT_RULES:
            raise RuntimeError(f'Refinement rule already registered for {call_target}!')
        _REFINEMENT_RULES[call_target] = fn
        return fn
    return register


def register_algebraic_expressions_inference_rule(call_target):
    """Decorator registering ``fn`` as the algebraic (sympy) rule for ``call_target``."""
    def register(fn):
        if call_target in _RULES:
            raise RuntimeError(f'Rule already registered for {call_target}!')
        _RULES[call_target] = fn
        return fn
    return register


@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def add_inference_rule(n: Node):
    """
    Apply the addition inference rule. This includes:
    - scalar addition
    - broadcasting semantics

    Note that we always return the least precise type between
    the operands (after applying broadcasting) to be the final type of the operation

    Note that we do not modify the operand types themselves after applying broadcasting
    to them. We only use them to calculate the final type.
    The broadcast operand types are recorded in ``n.meta`` under the string
    names of the operands, and ``n.meta['broadcast']`` records whether
    broadcasting changed anything.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)
    t1 = n.args[0].type
    t2 = n.args[1].type

    # handle scalar addition
    if t1 == int and isinstance(t2, TensorType):
        n.type = t2
        return n.type

    # handle scalar addition
    elif t2 == int and isinstance(t1, TensorType):
        n.type = t1
        return n.type

    # we bring the new types to the point where
    # we can check for consistency
    # any inconsistency would not have been caused
    # by broadcasting at this point
    (new_t1, new_t2) = broadcast_types(t1, t2)

    if new_t1 != t1 or new_t2 != t2:
        n.meta['broadcast'] = True
        n.meta[str(n.args[0])] = new_t1
        n.meta[str(n.args[1])] = new_t2
    else:
        n.meta['broadcast'] = False
        # nothing changed: keep the original operand types
        new_t1, new_t2 = t1, t2

    # we check for consistency between the new types
    if is_consistent(new_t1, new_t2):
        # we return the less precise type because
        # broadcasting may have happened
        # for operands with shape [1,2,Dyn] and [1,2,1]
        # we have to assign the node [1,2,Dyn]
        if is_more_precise(new_t1, new_t2):
            n.type = new_t2
        else:
            n.type = new_t1
        return n.type
    else:
        raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.'
                        f' Types should match ')


@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, traced):
    """
    The current getattr rule only handles the shape attribute
    Can be extended to other attributes
    The most representative type we have is "Dyn" but the system
    can be extended with more types, such as a type to represent shapes.

    ``traced`` is accepted for interface compatibility with the caller
    (GraphTypeChecker passes it) but is not used yet.

    Raises:
        TypeError: for any attribute other than ``shape``.
    """
    attr_name = n.args[1]

    if attr_name == "shape":
        n.type = Dyn
    else:
        raise TypeError("Not yet implemented")

    # TODO. We leave it like this till we add a type to represent tensor sizes
    return n.type


@register_inference_rule(torch.transpose)
def transpose_inference_rule(n: Node):
    """
    We check that dimensions for the transpose operations
    are within range of the tensor type of the node.

    The swapped type is combined with any existing node type through
    ``get_greatest_upper_bound``. Only non-negative dimension indices are
    accepted.

    Raises:
        TypeError: if a dimension is out of range or the input type is
            neither Dyn nor a TensorType.
    """
    if n.target == torch.transpose:
        assert isinstance(n.args[0], Node)
        t = n.args[0].type

        assert isinstance(n.args[1], int)
        assert isinstance(n.args[2], int)
        dim1, dim2 = n.args[1], n.args[2]

        if t == Dyn:
            n.type = Dyn
            return n.type

        elif isinstance(t, TensorType):
            if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
                new_type = list(t.__args__)
                new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
                # store dims as a tuple, like every other TensorType built here
                final = TensorType(tuple(new_type))
                n.type = get_greatest_upper_bound(n.type, final)
                return n.type
            else:
                raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
        else:
            raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')


@register_inference_rule(torch.reshape)
def reshape_inference_rule(n: Node):
    """
    Without dynamism, the rule checks that the
    product of the elements of the argument tensor
    type is equal to the product of the elements
    of the required shape. We gradualize this rule
    by adding a case to handle fully dynamic input
    as well as input where some of the tensor dimensions
    are unknown (Dyn in the input, or -1 in the requested shape).
    In this case we check for divisibility of the known products.

    Returns the requested shape as a TensorType, with -1 mapped to Dyn.

    Raises:
        TypeError: if the shapes are incompatible.
    """
    assert isinstance(n.args[0], Node)
    t1 = n.args[0].type

    assert isinstance(n.args[1], list)
    t2 = n.args[1]
    t2_type = TensorType(tuple(Dyn if elem == -1 else elem for elem in t2))

    # if we do not know the original tensor dimension,
    # we return the required dimension
    if t1 == Dyn:
        n.type = t2_type
        return t2_type

    elif isinstance(t1, TensorType):
        known_in = [e for e in t1.__args__ if e != Dyn]
        known_out = [e for e in t2 if e != -1]
        # initial value 1 so that rank-0 shapes do not crash reduce()
        p1 = reduce(operator.mul, known_in, 1)
        p2 = reduce(operator.mul, known_out, 1)

        if len(known_in) == len(t1.__args__) and len(known_out) == len(t2):
            # fully static: the element counts must match exactly
            compatible = p1 == p2
        else:
            # some dimension is unknown: only divisibility can be checked
            # (guarding against division by zero for zero-sized dims)
            compatible = (p1 == p2
                          or (p2 != 0 and p1 % p2 == 0)
                          or (p1 != 0 and p2 % p1 == 0))

        if compatible:
            n.type = t2_type
            return t2_type
        else:
            raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
    else:
        raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')


@register_inference_rule(BatchNorm2d)
def bn2d_inference_rule(n: Node, module_instance):
    """
    Given a BatchNorm2D instance and a node check the following conditions:
    - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4)
    - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4')
    - t is consistent with t'
    - x_2 is consistent with the module's num_features
    - x_2' is consistent with the module's num_features
    output type: the more precise type of t and t'

    Note: the input node's type is overwritten with its rank-4 expansion.
    """
    assert isinstance(n.args[0], Node)
    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
    arg_type = n.args[0].type
    n.type = expand_to_tensor_dim(n.type, 4)

    # we check the conditions on the incoming argument
    # and any existing annotation
    # we also check for consistency between both annotations
    if is_consistent(arg_type.__args__[1], module_instance.num_features) and \
            is_consistent(n.type.__args__[1], module_instance.num_features) and \
            is_consistent(arg_type, n.type):

        # we choose the more precise type
        # to be the node type
        # so if an incoming argument has more type information
        # we set this node's type to be the argument type
        n.type = get_greatest_upper_bound(arg_type, n.type)
        return n.type
    else:
        raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')


def calculate_out_dimension(d_in, module_instance, index):
    """
    For calculating h_out and w_out according to the conv2D documentation:

        out = (d_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    ``index`` selects which element (0 for height, 1 for width) of each
    per-dimension hyper-parameter is used. Int hyper-parameters apply to
    both dimensions.

    Returns Dyn when ``d_in`` is Dyn.

    Raises:
        TypeError: if ``d_in`` is not an int, a sympy Symbol, or Dyn.
    """
    padding = (module_instance.padding, module_instance.padding) \
        if isinstance(module_instance.padding, int) else module_instance.padding
    kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \
        if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size
    stride = (module_instance.stride, module_instance.stride) \
        if isinstance(module_instance.stride, int) else module_instance.stride
    dilation = (module_instance.dilation, module_instance.dilation) \
        if isinstance(module_instance.dilation, int) else module_instance.dilation

    DIMENSION_TYPES = (int, sympy.Symbol)

    if d_in == Dyn:
        return Dyn

    elif isinstance(d_in, DIMENSION_TYPES):
        n = d_in + 2 * padding[index] - \
            dilation[index] * \
            (kernel_size[index] - 1) - 1

        # use the stride of the requested dimension (non-square strides)
        return (n // stride[index]) + 1

    else:
        raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}')


def get_greatest_upper_bound(type1, type2):
    """
    Get the most precise type that's consistent with the given types.

    - Dyn is the identity: the other type is returned.
    - For two TensorTypes, each dimension takes the more precise of the pair.

    Raises:
        TypeError: if two TensorTypes are inconsistent, or if the types are
            of any other kind (previously ``None`` was returned silently).
    """
    if type1 == Dyn:
        return type2
    elif type2 == Dyn:
        return type1
    elif isinstance(type1, TensorType) and isinstance(type2, TensorType):
        if not is_consistent(type1, type2):
            raise TypeError(f'Inconsistent types {type1}, {type2}')
        gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)]
        return TensorType(tuple(gub))
    else:
        raise TypeError(f'Cannot compute the greatest upper bound of {type1} and {type2}')


@register_inference_rule(Conv2d)
def conv2d_inference_rule(n: Node, module_instance):
    """
    Given a Conv2D instance and a node check the following conditions:
    - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W)
    - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4')
    - x_2 is consistent with the module's in_channels
    - let o = (x_1, out_channels, H_out, W_out)
    then the output is the greatest upper bound of o and the existing node type t'.

    Note: the input node's type is overwritten with its rank-4 expansion.
    """
    assert isinstance(n.args[0], Node)
    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
    arg_type = n.args[0].type
    curr_node_type = expand_to_tensor_dim(n.type, 4)

    if is_consistent(arg_type.__args__[1], module_instance.in_channels):
        w_in = arg_type.__args__[3]
        h_in = arg_type.__args__[2]
        h_out = calculate_out_dimension(h_in, module_instance, 0)
        w_out = calculate_out_dimension(w_in, module_instance, 1)
        new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
        gub = get_greatest_upper_bound(new_type, curr_node_type)
        n.type = gub
        return n.type
    else:
        raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}')


@register_inference_rule(torch.nn.ReLU)
def relu_inference_rule(n: Node, module_instance):
    """
    Input and output shapes should be equal.

    A Dyn input is first expanded to the rank of an existing TensorType
    annotation on the node.
    """
    assert isinstance(n.args[0], Node)

    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))

    if isinstance(n.args[0].type, TensorType):
        n.type = get_greatest_upper_bound(n.args[0].type, n.type)
    return n.type


def maxpool2d_check(typ, module_instance):
    """
    Applies the maxpool2d shape information to the input.
    This affects the last two dimensions (height, width).

    Raises:
        TypeError: if ``typ`` is not of rank 3 or 4.
    """
    new_type_list = list(typ.__args__)
    if len(new_type_list) == 4 or len(new_type_list) == 3:
        w_in = new_type_list[-1]
        h_in = new_type_list[-2]

        h_out = calculate_out_dimension(h_in, module_instance, 0)
        w_out = calculate_out_dimension(w_in, module_instance, 1)

        new_type_list[-1] = w_out
        new_type_list[-2] = h_out
        return TensorType(tuple(new_type_list))

    else:
        raise TypeError(f'Wrong size {typ} for {module_instance}')


@register_inference_rule(torch.nn.MaxPool2d)
def maxpool2d_inference_rule(n: Node, module_instance):
    """
    Given a MaxPool2D instance and a node check the following conditions:
    - Input size matches size 3 or 4
    - Current node type is consistent with the output type we will calculate
    - Input size matches output size and the last two dimensions of the output
      are w_out and h_out. The remaining dimensions are the same as the input
    - Our final result is the greatest upper bound of the output we calculate
      and the current node type.
    """
    assert isinstance(n.args[0], Node)

    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
    if isinstance(n.args[0].type, TensorType):
        output = maxpool2d_check(n.args[0].type, module_instance)
        n.type = get_greatest_upper_bound(output, n.type)
    return n.type


def linear_check(tensor_type, module_instance):
    """
    Checks that an input tensor type satisfies the conditions for linear operation
    and returns the output type based on in and out features given by module_instance.

    The last dimension must be consistent with ``in_features`` and is
    replaced by ``out_features``.

    Raises:
        TypeError: if the rank is below 2 or the last dimension is inconsistent.
    """
    if len(tensor_type.__args__) >= 2:
        if is_consistent(module_instance.in_features, tensor_type.__args__[-1]):
            new_type_args = list(tensor_type.__args__)
            new_type_args[-1] = module_instance.out_features
            return TensorType(tuple(new_type_args))
        else:
            raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}')
    else:
        raise TypeError(f'Type {tensor_type} must have rank 2 or more.')


@register_inference_rule(torch.nn.Linear)
def linear_inference_rule(n: Node, module_instance):
    """
    Applies the shape information to the input then gets the greatest upper bound
    of the resulting type and the existing type.
    """
    assert isinstance(n.args[0], Node)
    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
    if isinstance(n.args[0].type, TensorType):
        output_type = linear_check(n.args[0].type, module_instance)
        n.type = get_greatest_upper_bound(output_type, n.type)
    return n.type


def adaptiveavgpool2d_check(tensor_type, module_instance):
    """
    Replace the last two dimensions of ``tensor_type`` with the module's
    ``output_size``. A ``None`` entry in ``output_size`` takes the value of
    the other entry, as in this file's original logic.

    The module's ``output_size`` is never modified (a copy is used).

    Raises:
        TypeError: if ``tensor_type`` is not of rank 3 or 4.
    """
    output_size = module_instance.output_size
    if isinstance(output_size, int):
        output_size = [output_size, output_size]
    else:
        # copy any sequence (tuple or list) so the module attribute is not mutated
        output_size = list(output_size)
        if output_size[0] is None:
            output_size[0] = output_size[1]
        if output_size[1] is None:
            output_size[1] = output_size[0]

    new_type_list = list(tensor_type.__args__)

    if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3:
        new_type_list[-1] = output_size[1]
        new_type_list[-2] = output_size[0]

        return TensorType(tuple(new_type_list))

    else:
        raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}')


@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
    """
    The input and output sizes should be the same except for the last
    two dimensions taken from the input, which represent width and height.
    """
    assert isinstance(n.args[0], Node)
    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
    if isinstance(n.args[0].type, TensorType):
        output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance)
        n.type = get_greatest_upper_bound(n.type, output_type)
    return n.type


def flatten_check(tensor_type, start_dim, end_dim):
    """
    Compute the type of ``tensor_type`` flattened from ``start_dim`` to
    ``end_dim`` (both inclusive; negative values count from the end, as in
    ``torch.flatten``).

    The flattened dimensions collapse into their product, or into Dyn if any
    of them is Dyn.

    Raises:
        TypeError: if the dimensions are out of range or ``start_dim`` comes
            after ``end_dim``.
    """
    rank = len(tensor_type.__args__)

    # normalize to a half-open range [start_dim, end_dim)
    start_dim = rank + start_dim if start_dim < 0 else start_dim
    end_dim = rank + end_dim + 1 if end_dim < 0 else end_dim + 1

    if 0 <= start_dim <= (rank - 1) and 0 <= end_dim <= rank and start_dim < end_dim:
        my_args = list(tensor_type.__args__)
        lhs = my_args[0:start_dim]
        rhs = my_args[end_dim:]
        mid = my_args[start_dim:end_dim]
        if Dyn in mid:
            mid = [Dyn]
        else:
            mid = [reduce(operator.mul, mid)]
        new_type_list = lhs + mid + rhs
        return TensorType(tuple(new_type_list))
    else:
        raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}')


@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node):
    """
    Applies the flatten shape information to the input then gets the
    greatest upper bound of the resulting type and the existing type.

    Positional ``start_dim``/``end_dim`` default to 1 and -1.
    """
    assert isinstance(n.args[0], Node)

    # set the default start and end dims
    start_dim = 1
    end_dim = -1

    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]

    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))

    if isinstance(n.args[0].type, TensorType):
        output_type = flatten_check(n.args[0].type, start_dim, end_dim)
        n.type = get_greatest_upper_bound(output_type, n.type)

    return n.type


class GraphTypeChecker:
    """Runs the registered gradual type inference rules over a traced graph."""

    def __init__(self, env, traced):
        self.env = env
        self.traced = traced

    def type_check(self):
        """
        A gradual type checker for graphs
        Effect: every node's field type will be
        populated with a type after type-checking is done.

        Returns True; a node that does not type check raises an error
        (TypeError or RuntimeError from the rules) instead.
        """
        graph = self.traced.graph

        # type check every node with gradual type rules
        for n in graph.nodes:
            self.type_check_node(n)
        return True

    def type_check_node(self, n: Node):
        """
        Type check a given fx node.
        Current operations:
        - Reshape
        - Transpose
        - Add
        - Relu
        - conv2d
        - batchnorm2d
        - flatten
        - maxpool2d
        - adaptiveavgpool2d
        - linear
        """
        if n.type is None:
            n.type = Dyn

        if n.op == 'placeholder':
            return n.type

        elif n.op == 'get_attr':
            t = get_parameter(self.traced, n.target)  # type: ignore[arg-type]
            if isinstance(t.data, torch.Tensor):
                n.type = TensorType(t.data.shape)
            return n.type

        elif n.op == 'call_function':
            if n.target == getattr:
                assert getattr in _INFERENCE_RULES
                return _INFERENCE_RULES[n.target](n, self.traced)

            elif n.target in _INFERENCE_RULES:
                return _INFERENCE_RULES[n.target](n)
            else:
                raise RuntimeError(f'No inference rule registered for target {n.target}!')

        elif n.op == 'call_module':
            module_instance = self.traced.get_submodule(n.target)
            if type(module_instance) in _INFERENCE_RULES:
                return _INFERENCE_RULES[type(module_instance)](n, module_instance)
            else:
                raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')

        elif n.op == 'output':
            def get_node_type(a):
                return a.type
            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
            return n.type

        else:
            raise NotImplementedError(f"Method {n.op} not yet implemented")


@register_refinement_rule(Conv2d)
def conv_refinement_rule(n: Node):
    """
    The equality constraints are between the first dimension of
    the input and output.
    """
    res = []
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
    return res


@register_refinement_rule(torch.nn.Linear)
def linear_refinement_rule(n: Node):
    """
    The equality constraints are between the first dimension of
    the input and output.
    """
    res = []
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
    return res


@register_refinement_rule(BatchNorm2d)
@register_refinement_rule(torch.nn.ReLU)
def all_eq(n: Node):
    """
    For operations where the input shape is equal to the output shape:
    every input dimension is constrained equal to the matching output dimension.
    """
    res = []
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
        args1 = arg_type.__args__
        args2 = n.type.__args__
        res = [Equality(args1[i], args2[i]) for i in range(len(args1))]
    return res


@register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
@register_refinement_rule(torch.nn.MaxPool2d)
def first_two_eq(n: Node):
    """
    For operations where the first two dimensions of the input and output shape
    are equal.
    """
    res = []
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
        args1 = arg_type.__args__
        args2 = n.type.__args__
        res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
    return res


@register_refinement_rule(torch.add)
@register_refinement_rule(operator.add)
def element_wise_eq(n: Node):
    """
    For element-wise operations and handles broadcasting.
    Note that after applying broadcasting to the arguments
    we are able to determine if certain dimensions have not been broadcast
    if they are symbolically equal.

    In this case, we can establish equality between those dimensions and the
    corresponding output dimensions.

    Note that it takes two iterations for this result. One iteration to establish
    equality between certain dimensions of the operands (requiring the whole solver
    including unification) and another iteration to establish equality between the operands
    and the resulting type, requiring another round of constraint generation and unification.
    """
    res = []
    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        arg_type1 = n.args[0].type
        arg_type2 = n.args[1].type
        if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
            args1, args2 = broadcast_types(arg_type1, arg_type2)
            # by this point, we know that args1 and args2 are the same size.
            a1 = args1.__args__
            a2 = args2.__args__
            a3 = n.type.__args__

            # we would be here in the second iteration where we establish equality
            # between operand type dimensions and the resulting type dimensions
            res = [Equality(x, z) for x, y, z in zip(a1, a2, a3) if x == y]
    return res


@register_refinement_rule(torch.flatten)
def flatten_refinement_rule(n: Node):
    """
    Generates equality constraints between the dimensions of the input and output
    that will not be involved in the flatten operation.

    With the flattened input range [start_dim, end_dim] collapsed into the
    single output dimension ``start_dim``:
    - output[:start_dim] == input[:start_dim]
    - output[start_dim + 1:] == input[end_dim + 1:]
    Indices are normalized against the *input* rank.
    """
    assert isinstance(n.args[0], Node)

    eq_const = []

    start_dim = 1
    end_dim = -1

    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]

    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType):
        arg_type = n.args[0].type
        in_args = arg_type.__args__
        out_args = n.type.__args__
        rank = len(in_args)
        # normalize to a half-open input range [start_dim, end_dim)
        start_dim = rank + start_dim if start_dim < 0 else start_dim
        end_dim = rank + end_dim + 1 if end_dim < 0 else end_dim + 1

        # leading dimensions are untouched
        for t1, t2 in zip(out_args[0:start_dim], in_args[0:start_dim]):
            eq_const.append(Equality(t1, t2))

        # trailing dimensions follow the single collapsed output dimension
        for t1, t2 in zip(out_args[start_dim + 1:], in_args[end_dim:]):
            eq_const.append(Equality(t1, t2))
    return eq_const


@register_algebraic_expressions_inference_rule(Conv2d)
def conv_rule(n: Node, module_instance):
    """
    Represents the output in terms of an algebraic expression w.r.t
    the input when possible: the last two output dimensions are recomputed
    from the input's height/width via ``calculate_out_dimension``.
    Returns None when either type is not a TensorType.
    """
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
        w_in = arg_type.__args__[3]
        h_in = arg_type.__args__[2]
        h_out = calculate_out_dimension(h_in, module_instance, 0)
        w_out = calculate_out_dimension(w_in, module_instance, 1)
        new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out))
        n.type = new_type
        return new_type


class Refine:
    """
    Symbolic shape inference.
    Generates constraints over type variables.
    Currently all constraints are equality constraints.
    """
    def __init__(self, traced):
        self.constraints = []
        self.traced = traced
        # source of fresh integer ids for type variables
        self.symbol_iter = itertools.count(start=0, step=1)

    def refine(self):
        """
        Generates constraints for
        every node in the graph based on
        the operation, accumulating them in ``self.constraints``.
        """
        graph = self.traced.graph
        for n in graph.nodes:
            self.refine_node(n)
        return True

    def symbolic_relations(self):
        """
        Infers algebraic relations for every node in the graph.
        """
        graph = self.traced.graph
        for n in graph.nodes:
            self.infer_symbolic_relations(n)
        return True

    def replace_dyn_with_fresh_var(self, typ):
        """
        Replace all unknown types (Dyn) with fresh type variables, recursing
        into TensorTypes, lists and tuples (container kind is preserved).
        """
        if typ == Dyn:
            new_symbol = Var(next(self.symbol_iter))
            return new_symbol
        elif isinstance(typ, TensorType):
            new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__]
            return TensorType(tuple(new_args))
        elif isinstance(typ, list):
            return [self.replace_dyn_with_fresh_var(t) for t in typ]
        elif isinstance(typ, tuple):
            return tuple(self.replace_dyn_with_fresh_var(t) for t in typ)
        else:
            return typ

    def convert_to_sympy_symbols(self, typ):
        """
        Replace all type variables with sympy symbols of the same name,
        recursing into TensorTypes, lists and tuples (container kind is preserved).
        """
        if isinstance(typ, Var):
            return sympy.symbols(str(typ))
        elif isinstance(typ, TensorType):
            new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
            return TensorType(tuple(new_args))
        elif isinstance(typ, list):
            return [self.convert_to_sympy_symbols(t) for t in typ]
        elif isinstance(typ, tuple):
            return tuple(self.convert_to_sympy_symbols(t) for t in typ)
        else:
            return typ

    def refine_node(self, n: Node):
        """
        Generate equality constraints for call_module and call_function
        nodes and append them to ``self.constraints``.
        Models the relation between input and output dimensions
        using constraints in case they are both tensors.
        Dyn dimensions of ``n.type`` are first replaced with fresh type variables.
        For output nodes, ``n.type`` is set to (and returns) the types of its arguments.
        All operations used in resnet50 are defined.
        """
        if n.type is None:
            n.type = Dyn

        n.type = self.replace_dyn_with_fresh_var(n.type)

        if n.op == 'call_function':
            if n.target in _REFINEMENT_RULES:
                self.constraints += _REFINEMENT_RULES[n.target](n)

        if n.op == 'call_module':
            module_instance = self.traced.get_submodule(n.target)
            if type(module_instance) in _REFINEMENT_RULES:
                self.constraints += _REFINEMENT_RULES[type(module_instance)](n)

        if n.op == 'output':
            def get_node_type(a):
                return a.type
            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
            return n.type

    def infer_symbolic_relations(self, n: Node):
        """
        Convert type variables in ``n.type`` to sympy symbols, then apply the
        algebraic rule registered for the node's target, if any.
        """
        n.type = self.convert_to_sympy_symbols(n.type)
        if n.op == 'call_function':
            if n.target in _RULES:
                return _RULES[n.target](n)

        if n.op == 'call_module':
            module_instance = self.traced.get_submodule(n.target)
            if type(module_instance) in _RULES:
                return _RULES[type(module_instance)](n, module_instance)

        if n.op == 'output':
            def get_node_type(a):
                return a.type
            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
            return n.type


def get_parameter(traced, target: str):
    """
    Returns the attribute given by ``target`` if it exists,
    otherwise throws an error.

    See the docstring for ``get_submodule`` for a more detailed
    explanation of this method's functionality as well as how to
    correctly specify ``target``.

    Args:
        target: The fully-qualified string name of the Parameter
            to look for. (See ``get_submodule`` for how to specify a
            fully-qualified string.)

    Returns:
        The attribute referenced by ``target``. Note that no check is made
        that it is actually an ``nn.Parameter``.

    Raises:
        AttributeError: If the module path is invalid (raised by
            ``get_submodule``) or the owning module has no attribute
            named after the last component of ``target``.
    """
    module_path, _, param_name = target.rpartition(".")

    mod: torch.nn.Module = traced.get_submodule(module_path)

    if not hasattr(mod, param_name):
        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")

    param: torch.nn.Parameter = getattr(mod, param_name)

    return param