# mypy: allow-untyped-defs
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union


number = Union[int, float]
# flake8: noqa

###
# There are generated files that depend on this file
# To re-generate, please run from the root of the repo:
# python torchgen/shape_functions/gen_jit_shape_functions.py

# How to test:
# After regenerating files, compile PyTorch.
# Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
# If you have enabled opinfo testing for the op, also run:
# python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
# to reproduce errors from opinfo tests.

# Example PR: https://github.com/pytorch/pytorch/pull/80860/files
####

import torch


def broadcast(a: List[int], b: List[int]):
    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes: List[int] = []

    for i in range(ndim):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if (dimA >= 0) else 1
        sizeB = b[dimB] if (dimB >= 0) else 1

        if sizeA != sizeB and sizeA != 1 and sizeB != 1:
            # TODO: only assertion error is bound in C++ compilation right now
            raise AssertionError(
                f"The size of tensor a ({sizeA}) must match the size of tensor b ({sizeB}) at non-singleton dimension {i}"
            )

        expandedSizes.append(sizeB if sizeA == 1 else sizeA)

    return expandedSizes


def broadcast_three(a: List[int], b: List[int], c: List[int]):
    return broadcast(broadcast(a, b), c)


def broadcast_one_three(a: List[int], b: Any, c: List[int]):
    return broadcast(a, c)


def adaptive_avg_pool2d(self: List[int], out: List[int]):
    assert len(out) == 2
    assert len(self) == 3 or len(self) == 4
    for i in range(1, len(self)):
        assert self[i] != 0

    shape: List[int] = []
    for i in range(0, len(self) - 2):
        shape.append(self[i])
    for elem in out:
        shape.append(elem)
    return shape


def _copy(self: List[int]):
    out: List[int] = []
    for elem in self:
        out.append(elem)
    return out


def unary(self: List[int]):
    return _copy(self)


def broadcast_inplace(a: List[int], b: List[int]):
    dimsA = len(a)
    dimsB = len(b)
    if dimsB > dimsA:
        raise AssertionError(
            f"The dims of tensor b ({dimsB}) must be less than or equal to the dims of tensor a ({dimsA})"
        )
    for dimA in range(dimsA):
        dimB = dimsB - dimsA + dimA
        sizeA = a[dimA]
        sizeB = b[dimB] if (dimB >= 0) else 1
        if sizeA != sizeB and sizeB != 1:
            # TODO: only assertion error is bound in C++ compilation right now
            raise AssertionError(
                "The size of tensor a {} must match the size of tensor b ("
                "{}) at non-singleton dimension {}".format(sizeA, sizeB, dimA)
            )
    return _copy(a)


def expand(self: List[int], sizes: List[int]):
    assert len(sizes) >= len(self)
    ndim = len(sizes)
    tensor_dim = len(self)
    if ndim == 0:
        return _copy(sizes)
    out: List[int] = []
    for i in range(ndim):
        offset = ndim - 1 - i
        dim = tensor_dim - 1 - offset
        size = self[dim] if dim >= 0 else 1
        targetSize = sizes[i]
        if targetSize == -1:
            assert dim >= 0
            targetSize = size
        if size != targetSize:
            assert size == 1
            size = targetSize
        out.append(size)
    return out


def expand_one_unused(self: List[int], sizes: List[int], inp0: Any):
    return expand(self, sizes)
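# Illustrative examples for broadcast() and expand() above (doctest-style
# comments only, not part of the generated shape registry): sizes are aligned
# from the right and singleton dimensions are stretched; -1 in expand() keeps
# the existing size.
#
#     >>> broadcast([3, 1, 5], [4, 1])
#     [3, 4, 5]
#     >>> expand([3, 1], [-1, 4])
#     [3, 4]

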
def infer_size_impl(shape: List[int], numel: int) -> List[int]:
    newsize = 1
    infer_dim: Optional[int] = None
    for dim in range(len(shape)):
        if shape[dim] == -1:
            if infer_dim is not None:
                raise AssertionError("only one dimension can be inferred")
            infer_dim = dim
        elif shape[dim] >= 0:
            newsize *= shape[dim]
        else:
            raise AssertionError("invalid shape dimensions")
    if not (
        numel == newsize
        or (infer_dim is not None and newsize > 0 and numel % newsize == 0)
    ):
        raise AssertionError("invalid shape")
    out = _copy(shape)
    if infer_dim is not None:
        out[infer_dim] = numel // newsize
    return out


def numel(sizes: List[int]):
    numel = 1
    for elem in sizes:
        numel *= elem
    return numel


def view(self: List[int], sizes: List[int]):
    return infer_size_impl(sizes, numel(self))


def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False):
    return view(self, sizes)


def sum_mean_dim(
    self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any
):
    out: List[int] = []
    if opt_dims is None or len(opt_dims) == 0:
        dims: List[int] = list(range(len(self)))
    else:
        dims = opt_dims

    for idx in range(len(self)):
        is_mean_dim: bool = False
        for reduce_dim in dims:
            if idx == maybe_wrap_dim(reduce_dim, len(self)):
                is_mean_dim = True
        if is_mean_dim:
            if keep_dim:
                out.append(1)
        else:
            out.append(self[idx])
    return out


def max_dim(self: List[int], dim: int, keep_dim: bool):
    out = sum_mean_dim(self, [dim], keep_dim, None)
    return out, out


# note: Python integer division already rounds down towards negative infinity, so no special arithmetic is needed
def div_rtn(x: int, y: int):
    return x // y


def pooling_output_shape_pad_lr(
    inputSize: int,
    kernelSize: int,
    pad_l: int,
    pad_r: int,
    stride: int,
    dilation: int,
    ceil_mode: bool,
):
    outputSize = (
        div_rtn(
            inputSize
            + pad_l
            + pad_r
            - dilation * (kernelSize - 1)
            - 1
            + (stride - 1 if ceil_mode else 0),
            stride,
        )
        + 1
    )
    if ceil_mode:
        if (outputSize - 1) * stride >= inputSize + pad_l:
            outputSize = outputSize - 1
    return outputSize


def pooling_output_shape(
    inputSize: int,
    kernelSize: int,
    pad_l: int,
    stride: int,
    dilation: int,
    ceil_mode: bool,
):
    assert stride != 0, "stride should not be zero"
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode
    )


def pool2d_shape_check(
    input: List[int],
    kH: int,
    kW: int,
    dH: int,
    dW: int,
    padH: int,
    padW: int,
    dilationH: int,
    dilationW: int,
    nInputPlane: int,
    inputHeight: int,
    inputWidth: int,
    outputHeight: int,
    outputWidth: int,
):
    ndim = len(input)
    nOutputPlane = nInputPlane

    assert kW > 0 and kH > 0
    assert dW > 0 and dH > 0
    assert dilationH > 0 and dilationW > 0

    valid_dims = input[1] != 0 and input[2] != 0
    assert (
        ndim == 3
        and input[0] != 0
        and valid_dims
        or (ndim == 4 and valid_dims and input[3] != 0)
    )

    assert kW // 2 >= padW and kH // 2 >= padH
    assert outputWidth >= 1 and outputHeight >= 1
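# Illustrative example for pooling_output_shape() above (doctest-style comment
# only, not part of the generated shape registry). For inputSize=10,
# kernelSize=3, pad=1, stride=2, dilation=1 the formula gives
# floor((10 + 2*1 - 3) / 2) + 1 = 5, and ceil_mode=True rounds up instead:
#
#     >>> pooling_output_shape(10, 3, 1, 2, 1, False)
#     5
#     >>> pooling_output_shape(10, 3, 1, 2, 1, True)
#     6

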
def max_pool2d(
    input: List[int],
    kernel_size: List[int],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    ceil_mode: bool,
):
    assert (
        len(kernel_size) == 1 or len(kernel_size) == 2
    ), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
    kH = kernel_size[0]
    kW = kH if len(kernel_size) == 1 else kernel_size[1]

    assert (
        len(stride) == 0 or len(stride) == 1 or len(stride) == 2
    ), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
    dH = kH if len(stride) == 0 else stride[0]
    if len(stride) == 0:
        dW = kW
    elif len(stride) == 1:
        dW = dH
    else:
        dW = stride[1]

    assert (
        len(padding) == 1 or len(padding) == 2
    ), "max_pool2d: padding must either be a single int, or a tuple of two ints"
    padH = padding[0]
    padW = padH if len(padding) == 1 else padding[1]

    assert (
        len(dilation) == 1 or len(dilation) == 2
    ), "max_pool2d: dilation must be either a single int, or a tuple of two ints"
    dilationH = dilation[0]
    dilationW = dilationH if len(dilation) == 1 else dilation[1]

    assert len(input) == 3 or len(input) == 4

    nbatch = input[-4] if len(input) == 4 else 1
    nInputPlane = input[-3]
    inputHeight = input[-2]
    inputWidth = input[-1]

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)

    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        dilationH,
        dilationW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
    )

    if len(input) == 3:
        return [nInputPlane, outputHeight, outputWidth]
    else:
        return [nbatch, nInputPlane, outputHeight, outputWidth]


def max_pool2d_with_indices(
    input: List[int],
    kernel_size: List[int],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    ceil_mode: bool,
):
    out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
    return (out, out)


def upsample_nearest2d(
    input: List[int],
    output_size: Optional[List[int]],
    scale_factors: Optional[List[float]],
):
    out: List[int] = []
    out.append(input[0])
    out.append(input[1])

    if scale_factors is None and output_size is None:
        assert 0, "Either output_size or scale_factors must be specified"

    if output_size is not None:
        assert (
            scale_factors is None
        ), "Must specify exactly one of output_size and scale_factors"
        assert len(output_size) == 2
        out.append(output_size[0])
        out.append(output_size[1])

    if scale_factors is not None:
        assert (
            output_size is None
        ), "Must specify exactly one of output_size and scale_factors"
        assert len(scale_factors) == 2
        out.append(int(input[2] * scale_factors[0]))
        out.append(int(input[3] * scale_factors[1]))

    return out


def mm(self: List[int], mat2: List[int]):
    assert len(self) == 2, "self must be a matrix"
    assert len(mat2) == 2, "mat2 must be a matrix"

    assert self[1] == mat2[0]
    return [self[0], mat2[1]]


def dot(self: List[int], tensor: List[int]):
    assert len(self) == 1 and len(tensor) == 1
    assert self[0] == tensor[0]
    out: List[int] = []
    return out


def mv(self: List[int], vec: List[int]):
    assert len(self) == 2 and len(vec) == 1
    assert self[1] == vec[0]
    # TODO: return self
    return [self[0]]
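# Illustrative examples for dot(), mv() and mm() above (doctest-style comments
# only, not part of the generated shape registry): dot() yields a 0-d result,
# mv() a vector with the matrix's row count, mm() the outer row/column pair.
#
#     >>> dot([3], [3])
#     []
#     >>> mv([2, 3], [3])
#     [2]
#     >>> mm([2, 3], [3, 4])
#     [2, 4]

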
def unsqueeze(li: List[int], dim: int):
    dim = maybe_wrap_dim(dim, len(li) + 1)
    out = _copy(li)
    out.insert(dim, 1)
    return out


def squeeze_nodim(li: List[int]):
    out: List[int] = []
    for i in range(len(li)):
        if li[i] != 1:
            out.append(li[i])
    return out


def squeeze(li: List[int], dim: int):
    out: List[int] = []
    wrapped_dim = maybe_wrap_dim(dim, len(li))
    for i in range(len(li)):
        if i == wrapped_dim:
            if li[i] != 1:
                out.append(li[i])
        else:
            out.append(li[i])
    return out


def squeeze_dims(li: List[int], dims: List[int]):
    if len(dims) == 0:
        return li
    wrapped_dims = _copy(dims)
    for i in range(len(dims)):
        wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li))
    result: List[int] = []
    for i in range(len(li)):
        if li[i] == 1:
            if i not in wrapped_dims:
                result.append(li[i])
        else:
            result.append(li[i])
    return result


def index_select(self: List[int], dim: int, index: List[int]):
    dim = maybe_wrap_dim(dim, len(self))
    numel = multiply_integers(index)
    assert len(index) <= 1
    assert dim == 0 or dim < len(self)
    result_size: List[int] = []
    for i in range(len(self)):
        if dim == i:
            result_size.append(numel)
        else:
            result_size.append(self[i])
    return result_size


def embedding(
    weight: List[int],
    indices: List[int],
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
):
    assert len(weight) == 2
    if len(indices) == 1:
        return index_select(weight, 0, indices)
    size = _copy(indices)
    size.append(weight[1])
    return size


def max_int():
    return 9223372036854775807


def slice(
    self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int
):
    ndim = len(self)
    assert ndim != 0
    dim = maybe_wrap_dim(dim, ndim)
    start_val = start if start is not None else 0
    end_val = end if end is not None else max_int()
    assert step > 0
    if start_val == max_int():
        start_val = 0
    if start_val < 0:
        start_val += self[dim]
    if end_val < 0:
        end_val += self[dim]
    if start_val < 0:
        start_val = 0
    elif start_val > self[dim]:
        start_val = self[dim]
    if end_val < start_val:
        end_val = start_val
    elif end_val >= self[dim]:
        end_val = self[dim]
    slice_len = end_val - start_val
    out = _copy(self)
    out[dim] = (slice_len + step - 1) // step
    return out
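# Illustrative examples for slice() above (doctest-style comments only, not
# part of the generated shape registry): negative indices wrap around, the
# range is clamped to the dimension, and the length is rounded up by the step.
#
#     >>> slice([10], 0, 1, 8, 3)        # elements 1, 4, 7
#     [3]
#     >>> slice([6, 4], 1, None, -1, 1)  # end=-1 wraps to 3
#     [6, 3]

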
def check_cat_no_zero_dim(tensors: List[List[int]]):
    for tensor in tensors:
        assert len(tensor) > 0


def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
    out_dim: Optional[int] = None
    for size in tensor_sizes:
        if not (len(size) == 1 and size[0] == 0):
            if out_dim is None:
                out_dim = maybe_wrap_dim(dim, len(size))
    if out_dim is None:
        out_dim = dim
    return out_dim


def should_skip(tensor: List[int]):
    return numel(tensor) == 0 and len(tensor) == 1


def check_cat_shape_except_dim(
    first: List[int], second: List[int], dimension: int, index: int
):
    first_dims = len(first)
    second_dims = len(second)
    assert first_dims == second_dims, "Tensors must have same number of dimensions"
    for dim in range(0, first_dims):
        if dim != dimension:
            assert (
                first[dim] == second[dim]
            ), "Sizes of tensors must match except in dimension"


def cat(tensors: List[List[int]], dim: int):
    check_cat_no_zero_dim(tensors)
    dim = legacy_cat_wrap_dim(dim, tensors)
    assert len(tensors) > 0
    not_skipped_tensor: Optional[List[int]] = None
    for tensor in tensors:
        if not should_skip(tensor):
            not_skipped_tensor = tensor
    if not_skipped_tensor is None:
        return [0]

    cat_dim_size = 0

    for i in range(len(tensors)):
        tensor = tensors[i]
        if not should_skip(tensor):
            check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
            cat_dim_size = cat_dim_size + tensor[dim]

    result_size = _copy(not_skipped_tensor)
    result_size[dim] = cat_dim_size
    return result_size


def stack(tensors: List[List[int]], dim: int):
    unsqueezed_tensors: List[List[int]] = []
    for tensor in tensors:
        unsqueezed = unsqueeze(tensor, dim)
        unsqueezed_tensors.append(unsqueezed)
    return cat(unsqueezed_tensors, dim)


def select(self: List[int], dim: int, index: int):
    ndim = len(self)
    assert ndim != 0
    dim = maybe_wrap_dim(dim, ndim)
    size = self[dim]
    assert not (index < -size or index >= size)
    if index < 0:
        index += size
    out: List[int] = []
    for i in range(ndim):
        if i != dim:
            out.append(self[i])
    return out


def matmul(tensor1: List[int], tensor2: List[int]):
    dim_tensor1 = len(tensor1)
    dim_tensor2 = len(tensor2)
    if dim_tensor1 == 1 and dim_tensor2 == 1:
        return dot(tensor1, tensor2)
    elif dim_tensor1 == 2 and dim_tensor2 == 1:
        return mv(tensor1, tensor2)
    elif dim_tensor1 == 1 and dim_tensor2 == 2:
        return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0)
    elif dim_tensor1 == 2 and dim_tensor2 == 2:
        return mm(tensor1, tensor2)
    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
        # We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 can be a list);
        # we track m1 vs m2 separately even though they must match for nicer error messages
        n = tensor1[-2] if dim_tensor1 > 1 else 1
        m1 = tensor1[-1]
        batch_tensor1: List[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor1 - 2):
            batch_tensor1.append(tensor1[i])
        m2 = tensor2[-1] if dim_tensor2 > 1 else 1
        p = tensor2[-1]
        batch_tensor2: List[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor2 - 2):
            batch_tensor2.append(tensor2[i])

        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
        expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)

        # todo: copy ?
        output_shape = expand_batch_portion
        if dim_tensor1 > 1:
            output_shape.append(n)

        if dim_tensor2 > 1:
            output_shape.append(p)

        return output_shape
    else:
        assert False, "both arguments to matmul need to be at least 1D"
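# Illustrative example for matmul() above (doctest-style comment only, not part
# of the generated shape registry): for batched operands the batch dims are
# broadcast and the trailing matrix dims contribute n x p.
#
#     >>> matmul([2, 1, 4, 5], [3, 5, 6])
#     [2, 3, 4, 6]

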
def t(self: List[int]):
    assert len(self) <= 2
    self_len = len(self)
    if self_len == 0:
        out: List[int] = []
        return out
    elif self_len == 1:
        return [self[0]]
    else:
        return [self[1], self[0]]


def transpose(self: List[int], dim0: int, dim1: int):
    ndims = len(self)
    dim0 = maybe_wrap_dim(dim0, ndims)
    dim1 = maybe_wrap_dim(dim1, ndims)
    if dim0 == dim1:
        return _copy(self)
    out: List[int] = []
    for i in range(ndims):
        if i == dim0:
            out.append(self[dim1])
        elif i == dim1:
            out.append(self[dim0])
        else:
            out.append(self[i])
    return out


def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
    out = matmul(input, t(weight))
    if bias is not None:
        assert broadcast(bias, out) == out
    return out


def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
    return broadcast(self, mm(mat1, mat2))


def check_non_negative(array: List[int]) -> bool:
    # NB: despite the name, this returns True when a negative value is present
    # TODO: look into rewriting with early return and getting loop unrolling to fire
    non_negative = False
    for val in array:
        if val < 0:
            non_negative = True
    return non_negative


def check_shape_forward(
    input: List[int],
    weight_sizes: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    k = len(input)
    weight_dim = len(weight_sizes)

    # TODO: assertions could be expanded with the error messages
    assert not check_non_negative(padding)
    assert not check_non_negative(stride)

    assert weight_dim == k
    assert weight_sizes[0] >= groups
    assert (weight_sizes[0] % groups) == 0
    # only handling not transposed
    assert input[1] == weight_sizes[1] * groups
    assert bias is None or (len(bias) == 1 and bias[0] == weight_sizes[0])

    for i in range(2, k):
        assert (input[i] + 2 * padding[i - 2]) >= (
            dilation[i - 2] * (weight_sizes[i] - 1) + 1
        )

    # this is not handling transposed convolution yet


def conv_output_size(
    input_size: List[int],
    weight_size: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    check_shape_forward(
        input_size, weight_size, bias, stride, padding, dilation, groups
    )

    has_dilation = len(dilation) > 0
    dim = len(input_size)
    output_size: List[int] = []
    input_batch_size_dim = 0
    weight_output_channels_dim = 0
    output_size.append(input_size[input_batch_size_dim])
    output_size.append(weight_size[weight_output_channels_dim])

    for d in range(2, dim):
        dilation_ = dilation[d - 2] if has_dilation else 1
        kernel = dilation_ * (weight_size[d] - 1) + 1
        output_size.append(
            (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
        )
    return output_size
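# Illustrative example for conv_output_size() above (doctest-style comment
# only, not part of the generated shape registry). For a 3x3 kernel, stride 1,
# padding 1, dilation 1 the spatial size stays (32 + 2*1 - 3) // 1 + 1 = 32:
#
#     >>> conv_output_size([1, 3, 32, 32], [16, 3, 3, 3], None, [1, 1], [1, 1], [1, 1], 1)
#     [1, 16, 32, 32]

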
def conv1d(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    assert len(weight) == 3
    assert len(input) == 3
    return conv_output_size(input, weight, bias, stride, padding, dilation, groups)


def conv2d(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    assert len(weight) == 4
    assert len(input) == 4
    return conv_output_size(input, weight, bias, stride, padding, dilation, groups)


def conv_backwards(
    grad_output: List[int],
    input: List[int],
    weight: List[int],
    biases: Optional[List[int]],
):
    # Bias gradient is always generated regardless of whether biases is supplied
    return _copy(input), _copy(weight), [grad_output[1]]


def conv_transpose2d_input(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]] = None,
    stride: Optional[List[int]] = None,
    padding: Optional[List[int]] = None,
    output_padding: Optional[List[int]] = None,
    groups: int = 1,
    dilation: Optional[List[int]] = None,
) -> List[int]:
    if stride is None:
        stride = [1, 1]
    if padding is None:
        padding = [0, 0]
    if output_padding is None:
        output_padding = [0, 0]
    if dilation is None:
        dilation = [1, 1]
    has_dilation = len(dilation) > 0
    dim = len(input)
    output_size: List[int] = []
    input_batch_size_dim = 0
    weight_output_channels_dim = 1
    output_size.append(input[input_batch_size_dim])
    output_size.append(weight[weight_output_channels_dim] * groups)

    for d in range(2, dim):
        dilation_ = dilation[d - 2] if has_dilation else 1
        kernel = dilation_ * (weight[d] - 1)
        output_size.append(
            (input[d] - 1) * stride[d - 2]
            - 2 * padding[d - 2]
            + kernel
            + output_padding[d - 2]
            + 1
        )
    return output_size


def conv_forwards(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
) -> List[int]:
    has_dilation = len(dilation) > 0
    has_output_padding = len(output_padding) > 0
    dim = len(input)
    output_size: List[int] = []
    input_batch_size_dim = 0
    weight_output_channels_dim = 1 if transposed else 0
    output_size.append(input[input_batch_size_dim])
    if transposed:
        output_size.append(weight[weight_output_channels_dim] * groups)
    else:
        output_size.append(weight[weight_output_channels_dim])

    for d in range(2, dim):
        dilation_ = dilation[d - 2] if has_dilation else 1
        output_padding_ = output_padding[d - 2] if has_output_padding else 0
        if transposed:
            kernel = dilation_ * (weight[d] - 1)
            output_size.append(
                (input[d] - 1) * stride[d - 2]
                - 2 * padding[d - 2]
                + kernel
                + output_padding_
                + 1
            )
        else:
            kernel = dilation_ * (weight[d] - 1) + 1
            output_size.append(
                (input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
            )
    return output_size


def _conv_forwards(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
    benchmark: bool,
    deterministic: bool,
    cudnn_enabled: bool,
    allow_tf32: bool,
) -> List[int]:
    return conv_forwards(
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    )
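# Illustrative example for conv_transpose2d_input() above (and the transposed
# branch of conv_forwards(), which uses the same formula; doctest-style comment
# only, not part of the generated shape registry). Each spatial dim follows
# (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1,
# e.g. (8 - 1) * 2 - 2 + 2 + 1 + 1 = 16:
#
#     >>> conv_transpose2d_input([1, 16, 8, 8], [16, 33, 3, 3], None, [2, 2], [1, 1], [1, 1], 1, [1, 1])
#     [1, 33, 16, 16]

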
def batch_norm(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: Optional[List[int]],
    running_var: Optional[List[int]],
    training: bool,
    momentum: float,
    eps: float,
    cudnn_enabled: bool,
):
    out: List[int] = []
    for elem in input:
        out.append(elem)
    return out


def conv3d(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    assert len(weight) == 5
    assert len(input) == 5
    return conv_output_size(input, weight, bias, stride, padding, dilation, groups)


def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    if dim_post_expr <= 0:
        assert wrap_scalar
        dim_post_expr = 1
    min = -dim_post_expr
    max = dim_post_expr - 1
    assert not (dim < min or dim > max)
    if dim < 0:
        dim += dim_post_expr
    return dim


def zero_dim_tensor(input: Any):
    out: List[int] = []
    return out


def multiply_integers(li: List[int]):
    out = 1
    for elem in li:
        out = out * elem
    return out


def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
    assert end >= 0
    return [int(math.ceil(end))]


def arange_start(
    start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
):
    assert end >= 0
    assert end >= start
    return [int(math.ceil(end - start))]


def arange_start_step(
    start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
):
    assert step != 0
    if step < 0:
        assert start >= end
    else:
        assert end >= start
    return [int(math.ceil((end - start) / step))]


def permute(input: List[int], dims: List[int]):
    assert len(input) == len(dims)
    ndim = len(dims)
    seen_dims: List[int] = []
    newSizes: List[int] = []
    for i in range(ndim):
        dim = maybe_wrap_dim(dims[i], ndim)
        seen_dims.append(dim)
        newSizes.append(input[dim])
    for i in range(1, ndim):
        for j in range(i):
            assert seen_dims[i] != seen_dims[j]
    return newSizes


def movedim(self: List[int], source: List[int], destination: List[int]) -> List[int]:
    self_dim = len(self)
    if self_dim <= 1:
        return self
    normalized_src: List[int] = []
    normalized_dst: List[int] = []
    for i in range(len(source)):
        normalized_src.append(maybe_wrap_dim(source[i], self_dim))
        normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
    order = [-1 for i in range(self_dim)]
    src_dims = [i for i in range(self_dim)]
    dst_dims = [i for i in range(self_dim)]

    for i in range(len(source)):
        order[normalized_dst[i]] = normalized_src[i]
        src_dims[normalized_src[i]] = -1
        dst_dims[normalized_dst[i]] = -1

    source_dims: List[int] = []
    destination_dims: List[int] = []
    for ele in src_dims:
        if ele != -1:
            source_dims.append(ele)
    for ele in dst_dims:
        if ele != -1:
            destination_dims.append(ele)

    rest_dim = self_dim - len(source)
    for i in range(rest_dim):
        order[destination_dims[i]] = source_dims[i]
    return permute(self, order)
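# Illustrative example for movedim() above (doctest-style comment only, not
# part of the generated shape registry): moving dim 0 to position 3 of a 4-d
# size shifts the remaining dims left.
#
#     >>> movedim([2, 3, 4, 5], [0], [3])
#     [3, 4, 5, 2]

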
def flatten(input: List[int], start_dim: int, end_dim: int):
    start_dim = maybe_wrap_dim(start_dim, len(input))
    end_dim = maybe_wrap_dim(end_dim, len(input))
    assert start_dim <= end_dim
    if len(input) == 0:
        return [1]
    if start_dim == end_dim:
        # TODO: return self
        out: List[int] = []
        for elem in input:
            out.append(elem)
        return out
    slice_numel = 1
    for i in range(start_dim, end_dim + 1):
        slice_numel *= input[i]
    # TODO: use slicing when slice optimization has landed
    # slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1])
    shape: List[int] = []
    for i in range(start_dim):
        shape.append(input[i])
    shape.append(slice_numel)
    for i in range(end_dim + 1, len(input)):
        shape.append(input[i])
    return shape


def nonzero_lower_bound(input: List[int]):
    return [0, len(input)]


def nonzero_upper_bound(input: List[int]):
    return [numel(input), len(input)]


def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
    dim = maybe_wrap_dim(dim, len(self))
    out: List[int] = []
    for i, self_dim in enumerate(self):
        if i == dim:
            if keepdim:
                out.append(1)
        else:
            out.append(self_dim)
    return out


def argmax(
    self: List[int], dim: Optional[int] = None, keepdim: bool = False
) -> List[int]:
    if dim is None:
        return []
    return _reduce_along_dim(self, dim, keepdim)


def bmm(self: List[int], mat2: List[int]) -> List[int]:
    assert len(self) == 3, "bmm only supports 3D tensors"
    assert len(mat2) == 3, "bmm only supports 3D tensors"
    assert self[0] == mat2[0], "mismatching batch dimension"
    assert self[2] == mat2[1], "mismatching contracting dimension"
    return [self[0], self[1], mat2[2]]


def _shape_as_tensor(self: List[int]) -> List[int]:
    return [len(self)]


def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
    if len(self) == 0:
        result: List[int] = []
    else:
        assert (
            k <= self[dim]
        ), f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
        result = _copy(self)
        result[dim] = k
    return result, result


def nll_loss_forward(
    self: List[int], target: List[int], weight: Optional[List[int]], reduction: int
) -> Tuple[List[int], List[int]]:
    # This is taken shamelessly from the meta function in LossNLL.cpp
    self_dim = len(self)
    target_dim = len(target)
    assert 0 < self_dim <= 2
    assert target_dim <= 1
    no_batch_dim = self_dim == 1 and target_dim == 0
    assert no_batch_dim or (self[0] == target[0])
    n_classes = self[-1]
    scalar_shape: List[int] = []
    assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
    if reduction == 0 and self_dim == 2:
        reduction_shape = [self[0]]
    else:
        reduction_shape = scalar_shape
    return reduction_shape, scalar_shape


def native_layer_norm(
    input: List[int], normalized_shape: List[int]
) -> Tuple[List[int], List[int], List[int]]:
    reduction_shape: List[int] = []
    num_unreduced_dimensions = len(input) - len(normalized_shape)
    assert num_unreduced_dimensions >= 0
    for i in range(num_unreduced_dimensions):
        reduction_shape.append(input[i])
    for i in range(num_unreduced_dimensions, len(input)):
        reduction_shape.append(1)
    return _copy(input), reduction_shape, reduction_shape
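# Illustrative example for native_layer_norm() above (doctest-style comment
# only, not part of the generated shape registry): the normalized trailing
# dims collapse to 1 in the mean/rstd shapes.
#
#     >>> native_layer_norm([2, 5, 10], [10])
#     ([2, 5, 10], [2, 5, 1], [2, 5, 1])

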
def native_batch_norm(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: Optional[List[int]],
    running_var: Optional[List[int]],
    training: bool,
) -> Tuple[List[int], List[int], List[int]]:
    if training:
        _size = [input[1]]
    else:
        _size = [0]
    return _copy(input), _size, _size


def _batch_norm_with_update(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: Optional[List[int]],
    running_var: Optional[List[int]],
) -> Tuple[List[int], List[int], List[int], List[int]]:
    _size = [input[1]]
    return _copy(input), _size, _size, [0]


def cross_entropy_loss(
    self: List[int],
    target: List[int],
    weight: Optional[List[int]] = None,
    reduction: int = 1,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
) -> List[int]:
    result_shape = nll_loss_forward(self, target, weight, reduction)[0]
    return result_shape


"""
Currently deferring the enabling of this, as part of the proposal to suspend
adding ops.
There are currently cases in the test case where this is being called
in the SSA opinfo tests with unexpected values (e.g. a list of two ints, see the first
opinfo test). The behavior of index is significantly dependent on the inputs.

This could be an error with how we are matching up shape functions, or that this
function needs to just implement everything.

def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
    assert len(indices) <= len(self), "More indices than dimensions to index"
    broadcasted_shape: List[int] = []
    for index_tensor_shape in indices:
        if index_tensor_shape is not None:
            broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape)
    return broadcasted_shape
"""

ScriptFn = torch._C.ScriptFunction
shape_compute_graph_mapping: Dict[str, ScriptFn] = {}
bounded_compute_graph_mapping: Dict[str, Tuple[ScriptFn, ScriptFn]] = {}
script_func_map: Dict[Callable, ScriptFn] = {}


def process_func(func: Callable):
    if func not in script_func_map:
        scripted_func = torch.jit.script(func)

        torch._C._jit_pass_inline(scripted_func.graph)

        for _ in range(2):
            torch._C._jit_pass_peephole(scripted_func.graph)
            torch._C._jit_pass_constant_propagation(scripted_func.graph)

        script_func_map[func] = scripted_func
    return script_func_map[func]


def add_shape_compute_mapping(operator_schema: str, func: Callable):
    global shape_compute_graph_mapping

    shape_compute_graph_mapping[operator_schema] = process_func(func)


def add_bounded_compute_mapping(
    operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable
):
    # Adds a shape compute function for both upper and lower bounds
    fns = (process_func(lower_bound_func), process_func(upper_bound_func))
    bounded_compute_graph_mapping[operator_schema] = fns


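# Illustrative sketch of the registration pattern used below (comment only;
# this particular op/function pairing is hypothetical and is not actually
# registered here). A shape function takes and returns List[int] sizes; it is
# scripted and inlined by process_func() and keyed by the full operator schema:
#
#     def relu(self: List[int]):  # hypothetical example helper
#         return _copy(self)
#
#     add_shape_compute_mapping("aten::relu(Tensor self) -> Tensor", relu)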
add_shape_compute_mapping(
    "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
    unary,
)
add_shape_compute_mapping(
    "aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::dropout(Tensor input, float p, bool train) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor",
    adaptive_avg_pool2d,
)
add_shape_compute_mapping(
    "prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping(
    "aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    unary,
)
add_shape_compute_mapping(
    "aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
    unary,
)
add_shape_compute_mapping(
    "aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    arange_end,
)
add_shape_compute_mapping(
    "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    arange_start,
)
add_shape_compute_mapping(
    "aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    arange_start_step,
)
add_shape_compute_mapping("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim)
add_shape_compute_mapping(
    "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze
)
add_shape_compute_mapping(
    "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims
)
add_shape_compute_mapping(
    "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze
)
add_shape_compute_mapping(
    "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
    slice,
)
add_shape_compute_mapping(
    "aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select
)
add_shape_compute_mapping(
    "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select
)
add_shape_compute_mapping(
    "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
    "float eps=1e-05, bool cudnn_enable=True) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)",
    unary,
)
add_shape_compute_mapping(
    "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor",
    embedding,
)
add_shape_compute_mapping("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm)
add_shape_compute_mapping("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot)
add_shape_compute_mapping("aten::mv(Tensor self, Tensor vec) -> Tensor", mv)
add_shape_compute_mapping("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul)
add_shape_compute_mapping(
    "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear
)
add_shape_compute_mapping(
    "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
    max_pool2d,
)
add_shape_compute_mapping(
    "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)",
    max_pool2d_with_indices,
)
add_shape_compute_mapping("aten::t(Tensor(a) self) -> Tensor(a)", t)
add_shape_compute_mapping(
    "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose
)
add_shape_compute_mapping(
    "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor",
    conv1d,
)
add_shape_compute_mapping(
    "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
    conv2d,
)
add_shape_compute_mapping(
    "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
    batch_norm,
)
add_shape_compute_mapping(
    "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor",
    conv3d,
)
add_shape_compute_mapping(
    "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)",
    conv_backwards,
)
add_shape_compute_mapping(
    "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
    conv_forwards,
)
add_shape_compute_mapping(
    "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
    _conv_forwards,
)
add_shape_compute_mapping(
    "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor",
    conv_transpose2d_input,
)
add_shape_compute_mapping(
    "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)",
    flatten,
)
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack)
add_shape_compute_mapping(
    "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute
)
add_shape_compute_mapping(
    "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)",
    movedim,
)
add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)
add_shape_compute_mapping(
    "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand
)
add_shape_compute_mapping(
    "aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)",
    expand_one_unused,
)
add_shape_compute_mapping(
    "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    sum_mean_dim,
)
add_shape_compute_mapping(
    "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    sum_mean_dim,
)
add_shape_compute_mapping(
    "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
    max_dim,
)
add_shape_compute_mapping(
    "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping(
    "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping(
    "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
    addmm,
)
add_shape_compute_mapping(
    "aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)",
    upsample_nearest2d,
)
add_shape_compute_mapping(
    "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor",
    unary,
)
add_shape_compute_mapping("aten::dequantize(Tensor self) -> Tensor", unary)
add_shape_compute_mapping(
    "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
    broadcast,
)
add_shape_compute_mapping(
    "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax
)
add_shape_compute_mapping("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm)
add_shape_compute_mapping(
    "aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor
)
add_shape_compute_mapping(
    "aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
    topk,
)
add_shape_compute_mapping(
    "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)",
    nll_loss_forward,
)
add_shape_compute_mapping(
    "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
    native_layer_norm,
)
add_shape_compute_mapping(
    "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
add_shape_compute_mapping(
    "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
add_shape_compute_mapping(
    "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
add_shape_compute_mapping(
    "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
    _batch_norm_with_update,
)

add_shape_compute_mapping(
    "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
    cross_entropy_loss,
)
# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)

# TODO: migrate over all of symbolic_shape_registry_util.cpp
# These are duplicated here so that the functions will be serialized
add_shape_compute_mapping(
    "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor",
    broadcast_three,
)
add_shape_compute_mapping(
    "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
    broadcast_one_three,
)
add_shape_compute_mapping(
    "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
    broadcast_inplace,
)

# quantized_conv_prepack TODO

# Shape Compute Fn with upper and lower bounds
add_bounded_compute_mapping(
    "aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound
)
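# Illustrative note on the bounded mapping above (doctest-style comment only,
# not part of the generated shape registry): nonzero's output shape is
# data-dependent, so only bounds are registered -- between 0 and numel(input)
# rows, each with one column per input dimension.
#
#     >>> nonzero_lower_bound([2, 3])
#     [0, 2]
#     >>> nonzero_upper_bound([2, 3])
#     [6, 2]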