// // Copyright (c) 2023 Apple Inc. All rights reserved. // Provided subject to the LICENSE file in the top level directory. // #pragma once // Obj-C headers #include #include #include #include // Runtime headers #include #include // MPS headers #include #include #include #include #include #include namespace executorch { namespace backends { namespace mps { namespace delegate { using Error = executorch::runtime::Error; using DataType = mpsgraph::MPSDataType; using TensorPtr = const mpsgraph::MPSTensor *; using NodePtr = const mpsgraph::MPSNode *; #define _DEFINE_MPS_OP(name) Error mps##name##Op(NodePtr nodePtr); /** * Helper class to construct a MPSGraph object from a serialized MPS FlatBuffer model. * It records all the input placeholders, lifted weights/biases and output feeds. */ class MPSGraphBuilder { public: MPSGraphBuilder(const void *buffer_pointer, size_t num_bytes, std::unordered_map &mpsGraphTensorToId); ~MPSGraphBuilder() = default; Error compileModel(); MPSGraph *getMPSGraph(); MPSGraphExecutable *getMPSGraphExecutable(); private: // Input feeds & constant ops Error mpsGraphRankedPlaceholder(int32_t id); Error mpsConstantOp(int32_t id); // Activation ops _DEFINE_MPS_OP(HardTanh); _DEFINE_MPS_OP(ReLU); _DEFINE_MPS_OP(GELU); _DEFINE_MPS_OP(LeakyReLU); _DEFINE_MPS_OP(Softmax); _DEFINE_MPS_OP(LogSoftmax); // Arithmetic Binary Ops _DEFINE_MPS_OP(Add); _DEFINE_MPS_OP(Sub); _DEFINE_MPS_OP(Mul); _DEFINE_MPS_OP(Div); _DEFINE_MPS_OP(Pow); _DEFINE_MPS_OP(Fmod); _DEFINE_MPS_OP(Remainder); _DEFINE_MPS_OP(BitwiseAnd); _DEFINE_MPS_OP(BitwiseOr); _DEFINE_MPS_OP(BitwiseXor); _DEFINE_MPS_OP(Minimum); // Comparison ops _DEFINE_MPS_OP(Eq); _DEFINE_MPS_OP(Ne); _DEFINE_MPS_OP(Ge); _DEFINE_MPS_OP(Gt); _DEFINE_MPS_OP(Le); _DEFINE_MPS_OP(Lt); // Unary ops _DEFINE_MPS_OP(Exp); _DEFINE_MPS_OP(Exp2); _DEFINE_MPS_OP(Reciprocal); _DEFINE_MPS_OP(Sqrt); _DEFINE_MPS_OP(Neg); _DEFINE_MPS_OP(Log); _DEFINE_MPS_OP(Log10); _DEFINE_MPS_OP(Log2); _DEFINE_MPS_OP(Erf); _DEFINE_MPS_OP(Floor); _DEFINE_MPS_OP(Ceil); _DEFINE_MPS_OP(Rsqrt); _DEFINE_MPS_OP(Sigmoid); _DEFINE_MPS_OP(Sin); _DEFINE_MPS_OP(Sign); _DEFINE_MPS_OP(Cos); _DEFINE_MPS_OP(Tan); _DEFINE_MPS_OP(Abs); _DEFINE_MPS_OP(Asin); _DEFINE_MPS_OP(Acos); _DEFINE_MPS_OP(Atan); _DEFINE_MPS_OP(Sinh); _DEFINE_MPS_OP(Cosh); _DEFINE_MPS_OP(Tanh); _DEFINE_MPS_OP(Asinh); _DEFINE_MPS_OP(Acosh); _DEFINE_MPS_OP(Atanh); _DEFINE_MPS_OP(BitwiseNot); _DEFINE_MPS_OP(Isnan); _DEFINE_MPS_OP(Isinf); _DEFINE_MPS_OP(Round); _DEFINE_MPS_OP(LogicalNot); _DEFINE_MPS_OP(NormCdf); // Clamp ops _DEFINE_MPS_OP(Clamp); _DEFINE_MPS_OP(Where); // BitWise ops // Convolution ops _DEFINE_MPS_OP(Conv2D); _DEFINE_MPS_OP(DepthwiseConv2D); // Indexing ops _DEFINE_MPS_OP(IndexSelect); _DEFINE_MPS_OP(Embedding); _DEFINE_MPS_OP(IndexTensor); _DEFINE_MPS_OP(IndexPut); _DEFINE_MPS_OP(Scatter); // Linear algebra ops _DEFINE_MPS_OP(MatMul); _DEFINE_MPS_OP(Addmm); // Constant ops _DEFINE_MPS_OP(Full); _DEFINE_MPS_OP(FullLike); // Normalization ops _DEFINE_MPS_OP(BatchNorm); _DEFINE_MPS_OP(LayerNorm); // Reduce ops _DEFINE_MPS_OP(Mean); // Shape ops _DEFINE_MPS_OP(Permute); _DEFINE_MPS_OP(View); _DEFINE_MPS_OP(Expand); _DEFINE_MPS_OP(Cat); _DEFINE_MPS_OP(Squeeze); _DEFINE_MPS_OP(Unsqueeze); _DEFINE_MPS_OP(Select); _DEFINE_MPS_OP(Slice); _DEFINE_MPS_OP(PixelShuffle); _DEFINE_MPS_OP(SplitWithSizes); _DEFINE_MPS_OP(Cast); // Pooling ops _DEFINE_MPS_OP(MaxPool2DWithIndices); _DEFINE_MPS_OP(AvgPool2D); // Pad ops _DEFINE_MPS_OP(ConstantPadND); // Range ops _DEFINE_MPS_OP(Arange); // Quant-Dequant ops _DEFINE_MPS_OP(DequantizePerChannelGroup); // Helper functions Error addNodeToMPSGraph(NodePtr nodePtr); Error compileMetalKernel(NodePtr nodePtr); MPSShape *getMPSShape(int32_t id); MPSShape *getMPSShape(const flatbuffers::Vector *shape); int64_t numel(const flatbuffers::Vector *shape); MPSDataType getMPSDataType(int32_t id); MPSDataType getMPSDataType(DataType serializedDataType); MPSGraphTensor *getMPSGraphTensor(int32_t id); NSData *getConstantData(int32_t id); std::pair getMinMaxValues(NodePtr nodePtr); Error compileMPSGraph(); Error compileMetalKernel(); // Each MPSGraph op result in at least MPSGraphTensor being // produced, which will be stored in this structure. Other ops // can reference the saved tensor by the AOT id (1:1 mapping). std::vector _idToMPSGraphTensor; std::unordered_map &_mpsGraphTensorToId; // FlatBuffer serialized graph containing the nodes from the original model. const mpsgraph::MPSGraph *_flatBufferGraph; // FlatBuffer raw bytes of the serialized MPS model. const void *_buffer_pointer; size_t _num_bytes; bool _metal_kernel; MPSGraph *_mpsGraph; MPSGraphExecutable *_mpsGraphExecutable; NSMutableDictionary *_feeds; NSMutableArray *_targetTensors; const uint8_t *_constant_data_ptr; }; #undef _DEFINE_MPS_OP } // namespace delegate } // namespace mps } // namespace backends } // namespace executorch