//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

// FlatBuffers schema for serialized MPSGraph graphs (root table: MPSGraph).
//
// Binary-compatibility rules (FlatBuffers schema evolution):
//  - A table field's slot id is its position in the table declaration.
//    Only append new fields at the end of a table; never reorder, insert,
//    or remove existing fields.
//  - Enum values are stored numerically; never renumber existing values.
//  - A union member's type tag is its position in the union declaration.
//    Only append new members.

namespace mpsgraph;

// 4-character identifier embedded in every serialized buffer.
// Update after any BC-breaking (backward-incompatible) schema change.
file_identifier "MP00";

// Element data type of MPS values. Used by MPSTensor.datatype and by the
// dtype fields of ops such as MPSCast, _MPSFull, MPSArange and
// MPSDequantizePerChannelGroup. Stored as a 16-bit `short`.
// All values are explicit; never renumber them (see rules above).
// 0 is also the value a reader sees for an absent MPSDataType field.
enum MPSDataType : short {
  mps_data_type_invalid = 0,

  // Floating-point types.
  mps_data_type_float16 = 1,
  mps_data_type_float32 = 2,
  mps_data_type_float64 = 3,
  mps_data_type_bfloat16 = 4,

  // Signed integers.
  mps_data_type_int4 = 5,
  mps_data_type_int8 = 6,
  mps_data_type_int16 = 7,
  mps_data_type_int32 = 8,
  mps_data_type_int64 = 9,


  // Unsigned integers. range: [0, UTYPE_MAX]
  mps_data_type_uint4 = 10,
  mps_data_type_uint8 = 11,
  mps_data_type_uint16 = 12,
  mps_data_type_uint32 = 13,
  mps_data_type_uint64 = 14,

  mps_data_type_bool = 15,

  // Complex types.
  mps_data_type_complex_float16 = 16,
  mps_data_type_complex_float32 = 17,
}

// Kind of graph stored in MPSGraph.graph_type.
// Ops like index.Tensor and index.put are currently implemented as
// Metal kernels for cases MPSGraph does not support.
// Values are implicit: mps_graph = 0 (also the default when the field is
// absent), metal_kernel = 1. Append new values only.
enum OpType : short {
  mps_graph,
  metal_kernel
}

// Helper tables that define the number of input and output tensors for a
// node. Not meant to be used directly: MPSNodeUnion aliases them under
// op-specific member names (e.g. `MPSMul: _MPSNode2x1`).
// NOTE(review): the *_id fields presumably index MPSGraph.mps_values —
// confirm against the serializer/runtime.

// A node with one input and one output.
table _MPSNode1x1 {
  input1_id:int;
  output_id:int;
}

// A node with two inputs and one output.
table _MPSNode2x1 {
  input1_id:int;
  input2_id:int;
  output_id:int;
}

// A two-input node with an extra rounding mode.
// Used by MPSDiv, MPSFmod and MPSRemainder.
table _MPSDivNode2x1 {
  input1_id:int;
  input2_id:int;
  output_id:int;
  // NOTE(review): the accepted strings are not defined in this schema;
  // presumably they mirror torch.div's rounding_mode — confirm in runtime.
  rounding_mode:string;
}

// A two-input node with a scalar multiplier. Used by MPSAdd and MPSSub.
table _MPSNodeWithAlpha2x1 {
  input1_id:int;
  input2_id:int;
  output_id:int;
  // No schema default is given, so an absent alpha reads back as 0.0;
  // writers should always set it explicitly.
  alpha:float;
}

// A node with three inputs and one output.
table _MPSNode3x1 {
  // Used by MPSWhere in MPSNodeUnion.
  input1_id:int;
  input2_id:int;
  input3_id:int;
  output_id:int;
}

// Optional min/max pair carried by MPSNode.min_max.
// Both bounds read back as 0.0 when absent.
// NOTE(review): how the runtime consumes it (e.g. as MPSClamp bounds)
// is not visible in this schema — confirm.
table MPSMinMax {
  min_value:float;
  max_value:float;
}

// 2D pooling parameters. Shared by MPSMaxPool2DWithIndices and
// MPSAvgPool2D in MPSNodeUnion.
table MPSPooling2D {
  input1_id:int;
  kernel_height:int;
  kernel_width:int;
  stride_height:int;
  stride_width:int;
  padding_left:int;
  padding_right:int;
  padding_top:int;
  padding_bottom:int;
  dilation_height:int;
  dilation_width:int;
  ceil_mode:bool;
  // NOTE(review): count_include_pad and divisor_override are presumably
  // only meaningful for average pooling — confirm.
  count_include_pad:bool;
  divisor_override:int;
  output1_id:int;
  // NOTE(review): presumably the indices output of
  // MPSMaxPool2DWithIndices — confirm.
  output2_id:int;
}

// Activation ops.
table MPSHardTanh {
  input1_id:int;
  output_id:int;
  min_value:float;
  max_value:float;
}

table MPSGELU {
  input1_id:int;
  output_id:int;
  // NOTE(review): accepted strings are not defined here; presumably they
  // mirror torch GELU's `approximate` argument — confirm.
  approximate:string;
}

table MPSLeakyReLU {
  input1_id:int;
  output_id:int;
  negative_slope:float;
}

// Shared by MPSSoftmax and MPSLogSoftmax in MPSNodeUnion.
table MPSSoftmax {
  input1_id:int;
  output_id:int;
  dim:int;
  half_to_float:bool;
}

// Clamp ops
// Carries only tensor ids; the clamp bounds are not fields of this table.
// NOTE(review): bounds are presumably supplied via MPSNode.min_max — confirm.
table MPSClamp {
  input1_id:int;
  output_id:int;
}

// Reduce ops
table MPSMean {
  input1_id:int;
  output_id:int;
  // NOTE(review): num_dims looks redundant with the length of `dims`;
  // confirm which one the runtime reads.
  num_dims:int;
  dims:[int];
  keep_dims:bool;
}

// Indexing ops
table MPSIndexSelect {
  input1_id:int;
  output_id:int;
  dim:int;
  index_id:int;
}

table MPSEmbedding {
  input1_id:int;
  input2_id:int;
  output_id:int;
  padding_idx:int;
  scale_grad_by_freq:bool;
  sparse:bool;
}

// indices_id is a list of ids (one entry per index tensor).
table MPSIndexTensor {
  input1_id:int;
  indices_id:[int];
  output_id:int;
}

table MPSIndexPut {
  input1_id:int;
  indices_id:[int];
  values_shape:[int];
  values_id:int;
  output_id:int;
}

// Note: `dim` is a 64-bit `long` here, unlike the `int` dims of most ops.
table MPSScatter {
  input1_id:int;
  output_id:int;
  dim:long;
  idx_id:int;
  src_id:int;
}

// Shape ops.
table MPSPermute {
  input1_id:int;
  output_id:int;
  // NOTE(review): num_dims looks redundant with the length of `perm`.
  num_dims:int;
  perm:[int];
}

// Shared by MPSView and MPSExpand in MPSNodeUnion.
table MPSView {
  input1_id:int;
  output_id:int;
  // NOTE(review): num_dims looks redundant with the length of `shape`.
  num_dims:int;
  shape:[int];
}

// Variadic input: input_ids lists every tensor being concatenated.
table MPSCat {
  input_ids:[int];
  output_id:int;
  dim:int;
}

table MPSSqueeze {
  input1_id:int;
  output_id:int;
  dims:[int];
}

table MPSUnsqueeze {
  input1_id:int;
  output_id:int;
  dim:int;
}

// `index` is a plain integer attribute, not a tensor id.
table MPSSelect {
  input1_id:int;
  output_id:int;
  dim:int;
  index:int;
}

// All slice parameters are 64-bit `long`.
table MPSSlice {
  input1_id:int;
  output_id:int;
  dim:long;
  start:long;
  end:long;
  step:long;
}

table MPSPixelShuffle {
  input1_id:int;
  output_id:int;
  upscale_factor:int;
}

// Multiple outputs: output_ids lists the produced tensors.
table MPSSplitWithSizes {
  input1_id:int;
  output_ids:[int];
  split_sizes:[int];
  dim:int;
}

// dtype is the target element type of the cast.
table MPSCast {
  input1_id:int;
  output_id:int;
  dtype:MPSDataType;
}

// Linear algebra ops.
table MPSAddmm {
  input1_id:int;
  input2_id:int;
  input3_id:int;
  output_id:int;
  // No schema defaults: absent beta/alpha read back as 0.0, so writers
  // should always set both explicitly.
  beta:float;
  alpha:float;
}

// Constant ops
// Shared by MPSFull and MPSFullLike in MPSNodeUnion.
table _MPSFull {
  input1_id:int;
  output_id:int;
  shape:[int];
  fill_value: float;
  dtype:MPSDataType;
}

// Convolution ops.
// Shared by MPSConv2D and MPSDepthwiseConv2D in MPSNodeUnion.
table MPSConv {
  // NOTE(review): presumably input, weight and bias respectively — confirm.
  input1_id:int;
  input2_id:int;
  input3_id:int;
  output_id:int;
  stride_x:int;
  stride_y:int;
  dilation_x:int;
  dilation_y:int;
  groups:int;
  padding_left:int;
  padding_right:int;
  padding_top:int;
  padding_bottom:int;
}

// Normalization ops.
// NOTE: output2_id is declared before output1_id. Field order fixes the
// slot ids of already-serialized data, so do not "tidy" this ordering.
table MPSBatchNorm {
  input_id:int;
  mean_id:int;
  var_id:int;
  weight_id:int;
  bias_id:int;
  momentum:float;
  epsilon:float;
  output2_id:int;
  output1_id:int;
  output3_id:int;
}

// NOTE: as in MPSBatchNorm, output2_id precedes output1_id; keep the order.
table MPSLayerNorm {
  input1_id:int;
  normalized_shape:[int];
  weight_id:int;
  bias_id:int;
  eps:float;
  output2_id:int;
  output1_id:int;
  output3_id:int;
}

// Pooling ops: see MPSPooling2D.

// Pad ops
table MPSConstantPadND {
  input1_id:int;
  output_id:int;
  pad:[int];
  value:float;
}

// Range ops
// Has no tensor inputs; only an output plus scalar parameters.
table MPSArange {
  output_id:int;
  start:float;
  end:float;
  step:float;
  dtype:MPSDataType;
}

// Quant - Dequant ops
table MPSDequantizePerChannelGroup {
  input1_id:int;
  output_id:int;
  scales_id:int;
  zero_points_id:int;
  quant_min:int;
  quant_max:int;
  // NOTE(review): presumably the quantized input's type, with output_dtype
  // the dequantized result type — confirm.
  dtype:MPSDataType;
  group_size:int;
  output_dtype:MPSDataType;
}

// Every op a node can hold. `Name: Table` reuses an existing table's layout
// under a distinct union member; a bare `Name` uses the table of that name.
// A member's type tag is its position in this list (NONE = 0, then 1, 2, ...),
// so new ops must be appended at the end and existing entries never moved.
union MPSNodeUnion {
  // Activation ops
  MPSHardTanh,
  // ReLU is unary: it uses the one-input/one-output layout shared by the
  // other element-wise activations and unary ops below.
  // NOTE(review): this changes the on-disk layout of MPSReLU nodes.
  // Graphs written while it was declared as _MPSNode2x1 stored output_id
  // in a different field slot and must be re-exported. Decide whether
  // file_identifier needs a bump.
  MPSReLU: _MPSNode1x1,
  MPSGELU,
  MPSLeakyReLU,
  MPSSoftmax,
  MPSLogSoftmax: MPSSoftmax,

  // Binary ops
  MPSAdd: _MPSNodeWithAlpha2x1,
  MPSSub: _MPSNodeWithAlpha2x1,
  MPSMul: _MPSNode2x1,
  MPSDiv: _MPSDivNode2x1,
  MPSFmod: _MPSDivNode2x1,
  MPSRemainder: _MPSDivNode2x1,
  MPSMin: _MPSNode2x1,
  MPSMax: _MPSNode2x1,
  MPSPow: _MPSNode2x1,
  MPSAtan2: _MPSNode2x1,
  MPSBitwiseAnd: _MPSNode2x1,
  MPSBitwiseOr: _MPSNode2x1,
  MPSBitwiseXor: _MPSNode2x1,
  MPSMinimum: _MPSNode2x1,

  // Unary ops
  MPSExp: _MPSNode1x1,
  MPSExp2: _MPSNode1x1,
  MPSReciprocal: _MPSNode1x1,
  MPSSqrt: _MPSNode1x1,
  MPSNeg: _MPSNode1x1,
  MPSLog: _MPSNode1x1,
  MPSLog10: _MPSNode1x1,
  MPSLog2: _MPSNode1x1,
  MPSErf: _MPSNode1x1,
  MPSFloor: _MPSNode1x1,
  MPSCeil: _MPSNode1x1,
  MPSRsqrt: _MPSNode1x1,
  MPSSigmoid: _MPSNode1x1,
  MPSSin: _MPSNode1x1,
  MPSSign: _MPSNode1x1,
  MPSCos: _MPSNode1x1,
  MPSTan: _MPSNode1x1,
  MPSAbs: _MPSNode1x1,
  MPSAsin: _MPSNode1x1,
  MPSAcos: _MPSNode1x1,
  MPSAtan: _MPSNode1x1,
  MPSSinh: _MPSNode1x1,
  MPSCosh: _MPSNode1x1,
  MPSTanh: _MPSNode1x1,
  MPSAsinh: _MPSNode1x1,
  MPSAcosh: _MPSNode1x1,
  MPSAtanh: _MPSNode1x1,
  MPSBitwiseNot: _MPSNode1x1,
  MPSIsnan: _MPSNode1x1,
  MPSIsinf: _MPSNode1x1,
  MPSRound: _MPSNode1x1,
  MPSLogicalNot: _MPSNode1x1,

  // Linear algebra ops
  MPSMatMul: _MPSNode2x1,
  MPSAddmm,

  // Constant ops
  MPSFull: _MPSFull,
  MPSFullLike: _MPSFull,

  // Clamp ops
  MPSClamp,
  MPSWhere: _MPSNode3x1,

  // Indexing ops
  MPSIndexSelect,
  MPSEmbedding,
  MPSIndexTensor,
  MPSIndexPut,
  MPSScatter,

  // Reduce ops
  MPSMean,

  // Shape ops
  MPSPermute,
  MPSView,
  MPSExpand: MPSView,
  MPSCat,
  MPSSqueeze,
  MPSUnsqueeze,
  MPSSelect,
  MPSSlice,
  MPSPixelShuffle,
  MPSSplitWithSizes,
  MPSCast,

  // Convolution ops
  MPSConv2D: MPSConv,
  MPSDepthwiseConv2D: MPSConv,

  // Comparison ops
  MPSEq: _MPSNode2x1,
  MPSNe: _MPSNode2x1,
  MPSGe: _MPSNode2x1,
  MPSGt: _MPSNode2x1,
  MPSLe: _MPSNode2x1,
  MPSLt: _MPSNode2x1,

  // Normalization ops
  MPSBatchNorm,
  MPSLayerNorm,

  // Pooling ops
  MPSMaxPool2DWithIndices: MPSPooling2D,
  MPSAvgPool2D: MPSPooling2D,

  // Pad ops
  MPSConstantPadND,

  // Range ops
  MPSArange,

  // Quant-Dequant ops
  MPSDequantizePerChannelGroup,
}

// One graph node: the op-specific payload plus an optional MPSMinMax.
// The union field also implies a hidden `mpsnode_union_type` tag field.
table MPSNode {
  mpsnode_union:MPSNodeUnion;
  min_max:MPSMinMax;
}

// taken from executorch
// Data buffer abstraction.
// Deprecated: only referenced by the deprecated MPSTensor.constant_buffer.
// The bytes of `storage` are aligned to 16 within the serialized buffer.
table Buffer {
  storage:[ubyte] (force_align: 16);
}

// Description of one value (tensor) in MPSGraph.mps_values.
table MPSTensor {
  datatype:MPSDataType;
  // NOTE(review): num_dims looks redundant with the length of `dims`.
  num_dims:int;
  dims:[int];
  // NOTE(review): presumably the byte size of this tensor's constant data
  // (0 / absent for non-constants) — confirm.
  constant_buffer_size:uint64;
  constant_buffer:Buffer; // deprecated
  // NOTE(review): presumably an offset into the data described by
  // MPSGraph.constant_segment — confirm against the serializer.
  segment_offset:uint64;
}

table DataSegment {
  // Segment offsets are relative to the segment base offset provided in
  // the extended file header. Segments will typically be aligned in a
  // way to make it possible to use mmap() to load them.
  offset: uint64;

  // The size in bytes of valid data starting at the offset. The segment
  // data may be followed by padding before the segment that follows it,
  // to make it easier to use mmap().
  size: uint64;
}

// Root table of every buffer written with this schema.
table MPSGraph {
  // Schema version.
  version:string;
  mps_nodes:[MPSNode];
  mps_values:[MPSTensor];

  // NOTE(review): these presumably index mps_values — confirm.
  input_ids:[int];
  output_ids:[int];
  constant_ids:[int];

  // Reads back as mps_graph (0) when absent.
  graph_type:OpType;

  constant_segment:DataSegment;
}

root_type MPSGraph;