# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Callable, List, Optional, Tuple

import torch
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
from executorch.exir.passes.replace_aten_with_edge_pass import (
    aten_to_edge,
    should_lower_to_edge,
)
from torch import fx
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.library import impl, register_fake


__all__ = [
    "get_quant_patterns_and_replacements",
]

# TODO: extending an existing library that is defined in OSS might be a bit
# confusing; we can investigate whether it is possible to define a new library

quantized_decomposed_lib.define(
    "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)

quantized_decomposed_lib.define(
    "embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
)


def embedding_weight_checks(weight, weight_scales, weight_zero_points):
    assert weight.dtype in [
        torch.int8,
        torch.uint8,
    ], f"Expecting weights to be of dtype in [torch.int8, torch.uint8], but got {weight.dtype}"
    assert (
        weight.dim() == 2
    ), f"Expecting weight tensor to have dim()==2, but found {weight.dim()}"

    assert weight_scales.dtype in [
        torch.float16,
        torch.float32,
    ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32], but got {weight_scales.dtype}"
    assert (
        weight_scales.dim() == 1 or weight_scales.dim() == 2
    ), f"Expecting weight_scales tensor to have rank 1 or 2, but found {weight_scales.dim()}"
    assert weight_scales.size(0) == weight.size(
        0
    ), f"Expecting weight and scale tensor to have same number of rows, but found {weight.size()} and {weight_scales.size()}"

    assert (
        weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
    assert (
        weight_zero_points is None or weight_zero_points.dim() == 1
    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
    assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
        0
    ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
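
# Editorial note on the layout assumed throughout this file: a rank-1
# `weight_scales` carries one scale per embedding row (per-channel
# quantization), while a rank-2 `weight_scales` of shape
# (num_rows, num_groups) carries one scale per group within each row
# (groupwise quantization). The kernels below recover the group size from
# the packed weight width, e.g. for embedding_byte a (4, 64) uint8 weight
# with (4, 8) scales implies 64 // 8 = 8 quantized elements per group.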


@impl(quantized_decomposed_lib, "embedding_byte", "CompositeExplicitAutograd")
def embedding_byte(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = weight.size(1) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_byte.out")
def embedding_byte_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_byte(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
    )


@impl(quantized_decomposed_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
def embedding_byte_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = weight.size(1) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_byte.dtype_out")
def embedding_byte_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_byte_dtype(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
        dtype,
    )
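
# Editorial note: the fake (meta) kernels registered for the `.out` and
# `.dtype_out` variants above simply call the functional implementations;
# at export/tracing time only output shapes and dtypes matter, so the
# preallocated `out` argument is ignored.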


quantized_decomposed_lib.define(
    "embedding_2bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_2bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_2bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)

quantized_decomposed_lib.define(
    "embedding_2bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
)


@impl(quantized_decomposed_lib, "embedding_2bit", "CompositeExplicitAutograd")
def embedding_2bit(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    # Each packed byte holds four 2-bit values, so the unpacked row width is
    # 4 * weight.size(1).
    group_size = (4 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    # Unpack the four 2-bit fields of each byte, lowest bits first.
    weight_0 = weight & 3
    weight_1 = (weight & 12) >> 2
    weight_2 = (weight & 48) >> 4
    weight_3 = (weight & 192) >> 6
    weight_unpacked = torch.stack((weight_0, weight_1, weight_2, weight_3), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    # Shift the unsigned 2-bit codes [0, 3] into the signed range [-2, 1].
    weight = weight.view(torch.int8).add(-2)

    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)
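
# Worked example for the 2-bit unpacking above (illustrative values): the
# packed byte 0b11100100 yields the fields (0b00, 0b01, 0b10, 0b11)
# = (0, 1, 2, 3); after the fixed -2 offset these become the signed codes
# (-2, -1, 0, 1), i.e. the full 2-bit quant range [-2, 1].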


@register_fake("quantized_decomposed::embedding_2bit.out")
def embedding_2bit_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_2bit(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
    )


@impl(quantized_decomposed_lib, "embedding_2bit.dtype", "CompositeExplicitAutograd")
def embedding_2bit_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = (4 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    weight_0 = weight & 3
    weight_1 = (weight & 12) >> 2
    weight_2 = (weight & 48) >> 4
    weight_3 = (weight & 192) >> 6
    weight_unpacked = torch.stack((weight_0, weight_1, weight_2, weight_3), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    weight = weight.view(torch.int8).add(-2)

    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_2bit.dtype_out")
def embedding_2bit_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_2bit_dtype(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
        dtype,
    )


quantized_decomposed_lib.define(
    "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_4bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)

quantized_decomposed_lib.define(
    "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
)
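
# Worked example for the 4-bit unpacking below (illustrative values): a
# uint8-packed byte 0xAB (= 171) splits into the high nibble 171 // 16 = 10
# and the low nibble 171 % 16 = 11; after the fixed -8 offset these become
# the signed codes (2, 3), i.e. values in the 4-bit quant range [-8, 7].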


@impl(quantized_decomposed_lib, "embedding_4bit", "CompositeExplicitAutograd")
def embedding_4bit(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    # Each packed byte holds two 4-bit values, so the unpacked row width is
    # 2 * weight.size(1).
    group_size = (2 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    # Split each byte into its high and low nibble.
    weight_even = weight.div(16, rounding_mode="trunc")
    weight_odd = weight.remainder(16)
    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    # Shift the unsigned 4-bit codes [0, 15] into the signed range [-8, 7].
    weight = weight.view(torch.int8).add(-8)

    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_4bit.out")
def embedding_4bit_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_4bit(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
    )


@impl(quantized_decomposed_lib, "embedding_4bit.dtype", "CompositeExplicitAutograd")
def embedding_4bit_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = (2 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    weight_even = weight.div(16, rounding_mode="trunc")
    weight_odd = weight.remainder(16)
    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    weight = weight.view(torch.int8).add(-8)

    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_4bit.dtype_out")
def embedding_4bit_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_4bit_dtype(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
        dtype,
    )


quantized_decomposed_lib.define(
    "mixed_mm(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points) -> Tensor",
)

quantized_decomposed_lib.define(
    "mixed_linear(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc"
)

quantized_decomposed_lib.define(
    "add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor"
)

quantized_decomposed_lib.define(
    "add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc"
)


def _trace_and_lower_to_edge_ops(f: Callable) -> fx.GraphModule:
    gm = fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        if node.op == "call_function" and should_lower_to_edge(node.target):
            node.target = aten_to_edge(node.target)
    gm.recompile()
    return gm


def _sixth_input_is_scalar(match, original_graph, pattern_graph):
    """Check that the node matched to the sixth input of the pattern graph
    is a scalar number.
    """
    input_idx = 0
    for node in pattern_graph.nodes:
        if node.op == "placeholder":
            if input_idx == 5:
                num_node = node
            input_idx += 1
    if not isinstance(match.nodes_map[num_node], (int, float)):
        return False
    return True
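
# Editorial note on the two helpers above: _trace_and_lower_to_edge_ops
# symbolically traces a pattern/replacement callable and rewrites the ATen
# call targets that should be lowered into their edge-dialect equivalents,
# so the traced graphs line up with post-to_edge programs.
# _sixth_input_is_scalar is a match filter: it vetoes candidate matches
# whose sixth placeholder is bound to a tensor rather than a Python
# int/float, which the `.scalar` overloads below require.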


def _get_binary_op_patterns_and_replacements(
    binary_op: Callable,
    qbinary_op: Callable,
    qbinary_scalar_op: Callable,
    qbinary_relu_op: Callable,
) -> List[Tuple[Callable, Callable, List[Callable]]]:
    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_op.name())
    def binary_op_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            y, y_scale, y_zero_point, y_qmin, y_qmax, torch.uint8
        )

        out = binary_op(x, y)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_op_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            y,
            y_scale,
            y_zero_point,
            y_qmin,
            y_qmax,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_scalar_op.name())
    def binary_op_scalar_1_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )

        out = binary_op(x, num)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_op_scalar_1_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_scalar_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            num,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_scalar_op.name())
    def binary_op_scalar_2_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )

        out = binary_op(num, x)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_op_scalar_2_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_scalar_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            num,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_relu_op.name())
    def binary_relu_op_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            y, y_scale, y_zero_point, y_qmin, y_qmax, torch.uint8
        )

        out = binary_op(x, y)
        out = torch.ops.aten.relu.default(out)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_relu_op_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_relu_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            y,
            y_scale,
            y_zero_point,
            y_qmin,
            y_qmax,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

    # The fused binary+relu pair is listed first so the rewriter tries it
    # before the plain binary pattern, which would otherwise match inside
    # the fused subgraph.
    return [
        (
            _trace_and_lower_to_edge_ops(binary_relu_op_pattern),
            _trace_and_lower_to_edge_ops(binary_relu_op_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(binary_op_pattern),
            _trace_and_lower_to_edge_ops(binary_op_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(binary_op_scalar_1_pattern),
            _trace_and_lower_to_edge_ops(binary_op_scalar_1_replacement),
            [_sixth_input_is_scalar],
        ),
        (
            _trace_and_lower_to_edge_ops(binary_op_scalar_2_pattern),
            _trace_and_lower_to_edge_ops(binary_op_scalar_2_replacement),
            [_sixth_input_is_scalar],
        ),
    ]
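
# Illustrative before/after for the tensor-tensor rewrite above (editorial
# sketch with hypothetical qparams): a subgraph of the form
#   x = dequantize_per_tensor(x, 0.1, 128, 0, 255, torch.uint8)
#   y = dequantize_per_tensor(y, 0.2, 120, 0, 255, torch.uint8)
#   out = add(x, y)
#   out = quantize_per_tensor(out, 0.3, 130, 0, 255, torch.uint8)
# collapses into a single call
#   out = quantized_decomposed.add(x, 0.1, 128, 0, 255,
#                                  y, 0.2, 120, 0, 255,
#                                  0.3, 130, 0, 255)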


def _get_binary_ops_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):

    # TODO: replace qbinary op with the ops implemented in lean mode
    binary_op_to_qbinary_ops = {
        exir_ops.edge.aten.add.Tensor: (
            exir_ops.edge.quantized_decomposed.add.default,
            exir_ops.edge.quantized_decomposed.add.scalar,
            exir_ops.edge.quantized_decomposed.add_relu.default,
        ),
    }
    pattern_and_replacements = []
    for binary_op, (qbop, qbscalar_op, qbrelu_op) in binary_op_to_qbinary_ops.items():
        pattern_and_replacements.extend(
            _get_binary_op_patterns_and_replacements(
                binary_op, qbop, qbscalar_op, qbrelu_op
            )
        )

    return pattern_and_replacements


def _get_reshape_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    def pattern(
        x,
        arg0,
        arg1,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )

        x = torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1)
        x = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            x, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return x

    def replacement(
        x,
        arg0,
        arg1,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):

        x = torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1)
        return x

    return [
        (
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
            [],
        )
    ]


def _get_slice_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    def pattern(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        x = torch.ops.aten.slice_copy.Tensor(x, dim, start, end)
        # Re-quantize with the same qparams as the input dequantize.
        x = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        return x

    def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
        x = torch.ops.aten.slice_copy.Tensor(x, dim, start, end)
        return x

    return [
        (
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
            [],
        )
    ]
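
# Editorial note: reshape and slice only rearrange or select elements of the
# quantized payload, so the rewrites above drop the dequantize/quantize pair
# and run the op directly on the quantized tensor. The slice pattern binds
# the same qparams on both sides of the op, which makes the removal exactly
# value-preserving there; the reshape pattern matches separate input and
# output qparams, and dropping them assumes the two agree.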


def _get_embedding_ops_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    def get_pattern_and_replacement():
        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
        ):
            weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
                weight,
                weight_scales,
                weight_zero_points,
                0,
                weight_quant_min,
                weight_quant_max,
                torch.uint8,
            )
            out = torch.ops.aten.embedding.default(weight, indices)
            return out

        def replacement(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out

        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
        ):
            weight = (
                torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
                    weight,
                    weight_scales,
                    weight_zero_points,
                    weight_quant_min,
                    weight_quant_max,
                    weight.dtype,
                    group_size,
                    weight_scales.dtype,
                )
            )
            out = torch.ops.aten.embedding.default(weight, indices)
            return out

        def replacement_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out

        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern_with_padding_idx(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            padding_idx,
        ):
            weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
                weight,
                weight_scales,
                weight_zero_points,
                0,
                weight_quant_min,
                weight_quant_max,
                torch.uint8,
            )
            out = torch.ops.aten.embedding.default(weight, indices, padding_idx)
            return out

        def replacement_with_padding_idx(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            _,  # padding_idx only matters for training and not when running op for inference
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out

        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern_with_padding_idx_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            padding_idx,
        ):
            weight = (
                torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
                    weight,
                    weight_scales,
                    weight_zero_points,
                    weight_quant_min,
                    weight_quant_max,
                    weight.dtype,
                    group_size,
                    weight_scales.dtype,
                )
            )
            out = torch.ops.aten.embedding.default(weight, indices, padding_idx)
            return out

        def replacement_with_padding_idx_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            _,  # padding_idx only matters for training and not when running op for inference
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out
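
        # Editorial note: the groupwise patterns bind `group_size` as an
        # input, but the replacements do not forward it; the quantized
        # embedding kernels recover the group size from the weight and
        # scales shapes instead (see embedding_byte above). Likewise,
        # padding_idx is dropped because it only affects embedding
        # gradients at training time.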

        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte.dtype")
        def pattern_with_dtype_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            dtype,
        ):
            weight = (
                torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
                    weight,
                    weight_scales,
                    weight_zero_points,
                    weight_quant_min,
                    weight_quant_max,
                    weight.dtype,
                    group_size,
                    dtype,
                )
            )
            out = torch.ops.aten.embedding.default(weight, indices)
            return out

        def replacement_with_dtype_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            dtype,
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.dtype(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
                dtype=dtype,
            )
            return out

        return [
            (
                _trace_and_lower_to_edge_ops(pattern),
                _trace_and_lower_to_edge_ops(replacement),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_groupwise),
                _trace_and_lower_to_edge_ops(replacement_groupwise),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_with_padding_idx),
                _trace_and_lower_to_edge_ops(replacement_with_padding_idx),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_with_padding_idx_groupwise),
                _trace_and_lower_to_edge_ops(replacement_with_padding_idx_groupwise),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_with_dtype_groupwise),
                _trace_and_lower_to_edge_ops(replacement_with_dtype_groupwise),
                [],
            ),
        ]

    patterns_and_replacements = []
    patterns_and_replacements.extend(
        get_pattern_and_replacement(),
    )
    return patterns_and_replacements
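
# Editorial note: the string literal below is an intentionally disabled
# sketch of fixed-qparams rewrites (e.g. softmax quantized with scale 1/256
# and zero point 0). It is kept for reference and can be revived once the
# corresponding quantized ops are implemented.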


"""
def _get_fixed_qparams_ops_patterns_and_replacements() -> List[Tuple[Callable, Callable, List[Callable]]]:
    fixed_qparams_op_to_qop = {
        torch.ops.aten.softmax: (torch.ops.quantized_decomposed.softmax, 1.0 / 256.0, 0)
    }
    def get_pattern_and_replacement(fixed_qparams_op, fixed_qparams_qop, fixed_scale, fixed_zero_point):
        def pattern(x, x_scale, x_zero_point, x_qmin, x_qmax):
            x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8)
            x = fixed_qparams_op(x)
            x = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, fixed_scale, fixed_zero_point, 0, 255, torch.uint8)
            return x

        def replacement(x, x_scale, x_zero_point, x_qmin, x_qmax):
            x = fixed_qparams_qop(x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8)
            return x

        return [(pattern, replacement, [])]

    patterns_and_replacements = []
    for op, (qop, fixed_scale, fixed_zero_point) in fixed_qparams_op_to_qop.items():
        patterns_and_replacements.extend(
            get_pattern_and_replacement(op, qop, fixed_scale, fixed_zero_point)
        )
"""


def get_quant_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):

    return copy.copy(
        [
            *_get_binary_ops_patterns_and_replacements(),
            *_get_reshape_patterns_and_replacements(),
            *_get_slice_patterns_and_replacements(),
            # TODO: enable after the corresponding ops are implemented
            # *_get_fixed_qparams_ops_patterns_and_replacements(),
            *_get_embedding_ops_patterns_and_replacements(),
        ]
    )
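
# Example usage (editorial sketch; `gm` is a hypothetical fx.GraphModule
# that has already been lowered to the edge dialect):
#
#   from torch.fx import subgraph_rewriter
#
#   for pattern, replacement, match_filters in get_quant_patterns_and_replacements():
#       subgraph_rewriter.replace_pattern_with_filters(
#           gm, pattern, replacement, match_filters
#       )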