#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

"""
Python dataclass definitions for the MPS graph serialization schema.

Please refer to executorch/backends/apple/mps/serialization/schema.fbs for the schema definitions
"""

from dataclasses import dataclass, field
from enum import IntEnum
from typing import List, Optional, Union


class MPSDataType(IntEnum):
    """Element data types known to the MPS graph schema.

    NOTE(review): the integer values presumably have to match the
    corresponding enum in schema.fbs -- do not renumber without checking it.
    """

    mps_data_type_invalid = 0
    mps_data_type_float16 = 1
    mps_data_type_float32 = 2
    mps_data_type_float64 = 3
    mps_data_type_bfloat16 = 4

    # Signed integers.
    mps_data_type_int4 = 5
    mps_data_type_int8 = 6
    mps_data_type_int16 = 7
    mps_data_type_int32 = 8
    mps_data_type_int64 = 9

    # Unsigned integers. range: [0, UTYPE_MAX]
    mps_data_type_uint4 = 10
    mps_data_type_uint8 = 11
    mps_data_type_uint16 = 12
    mps_data_type_uint32 = 13
    mps_data_type_uint64 = 14

    mps_data_type_bool = 15

    mps_data_type_complex_float16 = 16
    mps_data_type_complex_float32 = 17


class OpType(IntEnum):
    """Kind of graph being serialized (stored in MPSGraph.graph_type)."""

    mps_graph = 0
    metal_kernel = 1


@dataclass
class MPSNode1x1:
    """Base layout for an op with one input id and one output id.

    Ids are presumably indices into MPSGraph.mps_values -- confirm against
    schema.fbs.
    """

    input1_id: int
    output_id: int


@dataclass
class MPSNode2x1:
    """Base layout for an op with two input ids and one output id."""

    input1_id: int
    input2_id: int
    output_id: int


@dataclass
class MPSDivNode2x1(MPSNode2x1):
    """Two-input node that additionally carries a division rounding mode."""

    # None means "no rounding mode given". The accepted string values are not
    # defined here (presumably torch's "trunc" / "floor" -- confirm).
    rounding_mode: Optional[str] = None


@dataclass
class MPSNode3x1:
    """Base layout for an op with three input ids and one output id."""

    input1_id: int
    input2_id: int
    input3_id: int
    output_id: int


@dataclass
class MPSDequantizeNode(MPSNode1x1):
    """One-input node plus the ids of its scale and zero-point operands."""

    scales_id: int
    zero_points_id: int


@dataclass
class MPSConv(MPSNode3x1):
    """Shared parameters for the convolution ops (MPSConv2D and friends).

    All geometry fields default to 0, so callers are expected to fill them in.
    """

    stride_x: int = 0
    stride_y: int = 0
    dilation_x: int = 0
    dilation_y: int = 0
    groups: int = 0
    padding_left: int = 0
    padding_right: int = 0
    padding_top: int = 0
    padding_bottom: int = 0


@dataclass
class MPSPooling2D:
    """Shared parameters for the 2D pooling ops.

    Field order is significant: dataclass fields are positional in __init__.
    """

    input1_id: int
    kernel_height: int
    kernel_width: int
    stride_height: int
    stride_width: int
    padding_left: int
    padding_right: int
    padding_top: int
    padding_bottom: int
    dilation_height: int
    dilation_width: int
    ceil_mode: bool
    output1_id: int
    # -1 presumably means "no second output" (e.g. no indices) -- confirm.
    output2_id: int = -1
    count_include_pad: bool = True
    divisor_override: int = 0


@dataclass
class MPSMinMax:
    """Optional min/max bounds that can be attached to an MPSNode.

    The string defaults "-inf" / "inf" stand for unbounded values.
    """

    min_value: Union[float, str] = "-inf"
    max_value: Union[float, str] = "inf"


##
## Activation ops
##
@dataclass
class MPSHardTanh(MPSNode1x1):
    min_value: float = 0.0
    max_value: float = 0.0


@dataclass
class MPSReLU(MPSNode1x1):
    pass


@dataclass
class MPSGELU(MPSNode1x1):
    # Approximation mode name; "none" is the default.
    approximate: str = "none"


@dataclass
class MPSLeakyReLU(MPSNode1x1):
    negative_slope: float = 0.01


@dataclass
class MPSSoftmax(MPSNode1x1):
    dim: int = 0
    half_to_float: bool = False


@dataclass
class MPSLogSoftmax(MPSNode1x1):
    dim: int = 0
    half_to_float: bool = False


##
## Binary ops
##
@dataclass
class MPSAdd(MPSNode2x1):
    # Presumably mirrors torch.add's alpha multiplier on the second input.
    alpha: float = 1.0


@dataclass
class MPSSub(MPSNode2x1):
    # Presumably mirrors torch.sub's alpha multiplier on the second input.
    alpha: float = 1.0


@dataclass
class MPSMul(MPSNode2x1):
    pass


@dataclass
class MPSDiv(MPSDivNode2x1):
    pass


@dataclass
class MPSFmod(MPSDivNode2x1):
    pass


@dataclass
class MPSRemainder(MPSNode2x1):
    pass


@dataclass
class MPSMin(MPSNode2x1):
    pass


@dataclass
class MPSMax(MPSNode2x1):
    pass


@dataclass
class MPSPow(MPSNode2x1):
    pass


@dataclass
class MPSAtan2(MPSNode2x1):
    pass


@dataclass
class MPSBitwiseAnd(MPSNode2x1):
    pass


@dataclass
class MPSBitwiseOr(MPSNode2x1):
    pass


@dataclass
class MPSBitwiseXor(MPSNode2x1):
    pass


@dataclass
class MPSMinimum(MPSNode2x1):
    pass


##
## Unary ops
##
@dataclass
class MPSExp(MPSNode1x1):
    pass


@dataclass
class MPSExp2(MPSNode1x1):
    pass


@dataclass
class MPSReciprocal(MPSNode1x1):
    pass


@dataclass
class MPSSqrt(MPSNode1x1):
    pass


@dataclass
class MPSNeg(MPSNode1x1):
    pass


@dataclass
class MPSLog(MPSNode1x1):
    pass


@dataclass
class MPSLog10(MPSNode1x1):
    pass


@dataclass
class MPSLog2(MPSNode1x1):
    pass


@dataclass
class MPSErf(MPSNode1x1):
    pass


@dataclass
class MPSFloor(MPSNode1x1):
    pass


@dataclass
class MPSCeil(MPSNode1x1):
    pass


@dataclass
class MPSRsqrt(MPSNode1x1):
    pass


@dataclass
class MPSSigmoid(MPSNode1x1):
    pass


@dataclass
class MPSSin(MPSNode1x1):
    pass


@dataclass
class MPSSign(MPSNode1x1):
    pass


@dataclass
class MPSCos(MPSNode1x1):
    pass


@dataclass
class MPSTan(MPSNode1x1):
    pass


@dataclass
class MPSAbs(MPSNode1x1):
    pass


@dataclass
class MPSAsin(MPSNode1x1):
    pass


@dataclass
class MPSAcos(MPSNode1x1):
    pass


@dataclass
class MPSAtan(MPSNode1x1):
    pass


@dataclass
class MPSSinh(MPSNode1x1):
    pass


@dataclass
class MPSCosh(MPSNode1x1):
    pass


@dataclass
class MPSTanh(MPSNode1x1):
    pass


@dataclass
class MPSAsinh(MPSNode1x1):
    pass


@dataclass
class MPSAcosh(MPSNode1x1):
    pass


@dataclass
class MPSAtanh(MPSNode1x1):
    pass


@dataclass
class MPSBitwiseNot(MPSNode1x1):
    pass


@dataclass
class MPSIsnan(MPSNode1x1):
    pass


@dataclass
class MPSIsinf(MPSNode1x1):
    pass


@dataclass
class MPSRound(MPSNode1x1):
    pass


@dataclass
class MPSLogicalNot(MPSNode1x1):
    pass


# NOTE(review): MPSBitwise is not listed in MPSNodeUnion in this file --
# confirm whether it is still used anywhere before relying on it.
@dataclass
class MPSBitwise(MPSNode1x1):
    pass


##
## Linear algebra ops
##
@dataclass
class MPSMatMul(MPSNode2x1):
    pass


@dataclass
class MPSAddmm(MPSNode3x1):
    # Presumably mirrors torch.addmm: beta * input + alpha * (mat1 @ mat2).
    beta: float = 1.0
    alpha: float = 1.0


##
## Constant ops
##
@dataclass
class MPSFull:
    """Produces a tensor of the given shape filled with a constant value."""

    output_id: int
    shape: List[int]
    fill_value: float
    dtype: MPSDataType


@dataclass
class MPSFullLike(MPSNode1x1):
    """Constant fill whose shape is taken from the input tensor (presumably).

    fill_value may be a string (e.g. for non-finite values) or a float.
    """

    fill_value: Union[float, str] = 0.0
    dtype: MPSDataType = MPSDataType.mps_data_type_float32


##
## Clamp ops
##
@dataclass
class MPSClamp(MPSNode1x1):
    # Carries no bounds itself; presumably uses MPSNode.min_max -- confirm.
    pass


@dataclass
class MPSWhere(MPSNode3x1):
    pass


##
## Reduce ops
##
@dataclass
class MPSMean(MPSNode1x1):
    # num_dims is stored alongside dims (presumably len(dims)) for the schema.
    num_dims: int = 0
    dims: List[int] = field(default_factory=list)
    keep_dims: bool = False


##
## Indexing ops
##
@dataclass
class MPSIndexSelect(MPSNode1x1):
    dim: int = 0
    index_id: int = -1


@dataclass
class MPSEmbedding(MPSNode2x1):
    padding_idx: int = -1
    scale_grad_by_freq: bool = False
    sparse: bool = False


@dataclass
class MPSIndexTensor(MPSNode1x1):
    indices_id: List[int] = field(default_factory=list)


@dataclass
class MPSIndexPut(MPSNode1x1):
    indices_id: List[int] = field(default_factory=list)
    values_shape: List[int] = field(default_factory=list)
    values_id: int = -1


@dataclass
class MPSScatter(MPSNode1x1):
    dim: int = 0
    idx_id: int = -1
    src_id: int = -1


##
## Shape ops
##
@dataclass
class MPSPermute(MPSNode1x1):
    num_dims: int = 0
    perm: List[int] = field(default_factory=list)


@dataclass
class MPSView(MPSNode1x1):
    num_dims: int = 0
    shape: List[int] = field(default_factory=list)


@dataclass
class MPSExpand(MPSNode1x1):
    num_dims: int = 0
    shape: List[int] = field(default_factory=list)


@dataclass
class MPSCat:
    """Concatenates several inputs along `dim` into a single output."""

    input_ids: List[int]
    output_id: int
    dim: int


@dataclass
class MPSSqueeze(MPSNode1x1):
    dims: List[int] = field(default_factory=list)


@dataclass
class MPSUnsqueeze(MPSNode1x1):
    dim: int = 0


@dataclass
class MPSSelect(MPSNode1x1):
    dim: int = 0
    index: int = 0


@dataclass
class MPSSlice(MPSNode1x1):
    dim: int = 0
    start: int = -1
    end: int = -1
    step: int = 1


@dataclass
class MPSPixelShuffle(MPSNode1x1):
    upscale_factor: int = 1


@dataclass
class MPSSplitWithSizes:
    """Splits one input along `dim` into several outputs of the given sizes."""

    input1_id: int
    output_ids: List[int]
    split_sizes: List[int]
    dim: int


@dataclass
class MPSCast(MPSNode1x1):
    dtype: MPSDataType


##
## Convolution ops
##


@dataclass
class MPSConv2D(MPSConv):
    pass


@dataclass
class MPSDepthwiseConv2D(MPSConv):
    pass


##
## Comparison Ops
##
# Decorated like every other node class so each is a dataclass in its own
# right rather than relying on inherited dataclass machinery.
@dataclass
class MPSEq(MPSNode2x1):
    pass


@dataclass
class MPSNe(MPSNode2x1):
    pass


@dataclass
class MPSGe(MPSNode2x1):
    pass


@dataclass
class MPSGt(MPSNode2x1):
    pass


@dataclass
class MPSLe(MPSNode2x1):
    pass


@dataclass
class MPSLt(MPSNode2x1):
    pass


##
## Normalization op
##
@dataclass
class MPSBatchNorm:
    """Batch normalization node with its parameter ids and three outputs."""

    input_id: int
    mean_id: int
    var_id: int
    weight_id: int
    bias_id: int
    momentum: float
    epsilon: float
    output1_id: int
    output2_id: int
    output3_id: int


@dataclass
class MPSLayerNorm:
    """Layer normalization node with its parameter ids and three outputs."""

    input1_id: int
    normalized_shape: List[int]
    weight_id: int
    bias_id: int
    eps: float
    output1_id: int
    output2_id: int
    output3_id: int


##
## Pooling ops
##


@dataclass
class MPSMaxPool2DWithIndices(MPSPooling2D):
    pass


@dataclass
class MPSAvgPool2D(MPSPooling2D):
    pass


##
## Pad ops
##
@dataclass
class MPSConstantPadND(MPSNode1x1):
    pad: List[int] = field(default_factory=list)
    value: float = 0.0


##
## Range ops
##
@dataclass
class MPSArange:
    """Produces a range tensor from start to end with the given step."""

    output_id: int
    start: float
    end: float
    step: float
    dtype: MPSDataType


##
## Quant - Dequant ops
##
@dataclass
class MPSDequantizePerChannelGroup(MPSDequantizeNode):
    """Group-wise dequantization node (see MPSDequantizeNode for operand ids)."""

    quant_min: int
    quant_max: int
    dtype: MPSDataType
    group_size: int
    output_dtype: MPSDataType


# Every op payload that may appear in MPSNode.mpsnode_union. Any op dataclass
# the backend emits must be listed here so the declared union type covers it
# (including MPSLogSoftmax and MPSFmod, which were previously missing).
MPSNodeUnion = Union[
    # Activation ops
    MPSHardTanh,
    MPSReLU,
    MPSGELU,
    MPSLeakyReLU,
    MPSSoftmax,
    MPSLogSoftmax,
    # Binary ops
    MPSAdd,
    MPSSub,
    MPSMul,
    MPSDiv,
    MPSFmod,
    MPSMin,
    MPSMax,
    MPSPow,
    MPSRemainder,
    MPSAtan2,
    MPSBitwiseAnd,
    MPSBitwiseOr,
    MPSBitwiseXor,
    MPSMinimum,
    # Unary ops
    MPSExp,
    MPSExp2,
    MPSReciprocal,
    MPSSqrt,
    MPSNeg,
    MPSLog,
    MPSLog10,
    MPSLog2,
    MPSErf,
    MPSFloor,
    MPSCeil,
    MPSRsqrt,
    MPSSigmoid,
    MPSSin,
    MPSSign,
    MPSCos,
    MPSTan,
    MPSAbs,
    MPSAsin,
    MPSAcos,
    MPSAtan,
    MPSSinh,
    MPSCosh,
    MPSTanh,
    MPSAsinh,
    MPSAcosh,
    MPSAtanh,
    MPSBitwiseNot,
    MPSIsnan,
    MPSIsinf,
    MPSRound,
    MPSLogicalNot,
    # Linear algebra ops
    MPSMatMul,
    MPSAddmm,
    # Constant ops
    MPSFull,
    MPSFullLike,
    # Clamp ops
    MPSClamp,
    MPSWhere,
    # Reduce ops
    MPSMean,
    # Indexing ops
    MPSIndexSelect,
    MPSEmbedding,
    MPSIndexTensor,
    MPSIndexPut,
    MPSScatter,
    # Shape ops
    MPSPermute,
    MPSView,
    MPSExpand,
    MPSCat,
    MPSSqueeze,
    MPSUnsqueeze,
    MPSSelect,
    MPSSlice,
    MPSPixelShuffle,
    MPSSplitWithSizes,
    MPSCast,
    # Convolution ops
    MPSConv2D,
    MPSDepthwiseConv2D,
    # Comparison ops
    MPSEq,
    MPSNe,
    MPSGe,
    MPSGt,
    MPSLe,
    MPSLt,
    # Normalization ops
    MPSBatchNorm,
    MPSLayerNorm,
    # Pooling ops
    MPSMaxPool2DWithIndices,
    MPSAvgPool2D,
    # Pad ops
    MPSConstantPadND,
    # Range ops
    MPSArange,
    # Quant-Dequant ops
    MPSDequantizePerChannelGroup,
]


@dataclass
class MPSNode:
    """One serialized graph node: an op payload plus optional min/max bounds."""

    # Forward-referenced by name; resolves to the MPSNodeUnion alias above.
    mpsnode_union: "MPSNodeUnion"
    # Presumably consumed by clamp-like ops (MPSClamp has no bound fields of
    # its own) -- confirm against the runtime.
    min_max: Optional[MPSMinMax] = None


@dataclass
class Buffer:
    """Raw byte payload."""

    storage: bytes


@dataclass
class MPSTensor:
    """Description of one graph value: dtype, shape and constant data location."""

    datatype: MPSDataType
    # num_dims is stored alongside dims (presumably len(dims)) for the schema.
    num_dims: int
    dims: List[int]
    constant_buffer_size: int
    constant_buffer: Buffer  # deprecated
    # Presumably an offset into MPSGraph.constant_segment -- confirm.
    segment_offset: int = 0


@dataclass
class DataSegment:
    """An (offset, size) region of out-of-line data."""

    offset: int
    size: int


@dataclass
class MPSGraph:
    """Root serialized object: nodes, values and graph-level input/output ids."""

    version: str
    mps_nodes: List[MPSNode]
    mps_values: List[MPSTensor]
    input_ids: List[int]
    output_ids: List[int]
    constant_ids: List[int]
    graph_type: OpType
    constant_segment: DataSegment