# mypy: ignore-errors
import copy
import itertools
from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK
from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \
    Transpose
from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound
from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound
from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool
from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape
from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect
from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching
from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq
from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod
from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar
from torch.fx.tensor_type import TensorType, Dyn
from typing import Callable, Dict, List, Type

# Registry mapping each constraint class to the rule that rewrites it
# into simpler constraints (ultimately equalities and inequalities).
_TRANSFORMATION_RULES: Dict[Type[Constraint], Callable] = {}


def register_transformation_rule(call_target):
    def register(fn):
        if call_target in _TRANSFORMATION_RULES:
            raise RuntimeError(f'Transformation rule already registered for {call_target}!')
        _TRANSFORMATION_RULES[call_target] = fn
        return fn
    return register


def valid_index(index, dims):
    """
    Given a list of dimensions, checks if an index is valid in the list
    """
    try:
        dims[index]
        return T()
    except IndexError:
        return F()


@register_transformation_rule(Transpose)
def transform_transpose(constraint, counter):
    """
    Similar to a sequence of two index-selects
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    is_valid_index1 = valid_index(constraint.index1, dims)
    is_valid_index2 = valid_index(constraint.index2, dims)
    new_dims = copy.deepcopy(dims)
    nat_constraints = gen_nat_constraints(dims)

    if is_valid_index1 == T() and is_valid_index2 == T():
        new_dims[constraint.index1] = dims[constraint.index2]
        new_dims[constraint.index2] = dims[constraint.index1]

    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                                   *nat_constraints,
                                   is_valid_index1, is_valid_index2,
                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
    return transformed_constraint, counter
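

# A sketch of the rewrite above, assuming a Transpose constraint over a rank-2
# tensor with fresh dimensions [d1, d2] and index1=0, index2=1 (names are
# illustrative): the input is constrained to TensorType([d1, d2]) and the
# output to the swapped TensorType([d2, d1]), conjoined with the naturalness
# constraints 0 <= d1 and 0 <= d2 and both index-validity checks (here T()).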


@register_transformation_rule(IndexSelect)
def transform_index_select(constraint, counter):
    """
    Generate constraints for the given tensor size, check that the index is
    valid, and if so, generate a constraint replacing the input dimension
    with the required dimension
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    is_valid_index = valid_index(constraint.index, dims)
    nat_constraints = gen_nat_constraints(dims)

    # if the index is valid then replace the input dimension with the new dimension;
    # otherwise the dimension is left unchanged and the clause will contain False
    new_dims = copy.deepcopy(dims)
    if is_valid_index == T():
        new_dims[constraint.index] = constraint.dim_replace

    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                                   *nat_constraints,
                                   is_valid_index,
                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
    return transformed_constraint, counter


@register_transformation_rule(GetItem)
def transform_get_item(constraint, counter):
    """
    Generate an equality of the form:
        t = [a1, ..., an]
    then generate constraints that check if the given index is valid
    given this particular tensor size.
    If the index is valid, generate a constraint to get the item.
    Note that we already handled the Dyn input case in the previous step.
    Args:
        constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
        counter: variable tracking
    Returns: simplified constraints for GetItem

    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    is_valid_index = valid_index(constraint.index, dims)

    all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                       *nat_constraints,
                       is_valid_index]

    # if the index is valid, we generate a constraint for getting an item
    # otherwise this clause will have been UNSAT due to the wrong index
    if is_valid_index == T():
        all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq))

    return Conj(all_constraints), counter
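

# For instance (hypothetical dims): a GetItem with index 1 over a rank-3 tensor
# produces, for fresh dimensions [d1, d2, d3],
#     Conj([input = TensorType([d1, d2, d3]), 0 <= d1, 0 <= d2, 0 <= d3, T(), res = d2])
# whereas an out-of-range index such as 5 leaves F() in the conjunction,
# making the whole clause unsatisfiable.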


def valid_index_tensor(index, dims):
    """
    If the number of slice instances exceeds the length of the dimensions,
    then this is a type error, so we return False
    """
    slice_count = 0
    for s in index:
        if isinstance(s, slice):
            slice_count += 1
    if slice_count > len(dims):
        return F()
    else:
        return T()


@register_transformation_rule(GetItemTensor)
def transform_get_item_tensor(constraint, counter):
    """
    When the index is a tuple, the output will be a tensor.
    TODO: we have to check if this is the case for all HF models

    The cases we are covering here are a tuple with one of:
    - slice with default argument
    - None

    'None' appends 1 to the input tensor dimensions,
    so each occurrence of 'None' increases the rank by 1.

    A slice with default arguments does not change the rank.
    """
    assert isinstance(constraint.index_tuple, tuple)

    # generate a result tensor of the expected size
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    # generate a placeholder list of the right rank
    # where "slice" does not contribute to the rank and "None" does
    none_c = constraint.index_tuple.count(None)
    resulting_tensor_dims = (none_c + len(dims)) * [None]

    for i in range(len(constraint.index_tuple)):

        # insert 1 at the right location of the resulting tensor
        if constraint.index_tuple[i] is None:
            resulting_tensor_dims[i] = 1

        elif constraint.index_tuple[i] == slice(None, None, None):
            pass

        else:
            raise NotImplementedError('Method not yet implemented')

    # insert the remaining dimensions at the right locations
    dim_index = 0
    for i in range(len(resulting_tensor_dims)):
        if resulting_tensor_dims[i] is None:
            resulting_tensor_dims[i] = dims[dim_index]
            dim_index += 1

    # check if the index is valid
    is_valid_index = valid_index_tensor(constraint.index_tuple, dims)

    # check if the resulting tensor is within bounds
    if len(resulting_tensor_dims) > 4:
        return F(), counter

    else:
        constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                       BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
                       *nat_constraints,
                       is_valid_index]
        return Conj(constraints), counter


@register_transformation_rule(BinConstraintT)
def generate_binconstraint_t(constraint, counter):
    """
    Transform binary constraints for tensors
    """

    # precision constraints
    if constraint.op == op_precision:
        if constraint.lhs == Dyn:
            return T(), counter
        elif isinstance(constraint.lhs, TensorType):
            is_fully_static = all(d != Dyn for d in constraint.lhs.__args__)
            if is_fully_static:
                return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter
            else:
                new_dims = []

                for _ in range(len(constraint.lhs.__args__)):
                    dim, counter = gen_dvar(counter)
                    new_dims.append(dim)

                new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for
                                       new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \
                    [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \
                    [BinConstraintD(1, new_dim, op_leq) for
                     new_dim in new_dims]
                return Conj(new_dim_constraints), counter

    # matching
    elif constraint.op == op_matching:
        assert isinstance(constraint.rhs, TensorType)
        d1 = constraint.rhs.__args__[0]
        d2 = constraint.rhs.__args__[1]
        d3 = constraint.rhs.__args__[2]
        d4 = constraint.rhs.__args__[3]

        conj = [BinConstraintT(constraint.lhs, Dyn, op_eq),
                BinConstraintD(d1, Dyn, op_eq),
                BinConstraintD(d2, Dyn, op_eq),
                BinConstraintD(d3, Dyn, op_eq),
                BinConstraintD(d4, Dyn, op_eq)]
        return Disj([Conj(conj),
                     BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter

    elif constraint.op == op_consistency:
        c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)])
        [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter)

        return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter

    elif constraint.op == op_leq:
        assert isinstance(constraint.rhs, int)
        disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
        for i in range(1, constraint.rhs + 1):
            dims = []
            for j in range(1, i + 1):
                dim_var, counter = gen_dvar(counter)
                dims.append(dim_var)
            disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
        return Disj(disj), counter
    else:
        return constraint, counter
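

# Example expansion of the op_leq branch above (illustrative): a rank bound
# "t <= 2" on tensor variable t becomes, with fresh dimension variables,
#     Disj([t = Dyn, t = TensorType([d1]), t = TensorType([d2, d3])])
# i.e. t is either fully dynamic or a tensor of rank 1 or 2.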


@register_transformation_rule(BinConstraintD)
def generate_binconstraint_d(constraint, counter):
    """
    Transform binary constraints for dimensions
    """
    if constraint.op == op_precision:
        if isinstance(constraint.lhs, int):
            return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter
        elif constraint.lhs == Dyn:
            return T(), counter

    elif constraint.op == op_consistency:
        return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq),
                     BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter

    else:
        return constraint, counter


@register_transformation_rule(Conj)
def generate_conj(constraint, counter):
    """
    Transform conjunctions
    """
    new = []
    for c in constraint.conjucts:
        new_c, counter = transform_constraint(c, counter)
        new.append(new_c)
    return Conj(new), counter


@register_transformation_rule(Disj)
def generate_disj(constraint, counter):
    """
    Transform disjunctions
    """
    new = []
    for c in constraint.disjuncts:
        new_c, counter = transform_constraint(c, counter)
        new.append(new_c)
    return Disj(new), counter


@register_transformation_rule(TGreatestUpperBound)
def generate_gub(constraint, counter):
    """
    Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound
    on dimensions
    """
    c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq),
                     BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)])

    [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter)

    return Disj([c1, c2, c3, c4, c5]), counter


@register_transformation_rule(DGreatestUpperBound)
def generate_d_gub(constraint, counter):
    """
    Transform greatest upper bound for dimensions into equality constraints
    """
    c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)])
    c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)])
    c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)])
    return Disj([c1, c2, c3]), counter
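

# Informally, the three disjuncts above implement a join on the gradual
# dimension ordering (illustrative reading):
#     gub(Dyn, d) = d    (c1)
#     gub(d, Dyn) = d    (c2)
#     gub(d, d)   = d    (c3, when both sides agree)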


@register_transformation_rule(CalcConv)
def generate_calc_conv(constraint, counter):
    """
    Transform convolution constraints
    """
    d, counter = gen_tensor_dims(4, counter)
    conv_result = TensorType([d[0], d[1], d[2], d[3]])

    # the convolution result is a tensor of size 4
    c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq)

    # the second dimension of the output is equal to the output channels
    c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)])

    # the input corresponds to the output in the first dimension of the convolution
    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    c4, c5 = calc_last_two_dims(constraint, d)

    leq_constraints = Conj([BinConstraintD(0, d[0], op_leq),
                            BinConstraintD(0, d[1], op_leq),
                            BinConstraintD(0, d[2], op_leq),
                            BinConstraintD(0, d[3], op_leq)])

    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter


@register_transformation_rule(CalcMaxPool)
def generate_calc_maxpool(constraint, counter):
    """
    Transform maxpool constraints
    """
    d, counter = gen_tensor_dims(4, counter)
    maxpool_result = TensorType([d[0], d[1], d[2], d[3]])

    # the maxpool result is a tensor of size 4
    c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq)

    # the input corresponds to the output in the first and second dimension of maxpool
    c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq)
    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
    c4, c5 = calc_last_two_dims(constraint, d)

    leq_constraints = Conj([BinConstraintD(0, d[0], op_leq),
                            BinConstraintD(0, d[1], op_leq),
                            BinConstraintD(0, d[2], op_leq),
                            BinConstraintD(0, d[3], op_leq)])

    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter


@register_transformation_rule(CalcProduct)
def generate_calc_product(constraint, counter):
    """
    Transform flatten constraints
    """
    start = constraint.start
    end = constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened
    n = len(constraint.dims_to_flatten)

    # the boundary check is evaluated eagerly, at transformation time
    boundary_check = 0 <= start < end <= n

    c_boundary = T() if boundary_check else F()

    lhs = dims[0:start]
    rhs = dims[end:]
    mid = dims[start:end]

    all_possibilities = generate_all_int_dyn_dim_possibilities(mid)

    all_constraints = []

    for p in all_possibilities:
        p = list(p)
        # this tells us there is a dynamic variable
        contains_dyn = not all(c.op == op_neq for c in p)
        if contains_dyn:
            mid_var = [Dyn]
            total_constraints = lhs + mid_var + rhs
            if len(total_constraints) > 4:
                all_constraints.append(F())
            else:
                all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p))
        else:
            new_var, counter = gen_dvar(counter)
            mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)])
            mid_var = [new_var]
            total_constraints = lhs + mid_var + rhs
            if len(total_constraints) > 4:
                all_constraints.append(F())
            else:
                all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p))

    return Conj([Disj(all_constraints), c_boundary]), counter
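

# Worked sketch (hypothetical): flattening [d1, d2, d3, d4] with start=1, end=3
# keeps lhs=[d1] and rhs=[d4] and replaces mid=[d2, d3] either by Dyn (in the
# disjuncts where some middle dimension may equal Dyn) or by a fresh variable
# constrained to equal Prod([d2, d3]); the boundary check 0 <= 1 < 3 <= 4
# evaluates to True here, so c_boundary is T().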


@register_transformation_rule(CanReshape)
def generate_reshape(constraint, counter):
    """
    Transform reshape constraints
    """
    d, counter = gen_tensor_dims(4, counter)

    d1 = d[0]
    d2 = d[1]
    d3 = d[2]
    d4 = d[3]

    target = constraint.target.__args__

    is_fully_static = all(dim != Dyn for dim in target)

    # dynamic tensor
    c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
    c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
    c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
    c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
    c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)

    d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
    d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)

    d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
    d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)

    d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
    d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)

    d4_eq_dyn = BinConstraintD(d4, Dyn, op_eq)
    d4_neq_dyn = BinConstraintD(d4, Dyn, op_neq)

    nat_d1 = BinConstraintD(0, d1, op_leq)
    nat_d2 = BinConstraintD(0, d2, op_leq)
    nat_d3 = BinConstraintD(0, d3, op_leq)
    nat_d4 = BinConstraintD(0, d4, op_leq)

    if is_fully_static:
        # size 1 tensor
        c3_tensor1 = Disj([d1_eq_dyn,
                           (Conj([d1_neq_dyn,
                                  BinConstraintD(d1, Prod(target), op_eq)]))])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # size 2 tensor
        all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)])

        # size 3 tensor
        all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)])

        # size 4 tensor
        all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)])

        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
                     nat_d1, nat_d2, nat_d3, nat_d4]), counter

    # otherwise there must be exactly one occurrence of dyn in the target
    else:
        new_target = []

        for n in target:
            if n != Dyn:
                new_target.append(n)

        # tensor 1
        c3_tensor1 = Disj([d1_eq_dyn,
                           (Conj([d1_neq_dyn,
                                  is_dim_div_by_target(new_target, d1)]))])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # tensor 2
        c21 = Disj([d1_eq_dyn, d2_eq_dyn])
        c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))])
        all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])

        # tensor 3
        c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
        c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))])
        all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])

        # tensor 4
        c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
        c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))])
        all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])

        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
                     nat_d1, nat_d2, nat_d3, nat_d4]), counter
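

# Sketch of the non-static case (illustrative): reshaping to a target such as
# TensorType([Dyn, 2]) strips the Dyn, leaving new_target = [2]; a rank-2
# source [d1, d2] is then admissible when some di is Dyn, or when
# Prod([d1, d2]) mod Prod([2]) == 0 (see is_dim_div_by_target below).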


@register_transformation_rule(ApplyBroadcasting)
def generate_broadcasting(constraint, counter):
    """
    Transform broadcasting constraints
    """
    e11, e12 = constraint.res1, constraint.res2
    e1, e2 = constraint.input1, constraint.input2

    e1_dyn = BinConstraintT(e1, Dyn, op_eq)
    e2_dyn = BinConstraintT(e2, Dyn, op_eq)

    # Introduce dimensions
    e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
    e2_equal_e12 = BinConstraintT(e2, e12, op_eq)

    # dyn possibility
    e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12])
    e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12])

    # tensor possibility
    # generate dimensions to create tensors of size 1
    final_tensor_1_constraint, _, _, nat_dims_1, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter)

    # generate dimensions to create tensors of size 2
    final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \
        final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter)

    # generate dimensions to create tensors of size 3
    final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \
        final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter)

    # generate dimensions to create tensors of size 4
    final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \
        final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter)

    final_result = Disj([
        e1_dyn_constraint,
        e2_dyn_constraint,
        final_tensor_1_constraint,
        final_tensor_2_constraint_no_padding,
        final_tensor_2_constraint_padding_arg1,
        final_tensor_2_constraint_padding_arg2,
        final_tensor_3_constraint_no_padding,
        final_tensor_3_constraint_padding_arg1,
        final_tensor_3_constraint_padding_arg2,
        final_tensor_4_constraint_no_padding,
        final_tensor_4_constraint_padding_arg1,
        final_tensor_4_constraint_padding_arg2
    ])

    return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter


def transform_constraint(constraint: Constraint, counter: int):
    """
    Transforms a constraint into a simpler constraint.
    Ex: precision and consistency are transformed to equality
    Args:
        constraint: constraint to be transformed
        counter: for variable tracking

    Returns: The simplified constraint and the updated counter

    """
    if type(constraint) in _TRANSFORMATION_RULES:
        return _TRANSFORMATION_RULES[type(constraint)](constraint, counter)

    else:
        return constraint, counter
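

# A minimal usage sketch for the entry point above (names and the starting
# counter are illustrative; gen_dvar comes from this package's util module):
#
#     d, counter = gen_dvar(0)
#     c = BinConstraintD(d, d, op_consistency)
#     simplified, counter = transform_constraint(c, counter)
#     # 'simplified' is now the Disj of equalities built by generate_binconstraint_d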
Each list of constraints corresponds to 652 one possibility about the values of the dimension variables 653 """ 654 # generate all possibilities of being equal or not equal to dyn for my_list 655 eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))] 656 neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))] 657 d_possibilities = [] 658 659 for i in zip(eq_possibilities, neq_possibilities): 660 d_possibilities.append(list(i)) 661 all_possibilities = list(itertools.product(*d_possibilities)) 662 return all_possibilities 663 664 665def is_target_div_by_dim(target: List[int], dim: List[DVar]): 666 """ 667 Generate constraints to check if the target dimensions are divisible by the input dimensions 668 Args: 669 target: Target dimensions 670 dim: Input dimensions 671 672 Returns: Constraints to check divisibility 673 674 """ 675 return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq) 676 677 678def is_dim_div_by_target(target: List[int], dim: List[DVar]): 679 """ 680 Generate constraints to check if the input dimensions is divisible by the target dimensions 681 Args: 682 target: Target dimensions 683 dim: Input dimensions 684 685 Returns: Constraints to check divisibility 686 687 """ 688 return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq) 689 690 691def gen_all_reshape_possibilities(list_of_dims, target): 692 """ 693 Consider all possibilities what the input dimensions could be (number or dynamic) 694 Then generate the appropriate constraints using multiplication or mod depending on the possibility 695 The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn 696 for the input. Target is fixed because at most one dimension could be dyn. 697 We have different cases for this. 698 699 Args: 700 list_of_dims: The input list of dimensions 701 target: The tensor we want to reshape to 702 703 Returns: A disjunction of transformed reshape constraints 704 705 """ 706 all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims) 707 708 all_constraints = [] 709 710 for p in all_possibilities: 711 to_multiply = [] 712 713 p = list(p) 714 715 for constraint in p: 716 assert isinstance(constraint, BinConstraintD) 717 if constraint.op == op_neq: 718 to_multiply.append(constraint.lhs) 719 720 if not to_multiply: 721 all_constraints.append(Conj(p)) 722 723 elif len(to_multiply) < len(list_of_dims): 724 all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])) 725 else: 726 all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims), 727 Prod(target), op_eq)])) 728 729 return Disj(all_constraints) 730 731 732def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False): 733 """ 734 Apply broadcasting to the 'index' dimension of tensor_input1. 735 Args: 736 tensor_input1: should represent [d1, ..., d_index, ...] 


def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
    """
    Generate all possibilities of being equal or not equal to dyn for my_list
    Args:
        my_list: List of tensor dimensions

    Returns: A list of tuples of constraints. Each tuple of constraints corresponds to
             one possibility for the values of the dimension variables
    """
    # generate all possibilities of being equal or not equal to dyn for my_list
    eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))]
    neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))]
    d_possibilities = []

    for i in zip(eq_possibilities, neq_possibilities):
        d_possibilities.append(list(i))
    all_possibilities = list(itertools.product(*d_possibilities))
    return all_possibilities


def is_target_div_by_dim(target: List[int], dim: List[DVar]):
    """
    Generate constraints to check if the product of the target dimensions
    is divisible by the given dimension
    Args:
        target: Target dimensions
        dim: The dimension (or product of dimensions) to divide by

    Returns: Constraints to check divisibility

    """
    return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)


def is_dim_div_by_target(target: List[int], dim: List[DVar]):
    """
    Generate constraints to check if the given dimension is divisible
    by the product of the target dimensions
    Args:
        target: Target dimensions
        dim: The dimension (or product of dimensions) to be divided

    Returns: Constraints to check divisibility

    """
    return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq)


def gen_all_reshape_possibilities(list_of_dims, target):
    """
    Consider all possibilities for what the input dimensions could be (number or dynamic).
    Then generate the appropriate constraints using multiplication or mod depending on the possibility.
    The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn
    for the input. Target is fixed because at most one dimension could be dyn.
    We have different cases for this.

    Args:
        list_of_dims: The input list of dimensions
        target: The tensor we want to reshape to

    Returns: A disjunction of transformed reshape constraints

    """
    all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims)

    all_constraints = []

    for p in all_possibilities:
        to_multiply = []

        p = list(p)

        for constraint in p:
            assert isinstance(constraint, BinConstraintD)
            if constraint.op == op_neq:
                to_multiply.append(constraint.lhs)

        if not to_multiply:
            all_constraints.append(Conj(p))

        elif len(to_multiply) < len(list_of_dims):
            all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))]))
        else:
            all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims),
                                                            Prod(target), op_eq)]))

    return Disj(all_constraints)
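

# For two dimensions [d1, d2] the enumeration above yields four combinations:
#     (d1 = Dyn, d2 = Dyn), (d1 = Dyn, d2 != Dyn),
#     (d1 != Dyn, d2 = Dyn), (d1 != Dyn, d2 != Dyn)
# gen_all_reshape_possibilities then pairs each combination with either no
# extra constraint (all Dyn), a divisibility check (mixed), or a full product
# equality (all static).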


def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False):
    """
    Apply broadcasting to the 'index' dimension of tensor_input1.
    Args:
        tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1
        tensor_input2: represents the second input
        res1: broadcasted result 1
        res2: broadcasted result 2
        index: the index to broadcast
        padding: If padding was used, then tensor_input1[index] does not exist

    Returns: Constraints for broadcasting dimension 'index' of tensor_input1

    """
    if tensor_input1[index] is None:
        assert padding

    if not padding:
        # then the inputs are the same length so they all have dimensions at "index"
        return Conj([BinConstraintD(tensor_input1[index], 1, op_eq),
                     BinConstraintD(res1[index], res2[index], op_eq),
                     BinConstraintD(res2[index], tensor_input2[index], op_eq)])

    else:
        # we don't set the input dimension to 1, since it doesn't exist.
        return Conj([BinConstraintD(res1[index], res2[index], op_eq),
                     BinConstraintD(res2[index], tensor_input2[index], op_eq)])


def apply_padding(e1_var: TVar,
                  e11: BinConstraintT,
                  e2: BinConstraintT,
                  e12: BinConstraintT,
                  d2: List[DVar],
                  d11: List[DVar],
                  d12: List[DVar],
                  counter: int):
    """
    We are considering the possibility where one input has fewer dimensions than
    the other, so we apply padding to the broadcasted results

    Args:
        e1_var: Variable representing the first input, where padding will be applied
        e11: constraint of the form e11 = TensorType[d1, ..., dn]
        e2: constraint of the form e2 = TensorType[d1, ..., dn]
        e12: constraint of the form e12 = TensorType[d1, ..., dn]
        d2: Tensor variables for the second input
        d11: Tensor variables for the broadcasted first input
        d12: Tensor variables for the broadcasted second input
        counter: variable tracking

    Returns: A new constraint whose goal is to apply padding to the broadcasted result

    """

    res = []

    # pad the shorter input with None so we can pass it to the broadcasting helper function
    for i in range(1, len(d2)):

        d1, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)

        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)

        simulate_padding = [None] * (len(d2) - i)

        assert len(simulate_padding + d1) == len(d2)

        broadcast_padding = []

        # for every padding size, we also consider broadcasting
        for j in range(len(d2) - i):
            broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True))

        # we consider the possibilities for broadcasting for every dimension. Since we already
        # padded d1, we do not consider it while broadcasting
        all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1,
                                                                                            d2[(len(d2) - i):],
                                                                                            d11[(len(d2) - i):],
                                                                                            d12[(len(d2) - i):])
        # combine all constraints into a conjunction
        c = Conj([e1, e11, e2, e12,
                  *broadcast_padding,
                  all_broadcasting_possibilities,
                  *nat_constraints
                  ])
        res.append(c)

    return Disj(res), counter
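

# Sketch of broadcast_dim without padding (illustrative dimension variables):
# broadcasting input 1 over input 2 at dimension 0 asserts that input 1 is 1
# there and propagates input 2's size into both broadcasted results:
#     Conj([t1[0] = 1, res1[0] = res2[0], res2[0] = t2[0]])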


def no_broadcast_dim_with_index(d1: List[DVar],
                                d2: List[DVar],
                                d3: List[DVar],
                                d4: List[DVar],
                                i: int):
    """
    Args:
        d1: input 1
        d2: input 2
        d3: simulated broadcasting for input 1
        d4: simulated broadcasting for input 2
        i: the dimension index under consideration

    Returns: Constraints for when no broadcasting occurs at dimension i
    """
    return Conj([
        Disj([
            Conj([BinConstraintD(d1[i], 1, op_eq),
                  BinConstraintD(d2[i], 1, op_eq)]),

            Conj([BinConstraintD(d1[i], 1, op_neq),
                  BinConstraintD(d2[i], 1, op_neq)])]),

        BinConstraintD(d1[i], d3[i], op_eq),
        BinConstraintD(d2[i], d4[i], op_eq)])


def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
    """
    Generate lists of DVar to represent tensor dimensions
    Args:
        num_tensors: the required number of tensors
        dim_size: the number of dimensions for each tensor
        counter: variable tracking

    Returns: A list of lists of tensor dimensions

    """
    res = []

    for _ in range(num_tensors):
        dims, counter = gen_tensor_dims(dim_size, counter)
        res.append(dims)

    return res, counter


def create_equality_constraints_for_broadcasting(e1: TVar,
                                                 e2: TVar,
                                                 e11: TVar,
                                                 e12: TVar,
                                                 d1: List[DVar],
                                                 d2: List[DVar],
                                                 d11: List[DVar],
                                                 d12: List[DVar]):
    """
    Create equality constraints for when no broadcasting occurs
    Args:
        e1: Input 1
        e2: Input 2
        e11: Broadcasted input 1
        e12: Broadcasted input 2
        d1: Variables that store dimensions for e1
        d2: Variables that store dimensions for e2
        d11: Variables that store dimensions for e11
        d12: Variables that store dimensions for e12

    Returns: Four equality constraints

    """

    e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq)
    e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq)
    e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq)
    e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq)
    return [e1_tensor, e11_tensor, e2_tensor, e12_tensor]


def gen_consistency_constraints(constraint: Constraint, counter: int):
    """
    Args:
        constraint: Consistency constraint on tensors
        counter: for variable tracking

    Returns: Equality and consistency constraints on dimensions

    """

    all_constraints = []

    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)

        c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq),
                           BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] +
                          [BinConstraintD(d1, d2, op_consistency) for
                           d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints)

        all_constraints.append(c_tensor_i)

    return all_constraints, counter


def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
    """
    Args:
        constraint: Greatest upper bound on tensors
        counter: variable tracking

    Returns: A set of equality constraints and DGreatestUpperBound constraints

    """

    all_constraints = []

    for i in range(1, MAX_TENSOR_RANK + 1):
        c = []
        dims1, counter = gen_tensor_dims(i, counter)
        c1tensor = TensorType(dims1)

        dims2, counter = gen_tensor_dims(i, counter)
        c2tensor = TensorType(dims2)

        dims3, counter = gen_tensor_dims(i, counter)
        c3tensor = TensorType(dims3)

        c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq),
              BinConstraintT(constraint.rhs2, c2tensor, op_eq),
              BinConstraintT(constraint.res, c3tensor, op_eq)] + \
            gen_nat_constraints(dims1 + dims2 + dims3)

        assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__)
        for j in range(len(c3tensor.__args__)):
            c.append(DGreatestUpperBound(c3tensor.__args__[j],
                                         c1tensor.__args__[j],
                                         c2tensor.__args__[j]))

        all_constraints.append(Conj(c))
    return all_constraints, counter
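

# For each candidate rank i, the tensor-level upper bound reduces to
# dimension-level bounds, e.g. for i = 2 with fresh, illustrative dims:
#     rhs1 = TensorType([s1, s2]), rhs2 = TensorType([t1, t2]), res = TensorType([r1, r2])
# conjoined with DGreatestUpperBound(r1, s1, t1) and DGreatestUpperBound(r2, s2, t2).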


def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]):
    """
    Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
    We look at all combinations for all dimensions in d1 and d2
    Args:
        d1: input1 dimensions
        d2: input2 dimensions
        d11: broadcasted input1 dimensions
        d12: broadcasted input2 dimensions

    Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions

    """

    size = len(d1)

    res2 = []

    for i in range(size):
        t1 = broadcast_dim(d1, d2, d11, d12, i)
        t2 = broadcast_dim(d2, d1, d12, d11, i)
        t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i)

        res2.append(Disj([t1, t2, t3]))

    return Conj(res2)


def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int):
    """
    Simulates broadcasting on e1 and e2 and returns the results
    respectively in e11 and e12. Because of gradual types,
    e1 and e2 may not be equal. Similarly, e11 and e12 may not
    be equal. e11 and e12 should be guaranteed to be consistent
    as they represent the shapes of the tensors to be added after
    broadcasting.
    Args:
        e1: TVar representing the type of input 1
        e2: TVar representing the type of input 2
        e11: TVar representing the broadcasted input 1
        e12: TVar representing the broadcasted input 2
        i: The rank of the resulting type of addition
        counter: for variable tracking

    Returns: Simplified broadcasting constraints

    """
    dims, counter = gen_lists_of_dims(4, i, counter)
    [d1, d2, d3, d4] = dims
    nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))

    initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12,
                                                                                  d1, d2, d3, d4)

    [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints

    # without padding, broadcast all possibilities for tensors of size i
    final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints,
                                               generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)])

    # with padding, broadcast all possibilities for tensors of size i
    final_tensor_constraint_padding_arg1, counter = \
        apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter)

    final_tensor_constraint_padding_arg2, counter = \
        apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter)

    return final_tensor_constraint_no_padding, \
        final_tensor_constraint_padding_arg1, \
        final_tensor_constraint_padding_arg2, nat_dims_i, counter
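

# End-to-end sketch (illustrative): the constraint-generation pass produces
# constraints such as ApplyBroadcasting or CanReshape; calling
# transform_constraint on each one rewrites it, via the rules registered
# above, into a formula built only from Conj/Disj over equalities and
# inequalities on dimensions, which a solver backend (e.g. the Z3 translation
# in this package) can then decide.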