// xref: /aosp_15_r20/external/executorch/backends/apple/mps/serialization/schema.fbs (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
//
//  Copyright (c) 2023 Apple Inc. All rights reserved.
//  Provided subject to the LICENSE file in the top level directory.
//

// All declarations below belong to the `mpsgraph` namespace.
namespace mpsgraph;

// Four-character buffer identifier for this schema.
// Update after any backward-compatibility (BC) breaking change.
//
// Schema-evolution reminder: FlatBuffers assigns table field ids and union
// type tags from declaration order. New fields / union members must only be
// appended at the end. Existing ones must never be reordered or removed
// (mark them `(deprecated)` instead) unless this identifier is bumped.
file_identifier "MP00";

11// datatype for mps-values
12enum MPSDataType : short {
13  mps_data_type_invalid = 0,
14  mps_data_type_float16 = 1,
15  mps_data_type_float32 = 2,
16  mps_data_type_float64 = 3,
17  mps_data_type_bfloat16 = 4,
18
19  // Signed integers.
20  mps_data_type_int4 = 5,
21  mps_data_type_int8 = 6,
22  mps_data_type_int16 = 7,
23  mps_data_type_int32 = 8,
24  mps_data_type_int64 = 9,
25
26
27  // Unsigned integers. range: [0, UTYPE_MAX]
28  mps_data_type_uint4 = 10,
29  mps_data_type_uint8 = 11,
30  mps_data_type_uint16 = 12,
31  mps_data_type_uint32 = 13,
32  mps_data_type_uint64 = 14,
33
34  mps_data_type_bool = 15,
35
36  mps_data_type_complex_float16 = 16,
37  mps_data_type_complex_float32 = 17,
38}
39
// Type of MPSGraph.graph_type.
// Ops like index.Tensor and index.put are currently implemented as Metal
// kernels for the cases that MPSGraph does not support.
// Values are implicit: mps_graph = 0 (also the value read when the field is
// absent) and metal_kernel = 1. Append new values only.
enum OpType : short {
  mps_graph,
  metal_kernel
}

// Helper tables that define the number of input and output tensors of a node.
// They are not meant to be used directly. MPSNodeUnion refers to them through
// per-op aliases (e.g. `MPSExp: _MPSNode1x1`).
//
// NOTE(review): the *_id fields presumably index MPSGraph.mps_values; confirm
// against the serializer.

// A node with one input and one output (the unary ops in MPSNodeUnion).
table _MPSNode1x1 {
  input1_id:int;
  output_id:int;
}

// A node with two inputs and one output. In MPSNodeUnion it backs the plain
// binary, bitwise and comparison ops, MPSMatMul, and MPSReLU.
table _MPSNode2x1 {
  input1_id:int;
  input2_id:int;
  output_id:int;
}

// A two-input, one-output node that also carries a rounding mode.
// Aliased by MPSDiv, MPSFmod and MPSRemainder in MPSNodeUnion.
// NOTE(review): rounding_mode presumably mirrors torch.div's `rounding_mode`
// ("trunc", "floor", or unset); confirm.
table _MPSDivNode2x1 {
  input1_id:int;
  input2_id:int;
  output_id:int;
  rounding_mode:string;
}

// A two-input, one-output node with a scalar `alpha`.
// Aliased by MPSAdd and MPSSub in MPSNodeUnion.
// NOTE(review): alpha presumably scales input2 as in torch.add / torch.sub;
// confirm.
table _MPSNodeWithAlpha2x1 {
  input1_id:int;
  input2_id:int;
  output_id:int;
  alpha:float;
}

// A node with three inputs and one output (aliased by MPSWhere in
// MPSNodeUnion).
table _MPSNode3x1 {
  input1_id:int;
  input2_id:int;
  input3_id:int;
  output_id:int;
}

// Scalar bounds that can be attached to any node through MPSNode.min_max.
// NOTE(review): presumably consumed by MPSClamp, which declares no bound
// fields of its own; confirm.
table MPSMinMax {
  min_value:float;
  max_value:float;
}

// Parameters of a 2D pooling op. Shared by MPSMaxPool2DWithIndices and
// MPSAvgPool2D in MPSNodeUnion.
// NOTE(review): output2_id is presumably the indices output of max pooling
// (unused for average pooling). ceil_mode, count_include_pad and
// divisor_override presumably follow torch.nn.functional pooling semantics.
// Confirm both.
table MPSPooling2D {
  input1_id:int;
  kernel_height:int;
  kernel_width:int;
  stride_height:int;
  stride_width:int;
  padding_left:int;
  padding_right:int;
  padding_top:int;
  padding_bottom:int;
  dilation_height:int;
  dilation_width:int;
  ceil_mode:bool;
  count_include_pad:bool;
  divisor_override:int;
  output1_id:int;
  output2_id:int;
}

// Activation ops.

// Hardtanh activation.
// NOTE(review): min_value / max_value are presumably the lower / upper clamp
// bounds (torch hardtanh min_val / max_val); confirm.
table MPSHardTanh {
  input1_id:int;
  output_id:int;
  min_value:float;
  max_value:float;
}

// GELU activation.
// NOTE(review): approximate presumably matches torch GELU's `approximate`
// argument ("none" or "tanh"); confirm.
table MPSGELU {
  input1_id:int;
  output_id:int;
  approximate:string;
}

// LeakyReLU activation.
// NOTE(review): negative_slope is presumably the multiplier applied to
// negative inputs, as in torch leaky_relu; confirm.
table MPSLeakyReLU {
  input1_id:int;
  output_id:int;
  negative_slope:float;
}

// Softmax parameters. Shared by MPSSoftmax and MPSLogSoftmax in
// MPSNodeUnion.
// NOTE(review): dim is presumably the axis the softmax is computed along, and
// half_to_float presumably mirrors aten::_softmax's flag of the same name;
// confirm.
table MPSSoftmax {
  input1_id:int;
  output_id:int;
  dim:int;
  half_to_float:bool;
}

// Clamp ops

// Clamp node. It declares no bound fields of its own.
// NOTE(review): the bounds presumably come from the enclosing node's
// MPSNode.min_max (an MPSMinMax); confirm against the runtime.
table MPSClamp {
  input1_id:int;
  output_id:int;
}

// Reduce ops

// Mean reduction over the axes listed in `dims`.
// NOTE(review): num_dims presumably equals the length of `dims`, and
// keep_dims presumably mirrors torch.mean's `keepdim`; confirm.
table MPSMean {
  input1_id:int;
  output_id:int;
  num_dims:int;
  dims:[int];
  keep_dims:bool;
}

// Indexing ops

// index_select.
// NOTE(review): presumably mirrors torch.index_select(input, dim, index), with
// index_id being the id of the index tensor; confirm.
table MPSIndexSelect {
  input1_id:int;
  output_id:int;
  dim:int;
  index_id:int;
}

// Embedding lookup.
// NOTE(review): input1_id / input2_id are presumably the weight and indices
// tensors. padding_idx, scale_grad_by_freq and sparse presumably mirror
// torch.nn.functional.embedding. Confirm both.
table MPSEmbedding {
  input1_id:int;
  input2_id:int;
  output_id:int;
  padding_idx:int;
  scale_grad_by_freq:bool;
  sparse:bool;
}

// Advanced indexing (index.Tensor). indices_id is a list of value ids.
// NOTE(review): presumably one id per indexed dimension; confirm.
table MPSIndexTensor {
  input1_id:int;
  indices_id:[int];
  output_id:int;
}

// index.put: writes the values tensor (values_id) into input1 at the positions
// given by the index tensors listed in indices_id.
// NOTE(review): values_shape is presumably the shape of the values tensor;
// confirm.
table MPSIndexPut {
  input1_id:int;
  indices_id:[int];
  values_shape:[int];
  values_id:int;
  output_id:int;
}

// Scatter along `dim`. Note that `dim` is a 64-bit `long` here, unlike the
// `int` dims of most other tables.
// NOTE(review): idx_id / src_id are presumably the index and source tensors,
// as in torch.scatter; confirm.
table MPSScatter {
  input1_id:int;
  output_id:int;
  dim:long;
  idx_id:int;
  src_id:int;
}

// Shape ops.

// Permute the dimensions of input1 according to `perm`.
// NOTE(review): num_dims presumably equals the length of `perm`; confirm.
table MPSPermute {
  input1_id:int;
  output_id:int;
  num_dims:int;
  perm:[int];
}

// Target-shape node. Shared by MPSView and MPSExpand in MPSNodeUnion.
// NOTE(review): num_dims presumably equals the length of `shape`; confirm.
table MPSView {
  input1_id:int;
  output_id:int;
  num_dims:int;
  shape:[int];
}

// Concatenation of the values listed in input_ids along `dim`.
table MPSCat {
  input_ids:[int];
  output_id:int;
  dim:int;
}

// Squeeze.
// NOTE(review): `dims` presumably lists the size-1 axes to remove; confirm how
// an empty list is interpreted.
table MPSSqueeze {
  input1_id:int;
  output_id:int;
  dims:[int];
}

// Unsqueeze.
// NOTE(review): presumably inserts a size-1 axis at `dim`, as in
// torch.unsqueeze; confirm.
table MPSUnsqueeze {
  input1_id:int;
  output_id:int;
  dim:int;
}

// Select a single `index` along `dim`.
// NOTE(review): presumably mirrors torch.select(input, dim, index); confirm.
table MPSSelect {
  input1_id:int;
  output_id:int;
  dim:int;
  index:int;
}

// Slice along `dim` with start / end / step. All four fields are 64-bit
// (`long`).
// NOTE(review): end is presumably exclusive, as in torch slicing; confirm.
table MPSSlice {
  input1_id:int;
  output_id:int;
  dim:long;
  start:long;
  end:long;
  step:long;
}

// Pixel shuffle with the given upscale_factor.
table MPSPixelShuffle {
  input1_id:int;
  output_id:int;
  upscale_factor:int;
}

// Split input1 along `dim` into several outputs (output_ids).
// NOTE(review): split_sizes presumably gives the size of each output along
// `dim`, with one entry per output id; confirm.
table MPSSplitWithSizes {
  input1_id:int;
  output_ids:[int];
  split_sizes:[int];
  dim:int;
}

// Cast input1 to the element type `dtype`.
table MPSCast {
  input1_id:int;
  output_id:int;
  dtype:MPSDataType;
}

// Linear algebra ops.

// addmm.
// NOTE(review): presumably computes beta * input1 + alpha * (input2 @ input3),
// as in torch.addmm; confirm the operand roles.
table MPSAddmm {
  input1_id:int;
  input2_id:int;
  input3_id:int;
  output_id:int;
  beta:float;
  alpha:float;
}

// Constant ops

// Constant-fill node. Shared by MPSFull and MPSFullLike in MPSNodeUnion.
// NOTE(review): input1_id is presumably meaningful only for MPSFullLike (the
// reference tensor), with `shape` used for MPSFull; confirm.
table _MPSFull {
  input1_id:int;
  output_id:int;
  shape:[int];
  fill_value: float;
  dtype:MPSDataType;
}

// Convolution ops.

// Convolution parameters. Shared by MPSConv2D and MPSDepthwiseConv2D in
// MPSNodeUnion.
// NOTE(review): input2_id / input3_id are presumably the weight and (optional)
// bias, and the `_x` / `_y` suffixes presumably refer to width / height;
// confirm.
table MPSConv {
  input1_id:int;
  input2_id:int;
  input3_id:int;
  output_id:int;
  stride_x:int;
  stride_y:int;
  dilation_x:int;
  dilation_y:int;
  groups:int;
  padding_left:int;
  padding_right:int;
  padding_top:int;
  padding_bottom:int;
}

// Normalization ops.

// Batch normalization.
// output2_id is declared before output1_id. Declaration order fixes each
// field's id (vtable slot), so this order must be kept for compatibility with
// existing buffers.
// NOTE(review): momentum / epsilon presumably mirror torch batch_norm; the
// meaning of the three outputs should be confirmed against the runtime.
table MPSBatchNorm {
  input_id:int;
  mean_id:int;
  var_id:int;
  weight_id:int;
  bias_id:int;
  momentum:float;
  epsilon:float;
  output2_id:int;
  output1_id:int;
  output3_id:int;
}

// Layer normalization over `normalized_shape`.
// output2_id is declared before output1_id. Declaration order fixes each
// field's id (vtable slot), so this order must be kept for compatibility with
// existing buffers.
// NOTE(review): eps presumably mirrors torch layer_norm's `eps`; the meaning
// of the three outputs should be confirmed against the runtime.
table MPSLayerNorm {
  input1_id:int;
  normalized_shape:[int];
  weight_id:int;
  bias_id:int;
  eps:float;
  output2_id:int;
  output1_id:int;
  output3_id:int;
}

// Pooling ops: MPSPooling2D is declared earlier in this file.

// Pad ops
// Constant padding.
// NOTE(review): `pad` presumably follows torch.nn.functional.pad ordering
// ((before, after) pairs starting from the last dimension), and `value` is
// the fill value; confirm.
table MPSConstantPadND {
  input1_id:int;
  output_id:int;
  pad:[int];
  value:float;
}

// Range ops

// arange: takes no input ids; produces output_id with element type `dtype`.
// NOTE(review): `end` is presumably exclusive, as in torch.arange; confirm.
table MPSArange {
  output_id:int;
  start:float;
  end:float;
  step:float;
  dtype:MPSDataType;
}

// Quant - Dequant ops

// Per-channel-group dequantization.
// NOTE(review): presumably dequantizes input1 (stored as `dtype`, within
// [quant_min, quant_max]) using the tensors scales_id and zero_points_id, with
// group_size elements per group, producing `output_dtype`; confirm.
table MPSDequantizePerChannelGroup {
  input1_id:int;
  output_id:int;
  scales_id:int;
  zero_points_id:int;
  quant_min:int;
  quant_max:int;
  dtype:MPSDataType;
  group_size:int;
  output_dtype:MPSDataType;
}

// Op-specific payload of an MPSNode. Several members alias the same table
// (e.g. `MPSExp: _MPSNode1x1`), so each op gets its own type tag while sharing
// a field layout.
//
// FlatBuffers assigns union type tags in declaration order (NONE = 0, then 1,
// 2, ...). Members must therefore only be appended at the end; inserting,
// removing or reordering entries changes the meaning of existing buffers.
union MPSNodeUnion {
    // Activation ops
    MPSHardTanh,
    // NOTE(review): ReLU is unary but uses the two-input _MPSNode2x1 layout.
    // Switching it to _MPSNode1x1 would move output_id to a different field
    // slot and break already-serialized graphs, so the type is kept.
    // Presumably input2_id is left unset; confirm.
    MPSReLU: _MPSNode2x1,
    MPSGELU,
    MPSLeakyReLU,
    MPSSoftmax,
    MPSLogSoftmax: MPSSoftmax,

    // Binary ops
    MPSAdd: _MPSNodeWithAlpha2x1,
    MPSSub: _MPSNodeWithAlpha2x1,
    MPSMul: _MPSNode2x1,
    MPSDiv: _MPSDivNode2x1,
    MPSFmod: _MPSDivNode2x1,
    MPSRemainder: _MPSDivNode2x1,
    MPSMin: _MPSNode2x1,
    MPSMax: _MPSNode2x1,
    MPSPow: _MPSNode2x1,
    MPSAtan2: _MPSNode2x1,
    MPSBitwiseAnd: _MPSNode2x1,
    MPSBitwiseOr: _MPSNode2x1,
    MPSBitwiseXor: _MPSNode2x1,
    MPSMinimum: _MPSNode2x1,

    // Unary ops
    MPSExp: _MPSNode1x1,
    MPSExp2: _MPSNode1x1,
    MPSReciprocal: _MPSNode1x1,
    MPSSqrt: _MPSNode1x1,
    MPSNeg: _MPSNode1x1,
    MPSLog: _MPSNode1x1,
    MPSLog10: _MPSNode1x1,
    MPSLog2: _MPSNode1x1,
    MPSErf: _MPSNode1x1,
    MPSFloor: _MPSNode1x1,
    MPSCeil: _MPSNode1x1,
    MPSRsqrt: _MPSNode1x1,
    MPSSigmoid: _MPSNode1x1,
    MPSSin: _MPSNode1x1,
    MPSSign: _MPSNode1x1,
    MPSCos: _MPSNode1x1,
    MPSTan: _MPSNode1x1,
    MPSAbs: _MPSNode1x1,
    MPSAsin: _MPSNode1x1,
    MPSAcos: _MPSNode1x1,
    MPSAtan: _MPSNode1x1,
    MPSSinh: _MPSNode1x1,
    MPSCosh: _MPSNode1x1,
    MPSTanh: _MPSNode1x1,
    MPSAsinh: _MPSNode1x1,
    MPSAcosh: _MPSNode1x1,
    MPSAtanh: _MPSNode1x1,
    MPSBitwiseNot: _MPSNode1x1,
    MPSIsnan: _MPSNode1x1,
    MPSIsinf: _MPSNode1x1,
    MPSRound: _MPSNode1x1,
    MPSLogicalNot: _MPSNode1x1,

    // Linear algebra ops
    MPSMatMul: _MPSNode2x1,
    MPSAddmm,

    // Constant ops
    MPSFull: _MPSFull,
    MPSFullLike: _MPSFull,

    // Clamp ops
    MPSClamp,
    MPSWhere: _MPSNode3x1,

    // Indexing ops
    MPSIndexSelect,
    MPSEmbedding,
    MPSIndexTensor,
    MPSIndexPut,
    MPSScatter,

    // Reduce ops
    MPSMean,

    // Shape ops
    MPSPermute,
    MPSView,
    MPSExpand: MPSView,
    MPSCat,
    MPSSqueeze,
    MPSUnsqueeze,
    MPSSelect,
    MPSSlice,
    MPSPixelShuffle,
    MPSSplitWithSizes,
    MPSCast,

    // Convolution ops
    MPSConv2D: MPSConv,
    MPSDepthwiseConv2D: MPSConv,

    // Comparison ops
    MPSEq: _MPSNode2x1,
    MPSNe: _MPSNode2x1,
    MPSGe: _MPSNode2x1,
    MPSGt: _MPSNode2x1,
    MPSLe: _MPSNode2x1,
    MPSLt: _MPSNode2x1,

    // Normalization ops
    MPSBatchNorm,
    MPSLayerNorm,

    // Pooling ops
    MPSMaxPool2DWithIndices: MPSPooling2D,
    MPSAvgPool2D: MPSPooling2D,

    // Pad ops
    MPSConstantPadND,

    // Range ops
    MPSArange,

    // Quant-Dequant ops
    MPSDequantizePerChannelGroup,
}

// A single graph node: the op-specific payload plus optional scalar bounds.
table MPSNode {
  mpsnode_union:MPSNodeUnion;
  min_max:MPSMinMax;
}

// Data buffer abstraction, taken from ExecuTorch.
// Deprecated: in this schema it is referenced only by the deprecated
// MPSTensor.constant_buffer field. `storage` is declared 16-byte aligned.
table Buffer {
  storage:[ubyte] (force_align: 16);
}

// Description of one entry of MPSGraph.mps_values.
// NOTE(review): num_dims presumably equals the length of `dims`.
// constant_buffer_size is presumably the byte size of a constant's payload,
// stored at segment_offset within the data described by
// MPSGraph.constant_segment. Confirm both.
table MPSTensor {
  datatype:MPSDataType;
  num_dims:int;
  dims:[int];
  constant_buffer_size:uint64;
  constant_buffer:Buffer; // deprecated
  segment_offset:uint64;
}

// Location of a block of out-of-line data (used by MPSGraph.constant_segment).
table DataSegment {
  // Segment offsets are relative to the segment base offset provided in
  // the extended file header. Segments will typically be aligned in a
  // way to make it possible to use mmap() to load them.
  offset: uint64;

  // The size in bytes of valid data starting at the offset. The segment
  // data may be followed by padding before the segment that follows it,
  // to make it easier to use mmap().
  size: uint64;
}

// Root table: a serialized MPS graph.
// NOTE(review): input_ids / output_ids / constant_ids presumably index
// mps_values; confirm against the serializer.
table MPSGraph {
  // Schema version.
  version:string;
  mps_nodes:[MPSNode];
  mps_values:[MPSTensor];

  input_ids:[int];
  output_ids:[int];
  constant_ids:[int];

  // Reads back as mps_graph (0) when absent from the buffer.
  graph_type:OpType;

  constant_segment:DataSegment;
}

// MPSGraph is the root table of buffers written with this schema.
root_type MPSGraph;
