1# mypy: allow-untyped-defs
2import math
3from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
5
6number = Union[int, float]
7# flake8: noqa
8
9###
10# There are generated files that depend on this file
11# To re-generate, please run from the root of the repo:
12# python torchgen/shape_functions/gen_jit_shape_functions.py
13
14# How to test:
15# After regenerating files, compile PyTorch.
16# Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
17# If you have enabled opinfo testing for the op, also run:
18# python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
19# to reproduce errors from opinfo tests.
20
21# Example PR: https://github.com/pytorch/pytorch/pull/80860/files
22####
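# Editor's sketch of the workflow described above (illustrative, not normative):
# each function below computes an operator's output shape purely from input
# shapes (List[int]) and is registered against the operator's schema string via
# add_shape_compute_mapping(...) near the bottom of this file; the generator
# script listed above then embeds the scripted graphs in generated C++ sources.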
23
24import torch
25
26
27def broadcast(a: List[int], b: List[int]):
28    dimsA = len(a)
29    dimsB = len(b)
30    ndim = max(dimsA, dimsB)
31    expandedSizes: List[int] = []
32
33    for i in range(ndim):
34        offset = ndim - 1 - i
35        dimA = dimsA - 1 - offset
36        dimB = dimsB - 1 - offset
37        sizeA = a[dimA] if (dimA >= 0) else 1
38        sizeB = b[dimB] if (dimB >= 0) else 1
39
40        if sizeA != sizeB and sizeA != 1 and sizeB != 1:
41            # TODO: only assertion error is bound in C++ compilation right now
42            raise AssertionError(
43                f"The size of tensor a ({sizeA}) must match the size of tensor b ({sizeB}) at non-singleton dimension {i}"
44            )
45
46        expandedSizes.append(sizeB if sizeA == 1 else sizeA)
47
48    return expandedSizes
49
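# Illustrative example (editor's note): shapes are right-aligned and size-1 dims
# are expanded, mirroring eager broadcasting, e.g.
#   broadcast([3, 1, 5], [4, 1])  ->  [3, 4, 5]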
50
51def broadcast_three(a: List[int], b: List[int], c: List[int]):
52    return broadcast(broadcast(a, b), c)
53
54
55def broadcast_one_three(a: List[int], b: Any, c: List[int]):
56    return broadcast(a, c)
57
58
59def adaptive_avg_pool2d(self: List[int], out: List[int]):
60    assert len(out) == 2
61    assert len(self) == 3 or len(self) == 4
62    for i in range(1, len(self)):
63        assert self[i] != 0
64
65    shape: List[int] = []
66    for i in range(0, len(self) - 2):
67        shape.append(self[i])
68    for elem in out:
69        shape.append(elem)
70    return shape
71
72
73def _copy(self: List[int]):
74    out: List[int] = []
75    for elem in self:
76        out.append(elem)
77    return out
78
79
80def unary(self: List[int]):
81    return _copy(self)
82
83
84def broadcast_inplace(a: List[int], b: List[int]):
85    dimsA = len(a)
86    dimsB = len(b)
87    if dimsB > dimsA:
88        raise AssertionError(
89            f"The dims of tensor b ({dimsB}) must be less than or equal to the dims of tensor a ({dimsA})"
90        )
91    for dimA in range(dimsA):
92        dimB = dimsB - dimsA + dimA
93        sizeA = a[dimA]
94        sizeB = b[dimB] if (dimB >= 0) else 1
95        if sizeA != sizeB and sizeB != 1:
96            # TODO: only assertion error is bound in C++ compilation right now
97            raise AssertionError(
98                "The size of tensor a ({}) must match the size of tensor b ("
99                "{}) at non-singleton dimension {}".format(sizeA, sizeB, dimA)
100            )
101    return _copy(a)
102
103
104def expand(self: List[int], sizes: List[int]):
105    assert len(sizes) >= len(self)
106    ndim = len(sizes)
107    tensor_dim = len(self)
108    if ndim == 0:
109        return _copy(sizes)
110    out: List[int] = []
111    for i in range(ndim):
112        offset = ndim - 1 - i
113        dim = tensor_dim - 1 - offset
114        size = self[dim] if dim >= 0 else 1
115        targetSize = sizes[i]
116        if targetSize == -1:
117            assert dim >= 0
118            targetSize = size
119        if size != targetSize:
120            assert size == 1
121            size = targetSize
122        out.append(size)
123    return out
124
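# Illustrative example (editor's note): a -1 in sizes keeps the existing size, e.g.
#   expand([3, 1], [2, -1, 4])  ->  [2, 3, 4]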
125
126def expand_one_unused(self: List[int], sizes: List[int], inp0: Any):
127    return expand(self, sizes)
128
129
130def infer_size_impl(shape: List[int], numel: int) -> List[int]:
131    newsize = 1
132    infer_dim: Optional[int] = None
133    for dim in range(len(shape)):
134        if shape[dim] == -1:
135            if infer_dim is not None:
136                raise AssertionError("only one dimension can be inferred")
137            infer_dim = dim
138        elif shape[dim] >= 0:
139            newsize *= shape[dim]
140        else:
141            raise AssertionError("invalid shape dimensions")
142    if not (
143        numel == newsize
144        or (infer_dim is not None and newsize > 0 and numel % newsize == 0)
145    ):
146        raise AssertionError("invalid shape")
147    out = _copy(shape)
148    if infer_dim is not None:
149        out[infer_dim] = numel // newsize
150    return out
151
152
153def numel(sizes: List[int]):
154    numel = 1
155    for elem in sizes:
156        numel *= elem
157    return numel
158
159
160def view(self: List[int], sizes: List[int]):
161    return infer_size_impl(sizes, numel(self))
162
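# Illustrative example (editor's note): a single -1 is inferred from the numel, e.g.
#   view([2, 3, 4], [3, -1])  ->  [3, 8]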
163
164def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False):
165    return view(self, sizes)
166
167
168def sum_mean_dim(
169    self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any
170):
171    out: List[int] = []
172    if opt_dims is None or len(opt_dims) == 0:
173        dims: List[int] = list(range(len(self)))
174    else:
175        dims = opt_dims
176
177    for idx in range(len(self)):
178        is_mean_dim: bool = False
179        for reduce_dim in dims:
180            if idx == maybe_wrap_dim(reduce_dim, len(self)):
181                is_mean_dim = True
182        if is_mean_dim:
183            if keep_dim:
184                out.append(1)
185        else:
186            out.append(self[idx])
187    return out
188
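# Illustrative example (editor's note): reduced dims are dropped, or kept as 1
# when keep_dim is True, e.g.
#   sum_mean_dim([2, 3, 4], [1], False, None)  ->  [2, 4]
#   sum_mean_dim([2, 3, 4], [1], True, None)   ->  [2, 1, 4]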
189
190def max_dim(self: List[int], dim: int, keep_dim: bool):
191    out = sum_mean_dim(self, [dim], keep_dim, None)
192    return out, out
193
194
195# Note: Python already rounds down towards negative infinity on integer division; special arithmetic not needed
196def div_rtn(x: int, y: int):
197    return x // y
198
199
200def pooling_output_shape_pad_lr(
201    inputSize: int,
202    kernelSize: int,
203    pad_l: int,
204    pad_r: int,
205    stride: int,
206    dilation: int,
207    ceil_mode: bool,
208):
209    outputSize = (
210        div_rtn(
211            inputSize
212            + pad_l
213            + pad_r
214            - dilation * (kernelSize - 1)
215            - 1
216            + (stride - 1 if ceil_mode else 0),
217            stride,
218        )
219        + 1
220    )
221    if ceil_mode:
222        if (outputSize - 1) * stride >= inputSize + pad_l:
223            outputSize = outputSize - 1
224    return outputSize
225
226
227def pooling_output_shape(
228    inputSize: int,
229    kernelSize: int,
230    pad_l: int,
231    stride: int,
232    dilation: int,
233    ceil_mode: bool,
234):
235    assert stride != 0, "stride should not be zero"
236    return pooling_output_shape_pad_lr(
237        inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode
238    )
239
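# Worked example (editor's note), with pad_l == pad_r == pad and dilation 1:
#   pooling_output_shape(7, 3, 1, 2, 1, False)
#     = (7 + 1 + 1 - 1*(3 - 1) - 1) // 2 + 1 = 6 // 2 + 1 = 4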
240
241def pool2d_shape_check(
242    input: List[int],
243    kH: int,
244    kW: int,
245    dH: int,
246    dW: int,
247    padH: int,
248    padW: int,
249    dilationH: int,
250    dilationW: int,
251    nInputPlane: int,
252    inputHeight: int,
253    inputWidth: int,
254    outputHeight: int,
255    outputWidth: int,
256):
257    ndim = len(input)
258    nOutputPlane = nInputPlane
259
260    assert kW > 0 and kH > 0
261    assert dW > 0 and dH > 0
262    assert dilationH > 0 and dilationW > 0
263
264    valid_dims = input[1] != 0 and input[2] != 0
265    assert (
266        ndim == 3
267        and input[0] != 0
268        and valid_dims
269        or (ndim == 4 and valid_dims and input[3] != 0)
270    )
271
272    assert kW // 2 >= padW and kH // 2 >= padH
273    assert outputWidth >= 1 and outputHeight >= 1
274
275
276def max_pool2d(
277    input: List[int],
278    kernel_size: List[int],
279    stride: List[int],
280    padding: List[int],
281    dilation: List[int],
282    ceil_mode: bool,
283):
284    assert (
285        len(kernel_size) == 1 or len(kernel_size) == 2
286    ), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
287    kH = kernel_size[0]
288    kW = kH if len(kernel_size) == 1 else kernel_size[1]
289
290    assert (
291        len(stride) == 0 or len(stride) == 1 or len(stride) == 2
292    ), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
293    dH = kH if len(stride) == 0 else stride[0]
294    if len(stride) == 0:
295        dW = kW
296    elif len(stride) == 1:
297        dW = dH
298    else:
299        dW = stride[1]
300
301    assert (
302        len(padding) == 1 or len(padding) == 2
303    ), "max_pool2d: padding must either be a single int, or a tuple of two ints"
304    padH = padding[0]
305    padW = padH if len(padding) == 1 else padding[1]
306
307    assert (
308        len(dilation) == 1 or len(dilation) == 2
309    ), "max_pool2d: dilation must be either a single int, or a tuple of two ints"
310    dilationH = dilation[0]
311    dilationW = dilationH if len(dilation) == 1 else dilation[1]
312
313    assert len(input) == 3 or len(input) == 4
314
315    nbatch = input[-4] if len(input) == 4 else 1
316    nInputPlane = input[-3]
317    inputHeight = input[-2]
318    inputWidth = input[-1]
319
320    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
321    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
322
323    pool2d_shape_check(
324        input,
325        kH,
326        kW,
327        dH,
328        dW,
329        padH,
330        padW,
331        dilationH,
332        dilationW,
333        nInputPlane,
334        inputHeight,
335        inputWidth,
336        outputHeight,
337        outputWidth,
338    )
339
340    if len(input) == 3:
341        return [nInputPlane, outputHeight, outputWidth]
342    else:
343        return [nbatch, nInputPlane, outputHeight, outputWidth]
344
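# Illustrative example (editor's note): a 2x2, stride-2 pool halves the spatial dims, e.g.
#   max_pool2d([1, 3, 32, 32], [2], [2], [0], [1], False)  ->  [1, 3, 16, 16]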
345
346def max_pool2d_with_indices(
347    input: List[int],
348    kernel_size: List[int],
349    stride: List[int],
350    padding: List[int],
351    dilation: List[int],
352    ceil_mode: bool,
353):
354    out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
355    return (out, out)
356
357
358def upsample_nearest2d(
359    input: List[int],
360    output_size: Optional[List[int]],
361    scale_factors: Optional[List[float]],
362):
363    out: List[int] = []
364    out.append(input[0])
365    out.append(input[1])
366
367    if scale_factors is None and output_size is None:
368        assert 0, "Either output_size or scale_factors must be specified"
369
370    if output_size is not None:
371        assert (
372            scale_factors is None
373        ), "Must specify exactly one of output_size and scale_factors"
374        assert len(output_size) == 2
375        out.append(output_size[0])
376        out.append(output_size[1])
377
378    if scale_factors is not None:
379        assert (
380            output_size is None
381        ), "Must specify exactly one of output_size and scale_factors"
382        assert len(scale_factors) == 2
383        out.append(int(input[2] * scale_factors[0]))
384        out.append(int(input[3] * scale_factors[1]))
385
386    return out
387
388
389def mm(self: List[int], mat2: List[int]):
390    assert len(self) == 2, "self must be a matrix"
391    assert len(mat2) == 2, "mat2 must be a matrix"
392
393    assert self[1] == mat2[0]
394    return [self[0], mat2[1]]
395
396
397def dot(self: List[int], tensor: List[int]):
398    assert len(self) == 1 and len(tensor) == 1
399    assert self[0] == tensor[0]
400    out: List[int] = []
401    return out
402
403
404def mv(self: List[int], vec: List[int]):
405    assert len(self) == 2 and len(vec) == 1
406    assert self[1] == vec[0]
407    # TODO: return self
408    return [self[0]]
409
410
411def unsqueeze(li: List[int], dim: int):
412    dim = maybe_wrap_dim(dim, len(li) + 1)
413    out = _copy(li)
414    out.insert(dim, 1)
415    return out
416
417
418def squeeze_nodim(li: List[int]):
419    out: List[int] = []
420    for i in range(len(li)):
421        if li[i] != 1:
422            out.append(li[i])
423    return out
424
425
426def squeeze(li: List[int], dim: int):
427    out: List[int] = []
428    wrapped_dim = maybe_wrap_dim(dim, len(li))
429    for i in range(len(li)):
430        if i == wrapped_dim:
431            if li[i] != 1:
432                out.append(li[i])
433        else:
434            out.append(li[i])
435    return out
436
437
438def squeeze_dims(li: List[int], dims: List[int]):
439    if len(dims) == 0:
440        return li
441    wrapped_dims = _copy(dims)
442    for i in range(len(dims)):
443        wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li))
444    result: List[int] = []
445    for i in range(len(li)):
446        if li[i] == 1:
447            if i not in wrapped_dims:
448                result.append(li[i])
449        else:
450            result.append(li[i])
451    return result
452
453
454def index_select(self: List[int], dim: int, index: List[int]):
455    dim = maybe_wrap_dim(dim, len(self))
456    numel = multiply_integers(index)
457    assert len(index) <= 1
458    assert dim == 0 or dim < len(self)
459    result_size: List[int] = []
460    for i in range(len(self)):
461        if dim == i:
462            result_size.append(numel)
463        else:
464            result_size.append(self[i])
465    return result_size
466
467
468def embedding(
469    weight: List[int],
470    indices: List[int],
471    padding_idx: int = -1,
472    scale_grad_by_freq: bool = False,
473    sparse: bool = False,
474):
475    assert len(weight) == 2
476    if len(indices) == 1:
477        return index_select(weight, 0, indices)
478    size = _copy(indices)
479    size.append(weight[1])
480    return size
481
482
483def max_int():
484    return 9223372036854775807
485
486
487def slice(
488    self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int
489):
490    ndim = len(self)
491    assert ndim != 0
492    dim = maybe_wrap_dim(dim, ndim)
493    start_val = start if start is not None else 0
494    end_val = end if end is not None else max_int()
495    assert step > 0
496    if start_val == max_int():
497        start_val = 0
498    if start_val < 0:
499        start_val += self[dim]
500    if end_val < 0:
501        end_val += self[dim]
502    if start_val < 0:
503        start_val = 0
504    elif start_val > self[dim]:
505        start_val = self[dim]
506    if end_val < start_val:
507        end_val = start_val
508    elif end_val >= self[dim]:
509        end_val = self[dim]
510    slice_len = end_val - start_val
511    out = _copy(self)
512    out[dim] = (slice_len + step - 1) // step
513    return out
514
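# Illustrative example (editor's note): matches Python slicing semantics per dim, e.g.
#   slice([10, 4], 0, 2, None, 3)  ->  [3, 4]   (like x[2::3] on a length-10 dim)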
515
516def check_cat_no_zero_dim(tensors: List[List[int]]):
517    for tensor in tensors:
518        assert len(tensor) > 0
519
520
521def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
522    out_dim: Optional[int] = None
523    for size in tensor_sizes:
524        if not (len(size) == 1 and size[0] == 0):
525            if out_dim is None:
526                out_dim = maybe_wrap_dim(dim, len(size))
527    if out_dim is None:
528        out_dim = dim
529    return out_dim
530
531
532def should_skip(tensor: List[int]):
533    return numel(tensor) == 0 and len(tensor) == 1
534
535
536def check_cat_shape_except_dim(
537    first: List[int], second: List[int], dimension: int, index: int
538):
539    first_dims = len(first)
540    second_dims = len(second)
541    assert first_dims == second_dims, "Tensors must have same number of dimensions"
542    for dim in range(0, first_dims):
543        if dim != dimension:
544            assert (
545                first[dim] == second[dim]
546            ), "Sizes of tensors must match except in dimension"
547
548
549def cat(tensors: List[List[int]], dim: int):
550    check_cat_no_zero_dim(tensors)
551    dim = legacy_cat_wrap_dim(dim, tensors)
552    assert len(tensors) > 0
553    not_skipped_tensor: Optional[List[int]] = None
554    for tensor in tensors:
555        if not should_skip(tensor):
556            not_skipped_tensor = tensor
557    if not_skipped_tensor is None:
558        return [0]
559
560    cat_dim_size = 0
561
562    for i in range(len(tensors)):
563        tensor = tensors[i]
564        if not should_skip(tensor):
565            check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
566            cat_dim_size = cat_dim_size + tensor[dim]
567
568    result_size = _copy(not_skipped_tensor)
569    result_size[dim] = cat_dim_size
570    return result_size
571
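# Illustrative example (editor's note): sizes add along `dim` and must match elsewhere, e.g.
#   cat([[2, 3], [4, 3]], 0)  ->  [6, 3]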
572
573def stack(tensors: List[List[int]], dim: int):
574    unsqueezed_tensors: List[List[int]] = []
575    for tensor in tensors:
576        unsqueezed = unsqueeze(tensor, dim)
577        unsqueezed_tensors.append(unsqueezed)
578    return cat(unsqueezed_tensors, dim)
579
580
581def select(self: List[int], dim: int, index: int):
582    ndim = len(self)
583    assert ndim != 0
584    dim = maybe_wrap_dim(dim, ndim)
585    size = self[dim]
586    assert not (index < -size or index >= size)
587    if index < 0:
588        index += size
589    out: List[int] = []
590    for i in range(ndim):
591        if i != dim:
592            out.append(self[i])
593    return out
594
595
596def matmul(tensor1: List[int], tensor2: List[int]):
597    dim_tensor1 = len(tensor1)
598    dim_tensor2 = len(tensor2)
599    if dim_tensor1 == 1 and dim_tensor2 == 1:
600        return dot(tensor1, tensor2)
601    elif dim_tensor1 == 2 and dim_tensor2 == 1:
602        return mv(tensor1, tensor2)
603    elif dim_tensor1 == 1 and dim_tensor2 == 2:
604        return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0)
605    elif dim_tensor1 == 2 and dim_tensor2 == 2:
606        return mm(tensor1, tensor2)
607    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
608        # We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 can be a list);
609        # we track m1 vs m2 separately even though they must match for nicer error messages
610        n = tensor1[-2] if dim_tensor1 > 1 else 1
611        m1 = tensor1[-1]
612        batch_tensor1: List[int] = []
613        # TODO: handling of slice
614        for i in range(dim_tensor1 - 2):
615            batch_tensor1.append(tensor1[i])
616        m2 = tensor2[-2] if dim_tensor2 > 1 else tensor2[-1]
617        p = tensor2[-1]
618        batch_tensor2: List[int] = []
619        # TODO: handling of slice
620        for i in range(dim_tensor2 - 2):
621            batch_tensor2.append(tensor2[i])
622
623        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
624        expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)
625
626        # TODO: copy?
627        output_shape = expand_batch_portion
628        if dim_tensor1 > 1:
629            output_shape.append(n)
630
631        if dim_tensor2 > 1:
632            output_shape.append(p)
633
634        return output_shape
635    else:
636        assert False, "both arguments to matmul need to be at least 1D"
637
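# Illustrative example (editor's note): batch dims broadcast, matrix dims contract, e.g.
#   matmul([10, 3, 4], [4, 5])  ->  [10, 3, 5]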
638
639def t(self: List[int]):
640    assert len(self) <= 2
641    self_len = len(self)
642    if self_len == 0:
643        out: List[int] = []
644        return out
645    elif self_len == 1:
646        return [self[0]]
647    else:
648        return [self[1], self[0]]
649
650
651def transpose(self: List[int], dim0: int, dim1: int):
652    ndims = len(self)
653    dim0 = maybe_wrap_dim(dim0, ndims)
654    dim1 = maybe_wrap_dim(dim1, ndims)
655    if dim0 == dim1:
656        return _copy(self)
657    out: List[int] = []
658    for i in range(ndims):
659        if i == dim0:
660            out.append(self[dim1])
661        elif i == dim1:
662            out.append(self[dim0])
663        else:
664            out.append(self[i])
665    return out
666
667
668def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
669    out = matmul(input, t(weight))
670    if bias is not None:
671        assert broadcast(bias, out) == out
672    return out
673
674
675def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
676    return broadcast(self, mm(mat1, mat2))
677
678
679def check_non_negative(array: List[int]) -> bool:
680    # TODO: look into rewriting with early return and getting loop unrolling to fire
681    has_negative = False
682    for val in array:
683        if val < 0:
684            has_negative = True
685    return has_negative
686
687
688def check_shape_forward(
689    input: List[int],
690    weight_sizes: List[int],
691    bias: Optional[List[int]],
692    stride: List[int],
693    padding: List[int],
694    dilation: List[int],
695    groups: int,
696):
697    k = len(input)
698    weight_dim = len(weight_sizes)
699
700    # TODO: assertions could be expanded with the error messages
701    assert not check_non_negative(padding)
702    assert not check_non_negative(stride)
703
704    assert weight_dim == k
705    assert weight_sizes[0] >= groups
706    assert (weight_sizes[0] % groups) == 0
707    # only handling not transposed
708    assert input[1] == weight_sizes[1] * groups
709    assert bias is None or (len(bias) == 1 and bias[0] == weight_sizes[0])
710
711    for i in range(2, k):
712        assert (input[i] + 2 * padding[i - 2]) >= (
713            dilation[i - 2] * (weight_sizes[i] - 1) + 1
714        )
715
716    # this is not handling transposed convolution yet
717
718
719def conv_output_size(
720    input_size: List[int],
721    weight_size: List[int],
722    bias: Optional[List[int]],
723    stride: List[int],
724    padding: List[int],
725    dilation: List[int],
726    groups: int,
727):
728    check_shape_forward(
729        input_size, weight_size, bias, stride, padding, dilation, groups
730    )
731
732    has_dilation = len(dilation) > 0
733    dim = len(input_size)
734    output_size: List[int] = []
735    input_batch_size_dim = 0
736    weight_output_channels_dim = 0
737    output_size.append(input_size[input_batch_size_dim])
738    output_size.append(weight_size[weight_output_channels_dim])
739
740    for d in range(2, dim):
741        dilation_ = dilation[d - 2] if has_dilation else 1
742        kernel = dilation_ * (weight_size[d] - 1) + 1
743        output_size.append(
744            (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
745        )
746    return output_size
747
748
749def conv1d(
750    input: List[int],
751    weight: List[int],
752    bias: Optional[List[int]],
753    stride: List[int],
754    padding: List[int],
755    dilation: List[int],
756    groups: int,
757):
758    assert len(weight) == 3
759    assert len(input) == 3
760    return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
761
762
763def conv2d(
764    input: List[int],
765    weight: List[int],
766    bias: Optional[List[int]],
767    stride: List[int],
768    padding: List[int],
769    dilation: List[int],
770    groups: int,
771):
772    assert len(weight) == 4
773    assert len(input) == 4
774    return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
775
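# Illustrative example (editor's note): each spatial dim follows
# (size + 2*pad - dilation*(k - 1) - 1) // stride + 1, e.g.
#   conv2d([1, 3, 32, 32], [8, 3, 3, 3], None, [1, 1], [1, 1], [1, 1], 1)  ->  [1, 8, 32, 32]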
776
777def conv_backwards(
778    grad_output: List[int],
779    input: List[int],
780    weight: List[int],
781    biases: Optional[List[int]],
782):
783    # The bias gradient is always generated regardless of whether biases is supplied
784    return _copy(input), _copy(weight), [grad_output[1]]
785
786
787def conv_transpose2d_input(
788    input: List[int],
789    weight: List[int],
790    bias: Optional[List[int]] = None,
791    stride: Optional[List[int]] = None,
792    padding: Optional[List[int]] = None,
793    output_padding: Optional[List[int]] = None,
794    groups: int = 1,
795    dilation: Optional[List[int]] = None,
796) -> List[int]:
797    if stride is None:
798        stride = [1, 1]
799    if padding is None:
800        padding = [0, 0]
801    if output_padding is None:
802        output_padding = [0, 0]
803    if dilation is None:
804        dilation = [1, 1]
805    has_dilation = len(dilation) > 0
806    dim = len(input)
807    output_size: List[int] = []
808    input_batch_size_dim = 0
809    weight_output_channels_dim = 1
810    output_size.append(input[input_batch_size_dim])
811    output_size.append(weight[weight_output_channels_dim] * groups)
812
813    for d in range(2, dim):
814        dilation_ = dilation[d - 2] if has_dilation else 1
815        kernel = dilation_ * (weight[d] - 1)
816        output_size.append(
817            (input[d] - 1) * stride[d - 2]
818            - 2 * padding[d - 2]
819            + kernel
820            + output_padding[d - 2]
821            + 1
822        )
823    return output_size
824
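# Illustrative example (editor's note): transposed conv grows spatial dims via
# (size - 1)*stride - 2*pad + dilation*(k - 1) + output_padding + 1, e.g.
#   conv_transpose2d_input([1, 16, 8, 8], [16, 33, 3, 3], None, [2, 2], [1, 1], [1, 1], 1, [1, 1])
#     ->  [1, 33, 16, 16]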
825
826def conv_forwards(
827    input: List[int],
828    weight: List[int],
829    bias: Optional[List[int]],
830    stride: List[int],
831    padding: List[int],
832    dilation: List[int],
833    transposed: bool,
834    output_padding: List[int],
835    groups: int,
836) -> List[int]:
837    has_dilation = len(dilation) > 0
838    has_output_padding = len(output_padding) > 0
839    dim = len(input)
840    output_size: List[int] = []
841    input_batch_size_dim = 0
842    weight_output_channels_dim = 1 if transposed else 0
843    output_size.append(input[input_batch_size_dim])
844    if transposed:
845        output_size.append(weight[weight_output_channels_dim] * groups)
846    else:
847        output_size.append(weight[weight_output_channels_dim])
848
849    for d in range(2, dim):
850        dilation_ = dilation[d - 2] if has_dilation else 1
851        output_padding_ = output_padding[d - 2] if has_output_padding else 0
852        if transposed:
853            kernel = dilation_ * (weight[d] - 1)
854            output_size.append(
855                (input[d] - 1) * stride[d - 2]
856                - 2 * padding[d - 2]
857                + kernel
858                + output_padding_
859                + 1
860            )
861        else:
862            kernel = dilation_ * (weight[d] - 1) + 1
863            output_size.append(
864                (input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
865            )
866    return output_size
867
868
869def _conv_forwards(
870    input: List[int],
871    weight: List[int],
872    bias: Optional[List[int]],
873    stride: List[int],
874    padding: List[int],
875    dilation: List[int],
876    transposed: bool,
877    output_padding: List[int],
878    groups: int,
879    benchmark: bool,
880    deterministic: bool,
881    cudnn_enabled: bool,
882    allow_tf32: bool,
883) -> List[int]:
884    return conv_forwards(
885        input,
886        weight,
887        bias,
888        stride,
889        padding,
890        dilation,
891        transposed,
892        output_padding,
893        groups,
894    )
895
896
897def batch_norm(
898    input: List[int],
899    weight: Optional[List[int]],
900    bias: Optional[List[int]],
901    running_mean: Optional[List[int]],
902    running_var: Optional[List[int]],
903    training: bool,
904    momentum: float,
905    eps: float,
906    cudnn_enabled: bool,
907):
908    out: List[int] = []
909    for elem in input:
910        out.append(elem)
911    return out
912
913
914def conv3d(
915    input: List[int],
916    weight: List[int],
917    bias: Optional[List[int]],
918    stride: List[int],
919    padding: List[int],
920    dilation: List[int],
921    groups: int,
922):
923    assert len(weight) == 5
924    assert len(input) == 5
925    return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
926
927
928def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
929    if dim_post_expr <= 0:
930        assert wrap_scalar
931        dim_post_expr = 1
932    min = -dim_post_expr
933    max = dim_post_expr - 1
934    assert not (dim < min or dim > max)
935    if dim < 0:
936        dim += dim_post_expr
937    return dim
938
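# Illustrative example (editor's note): negative dims wrap like Python indices, e.g.
#   maybe_wrap_dim(-1, 4)  ->  3,   maybe_wrap_dim(2, 4)  ->  2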
939
940def zero_dim_tensor(input: Any):
941    out: List[int] = []
942    return out
943
944
945def multiply_integers(li: List[int]):
946    out = 1
947    for elem in li:
948        out = out * elem
949    return out
950
951
952def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
953    assert end >= 0
954    return [int(math.ceil(end))]
955
956
957def arange_start(
958    start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
959):
960    assert end >= 0
961    assert end >= start
962    return [int(math.ceil(end - start))]
963
964
965def arange_start_step(
966    start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
967):
968    assert step != 0
969    if step < 0:
970        assert start >= end
971    else:
972        assert end >= start
973    return [int(math.ceil((end - start) / step))]
974
975
976def permute(input: List[int], dims: List[int]):
977    assert len(input) == len(dims)
978    ndim = len(dims)
979    seen_dims: List[int] = []
980    newSizes: List[int] = []
981    for i in range(ndim):
982        dim = maybe_wrap_dim(dims[i], ndim)
983        seen_dims.append(dim)
984        newSizes.append(input[dim])
985    for i in range(1, ndim):
986        for j in range(i):
987            assert seen_dims[i] != seen_dims[j]
988    return newSizes
989
990
991def movedim(self: List[int], source: List[int], destination: List[int]) -> List[int]:
992    self_dim = len(self)
993    if self_dim <= 1:
994        return self
995    normalized_src: List[int] = []
996    normalized_dst: List[int] = []
997    for i in range(len(source)):
998        normalized_src.append(maybe_wrap_dim(source[i], self_dim))
999        normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
1000    order = [-1 for i in range(self_dim)]
1001    src_dims = [i for i in range(self_dim)]
1002    dst_dims = [i for i in range(self_dim)]
1003
1004    for i in range(len(source)):
1005        order[normalized_dst[i]] = normalized_src[i]
1006        src_dims[normalized_src[i]] = -1
1007        dst_dims[normalized_dst[i]] = -1
1008
1009    source_dims: List[int] = []
1010    destination_dims: List[int] = []
1011    for ele in src_dims:
1012        if ele != -1:
1013            source_dims.append(ele)
1014    for ele in dst_dims:
1015        if ele != -1:
1016            destination_dims.append(ele)
1017
1018    rest_dim = self_dim - len(source)
1019    for i in range(rest_dim):
1020        order[destination_dims[i]] = source_dims[i]
1021    return permute(self, order)
1022
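# Illustrative example (editor's note): moving dim 0 to position 2, e.g.
#   movedim([2, 3, 5, 7], [0], [2])  ->  [3, 5, 2, 7]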
1023
1024def flatten(input: List[int], start_dim: int, end_dim: int):
1025    start_dim = maybe_wrap_dim(start_dim, len(input))
1026    end_dim = maybe_wrap_dim(end_dim, len(input))
1027    assert start_dim <= end_dim
1028    if len(input) == 0:
1029        return [1]
1030    if start_dim == end_dim:
1031        # TODO: return self
1032        out: List[int] = []
1033        for elem in input:
1034            out.append(elem)
1035        return out
1036    slice_numel = 1
1037    for i in range(start_dim, end_dim + 1):
1038        slice_numel *= input[i]
1039    # TODO: use slicing when slice optimization has landed
1040    # slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1])
1041    shape: List[int] = []
1042    for i in range(start_dim):
1043        shape.append(input[i])
1044    shape.append(slice_numel)
1045    for i in range(end_dim + 1, len(input)):
1046        shape.append(input[i])
1047    return shape
1048
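# Illustrative example (editor's note): dims start_dim..end_dim collapse into one, e.g.
#   flatten([2, 3, 4, 5], 1, 2)  ->  [2, 12, 5]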
1049
1050def nonzero_lower_bound(input: List[int]):
1051    return [0, len(input)]
1052
1053
1054def nonzero_upper_bound(input: List[int]):
1055    return [numel(input), len(input)]
1056
1057
1058def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
1059    dim = maybe_wrap_dim(dim, len(self))
1060    out: List[int] = []
1061    for i, self_dim in enumerate(self):
1062        if i == dim:
1063            if keepdim:
1064                out.append(1)
1065        else:
1066            out.append(self_dim)
1067    return out
1068
1069
1070def argmax(
1071    self: List[int], dim: Optional[int] = None, keepdim: bool = False
1072) -> List[int]:
1073    if dim is None:
1074        return []
1075    return _reduce_along_dim(self, dim, keepdim)
1076
1077
1078def bmm(self: List[int], mat2: List[int]) -> List[int]:
1079    assert len(self) == 3, "bmm only supports 3D tensors"
1080    assert len(mat2) == 3, "bmm only supports 3D tensors"
1081    assert self[0] == mat2[0], "mismatching batch dimension"
1082    assert self[2] == mat2[1], "mismatching contracting dimension"
1083    return [self[0], self[1], mat2[2]]
1084
1085
1086def _shape_as_tensor(self: List[int]) -> List[int]:
1087    return [len(self)]
1088
1089
1090def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
1091    if len(self) == 0:
1092        result: List[int] = []
1093    else:
1094        assert (
1095            k <= self[dim]
1096        ), f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
1097        result = _copy(self)
1098        result[dim] = k
1099    return result, result
1100
1101
1102def nll_loss_forward(
1103    self: List[int], target: List[int], weight: Optional[List[int]], reduction: int
1104) -> Tuple[List[int], List[int]]:
1105    # This is taken shamelessly from the meta function in LossNLL.cpp
1106    self_dim = len(self)
1107    target_dim = len(target)
1108    assert 0 < self_dim <= 2
1109    assert target_dim <= 1
1110    no_batch_dim = self_dim == 1 and target_dim == 0
1111    assert no_batch_dim or (self[0] == target[0])
1112    n_classes = self[-1]
1113    scalar_shape: List[int] = []
1114    assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
1115    if reduction == 0 and self_dim == 2:
1116        reduction_shape = [self[0]]
1117    else:
1118        reduction_shape = scalar_shape
1119    return reduction_shape, scalar_shape
1120
1121
1122def native_layer_norm(
1123    input: List[int], normalized_shape: List[int]
1124) -> Tuple[List[int], List[int], List[int]]:
1125    reduction_shape: List[int] = []
1126    num_unreduced_dimensions = len(input) - len(normalized_shape)
1127    assert num_unreduced_dimensions >= 0
1128    for i in range(num_unreduced_dimensions):
1129        reduction_shape.append(input[i])
1130    for i in range(num_unreduced_dimensions, len(input)):
1131        reduction_shape.append(1)
1132    return _copy(input), reduction_shape, reduction_shape
1133
1134
1135def native_batch_norm(
1136    input: List[int],
1137    weight: Optional[List[int]],
1138    bias: Optional[List[int]],
1139    running_mean: Optional[List[int]],
1140    running_var: Optional[List[int]],
1141    training: bool,
1142) -> Tuple[List[int], List[int], List[int]]:
1143    if training:
1144        _size = [input[1]]
1145    else:
1146        _size = [0]
1147    return _copy(input), _size, _size
1148
1149
1150def _batch_norm_with_update(
1151    input: List[int],
1152    weight: Optional[List[int]],
1153    bias: Optional[List[int]],
1154    running_mean: Optional[List[int]],
1155    running_var: Optional[List[int]],
1156) -> Tuple[List[int], List[int], List[int], List[int]]:
1157    _size = [input[1]]
1158    return _copy(input), _size, _size, [0]
1159
1160
1161def cross_entropy_loss(
1162    self: List[int],
1163    target: List[int],
1164    weight: Optional[List[int]] = None,
1165    reduction: int = 1,
1166    ignore_index: int = -100,
1167    label_smoothing: float = 0.0,
1168) -> List[int]:
1169    result_shape = nll_loss_forward(self, target, weight, reduction)[0]
1170    return result_shape
1171
1172
1173"""
1174Currently deferring the enabling of this, as part of the proposal to suspend
1175adding ops.
1176There are currently cases in the test suite where this is being called
1177in the SSA opinfo tests with unexpected values (e.g. a list of two ints, see the first
1178opinfo test). The behavior of index is significantly dependent on the inputs.
1179
1180This could be an error with how we are matching up shape functions, or that this
1181function needs to just implement everything.
1182
1183def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
1184    assert len(indices) <= len(self), "More indices than dimensions to index"
1185    broadcasted_shape: List[int] = []
1186    for index_tensor_shape in indices:
1187        if index_tensor_shape is not None:
1188            broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape)
1189    return broadcasted_shape
1190"""
1191
1192ScriptFn = torch._C.ScriptFunction
1193shape_compute_graph_mapping: Dict[str, ScriptFn] = {}
1194bounded_compute_graph_mapping: Dict[str, Tuple[ScriptFn, ScriptFn]] = {}
1195script_func_map: Dict[Callable, ScriptFn] = {}
1196
1197
1198def process_func(func: Callable):
1199    if func not in script_func_map:
1200        scripted_func = torch.jit.script(func)
1201
1202        torch._C._jit_pass_inline(scripted_func.graph)
1203
1204        for _ in range(2):
1205            torch._C._jit_pass_peephole(scripted_func.graph)
1206            torch._C._jit_pass_constant_propagation(scripted_func.graph)
1207
1208        script_func_map[func] = scripted_func
1209    return script_func_map[func]
1210
1211
1212def add_shape_compute_mapping(operator_schema: str, func: Callable):
1213    global shape_compute_graph_mapping
1214
1215    shape_compute_graph_mapping[operator_schema] = process_func(func)
1216
1217
1218def add_bounded_compute_mapping(
1219    operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable
1220):
1221    # Adds a shape compute function for both upper and lower bounds
1222    fns = (process_func(lower_bound_func), process_func(upper_bound_func))
1223    bounded_compute_graph_mapping[operator_schema] = fns
1224
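# Usage sketch (editor's illustration; the schema below is only an example of the
# expected format, not a mapping that this file necessarily registers):
#   add_shape_compute_mapping("aten::relu(Tensor self) -> Tensor", unary)
# would script `unary` once, run the inlining/peephole/constant-propagation passes
# above, and register the resulting graph under that schema string.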
1225
1226add_shape_compute_mapping(
1227    "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
1228    unary,
1229)
1230add_shape_compute_mapping(
1231    "aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary
1232)
1233add_shape_compute_mapping(
1234    "aten::dropout(Tensor input, float p, bool train) -> Tensor", unary
1235)
1236add_shape_compute_mapping(
1237    "aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor",
1238    adaptive_avg_pool2d,
1239)
1240add_shape_compute_mapping(
1241    "prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor
1242)
1243add_shape_compute_mapping("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor)
1244add_shape_compute_mapping(
1245    "aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
1246    unary,
1247)
1248add_shape_compute_mapping(
1249    "aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
1250    unary,
1251)
1252add_shape_compute_mapping(
1253    "aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
1254    arange_end,
1255)
1256add_shape_compute_mapping(
1257    "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
1258    arange_start,
1259)
1260add_shape_compute_mapping(
1261    "aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
1262    arange_start_step,
1263)
1264add_shape_compute_mapping("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim)
1265add_shape_compute_mapping(
1266    "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze
1267)
1268add_shape_compute_mapping(
1269    "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims
1270)
1271add_shape_compute_mapping(
1272    "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze
1273)
1274add_shape_compute_mapping(
1275    "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
1276    slice,
1277)
1278add_shape_compute_mapping(
1279    "aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select
1280)
1281add_shape_compute_mapping(
1282    "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select
1283)
1284add_shape_compute_mapping(
1285    "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
1286    "float eps=1e-05, bool cudnn_enable=True) -> Tensor",
1287    unary,
1288)
1289add_shape_compute_mapping(
1290    "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary
1291)
1292add_shape_compute_mapping(
1293    "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
1294    unary,
1295)
1296add_shape_compute_mapping(
1297    "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)",
1298    unary,
1299)
1300add_shape_compute_mapping(
1301    "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor",
1302    embedding,
1303)
1304add_shape_compute_mapping("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm)
1305add_shape_compute_mapping("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot)
1306add_shape_compute_mapping("aten::mv(Tensor self, Tensor vec) -> Tensor", mv)
1307add_shape_compute_mapping("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul)
1308add_shape_compute_mapping(
1309    "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear
1310)
1311add_shape_compute_mapping(
1312    "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
1313    max_pool2d,
1314)
1315add_shape_compute_mapping(
1316    "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)",
1317    max_pool2d_with_indices,
1318)
1319add_shape_compute_mapping("aten::t(Tensor(a) self) -> Tensor(a)", t)
1320add_shape_compute_mapping(
1321    "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose
1322)
1323add_shape_compute_mapping(
1324    "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor",
1325    conv1d,
1326)
1327add_shape_compute_mapping(
1328    "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
1329    conv2d,
1330)
1331add_shape_compute_mapping(
1332    "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
1333    batch_norm,
1334)
1335add_shape_compute_mapping(
1336    "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor",
1337    conv3d,
1338)
1339add_shape_compute_mapping(
1340    "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)",
1341    conv_backwards,
1342)
1343add_shape_compute_mapping(
1344    "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
1345    conv_forwards,
1346)
1347add_shape_compute_mapping(
1348    "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
1349    _conv_forwards,
1350)
1351add_shape_compute_mapping(
1352    "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor",
1353    conv_transpose2d_input,
1354)
1355add_shape_compute_mapping(
1356    "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)",
1357    flatten,
1358)
1359add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
1360add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack)
1361add_shape_compute_mapping(
1362    "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute
1363)
1364add_shape_compute_mapping(
1365    "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)",
1366    movedim,
1367)
1368add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)
1369add_shape_compute_mapping(
1370    "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand
1371)
1372add_shape_compute_mapping(
1373    "aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)",
1374    expand_one_unused,
1375)
1376add_shape_compute_mapping(
1377    "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
1378    sum_mean_dim,
1379)
1380add_shape_compute_mapping(
1381    "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
1382    sum_mean_dim,
1383)
1384add_shape_compute_mapping(
1385    "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
1386    max_dim,
1387)
1388add_shape_compute_mapping(
1389    "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
1390)
1391add_shape_compute_mapping(
1392    "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
1393)
1394add_shape_compute_mapping(
1395    "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
1396    addmm,
1397)
1398add_shape_compute_mapping(
1399    "aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)",
1400    upsample_nearest2d,
1401)
1402add_shape_compute_mapping(
1403    "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor",
1404    unary,
1405)
1406add_shape_compute_mapping(
1407    "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor",
1408    unary,
1409)
1410add_shape_compute_mapping("aten::dequantize(Tensor self) -> Tensor", unary)
1411add_shape_compute_mapping(
1412    "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
1413    broadcast,
1414)
1415add_shape_compute_mapping(
1416    "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax
1417)
1418add_shape_compute_mapping("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm)
1419add_shape_compute_mapping(
1420    "aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor
1421)
1422add_shape_compute_mapping(
1423    "aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
1424    topk,
1425)
1426add_shape_compute_mapping(
1427    "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)",
1428    nll_loss_forward,
1429)
1430add_shape_compute_mapping(
1431    "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
1432    native_layer_norm,
1433)
1434add_shape_compute_mapping(
1435    "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
1436    native_batch_norm,
1437)
1438add_shape_compute_mapping(
1439    "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
1440    native_batch_norm,
1441)
1442add_shape_compute_mapping(
1443    "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
1444    native_batch_norm,
1445)
1446add_shape_compute_mapping(
1447    "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
1448    _batch_norm_with_update,
1449)
1450
1451add_shape_compute_mapping(
1452    "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
1453    cross_entropy_loss,
1454)
1455# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)
1456
1457# TODO: migrate over all of symbolic_shape_registry_util.cpp
1458# These are duplicated here so that the functions will be serialized
1459add_shape_compute_mapping(
1460    "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor",
1461    broadcast_three,
1462)
1463add_shape_compute_mapping(
1464    "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
1465    broadcast_one_three,
1466)
1467add_shape_compute_mapping(
1468    "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
1469    broadcast_inplace,
1470)
1471
1472# quantized_conv_prepack TODO
1473
1474# Shape Compute Fn with upper and lower bounds
1475add_bounded_compute_mapping(
1476    "aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound
1477)
1478