1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-strict 8 9from math import prod 10from typing import Optional, Tuple 11 12import torch 13from executorch.exir.scalar_type import ScalarType 14from torch.library import Library, register_fake 15 16from .utils import get_conv1d_output_size, get_conv2d_output_size 17 18lib = Library("cadence", "DEF") 19 20lib.define( 21 "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" 22) 23lib.define( 24 "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" 25) 26 27lib.define( 28 "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" 29) 30lib.define( 31 "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" 32) 33 34lib.define( 35 "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)" 36) 37lib.define( 38 "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)" 39) 40lib.define( 41 "quantized_layer_norm.per_tensor(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)" 42) 43lib.define( 44 "quantized_layer_norm.per_tensor_out(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)" 45) 46 47lib.define( 48 "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" 49) 50lib.define( 51 "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" 52) 53lib.define( 54 "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" 55) 56lib.define( 57 "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, " 58 "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor" 59) 60 61lib.define( 62 "quantized_relu(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Y)" 63) 64lib.define( 65 "quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor (a!)" 66) 67 68lib.define( 69 "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)" 70) 71lib.define( 72 "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" 73) 74lib.define( 75 "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)" 76) 77lib.define( 78 "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" 79) 80 81lib.define( 82 "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)" 83) 84lib.define( 85 "quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)" 86) 87 88lib.define( 89 "convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " 90 "int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)" 91) 92lib.define( 93 "transposed_convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " 94 "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False) -> (Tensor Y)" 95) 96lib.define("dequantize(Tensor X, Tensor X_scale, Tensor X_zero_point) -> (Tensor Y)") 97# cadence::quantized_relu is defined in OSS 98lib.define( 99 "quantized_add(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, " 100 "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)" 101) 102lib.define( 103 "quantized_mul(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, " 104 "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)" 105) 106lib.define( 107 "quantized_add_Scalar(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, " 108 "float out_scale, int out_zero_point) -> (Tensor Z)" 109) 110lib.define( 111 "quantized_mul_Scalar(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, " 112 "float out_scale, int out_zero_point) -> (Tensor Z)" 113) 114lib.define( 115 "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, " 116 "Tensor indices, bool pruned_weights=False) -> (Tensor X)" 117) 118# cadence::quantized_layer_norm is defined in OSS 119# cadence::quantized_conv is defined is OSS 120lib.define( 121 "quantized_transposed_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " 122 "int[] dilation, SymInt[] output_padding, int groups, int input_zero_point, Tensor weight_zero_point, " 123 "Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor out)" 124) 125lib.define( 126 "avg_pool2d(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, " 127 "bool count_include_pad=True, int? divisor_override=None, Tensor? in_zero_point=None, bool channel_last=False) -> (Tensor out)" 128) 129lib.define( 130 "im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, " 131 "Tensor in_zero_point, bool channel_last=False) -> (Tensor out)" 132) 133lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)") 134lib.define( 135 "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, " 136 "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)" 137) 138lib.define( 139 "requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, " 140 "Tensor out_zero_point, ScalarType out_dtype) -> (Tensor Y)" 141) 142lib.define( 143 "fully_connected(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor out)" 144) 145lib.define( 146 "quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " 147 "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" 148) 149 150 151# ------------------------------------ # 152# Migrated from custom_ops.ymal # 153# ------------------------------------ # 154# Migrated from the custom_ops.yaml files containing different operator variants (e.g., .out, .tensor_out) 155lib.define( 156 "convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, " 157 "int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" 158) 159lib.define( 160 "transposed_convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " 161 "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" 162) 163# cadence::quantized_relu.out is defined in OSS 164lib.define( 165 "quantized_relu.per_tensor(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift) -> Tensor" 166) 167lib.define( 168 "quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, " 169 "int out_shift, *, Tensor(a!) out) -> Tensor(a!)" 170) 171lib.define( 172 "quantized_add.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, " 173 "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" 174) 175lib.define( 176 "quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, " 177 "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" 178) 179lib.define( 180 "quantized_add_Scalar.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, " 181 "float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" 182) 183lib.define( 184 "quantized_mul_Scalar.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, " 185 "float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" 186) 187lib.define( 188 "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)" 189) 190lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)") 191lib.define( 192 "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " 193 "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" 194) 195lib.define( 196 "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, " 197 "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)" 198) 199 200lib.define( 201 "quantized_transposed_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, " 202 "SymInt[] padding, int[] dilation, SymInt[] output_padding, int groups, int input_zero_point, " 203 "Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, " 204 "Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" 205) 206lib.define( 207 "avg_pool2d.out(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, " 208 "bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, " 209 "Tensor? in_zero_point=None, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" 210) 211lib.define( 212 "im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, " 213 "Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" 214) 215lib.define( 216 "transposed_im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, " 217 "int[2] stride, int[2] output_padding, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" 218) 219lib.define( 220 "requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, " 221 "Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)" 222) 223 224 225# Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined 226aten_lib = Library("aten", "FRAGMENT") 227aten_lib.define( 228 "chunk.out(Tensor self, int chunks, int dim=0, *, Tensor(a!)[] out) -> ()" 229) 230aten_lib.define( 231 "contiguous.out(Tensor self, *, MemoryFormat memory_format=contiguous_format, " 232 "Tensor(a!) out) -> Tensor(a!)" 233) 234aten_lib.define( 235 "tensor_split.sections_out(Tensor self, int sections, int dim=0, *, Tensor(a!)[] out) -> ()" 236) 237aten_lib.define( 238 "_slice_copy_nop(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, " 239 "SymInt step=1) -> Tensor(a!)" 240) 241aten_lib.define( 242 "_select_copy_nop.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)" 243) 244aten_lib.define( 245 "_slice_copy_nop.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, " 246 "SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)" 247) 248aten_lib.define("_cat_nop(Tensor[] tensors, int dim=0) -> Tensor(a!)") 249aten_lib.define( 250 "_cat_nop.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)" 251) 252 253# Custom ops with jarvis_nn_ops namespace 254jarvis_nn_lib = Library("jarvis_nn_ops", "DEF") 255jarvis_nn_lib.define( 256 "attention_mask.out(Tensor input, Tensor start, Tensor stop, *, Tensor(a!) out) -> Tensor(a!)" 257) 258 259m = Library("cadence", "IMPL", "Meta") 260 261 262@register_fake("cadence::quantize_per_tensor") 263def quantize_per_tensor_meta( 264 input: torch.Tensor, 265 scale: float, 266 zero_point: int, 267 quant_min: int, 268 quant_max: int, 269 dtype: torch.dtype, 270) -> torch.Tensor: 271 return input.new_empty(input.size(), dtype=dtype) 272 273 274@register_fake("cadence::dequantize_per_tensor") 275def dequantize_per_tensor_meta( 276 input: torch.Tensor, 277 scale: float, 278 zero_point: int, 279 quant_min: int, 280 quant_max: int, 281 dtype: torch.dtype, 282) -> torch.Tensor: 283 return input.new_empty(input.size(), dtype=torch.float) 284 285 286@register_fake("cadence::quantized_linear") 287def quantized_linear_meta( 288 src: torch.Tensor, 289 weight: torch.Tensor, 290 bias: torch.Tensor, 291 in_zero_point: int, 292 weight_zero_point: torch.Tensor, 293 out_multiplier: torch.Tensor, 294 out_shift: torch.Tensor, 295 out_zero_point: int, 296 offset: Optional[torch.Tensor], 297) -> torch.Tensor: 298 # src comes in shape [leading_dims, in_dim] 299 # weight comes in shape [out_dim, in_dim] 300 # output comes in empty with shape [leading_dims, out_dim] 301 out_size = list(src.size()) 302 weight_size = list(weight.size()) 303 assert len(weight_size) == 2 304 out_size[-1] = weight_size[0] 305 return src.new_empty(out_size, dtype=src.dtype) 306 307 308@register_fake("cadence::quantized_linear.per_tensor") 309def quantized_linear_per_tensor_meta( 310 src: torch.Tensor, 311 weight: torch.Tensor, 312 bias: torch.Tensor, 313 in_zero_point: torch.SymInt, 314 weight_zero_point: torch.SymInt, 315 out_multiplier: torch.SymInt, 316 out_shift: torch.SymInt, 317 out_zero_point: torch.SymInt, 318 offset: Optional[torch.Tensor], 319) -> torch.Tensor: 320 # src comes in shape [leading_dims, in_dim] 321 # weight comes in shape [out_dim, in_dim] 322 # output comes in empty with shape [leading_dims, out_dim] 323 out_size = list(src.size()) 324 weight_size = list(weight.size()) 325 assert len(weight_size) == 2 326 out_size[-1] = weight_size[0] 327 return src.new_empty(out_size, dtype=src.dtype) 328 329 330@register_fake("cadence::quantized_conv") 331def quantized_conv_meta( 332 input: torch.Tensor, 333 weight: torch.Tensor, 334 bias: torch.Tensor, 335 stride: Tuple[int], 336 padding: Tuple[int], 337 dilation: Tuple[int], 338 groups: int, 339 in_zero_point: int, 340 weight_zero_point: torch.Tensor, 341 bias_scale: torch.Tensor, 342 output_scale: float, 343 output_zero_point: int, 344 out_multiplier: torch.Tensor, 345 out_shift: torch.Tensor, 346 channel_last: bool = False, 347) -> torch.Tensor: 348 if channel_last: 349 out_channels, *kernel_size, _ = weight.shape 350 else: 351 out_channels, _, *kernel_size = weight.shape 352 353 in_size = input.shape 354 # Assert that the input tensor has at least 3 dimensions, and at most 6 355 assert len(in_size) > 2 356 assert len(in_size) < 6 357 358 # Compute the output tensor size 359 output_size = ( 360 get_conv1d_output_size( 361 in_size, 362 out_channels, 363 stride[1], 364 padding[1], 365 dilation[1], 366 kernel_size[0], 367 channel_last, 368 ) 369 if len(in_size) == 3 370 else get_conv2d_output_size( 371 in_size, out_channels, stride, padding, dilation, kernel_size, channel_last 372 ) 373 ) 374 375 return input.new_empty(output_size, dtype=input.dtype) 376 377 378@register_fake("cadence::quantized_conv.per_tensor") 379def quantized_conv_per_tensor_meta( 380 input: torch.Tensor, 381 weight: torch.Tensor, 382 bias: torch.Tensor, 383 stride: Tuple[int], 384 padding: Tuple[int], 385 dilation: Tuple[int], 386 groups: int, 387 in_zero_point: int, 388 weight_zero_point: int, 389 bias_scale: float, 390 output_scale: float, 391 output_zero_point: int, 392 out_multiplier: int, 393 out_shift: int, 394 channel_last: bool = False, 395) -> torch.Tensor: 396 if channel_last: 397 out_channels, *kernel_size, _ = weight.shape 398 else: 399 out_channels, _, *kernel_size = weight.shape 400 401 in_size = input.shape 402 # Assert that the input tensor has at least 3 dimensions, and at most 6 403 assert len(in_size) > 2 404 assert len(in_size) < 6 405 406 # Compute the output tensor size 407 output_size = ( 408 get_conv1d_output_size( 409 in_size, 410 out_channels, 411 stride[1], 412 padding[1], 413 dilation[1], 414 kernel_size[0], 415 channel_last, 416 ) 417 if len(in_size) == 3 418 else get_conv2d_output_size( 419 in_size, out_channels, stride, padding, dilation, kernel_size, channel_last 420 ) 421 ) 422 423 return input.new_empty(output_size, dtype=input.dtype) 424 425 426@register_fake("cadence::quantized_layer_norm") 427def quantized_layer_norm_meta( 428 input: torch.Tensor, 429 X_scale: torch.Tensor, 430 X_zero_point: torch.Tensor, 431 normalized_shape: int, 432 weight: torch.Tensor, 433 bias: torch.Tensor, 434 eps: float, 435 output_scale: float, 436 output_zero_point: int, 437) -> torch.Tensor: 438 return input.new_empty(input.size(), dtype=input.dtype) 439 440 441@register_fake("cadence::quantized_layer_norm.per_tensor") 442def quantized_layer_norm_per_tensor_meta( 443 input: torch.Tensor, 444 X_scale: float, 445 X_zero_point: int, 446 normalized_shape: int, 447 weight: torch.Tensor, 448 bias: torch.Tensor, 449 eps: float, 450 output_scale: float, 451 output_zero_point: int, 452) -> torch.Tensor: 453 return input.new_empty(input.size(), dtype=input.dtype) 454 455 456@register_fake("cadence::quantized_relu") 457def quantized_relu_meta( 458 X: torch.Tensor, 459 X_zero_point: torch.Tensor, 460 out_zero_point: int, 461 out_multiplier: torch.Tensor, 462 out_shift: torch.Tensor, 463) -> torch.Tensor: 464 return X.new_empty(X.size(), dtype=X.dtype) 465 466 467@register_fake("cadence::quantized_matmul") 468def quantized_matmul_meta( 469 X: torch.Tensor, 470 X_zero_point: int, 471 Y: torch.Tensor, 472 Y_zero_point: int, 473 bias: Optional[torch.Tensor], 474 out_multiplier: int, 475 out_shift: int, 476 out_zero_point: int, 477 transposed: bool = False, 478) -> torch.Tensor: 479 X_size = list(X.size()) 480 Y_size = list(Y.size()) 481 482 # Get the batch dimensions for both tensors 483 X_batch_dims = X_size[:-2] 484 Y_batch_dims = Y_size[:-2] 485 486 # If they don't match, check that they're compatible 487 if X_batch_dims != Y_batch_dims: 488 assert prod(X_batch_dims) == prod( 489 Y_batch_dims 490 ), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}" 491 492 # Get the matmul output size 493 if transposed: 494 assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied" 495 mat_size = [X_size[-2], Y_size[-2]] 496 else: 497 assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied" 498 mat_size = [X_size[-2], Y_size[-1]] 499 500 # Combine the larger batch dimensions with the matmul output size 501 out_size = ( 502 X_batch_dims + mat_size 503 if len(X_batch_dims) > len(Y_batch_dims) 504 else Y_batch_dims + mat_size 505 ) 506 507 return X.new_empty(out_size, dtype=X.dtype) 508 509 510@register_fake("cadence::im2row") 511def im2row_meta( 512 input: torch.Tensor, 513 kernel_size: Tuple[int], 514 dilation: Tuple[int], 515 padding: Tuple[int], 516 stride: Tuple[int], 517 in_zero_point: torch.Tensor, 518 channel_last: bool = False, 519) -> torch.Tensor: 520 if len(input.shape) == 3: 521 height_dim = 1 if channel_last else 2 522 input = input.unsqueeze(height_dim) 523 524 batch_size = input.shape[0] 525 n_input_plane = input.shape[3] if channel_last else input.shape[1] 526 input_height = input.shape[1] if channel_last else input.shape[2] 527 input_width = input.shape[2] if channel_last else input.shape[3] 528 output_height = ( 529 input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1) 530 ) // stride[0] + 1 531 output_width = ( 532 input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1) 533 ) // stride[1] + 1 534 n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1] 535 output_size = torch.Size((batch_size, output_height * output_width, n_output_plane)) 536 return input.new_empty(output_size, dtype=input.dtype) 537 538 539# Define the abstract implementations of the operators as required 540@register_fake("cadence::linalg_vector_norm") 541def linalg_vector_norm_meta( 542 X: torch.Tensor, 543) -> torch.Tensor: 544 # Output of norm is a scalar, so we return a [] tensor 545 return X.new_empty([], dtype=X.dtype) 546 547 548@register_fake("cadence::requantize") 549def requantize_meta( 550 input: torch.Tensor, 551 in_scale: torch.Tensor, 552 in_zero_point: torch.Tensor, 553 out_scale: torch.Tensor, 554 out_zero_point: torch.Tensor, 555 dtype: ScalarType, 556) -> torch.Tensor: 557 return input.new_empty( 558 input.size(), 559 # pyre-ignore[6]: Incompatible type 560 dtype=dtype, 561 ) 562 563 564@register_fake("cadence::quantized_relu.per_tensor") 565def quantized_relu_per_tensor_meta( 566 input: torch.Tensor, 567 in_zero_point: int, 568 out_zero_point: int, 569 out_multiplier: int, 570 out_shift: int, 571) -> torch.Tensor: 572 return input.new_empty(input.size(), dtype=torch.uint8) 573 574 575@register_fake("cadence::fully_connected") 576def fully_connected_meta( 577 src: torch.Tensor, 578 weight: torch.Tensor, 579 bias: torch.Tensor, 580) -> torch.Tensor: 581 # src comes in shape [leading_dims, in_dim] 582 # weight comes in shape [out_dim, in_dim] 583 # output comes in empty with shape [leading_dims, out_dim] 584 out_size = list(src.size()) 585 weight_size = list(weight.size()) 586 assert len(weight_size) == 2 587 out_size[-1] = weight_size[0] 588 return src.new_empty(out_size, dtype=src.dtype) 589 590 591@register_fake("cadence::quantized_fully_connected") 592def quantized_fully_connected_meta( 593 src: torch.Tensor, 594 weight: torch.Tensor, 595 bias: torch.Tensor, 596 in_zero_point: int, 597 weight_zero_point: torch.Tensor, 598 out_multiplier: int, 599 out_shift: int, 600 out_zero_point: int, 601 offset: Optional[torch.Tensor], 602) -> torch.Tensor: 603 # src comes in shape [leading_dims, in_dim] 604 # weight comes in shape [out_dim, in_dim] 605 # output comes in empty with shape [leading_dims, out_dim] 606 out_size = list(src.size()) 607 weight_size = list(weight.size()) 608 assert len(weight_size) == 2 609 out_size[-1] = weight_size[0] 610 return src.new_empty(out_size, dtype=torch.uint8) 611 612 613@register_fake("cadence::convolution") 614def convolution_meta( 615 input: torch.Tensor, 616 weight: torch.Tensor, 617 bias: torch.Tensor, 618 stride: Tuple[int], 619 padding: Tuple[int], 620 dilation: Tuple[int], 621 groups: int, 622 channel_last: bool = False, 623) -> torch.Tensor: 624 if channel_last: 625 out_channels, *kernel_size, _ = weight.shape 626 else: 627 out_channels, _, *kernel_size = weight.shape 628 in_size = input.shape 629 # Assert that the input tensor has at least 3 dimensions, and at most 6 630 assert len(in_size) > 2 631 assert len(in_size) < 6 632 633 # Compute the output tensor size 634 output_size = ( 635 get_conv1d_output_size( 636 in_size, 637 out_channels, 638 stride[0], 639 padding[0], 640 dilation[0], 641 kernel_size[0], 642 channel_last, 643 ) 644 if len(in_size) == 3 645 else get_conv2d_output_size( 646 in_size, out_channels, stride, padding, dilation, kernel_size, channel_last 647 ) 648 ) 649 650 return input.new_empty(output_size, dtype=input.dtype) 651 652 653@register_fake("cadence::transposed_convolution") 654def transposed_convolution_meta( 655 input: torch.Tensor, 656 weight: torch.Tensor, 657 bias: torch.Tensor, 658 stride: Tuple[int], 659 padding: Tuple[int], 660 dilation: Tuple[int], 661 output_padding: Tuple[int], 662 groups: int, 663 channel_last: bool = False, 664) -> torch.Tensor: 665 # The native definition of torch transposed conv will have weight shape as 666 # (in_channels, out_channels/groups, *kernel_size). 667 # However, the two channel position is flipped in the Jarvis pass of replacing it 668 # with cadence::transposed_convolution here: https://fburl.com/code/d2s7pkyy 669 out_channels, _input_channels, *kernel_size = weight.shape 670 out_channels *= groups 671 in_size = input.shape 672 673 # Get the output size of a transposed 1D convolution given the input size and parameters 674 def get_conv_transpose1d_output_size( 675 in_size: torch.Size, 676 kernel_size: list[int], 677 out_channels: int, 678 stride: Tuple[int], 679 padding: Tuple[int], 680 dilation: Tuple[int], 681 output_padding: Tuple[int], 682 channel_last: bool = False, 683 ) -> torch.Size: 684 assert len(in_size) == 3 685 if channel_last: 686 N, L, C = in_size 687 else: 688 N, C, L = in_size 689 690 # Reference: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html 691 lout = ( 692 (L - 1) * stride[0] 693 - 2 * padding[0] 694 + dilation[0] * (kernel_size[0] - 1) 695 + output_padding[0] 696 + 1 697 ) 698 699 if channel_last: 700 return torch.Size((in_size[0], lout, out_channels)) 701 else: 702 return torch.Size((in_size[0], out_channels, lout)) 703 704 def get_conv_transpose2d_output_size( 705 in_size: torch.Size, 706 kernel_size: list[int], 707 out_channels: int, 708 stride: Tuple[int], 709 padding: Tuple[int], 710 dilation: Tuple[int], 711 output_padding: Tuple[int], 712 channel_last: bool = False, 713 ) -> torch.Size: 714 assert len(in_size) == 4 715 if channel_last: 716 N, H, W, C = in_size 717 else: 718 N, C, H, W = in_size 719 720 # Reference: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html 721 hout = ( 722 (H - 1) * stride[0] 723 - 2 * padding[0] 724 + dilation[0] * (kernel_size[0] - 1) 725 + output_padding[0] 726 + 1 727 ) 728 wout = ( 729 (W - 1) * stride[1] 730 - 2 * padding[1] 731 + dilation[1] * (kernel_size[1] - 1) 732 + output_padding[1] 733 + 1 734 ) 735 736 if channel_last: 737 return torch.Size((in_size[0], hout, wout, out_channels)) 738 else: 739 return torch.Size((in_size[0], out_channels, hout, wout)) 740 741 # Compute the output tensor size 742 if len(in_size) == 3: 743 output_size = get_conv_transpose1d_output_size( 744 in_size, 745 kernel_size, 746 out_channels, 747 stride, 748 padding, 749 dilation, 750 output_padding, 751 channel_last, 752 ) 753 elif len(in_size) == 4: 754 output_size = get_conv_transpose2d_output_size( 755 in_size, 756 kernel_size, 757 out_channels, 758 stride, 759 padding, 760 dilation, 761 output_padding, 762 channel_last, 763 ) 764 else: 765 raise NotImplementedError( 766 f"transposed_convolution meta is not implemented for input tensor with {len(in_size)} dimensions" 767 ) 768 769 return input.new_empty(output_size, dtype=input.dtype) 770 771 772@register_fake("cadence::avg_pool2d") 773def avg_pool2d_meta( 774 input: torch.Tensor, 775 kernel_size: Tuple[int], 776 stride: Tuple[int], 777 padding: Tuple[int], 778 ceil_mode: bool, 779 count_include_pad: Optional[bool] = True, 780 divisor_override: Optional[int] = None, 781 in_zero_point: Optional[int] = None, 782 channel_last: bool = False, 783) -> torch.Tensor: 784 # Use torch native meta kernels when operator semantics are similar 785 return torch._meta_registrations.meta_avg_pool2d( 786 input, 787 kernel_size, 788 stride, 789 padding, 790 ceil_mode, 791 count_include_pad, 792 divisor_override, 793 ) 794 795 796@register_fake("cadence::transposed_im2row") 797def transposed_im2row_meta( 798 input: torch.Tensor, 799 kernel_size: Tuple[int], 800 dilation: Tuple[int], 801 padding: Tuple[int], 802 stride: Tuple[int], 803 output_padding: Tuple[int], 804 in_zero_point: torch.Tensor, 805 channel_last: bool = False, 806) -> torch.Tensor: 807 if len(input.shape) == 3: 808 height_dim = 1 if channel_last else 2 809 input = input.unsqueeze(height_dim) 810 811 batch_size = input.shape[0] 812 n_input_plane = input.shape[3] if channel_last else input.shape[1] 813 input_height = input.shape[1] if channel_last else input.shape[2] 814 input_width = input.shape[2] if channel_last else input.shape[3] 815 output_height = ( 816 (input_height - 1) * stride[0] 817 - 2 * padding[0] 818 + dilation[0] * (kernel_size[0] - 1) 819 + output_padding[0] 820 + 1 821 ) 822 output_width = ( 823 (input_width - 1) * stride[1] 824 - 2 * padding[1] 825 + dilation[1] * (kernel_size[1] - 1) 826 + output_padding[1] 827 + 1 828 ) 829 n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1] 830 output_length = output_height * output_width 831 output_size = torch.Size((batch_size, output_length, n_output_plane)) 832 833 return input.new_empty(output_size, dtype=input.dtype) 834