# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
    op_mod, op_gt, op_lt, op_neq, op_eq
from torch.fx.tensor_type import TensorType, Dyn


class Constraint:
    """
    Base class shared by the constraint classes defined in this module
    """
    pass


class Conj(Constraint):
    """
    Conjunction of constraints
    """
    def __init__(self, conjuncts):
        """
        :param conjuncts: Conjunction of constraints
        """
        # NOTE: the attribute name is misspelled ("conjucts"), but it is part
        # of the public interface of this class, so it is kept as is.
        self.conjucts = conjuncts

    def __eq__(self, other):
        if not isinstance(other, Conj):
            return False
        return self.conjucts == other.conjucts

    def __repr__(self):
        return f'And({self.conjucts})'


class Disj(Constraint):
    """
    Disjunction of constraints
    """
    def __init__(self, disjuncts):
        """
        :param disjuncts: Disjunction of constraints
        """
        self.disjuncts = disjuncts

    def __eq__(self, other):
        if not isinstance(other, Disj):
            return False
        return self.disjuncts == other.disjuncts

    def __repr__(self):
        return f'Or({self.disjuncts})'


class Prod(Constraint):
    """
    Product of dimensions
    """
    def __init__(self, products):
        """
        :param products: lists of dimensions to multiply
        """
        self.products = products

    def __eq__(self, other):
        if not isinstance(other, Prod):
            return False
        return self.products == other.products

    def __repr__(self):
        return f'Product({self.products})'


class T(Constraint):
    """
    True
    """
    def __init__(self) -> None:
        pass

    def __eq__(self, other):
        return isinstance(other, T)

    def __repr__(self):
        return 'True'


class F(Constraint):
    """
    False
    """
    def __init__(self) -> None:
        pass

    def __eq__(self, other):
        return isinstance(other, F)

    def __repr__(self):
        return 'False'


class BinaryConstraint(Constraint):
    """
    Represents all binary operations
    """
    def __init__(self, lhs, rhs, op):
        """
        :param lhs: lhs of the constraint
        :param rhs: rhs of the constraint
        :param op: string representing the operation
        """
        self.lhs = lhs
        self.rhs = rhs
        self.op = op

    def __eq__(self, other):
        # The check is against BinaryConstraint itself, so instances of
        # different subclasses compare equal when lhs, rhs and op all match.
        if not isinstance(other, BinaryConstraint):
            return False
        return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op

    def __repr__(self):
        return f'({self.lhs} {self.op} {self.rhs})'


class BinConstraintT(BinaryConstraint):
    """
    Binary constraints about tensors
    """
    def __init__(self, lhs, rhs, op):
        """
        :param lhs: TVar, TensorType, int or Dyn
        :param rhs: TVar, TensorType, int or Dyn
        :param op: string representing the operation
        """
        assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \
            (isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn)
        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        return super().__eq__(other)


class BinConstraintD(BinaryConstraint):
    """
    Binary constraints about dimensions
    """
    def __init__(self, lhs, rhs, op):
        """
        :param lhs: algebraic expression, dimension or boolean expression
        :param rhs: algebraic expression, dimension or boolean expression
        :param op: string representing the operation
        """
        assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)
        assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)

        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        return super().__eq__(other)


class TGreatestUpperBound(Constraint):
    """
    Greatest Upper bound for tensors with dynamic type
    """
    def __init__(self, res, rhs1, rhs2):
        """
        :param res: tensor variable that stores the result of the output
        :param rhs1: tensor or tensor variable
        :param rhs2: tensor or tensor variable
        """
        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}'

    def __eq__(self, other):
        if not isinstance(other, TGreatestUpperBound):
            return False
        return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2


class DGreatestUpperBound(Constraint):
    """
    Greatest Upper bound for dimensions
    """
    def __init__(self, res, rhs1, rhs2):
        """
        :param res: Dimension variable to store the result
        :param rhs1: dimension variable 1
        :param rhs2: dimension variable 2
        """
        assert is_dim(res)
        assert is_dim(rhs1)
        assert is_dim(rhs2)

        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        return f'{self.res} = {self.rhs1}\u2294{self.rhs2}'

    def __eq__(self, other):
        if not isinstance(other, DGreatestUpperBound):
            return False
        return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2


class CanReshape(Constraint):
    """
    can_reshape constraint
    """
    def __init__(self, src, target):
        """
        :param src: tensor variable
        :param target: tensor
        """
        self.src = src
        self.target = target

    def __repr__(self):
        return f'can-reshape({self.src}, {self.target})'

    def __eq__(self, other):
        if not isinstance(other, CanReshape):
            return False
        return self.src == other.src and self.target == other.target


class IndexSelect(Constraint):
    """
    index_select constraint
    """

    def __init__(self, tensor_size, input_var, dim_replace, index, output):
        """
        Args:
            input_var: input to index_select
            tensor_size: tensor size we are considering
            dim_replace: the dimension of the output at "index"
            index: location of the dimensions to replace in the input
            output: variable to store the result
        """
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        assert isinstance(dim_replace, DVar) or dim_replace == Dyn
        assert isinstance(index, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.dim_replace = dim_replace
        self.index = index
        self.output = output

    def __repr__(self):

        return f' {self.output} = ' \
               f'IndexSelect({self.input_var}, ' \
               f'tensor_size: {self.tensor_size}, ' \
               f'{self.dim_replace}, ' \
               f'{self.index})'

    def __eq__(self, other):
        if not isinstance(other, IndexSelect):
            return False
        return self.tensor_size == other.tensor_size and \
            self.dim_replace == other.dim_replace and \
            self.index == other.index and \
            self.output == other.output and \
            self.input_var == other.input_var


class Transpose(Constraint):
    """
    transpose constraint
    """

    def __init__(self, tensor_size, input_var, index1, index2, output):
        """
        Args:
            tensor_size: current tensor size
            input_var: variable to hold input
            index1: dimension 1
            index2: dimension 2
            output: output that stores result
        """
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        assert isinstance(index1, int)
        assert isinstance(index2, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.index1 = index1
        self.index2 = index2
        self.output = output

    def __repr__(self):

        return f' {self.output} = ' \
               f'Transpose({self.input_var}, ' \
               f'tensor_size: {self.tensor_size}, ' \
               f'{self.index1}, ' \
               f'{self.index2})'

    def __eq__(self, other):
        if not isinstance(other, Transpose):
            return False
        return self.tensor_size == other.tensor_size and \
            self.index1 == other.index1 and \
            self.index2 == other.index2 and \
            self.output == other.output and \
            self.input_var == other.input_var


class GetItem(Constraint):

    def __init__(self, tensor_size, index, res, input_var):
        """
        Constraint for getting item given a tensor size
        :param tensor_size: actual number
        :param index: actual number representing the index
        :param res: dimension variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        assert isinstance(res, DVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index = index
        self.input_var = input_var

    def __repr__(self):
        return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'

    def __eq__(self, other):
        if not isinstance(other, GetItem):
            return False
        return self.res == other.res and \
            self.tensor_size == other.tensor_size and \
            self.index == other.index and \
            self.input_var == other.input_var


class GetItemTensor(Constraint):

    def __init__(self, tensor_size, index_tuple, res, input_var):
        """
        Constraint for getting item given a tensor size
        However, when the argument is a tuple, we will
        expect a tensor
        :param tensor_size: actual number representing the rank
        :param index_tuple: tuple for indexing
        :param res: tensor variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        assert isinstance(res, TVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index_tuple = index_tuple
        self.input_var = input_var

    def __repr__(self):
        return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'

    def __eq__(self, other):
        if not isinstance(other, GetItemTensor):
            return False
        return self.res == other.res and \
            self.tensor_size == other.tensor_size and \
            self.index_tuple == other.index_tuple and \
            self.input_var == other.input_var


class CalcConv(Constraint):

    def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param conv_result: the convolution result
        :param input_var: input to convolution
        :param c_out: output channel type
        :param kernel: kernel tuple
        :param padding: padding argument, stored as given
        :param stride: stride argument, stored as given
        :param dilation: dilation argument, stored as given
        :param matching_constraint_vars: stored as ``matching_constraint``;
            takes part in equality but is not shown in the repr
        """
        self.conv_result = conv_result
        self.input_var = input_var
        self.c_out = c_out
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        return f'{self.conv_result} =' \
               f' calc-conv({self.input_var},' \
               f' {self.c_out}, {self.kernel}, ' \
               f'{self.padding}, {self.stride},' \
               f' {self.dilation})'

    def __eq__(self, other):
        if not isinstance(other, CalcConv):
            return False
        return self.conv_result == other.conv_result and self.input_var == other.input_var and \
            self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \
            and self.stride == other.stride and self.dilation == other.dilation \
            and self.matching_constraint == other.matching_constraint


class CalcMaxPool(Constraint):

    def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param maxpool_result: the result of maxpool
        :param input_var: input to maxpool
        :param kernel: kernel tuple
        :param padding: padding argument, stored as given
        :param stride: stride argument, stored as given
        :param dilation: dilation argument, stored as given
        :param matching_constraint_vars: stored as ``matching_constraint``;
            takes part in equality but is not shown in the repr
        """
        self.maxpool_result = maxpool_result
        self.input_var = input_var
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        return f'{self.maxpool_result} =' \
               f' calc-maxpool({self.input_var},' \
               f' {self.kernel}, ' \
               f'{self.padding}, {self.stride},' \
               f' {self.dilation})'

    def __eq__(self, other):
        if not isinstance(other, CalcMaxPool):
            return False
        return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \
            and self.kernel == other.kernel and self.padding == other.padding \
            and self.stride == other.stride and self.dilation == other.dilation \
            and self.matching_constraint == other.matching_constraint


class ApplyBroadcasting(Constraint):
    def __init__(self, res1, res2, input1, input2):
        """
        :param res1: resulting tensor 1
        :param res2: resulting tensor 2
        :param input1: tensor variable 1
        :param input2: tensor variable 2
        """
        self.res1 = res1
        self.res2 = res2
        self.input1 = input1
        self.input2 = input2

    def __eq__(self, other):
        if not isinstance(other, ApplyBroadcasting):
            return False
        return self.res1 == other.res1 \
            and self.res2 == other.res2 \
            and self.input1 == other.input1 \
            and self.input2 == other.input2

    def __repr__(self):
        return f'{self.res1}, {self.res2} = apply-broadcasting({self.input1}, {self.input2})'


class CalcProduct(Constraint):
    """
    Given correct dimensions, calculate the product for flatten accounting for Dyn
    """
    def __init__(self, start, end, flattened, dims_to_flatten):
        """
        :param start: start index
        :param end: end index
        :param flattened: variable to store the product
        :param dims_to_flatten: the type which we will flatten
        """
        assert isinstance(dims_to_flatten, list)
        assert isinstance(flattened, TVar)
        assert isinstance(start, int)
        assert isinstance(end, int)

        self.start = start
        self.end = end
        self.dims_to_flatten = dims_to_flatten
        self.flattened = flattened

    def __eq__(self, other):
        if not isinstance(other, CalcProduct):
            return False
        return self.start == other.start and self.end == other.end and \
            self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened

    def __repr__(self):
        return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'


class TVar:
    """
    Tensor variable with no tensor constructor
    """
    def __init__(self, tvar):
        """
        :param tvar: tensor variable
        """
        self.tvar = tvar

    def __repr__(self):
        return f'TV({self.tvar})'

    def __eq__(self, other):
        if not isinstance(other, TVar):
            return False
        return self.tvar == other.tvar


class DVar:
    """
    Dimension variable
    """
    def __init__(self, c):
        """
        :param c: character or number
        """
        self.c = c

    def __repr__(self):
        return f'DV({self.c})'

    def __eq__(self, other):
        if not isinstance(other, DVar):
            return False
        return self.c == other.c


class BVar:
    """
    Boolean variable
    """
    def __init__(self, c):
        """
        :param c: character or number
        """
        self.c = c

    def __repr__(self):
        return f'BV({self.c})'

    def __eq__(self, other):
        if not isinstance(other, BVar):
            return False
        return self.c == other.c


def is_algebraic_expression(constraint):
    """
    Return True for a BinConstraintD whose op is one of add, sub, div, mul
    or mod, or for a Prod; False otherwise.
    """
    if isinstance(constraint, BinConstraintD):
        return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod]
    else:
        return isinstance(constraint, Prod)


def is_bool_expr(constraint):
    """
    Return True for a BinConstraintD whose op is one of gt, lt, neq or eq,
    or for a BVar, Conj or Disj; False otherwise.
    """
    if isinstance(constraint, BinConstraintD):
        return constraint.op in [op_gt, op_lt, op_neq, op_eq]
    else:
        return isinstance(constraint, (BVar, Conj, Disj))


def is_dim(d):
    """
    Return True if d is a DVar, an int, or Dyn.
    """
    return isinstance(d, (DVar, int)) or d == Dyn