xref: /aosp_15_r20/external/executorch/backends/apple/mps/serialization/mps_graph_schema.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6"""
7Please refer to executorch/backends/apple/mps/serialization/schema.fbs for the schema definitions
8"""
9
10from dataclasses import dataclass, field
11from enum import IntEnum
12from typing import List, Optional, Union
13
14
class MPSDataType(IntEnum):
    """Element types a serialized MPS tensor can carry.

    The explicit integer values presumably mirror the data-type enum in
    schema.fbs (see the module docstring), so members must keep their
    values; append new members rather than renumbering existing ones.
    """

    mps_data_type_invalid = 0

    # Floating point.
    mps_data_type_float16 = 1
    mps_data_type_float32 = 2
    mps_data_type_float64 = 3
    mps_data_type_bfloat16 = 4

    # Signed integers.
    mps_data_type_int4 = 5
    mps_data_type_int8 = 6
    mps_data_type_int16 = 7
    mps_data_type_int32 = 8
    mps_data_type_int64 = 9

    # Unsigned integers, range [0, UTYPE_MAX].
    mps_data_type_uint4 = 10
    mps_data_type_uint8 = 11
    mps_data_type_uint16 = 12
    mps_data_type_uint32 = 13
    mps_data_type_uint64 = 14

    # Boolean.
    mps_data_type_bool = 15

    # Complex.
    mps_data_type_complex_float16 = 16
    mps_data_type_complex_float32 = 17
40
41
class OpType(IntEnum):
    """Kind of graph recorded in ``MPSGraph.graph_type``."""

    mps_graph = 0
    metal_kernel = 1
45
46
@dataclass
class MPSNode1x1:
    """Base layout for nodes with one input tensor id and one output tensor id.

    NOTE(review): ids presumably index ``MPSGraph.mps_values``; confirm with
    the serializer.
    """

    input1_id: int
    output_id: int


@dataclass
class MPSNode2x1:
    """Base layout for nodes with two input tensor ids and one output tensor id."""

    input1_id: int
    input2_id: int
    output_id: int


@dataclass
class MPSDivNode2x1(MPSNode2x1):
    """Two-input node that additionally carries an optional rounding mode.

    ``rounding_mode`` defaults to ``None``. The field is annotated
    ``Optional[str]`` so the annotation agrees with that default; the old
    ``str`` annotation was contradicted by its own ``None`` default.
    """

    # NOTE(review): None presumably means "no rounding requested"; confirm.
    rounding_mode: Optional[str] = None


@dataclass
class MPSNode3x1:
    """Base layout for nodes with three input tensor ids and one output tensor id."""

    input1_id: int
    input2_id: int
    input3_id: int
    output_id: int


@dataclass
class MPSDequantizeNode(MPSNode1x1):
    """One-input node that also references scale and zero-point tensors by id."""

    scales_id: int
    zero_points_id: int


@dataclass
class MPSConv(MPSNode3x1):
    """Shared parameter layout for the convolution nodes.

    NOTE(review): the three inputs are presumably (input, weight, bias);
    confirm with the op builder. All parameters default to 0, so callers are
    expected to set them explicitly.
    """

    stride_x: int = 0
    stride_y: int = 0
    dilation_x: int = 0
    dilation_y: int = 0
    groups: int = 0
    padding_left: int = 0
    padding_right: int = 0
    padding_top: int = 0
    padding_bottom: int = 0
90
91
@dataclass
class MPSPooling2D:
    """Shared parameter layout for the 2-D pooling nodes.

    ``output2_id`` defaults to -1. NOTE(review): that presumably means "no
    second output"; confirm with the op builder.
    """

    # Input tensor id.
    input1_id: int
    # Pooling window size.
    kernel_height: int
    kernel_width: int
    # Window step.
    stride_height: int
    stride_width: int
    # Per-side padding.
    padding_left: int
    padding_right: int
    padding_top: int
    padding_bottom: int
    # Window dilation.
    dilation_height: int
    dilation_width: int
    ceil_mode: bool
    # Output tensor ids.
    output1_id: int
    output2_id: int = -1
    count_include_pad: bool = True
    divisor_override: int = 0
110
111
@dataclass
class MPSMinMax:
    """Min/max bounds that may be attached to a node via ``MPSNode.min_max``.

    Each bound is a float or a string. The defaults are the strings
    ``"-inf"`` and ``"inf"``. NOTE(review): strings are presumably used so
    unbounded values survive serialization; confirm.
    """

    min_value: Union[float, str] = "-inf"
    max_value: Union[float, str] = "inf"
116
117
##
## Activation ops
##
@dataclass
class MPSHardTanh(MPSNode1x1):
    """``hardtanh`` node with its lower and upper bounds."""

    min_value: float = 0.0
    max_value: float = 0.0


@dataclass
class MPSReLU(MPSNode1x1):
    """``relu`` node; no parameters beyond the input/output ids."""


@dataclass
class MPSGELU(MPSNode1x1):
    """``gelu`` node; ``approximate`` names the approximation mode."""

    approximate: str = "none"


@dataclass
class MPSLeakyReLU(MPSNode1x1):
    """``leaky_relu`` node with its negative slope."""

    negative_slope: float = 0.01


@dataclass
class MPSSoftmax(MPSNode1x1):
    """``softmax`` node over dimension ``dim``."""

    dim: int = 0
    # NOTE(review): presumably mirrors aten's half_to_float flag; confirm.
    half_to_float: bool = False


@dataclass
class MPSLogSoftmax(MPSNode1x1):
    """``log_softmax`` node over dimension ``dim``."""

    dim: int = 0
    half_to_float: bool = False
152
153
##
## Binary ops
##
@dataclass
class MPSAdd(MPSNode2x1):
    """``add`` node; ``alpha`` is its scalar multiplier (aten-style, presumably)."""

    alpha: float = 1.0


@dataclass
class MPSSub(MPSNode2x1):
    """``sub`` node; ``alpha`` is its scalar multiplier (aten-style, presumably)."""

    alpha: float = 1.0


@dataclass
class MPSMul(MPSNode2x1):
    """``mul`` node."""


@dataclass
class MPSDiv(MPSDivNode2x1):
    """``div`` node; inherits ``rounding_mode`` from ``MPSDivNode2x1``."""


@dataclass
class MPSFmod(MPSDivNode2x1):
    """``fmod`` node; inherits ``rounding_mode`` from ``MPSDivNode2x1``."""


@dataclass
class MPSRemainder(MPSNode2x1):
    """``remainder`` node."""


@dataclass
class MPSMin(MPSNode2x1):
    """``min`` node."""


@dataclass
class MPSMax(MPSNode2x1):
    """``max`` node."""


@dataclass
class MPSPow(MPSNode2x1):
    """``pow`` node."""


@dataclass
class MPSAtan2(MPSNode2x1):
    """``atan2`` node."""


@dataclass
class MPSBitwiseAnd(MPSNode2x1):
    """``bitwise_and`` node."""


@dataclass
class MPSBitwiseOr(MPSNode2x1):
    """``bitwise_or`` node."""


@dataclass
class MPSBitwiseXor(MPSNode2x1):
    """``bitwise_xor`` node."""


@dataclass
class MPSMinimum(MPSNode2x1):
    """``minimum`` node."""
225
226
##
## Unary ops
##
@dataclass
class MPSExp(MPSNode1x1):
    """Unary ``exp`` node."""


@dataclass
class MPSExp2(MPSNode1x1):
    """Unary ``exp2`` node."""


@dataclass
class MPSReciprocal(MPSNode1x1):
    """Unary ``reciprocal`` node."""


@dataclass
class MPSSqrt(MPSNode1x1):
    """Unary ``sqrt`` node."""


@dataclass
class MPSNeg(MPSNode1x1):
    """Unary ``neg`` node."""


@dataclass
class MPSLog(MPSNode1x1):
    """Unary ``log`` node."""


@dataclass
class MPSLog10(MPSNode1x1):
    """Unary ``log10`` node."""


@dataclass
class MPSLog2(MPSNode1x1):
    """Unary ``log2`` node."""


@dataclass
class MPSErf(MPSNode1x1):
    """Unary ``erf`` node."""


@dataclass
class MPSFloor(MPSNode1x1):
    """Unary ``floor`` node."""


@dataclass
class MPSCeil(MPSNode1x1):
    """Unary ``ceil`` node."""


@dataclass
class MPSRsqrt(MPSNode1x1):
    """Unary ``rsqrt`` node."""


@dataclass
class MPSSigmoid(MPSNode1x1):
    """Unary ``sigmoid`` node."""


@dataclass
class MPSSin(MPSNode1x1):
    """Unary ``sin`` node."""


@dataclass
class MPSSign(MPSNode1x1):
    """Unary ``sign`` node."""


@dataclass
class MPSCos(MPSNode1x1):
    """Unary ``cos`` node."""


@dataclass
class MPSTan(MPSNode1x1):
    """Unary ``tan`` node."""


@dataclass
class MPSAbs(MPSNode1x1):
    """Unary ``abs`` node."""


@dataclass
class MPSAsin(MPSNode1x1):
    """Unary ``asin`` node."""


@dataclass
class MPSAcos(MPSNode1x1):
    """Unary ``acos`` node."""


@dataclass
class MPSAtan(MPSNode1x1):
    """Unary ``atan`` node."""


@dataclass
class MPSSinh(MPSNode1x1):
    """Unary ``sinh`` node."""


@dataclass
class MPSCosh(MPSNode1x1):
    """Unary ``cosh`` node."""


@dataclass
class MPSTanh(MPSNode1x1):
    """Unary ``tanh`` node."""


@dataclass
class MPSAsinh(MPSNode1x1):
    """Unary ``asinh`` node."""


@dataclass
class MPSAcosh(MPSNode1x1):
    """Unary ``acosh`` node."""


@dataclass
class MPSAtanh(MPSNode1x1):
    """Unary ``atanh`` node."""


@dataclass
class MPSBitwiseNot(MPSNode1x1):
    """Unary ``bitwise_not`` node."""


@dataclass
class MPSIsnan(MPSNode1x1):
    """Unary ``isnan`` node."""


@dataclass
class MPSIsinf(MPSNode1x1):
    """Unary ``isinf`` node."""


@dataclass
class MPSRound(MPSNode1x1):
    """Unary ``round`` node."""


@dataclass
class MPSLogicalNot(MPSNode1x1):
    """Unary ``logical_not`` node."""


@dataclass
class MPSBitwise(MPSNode1x1):
    """Generic one-input bitwise node.

    NOTE(review): this class is not listed in ``MPSNodeUnion`` and may be
    unused; confirm before relying on it.
    """
393
394
##
## Linear algebra ops
##
@dataclass
class MPSMatMul(MPSNode2x1):
    """``matmul`` node over two inputs."""


@dataclass
class MPSAddmm(MPSNode3x1):
    """``addmm`` node; ``beta`` and ``alpha`` are its scalar coefficients."""

    beta: float = 1.0
    alpha: float = 1.0
407
408
##
## Constant ops
##
@dataclass
class MPSFull:
    """Creates a tensor of ``shape`` / ``dtype`` filled with ``fill_value``.

    Takes no input tensor; only the output id is recorded.
    """

    output_id: int
    shape: List[int]
    # NOTE(review): MPSFullLike accepts Union[float, str] here (strings such as
    # "inf" are used elsewhere in this schema); confirm whether MPSFull needs
    # the same.
    fill_value: float
    dtype: MPSDataType


@dataclass
class MPSFullLike(MPSNode1x1):
    """``full_like`` node: shaped like its input, filled with ``fill_value``."""

    fill_value: Union[float, str] = 0.0
    dtype: MPSDataType = MPSDataType.mps_data_type_float32
424
425
##
## Clamp ops
##
@dataclass
class MPSClamp(MPSNode1x1):
    """``clamp`` node.

    NOTE(review): it has no bound fields of its own; the bounds are presumably
    carried by ``MPSNode.min_max``. Confirm.
    """


@dataclass
class MPSWhere(MPSNode3x1):
    """``where`` node over three inputs."""
437
438
##
## Reduce ops
##
@dataclass
class MPSMean(MPSNode1x1):
    """``mean`` reduction node over the dimensions listed in ``dims``."""

    # NOTE(review): presumably num_dims == len(dims); confirm.
    num_dims: int = 0
    dims: List[int] = field(default_factory=list)
    keep_dims: bool = False
447
448
##
## Indexing ops
##
@dataclass
class MPSIndexSelect(MPSNode1x1):
    """``index_select`` node; ``index_id`` references the index tensor."""

    dim: int = 0
    # -1 is the default (presumably "unset").
    index_id: int = -1


@dataclass
class MPSEmbedding(MPSNode2x1):
    """``embedding`` node over two inputs plus the aten embedding flags."""

    padding_idx: int = -1
    scale_grad_by_freq: bool = False
    sparse: bool = False


@dataclass
class MPSIndexTensor(MPSNode1x1):
    """``index.Tensor`` node; ``indices_id`` lists the index tensor ids."""

    indices_id: List[int] = field(default_factory=list)


@dataclass
class MPSIndexPut(MPSNode1x1):
    """``index_put`` node with index tensor ids and the values tensor."""

    indices_id: List[int] = field(default_factory=list)
    values_shape: List[int] = field(default_factory=list)
    values_id: int = -1


@dataclass
class MPSScatter(MPSNode1x1):
    """``scatter`` node along ``dim`` with index and source tensor ids."""

    dim: int = 0
    idx_id: int = -1
    src_id: int = -1
482
483
##
## Shape ops
##
@dataclass
class MPSPermute(MPSNode1x1):
    """``permute`` node; ``perm`` is the dimension order."""

    # NOTE(review): presumably num_dims == len(perm); confirm.
    num_dims: int = 0
    perm: List[int] = field(default_factory=list)


@dataclass
class MPSView(MPSNode1x1):
    """``view`` node reshaping its input to ``shape``."""

    num_dims: int = 0
    shape: List[int] = field(default_factory=list)


@dataclass
class MPSExpand(MPSNode1x1):
    """``expand`` node broadcasting its input to ``shape``."""

    num_dims: int = 0
    shape: List[int] = field(default_factory=list)


@dataclass
class MPSCat:
    """``cat`` node; takes a variable number of inputs, unlike the 1x1 nodes."""

    input_ids: List[int]
    output_id: int
    dim: int


@dataclass
class MPSSqueeze(MPSNode1x1):
    """``squeeze`` node over the listed ``dims``."""

    dims: List[int] = field(default_factory=list)


@dataclass
class MPSUnsqueeze(MPSNode1x1):
    """``unsqueeze`` node inserting a dimension at ``dim``."""

    dim: int = 0


@dataclass
class MPSSelect(MPSNode1x1):
    """``select`` node picking ``index`` along ``dim``."""

    dim: int = 0
    index: int = 0


@dataclass
class MPSSlice(MPSNode1x1):
    """``slice`` node along ``dim`` with start/end/step."""

    dim: int = 0
    start: int = -1
    end: int = -1
    step: int = 1


@dataclass
class MPSPixelShuffle(MPSNode1x1):
    """``pixel_shuffle`` node with its upscale factor."""

    upscale_factor: int = 1


@dataclass
class MPSSplitWithSizes:
    """``split_with_sizes`` node: one input, one output id per split."""

    input1_id: int
    output_ids: List[int]
    split_sizes: List[int]
    dim: int


@dataclass
class MPSCast(MPSNode1x1):
    """Cast node converting its input to ``dtype``."""

    dtype: MPSDataType
552
553
##
## Convolution ops
##


@dataclass
class MPSConv2D(MPSConv):
    """2-D convolution node; all parameters come from ``MPSConv``."""


@dataclass
class MPSDepthwiseConv2D(MPSConv):
    """Depthwise 2-D convolution node; all parameters come from ``MPSConv``."""
567
568
##
## Comparison Ops
##
# These classes previously lacked @dataclass, unlike every other node class in
# this module. They add no fields, so the generated __init__/__repr__/__eq__
# match the ones inherited from MPSNode2x1; the decorator keeps them
# consistent with their siblings.
@dataclass
class MPSEq(MPSNode2x1):
    """``eq`` comparison node."""


@dataclass
class MPSNe(MPSNode2x1):
    """``ne`` comparison node."""


@dataclass
class MPSGe(MPSNode2x1):
    """``ge`` comparison node."""


@dataclass
class MPSGt(MPSNode2x1):
    """``gt`` comparison node."""


@dataclass
class MPSLe(MPSNode2x1):
    """``le`` comparison node."""


@dataclass
class MPSLt(MPSNode2x1):
    """``lt`` comparison node."""
594
595
##
## Normalization op
##
@dataclass
class MPSBatchNorm:
    """Batch-norm node: tensor ids for input and statistics, scalars, three outputs."""

    # Input and parameter tensor ids.
    input_id: int
    mean_id: int
    var_id: int
    weight_id: int
    bias_id: int
    # Scalar hyper-parameters.
    momentum: float
    epsilon: float
    # Output tensor ids.
    output1_id: int
    output2_id: int
    output3_id: int


@dataclass
class MPSLayerNorm:
    """Layer-norm node over ``normalized_shape`` with three outputs."""

    input1_id: int
    normalized_shape: List[int]
    weight_id: int
    bias_id: int
    eps: float
    # Output tensor ids.
    output1_id: int
    output2_id: int
    output3_id: int
623
624
##
## Pooling ops
##


@dataclass
class MPSMaxPool2DWithIndices(MPSPooling2D):
    """2-D max-pool node; all parameters come from ``MPSPooling2D``."""


@dataclass
class MPSAvgPool2D(MPSPooling2D):
    """2-D average-pool node; all parameters come from ``MPSPooling2D``."""
638
639
##
## Pad ops
##
@dataclass
class MPSConstantPadND(MPSNode1x1):
    """``constant_pad_nd`` node; pads its input with ``value``.

    NOTE(review): ``pad`` presumably follows aten's (last-dim-first) ordering;
    confirm.
    """

    pad: List[int] = field(default_factory=list)
    value: float = 0.0
647
648
##
## Range ops
##
@dataclass
class MPSArange:
    """``arange`` node; takes no input tensor, only range scalars and a dtype."""

    output_id: int
    start: float
    end: float
    step: float
    dtype: MPSDataType
659
660
##
## Quant - Dequant ops
##
@dataclass
class MPSDequantizePerChannelGroup(MPSDequantizeNode):
    """Group-wise per-channel dequantize node.

    Extends ``MPSDequantizeNode`` (input, output, scales and zero-point ids).
    NOTE(review): ``dtype`` is presumably the quantized storage type and
    ``output_dtype`` the dequantized result type; confirm.
    """

    quant_min: int
    quant_max: int
    dtype: MPSDataType
    group_size: int
    output_dtype: MPSDataType
671
672
# Every payload type that may appear in ``MPSNode.mpsnode_union``. This list
# presumably mirrors the node union in schema.fbs (see the module docstring).
# NOTE(review): tooling that resolves a union member by class name presumably
# searches these arguments, so each serializable node dataclass defined above
# should be listed. MPSLogSoftmax and MPSFmod were previously missing.
MPSNodeUnion = Union[
    # Activation ops
    MPSHardTanh,
    MPSReLU,
    MPSGELU,
    MPSLeakyReLU,
    MPSSoftmax,
    MPSLogSoftmax,
    # Binary ops
    MPSAdd,
    MPSSub,
    MPSMul,
    MPSDiv,
    MPSFmod,
    MPSMin,
    MPSMax,
    MPSPow,
    MPSRemainder,
    MPSAtan2,
    MPSBitwiseAnd,
    MPSBitwiseOr,
    MPSBitwiseXor,
    MPSMinimum,
    # Unary ops
    MPSExp,
    MPSExp2,
    MPSReciprocal,
    MPSSqrt,
    MPSNeg,
    MPSLog,
    MPSLog10,
    MPSLog2,
    MPSErf,
    MPSFloor,
    MPSCeil,
    MPSRsqrt,
    MPSSigmoid,
    MPSSin,
    MPSSign,
    MPSCos,
    MPSTan,
    MPSAbs,
    MPSAsin,
    MPSAcos,
    MPSAtan,
    MPSSinh,
    MPSCosh,
    MPSTanh,
    MPSAsinh,
    MPSAcosh,
    MPSAtanh,
    MPSBitwiseNot,
    MPSIsnan,
    MPSIsinf,
    MPSRound,
    MPSLogicalNot,
    # Linear algebra ops
    MPSMatMul,
    MPSAddmm,
    # Constant ops
    MPSFull,
    MPSFullLike,
    # Clamp ops
    MPSClamp,
    MPSWhere,
    # Reduce ops
    MPSMean,
    # Indexing ops
    MPSIndexSelect,
    MPSEmbedding,
    MPSIndexTensor,
    MPSIndexPut,
    MPSScatter,
    # Shape ops
    MPSPermute,
    MPSView,
    MPSExpand,
    MPSCat,
    MPSSqueeze,
    MPSUnsqueeze,
    MPSSelect,
    MPSSlice,
    MPSPixelShuffle,
    MPSSplitWithSizes,
    MPSCast,
    # Convolution ops
    MPSConv2D,
    MPSDepthwiseConv2D,
    # Comparison ops
    MPSEq,
    MPSNe,
    MPSGe,
    MPSGt,
    MPSLe,
    MPSLt,
    # Normalization ops
    MPSBatchNorm,
    MPSLayerNorm,
    # Pooling ops
    MPSMaxPool2DWithIndices,
    MPSAvgPool2D,
    # Pad ops
    MPSConstantPadND,
    # Range ops
    MPSArange,
    # Quant-Dequant ops
    MPSDequantizePerChannelGroup,
]
779
780
@dataclass
class MPSNode:
    """One graph node: an op payload plus optional min/max bounds."""

    # Any member of MPSNodeUnion (annotation kept as a string forward reference).
    mpsnode_union: "MPSNodeUnion"
    min_max: Optional[MPSMinMax] = None


@dataclass
class Buffer:
    """Raw-bytes wrapper; used as ``MPSTensor.constant_buffer``."""

    storage: bytes


@dataclass
class MPSTensor:
    """Description of one value in ``MPSGraph.mps_values``."""

    datatype: MPSDataType
    # NOTE(review): presumably num_dims == len(dims); confirm.
    num_dims: int
    dims: List[int]
    constant_buffer_size: int
    # NOTE(review): constant data presumably now lives in the graph's constant
    # segment at segment_offset instead of this field; confirm.
    constant_buffer: Buffer  # deprecated
    segment_offset: int = 0


@dataclass
class DataSegment:
    """An (offset, size) span; used as ``MPSGraph.constant_segment``."""

    offset: int
    size: int


@dataclass
class MPSGraph:
    """Top-level serialized MPS graph."""

    version: str
    mps_nodes: List[MPSNode]
    # NOTE(review): the *_ids lists below presumably index into mps_values.
    mps_values: List[MPSTensor]
    input_ids: List[int]
    output_ids: List[int]
    constant_ids: List[int]
    graph_type: OpType
    constant_segment: DataSegment
818