# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from math import prod
from typing import Optional, Tuple

import torch
from executorch.exir.scalar_type import ScalarType
from torch.library import Library, register_fake

from .utils import get_conv1d_output_size, get_conv2d_output_size

lib = Library("cadence", "DEF")

lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
lib.define(
    "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)
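
# A minimal sketch of the affine quantization quantize_per_tensor implements
# (assuming the standard per-tensor scheme; the real kernel lives in the
# Cadence backend):
#   q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max).to(dtype)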

lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
lib.define(
    "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)
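
# Conversely, dequantize_per_tensor maps quantized values back to float
# (again a sketch, assuming standard affine semantics):
#   x = (q.to(torch.float) - zero_point) * scale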

lib.define(
    "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
)
lib.define(
    "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_layer_norm.per_tensor(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
)
lib.define(
    "quantized_layer_norm.per_tensor_out(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)
lib.define(
    "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, "
    "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor"
)
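
# out_multiplier and out_shift encode the requantization scale in fixed point.
# A rough sketch of the intended arithmetic (an assumption about the kernel's
# convention, not taken from this file):
#   acc = (src - src_zero_point) @ (weight - weight_zero_point).T + bias
#   out = round(acc * out_multiplier * 2.0**out_shift / 2**31) + out_zero_point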

lib.define(
    "quantized_relu(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Y)"
)
lib.define(
    "quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
    "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
)
lib.define(
    "quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
    "int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)"
)
lib.define(
    "transposed_convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
    "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False) -> (Tensor Y)"
)
lib.define("dequantize(Tensor X, Tensor X_scale, Tensor X_zero_point) -> (Tensor Y)")
# cadence::quantized_relu is defined in OSS
lib.define(
    "quantized_add(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
    "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
)
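
# A sketch of quantized_add's semantics (assuming it dequantizes, adds, and
# requantizes; quantized_mul below is analogous):
#   Z = quantize((X - X_zero_point) * X_scale + (Y - Y_zero_point) * Y_scale,
#                out_scale, out_zero_point)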
lib.define(
    "quantized_mul(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
    "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
    "quantized_add_Scalar(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, "
    "float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
    "quantized_mul_Scalar(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, "
    "float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
    "Tensor indices, bool pruned_weights=False) -> (Tensor X)"
)
# cadence::quantized_layer_norm is defined in OSS
# cadence::quantized_conv is defined in OSS
lib.define(
    "quantized_transposed_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
    "int[] dilation, SymInt[] output_padding, int groups, int input_zero_point, Tensor weight_zero_point, "
    "Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor out)"
)
lib.define(
    "avg_pool2d(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, "
    "bool count_include_pad=True, int? divisor_override=None, Tensor? in_zero_point=None, bool channel_last=False) -> (Tensor out)"
)
lib.define(
    "im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
    "Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
)
lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
lib.define(
    "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
    "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
)
lib.define(
    "requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, "
    "Tensor out_zero_point, ScalarType out_dtype) -> (Tensor Y)"
)
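
# requantize converts directly between two sets of quantization parameters;
# conceptually (an assumption, kernels may differ in rounding):
#   Y = round((X - in_zero_point) * in_scale / out_scale) + out_zero_point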
lib.define(
    "fully_connected(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor out)"
)
lib.define(
    "quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
    "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)


# ------------------------------------ #
#   Migrated from custom_ops.yaml      #
# ------------------------------------ #
# Migrated from the custom_ops.yaml file, which contained the different operator variants (e.g., .out, .tensor_out)
lib.define(
    "convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, "
    "int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "transposed_convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
    "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
# cadence::quantized_relu.out is defined in OSS
lib.define(
    "quantized_relu.per_tensor(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift) -> Tensor"
)
lib.define(
    "quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, "
    "int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_add.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
    "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
    "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_add_Scalar.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, "
    "float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_mul_Scalar.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, "
    "float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)")
lib.define(
    "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
    "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
    "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "quantized_transposed_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, "
    "SymInt[] padding, int[] dilation, SymInt[] output_padding, int groups, int input_zero_point, "
    "Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, "
    "Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "avg_pool2d.out(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, "
    "bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, "
    "Tensor? in_zero_point=None, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
    "Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "transposed_im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, "
    "int[2] stride, int[2] output_padding, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, "
    "Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


# Custom ops in the aten namespace. The Library must be created with the
# FRAGMENT kind, since the aten library is already defined.
aten_lib = Library("aten", "FRAGMENT")
aten_lib.define(
    "chunk.out(Tensor self, int chunks, int dim=0, *, Tensor(a!)[] out) -> ()"
)
aten_lib.define(
    "contiguous.out(Tensor self, *, MemoryFormat memory_format=contiguous_format, "
    "Tensor(a!) out) -> Tensor(a!)"
)
aten_lib.define(
    "tensor_split.sections_out(Tensor self, int sections, int dim=0, *, Tensor(a!)[] out) -> ()"
)
aten_lib.define(
    "_slice_copy_nop(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, "
    "SymInt step=1) -> Tensor(a!)"
)
aten_lib.define(
    "_select_copy_nop.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)"
)
aten_lib.define(
    "_slice_copy_nop.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, "
    "SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)"
)
aten_lib.define("_cat_nop(Tensor[] tensors, int dim=0) -> Tensor(a!)")
aten_lib.define(
    "_cat_nop.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)"
)

# Custom ops in the jarvis_nn_ops namespace
jarvis_nn_lib = Library("jarvis_nn_ops", "DEF")
jarvis_nn_lib.define(
    "attention_mask.out(Tensor input, Tensor start, Tensor stop, *, Tensor(a!) out) -> Tensor(a!)"
)

m = Library("cadence", "IMPL", "Meta")


@register_fake("cadence::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=dtype)


@register_fake("cadence::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=torch.float)


@register_fake("cadence::quantized_linear")
def quantized_linear_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
    out_zero_point: int,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    # src comes in with shape [leading_dims, in_dim]
    # weight comes in with shape [out_dim, in_dim]
    # the output is an empty tensor with shape [leading_dims, out_dim]
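    # e.g., src [2, 7, 16] with weight [32, 16] -> out [2, 7, 32]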
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_linear.per_tensor")
def quantized_linear_per_tensor_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: torch.SymInt,
    weight_zero_point: torch.SymInt,
    out_multiplier: torch.SymInt,
    out_shift: torch.SymInt,
    out_zero_point: torch.SymInt,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    # src comes in with shape [leading_dims, in_dim]
    # weight comes in with shape [out_dim, in_dim]
    # the output is an empty tensor with shape [leading_dims, out_dim]
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_conv")
def quantized_conv_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    if channel_last:
        out_channels, *kernel_size, _ = weight.shape
    else:
        out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # Assert that the input tensor has at least 3 and at most 5 dimensions
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Compute the output tensor size
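    # (assuming the .utils helpers implement the standard formula, per spatial
    # dimension: out_dim = (in_dim + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1)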
    output_size = (
        get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            channel_last,
        )
        if len(in_size) == 3
        else get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
        )
    )

    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv.per_tensor")
def quantized_conv_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
    channel_last: bool = False,
) -> torch.Tensor:
    if channel_last:
        out_channels, *kernel_size, _ = weight.shape
    else:
        out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # Assert that the input tensor has at least 3 and at most 5 dimensions
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Compute the output tensor size
    output_size = (
        get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            channel_last,
        )
        if len(in_size) == 3
        else get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
        )
    )

    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_layer_norm")
def quantized_layer_norm_meta(
    input: torch.Tensor,
    X_scale: torch.Tensor,
    X_zero_point: torch.Tensor,
    normalized_shape: Tuple[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_layer_norm.per_tensor")
def quantized_layer_norm_per_tensor_meta(
    input: torch.Tensor,
    X_scale: float,
    X_zero_point: int,
    normalized_shape: Tuple[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_relu")
def quantized_relu_meta(
    X: torch.Tensor,
    X_zero_point: torch.Tensor,
    out_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    return X.new_empty(X.size(), dtype=X.dtype)


@register_fake("cadence::quantized_matmul")
def quantized_matmul_meta(
    X: torch.Tensor,
    X_zero_point: int,
    Y: torch.Tensor,
    Y_zero_point: int,
    bias: Optional[torch.Tensor],
    out_multiplier: int,
    out_shift: int,
    out_zero_point: int,
    transposed: bool = False,
) -> torch.Tensor:
    X_size = list(X.size())
    Y_size = list(Y.size())

    # Get the batch dimensions for both tensors
    X_batch_dims = X_size[:-2]
    Y_batch_dims = Y_size[:-2]

    # If they don't match, check that they're compatible
    if X_batch_dims != Y_batch_dims:
        assert prod(X_batch_dims) == prod(
            Y_batch_dims
        ), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}"

    # Get the matmul output size
    if transposed:
        assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied"
        mat_size = [X_size[-2], Y_size[-2]]
    else:
        assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied"
        mat_size = [X_size[-2], Y_size[-1]]

    # Combine the larger batch dimensions with the matmul output size
    out_size = (
        X_batch_dims + mat_size
        if len(X_batch_dims) > len(Y_batch_dims)
        else Y_batch_dims + mat_size
    )
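    # e.g., X [5, 2, 3] with Y [5, 3, 4] -> out [5, 2, 4]; with transposed=True,
    # X [2, 3] with Y [4, 3] -> out [2, 4]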

    return X.new_empty(out_size, dtype=X.dtype)


@register_fake("cadence::im2row")
def im2row_meta(
    input: torch.Tensor,
    kernel_size: Tuple[int],
    dilation: Tuple[int],
    padding: Tuple[int],
    stride: Tuple[int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    if len(input.shape) == 3:
        height_dim = 1 if channel_last else 2
        input = input.unsqueeze(height_dim)

    batch_size = input.shape[0]
    n_input_plane = input.shape[3] if channel_last else input.shape[1]
    input_height = input.shape[1] if channel_last else input.shape[2]
    input_width = input.shape[2] if channel_last else input.shape[3]
    output_height = (
        input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)
    ) // stride[0] + 1
    output_width = (
        input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)
    ) // stride[1] + 1
    n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
    output_size = torch.Size((batch_size, output_height * output_width, n_output_plane))
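    # e.g., an NCHW input [1, 8, 5, 5] with a 3x3 kernel, stride 1, padding 0,
    # and dilation 1 yields 3 * 3 = 9 patch positions of 8 * 3 * 3 = 72 values
    # each, i.e., output [1, 9, 72]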
    return input.new_empty(output_size, dtype=input.dtype)


# Define the abstract implementations of the operators as required
@register_fake("cadence::linalg_vector_norm")
def linalg_vector_norm_meta(
    X: torch.Tensor,
) -> torch.Tensor:
    # The output of the norm is a scalar, so we return a [] tensor
    return X.new_empty([], dtype=X.dtype)


@register_fake("cadence::requantize")
def requantize_meta(
    input: torch.Tensor,
    in_scale: torch.Tensor,
    in_zero_point: torch.Tensor,
    out_scale: torch.Tensor,
    out_zero_point: torch.Tensor,
    dtype: ScalarType,
) -> torch.Tensor:
    return input.new_empty(
        input.size(),
        # pyre-ignore[6]: Incompatible type
        dtype=dtype,
    )


@register_fake("cadence::quantized_relu.per_tensor")
def quantized_relu_per_tensor_meta(
    input: torch.Tensor,
    in_zero_point: int,
    out_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=torch.uint8)


@register_fake("cadence::fully_connected")
def fully_connected_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    # src comes in with shape [leading_dims, in_dim]
    # weight comes in with shape [out_dim, in_dim]
    # the output is an empty tensor with shape [leading_dims, out_dim]
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_fully_connected")
def quantized_fully_connected_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    out_multiplier: int,
    out_shift: int,
    out_zero_point: int,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    # src comes in with shape [leading_dims, in_dim]
    # weight comes in with shape [out_dim, in_dim]
    # the output is an empty tensor with shape [leading_dims, out_dim]
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    return src.new_empty(out_size, dtype=torch.uint8)


@register_fake("cadence::convolution")
def convolution_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    channel_last: bool = False,
) -> torch.Tensor:
    if channel_last:
        out_channels, *kernel_size, _ = weight.shape
    else:
        out_channels, _, *kernel_size = weight.shape
    in_size = input.shape
    # Assert that the input tensor has at least 3 and at most 5 dimensions
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Compute the output tensor size
    output_size = (
        get_conv1d_output_size(
            in_size,
            out_channels,
            stride[0],
            padding[0],
            dilation[0],
            kernel_size[0],
            channel_last,
        )
        if len(in_size) == 3
        else get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
        )
    )

    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::transposed_convolution")
def transposed_convolution_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    output_padding: Tuple[int],
    groups: int,
    channel_last: bool = False,
) -> torch.Tensor:
    # The native torch transposed conv has weight shape
    # (in_channels, out_channels/groups, *kernel_size).
    # However, the two channel positions are flipped by the Jarvis pass that
    # replaces it with cadence::transposed_convolution here: https://fburl.com/code/d2s7pkyy
    out_channels, _input_channels, *kernel_size = weight.shape
    out_channels *= groups
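    # e.g., with groups=2 and weight [8, 16, 3] in the flipped layout,
    # out_channels = 8 * 2 = 16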
    in_size = input.shape

    # Get the output size of a transposed 1D convolution given the input size and parameters
    def get_conv_transpose1d_output_size(
        in_size: torch.Size,
        kernel_size: list[int],
        out_channels: int,
        stride: Tuple[int],
        padding: Tuple[int],
        dilation: Tuple[int],
        output_padding: Tuple[int],
        channel_last: bool = False,
    ) -> torch.Size:
        assert len(in_size) == 3
        if channel_last:
            N, L, C = in_size
        else:
            N, C, L = in_size

        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
        lout = (
            (L - 1) * stride[0]
            - 2 * padding[0]
            + dilation[0] * (kernel_size[0] - 1)
            + output_padding[0]
            + 1
        )
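        # e.g., L=10, stride=2, padding=1, dilation=1, kernel_size=3,
        # output_padding=1 -> lout = 9 * 2 - 2 + 2 + 1 + 1 = 20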

        if channel_last:
            return torch.Size((in_size[0], lout, out_channels))
        else:
            return torch.Size((in_size[0], out_channels, lout))

    def get_conv_transpose2d_output_size(
        in_size: torch.Size,
        kernel_size: list[int],
        out_channels: int,
        stride: Tuple[int],
        padding: Tuple[int],
        dilation: Tuple[int],
        output_padding: Tuple[int],
        channel_last: bool = False,
    ) -> torch.Size:
        assert len(in_size) == 4
        if channel_last:
            N, H, W, C = in_size
        else:
            N, C, H, W = in_size

        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
        hout = (
            (H - 1) * stride[0]
            - 2 * padding[0]
            + dilation[0] * (kernel_size[0] - 1)
            + output_padding[0]
            + 1
        )
        wout = (
            (W - 1) * stride[1]
            - 2 * padding[1]
            + dilation[1] * (kernel_size[1] - 1)
            + output_padding[1]
            + 1
        )

        if channel_last:
            return torch.Size((in_size[0], hout, wout, out_channels))
        else:
            return torch.Size((in_size[0], out_channels, hout, wout))

    # Compute the output tensor size
    if len(in_size) == 3:
        output_size = get_conv_transpose1d_output_size(
            in_size,
            kernel_size,
            out_channels,
            stride,
            padding,
            dilation,
            output_padding,
            channel_last,
        )
    elif len(in_size) == 4:
        output_size = get_conv_transpose2d_output_size(
            in_size,
            kernel_size,
            out_channels,
            stride,
            padding,
            dilation,
            output_padding,
            channel_last,
        )
    else:
        raise NotImplementedError(
            f"transposed_convolution meta is not implemented for input tensor with {len(in_size)} dimensions"
        )

    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::avg_pool2d")
def avg_pool2d_meta(
    input: torch.Tensor,
    kernel_size: Tuple[int],
    stride: Tuple[int],
    padding: Tuple[int],
    ceil_mode: bool,
    count_include_pad: Optional[bool] = True,
    divisor_override: Optional[int] = None,
    in_zero_point: Optional[int] = None,
    channel_last: bool = False,
) -> torch.Tensor:
    # Reuse the torch-native meta kernel where the operator semantics are similar
    return torch._meta_registrations.meta_avg_pool2d(
        input,
        kernel_size,
        stride,
        padding,
        ceil_mode,
        count_include_pad,
        divisor_override,
    )


@register_fake("cadence::transposed_im2row")
def transposed_im2row_meta(
    input: torch.Tensor,
    kernel_size: Tuple[int],
    dilation: Tuple[int],
    padding: Tuple[int],
    stride: Tuple[int],
    output_padding: Tuple[int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    if len(input.shape) == 3:
        height_dim = 1 if channel_last else 2
        input = input.unsqueeze(height_dim)

    batch_size = input.shape[0]
    n_input_plane = input.shape[3] if channel_last else input.shape[1]
    input_height = input.shape[1] if channel_last else input.shape[2]
    input_width = input.shape[2] if channel_last else input.shape[3]
    output_height = (
        (input_height - 1) * stride[0]
        - 2 * padding[0]
        + dilation[0] * (kernel_size[0] - 1)
        + output_padding[0]
        + 1
    )
    output_width = (
        (input_width - 1) * stride[1]
        - 2 * padding[1]
        + dilation[1] * (kernel_size[1] - 1)
        + output_padding[1]
        + 1
    )
    n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
    output_length = output_height * output_width
    output_size = torch.Size((batch_size, output_length, n_output_plane))

    return input.new_empty(output_size, dtype=input.dtype)