//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#pragma once

// Obj-C headers
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// Runtime headers
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

// MPS headers
#include <executorch/backends/apple/mps/runtime/operations/MPSGraphSequoiaOps.h>
#include <executorch/backends/apple/mps/runtime/operations/MPSGraphVenturaOps.h>
#include <executorch/backends/apple/mps/runtime/operations/OperationUtils.h>
#include <executorch/backends/apple/mps/schema_generated.h>

// C++ standard library headers (include what we use rather than relying on
// transitive includes: size_t, int32_t/int64_t/uint8_t, std::pair).
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

using Error = executorch::runtime::Error;
using DataType = mpsgraph::MPSDataType;
using TensorPtr = const mpsgraph::MPSTensor *;
using NodePtr = const mpsgraph::MPSNode *;

// Declares one per-op lowering member function, e.g.
//   MPS_DECLARE_OP_HANDLER(Add);  ->  Error mpsAddOp(NodePtr nodePtr);
// The expansion intentionally has no trailing ';' so that each use site supplies
// exactly one (a doubled ';' inside a class body trips -Wextra-semi).
// The name avoids a leading underscore followed by an uppercase letter, which is
// reserved to the implementation in C and C++.
// The macro is #undef'd at the end of this header so it does not leak.
#define MPS_DECLARE_OP_HANDLER(name) Error mps##name##Op(NodePtr nodePtr)

/**
 * Helper class to construct a MPSGraph object from a serialized MPS FlatBuffer model.
 * It records all the input placeholders, lifted weights/biases and output feeds.
 *
 * Lifetime: the builder holds a *reference* to the caller-owned
 * `mpsGraphTensorToId` map and a raw pointer to the serialized buffer, so the
 * map must outlive the builder. NOTE(review): presumably the buffer must
 * outlive it too (it is stored as a raw `const void *`). Confirm in the
 * implementation whether anything is copied out of it.
 */
class MPSGraphBuilder {
public:
  /**
   * @param buffer_pointer Start of the serialized MPS FlatBuffer model.
   * @param num_bytes Size in bytes of the buffer at `buffer_pointer`.
   * @param mpsGraphTensorToId Caller-owned map from MPSGraphTensor to its AOT
   *        id. It is bound to a reference member, so it must stay alive for
   *        the lifetime of this builder.
   */
  MPSGraphBuilder(const void *buffer_pointer, size_t num_bytes,
                  std::unordered_map<MPSGraphTensor *, int32_t> &mpsGraphTensorToId);
  ~MPSGraphBuilder() = default;

  /// Builds the model described by the FlatBuffer. NOTE(review): presumably this
  /// dispatches to compileMPSGraph() or compileMetalKernel(). Confirm in the
  /// implementation.
  Error compileModel();
  /// Returns the graph held by this builder (`_mpsGraph`). Presumably only
  /// populated after compileModel() succeeds (TODO confirm).
  MPSGraph *getMPSGraph();
  /// Returns the compiled executable (`_mpsGraphExecutable`). Presumably only
  /// populated after compileModel() succeeds (TODO confirm).
  MPSGraphExecutable *getMPSGraphExecutable();

private:
  // Input feeds & constant ops
  Error mpsGraphRankedPlaceholder(int32_t id);
  Error mpsConstantOp(int32_t id);
  // Activation ops
  MPS_DECLARE_OP_HANDLER(HardTanh);
  MPS_DECLARE_OP_HANDLER(ReLU);
  MPS_DECLARE_OP_HANDLER(GELU);
  MPS_DECLARE_OP_HANDLER(LeakyReLU);
  MPS_DECLARE_OP_HANDLER(Softmax);
  MPS_DECLARE_OP_HANDLER(LogSoftmax);
  // Arithmetic & bitwise binary ops
  MPS_DECLARE_OP_HANDLER(Add);
  MPS_DECLARE_OP_HANDLER(Sub);
  MPS_DECLARE_OP_HANDLER(Mul);
  MPS_DECLARE_OP_HANDLER(Div);
  MPS_DECLARE_OP_HANDLER(Pow);
  MPS_DECLARE_OP_HANDLER(Fmod);
  MPS_DECLARE_OP_HANDLER(Remainder);
  MPS_DECLARE_OP_HANDLER(BitwiseAnd);
  MPS_DECLARE_OP_HANDLER(BitwiseOr);
  MPS_DECLARE_OP_HANDLER(BitwiseXor);
  MPS_DECLARE_OP_HANDLER(Minimum);
  // Comparison ops
  MPS_DECLARE_OP_HANDLER(Eq);
  MPS_DECLARE_OP_HANDLER(Ne);
  MPS_DECLARE_OP_HANDLER(Ge);
  MPS_DECLARE_OP_HANDLER(Gt);
  MPS_DECLARE_OP_HANDLER(Le);
  MPS_DECLARE_OP_HANDLER(Lt);
  // Unary ops (including bitwise/logical NOT)
  MPS_DECLARE_OP_HANDLER(Exp);
  MPS_DECLARE_OP_HANDLER(Exp2);
  MPS_DECLARE_OP_HANDLER(Reciprocal);
  MPS_DECLARE_OP_HANDLER(Sqrt);
  MPS_DECLARE_OP_HANDLER(Neg);
  MPS_DECLARE_OP_HANDLER(Log);
  MPS_DECLARE_OP_HANDLER(Log10);
  MPS_DECLARE_OP_HANDLER(Log2);
  MPS_DECLARE_OP_HANDLER(Erf);
  MPS_DECLARE_OP_HANDLER(Floor);
  MPS_DECLARE_OP_HANDLER(Ceil);
  MPS_DECLARE_OP_HANDLER(Rsqrt);
  MPS_DECLARE_OP_HANDLER(Sigmoid);
  MPS_DECLARE_OP_HANDLER(Sin);
  MPS_DECLARE_OP_HANDLER(Sign);
  MPS_DECLARE_OP_HANDLER(Cos);
  MPS_DECLARE_OP_HANDLER(Tan);
  MPS_DECLARE_OP_HANDLER(Abs);
  MPS_DECLARE_OP_HANDLER(Asin);
  MPS_DECLARE_OP_HANDLER(Acos);
  MPS_DECLARE_OP_HANDLER(Atan);
  MPS_DECLARE_OP_HANDLER(Sinh);
  MPS_DECLARE_OP_HANDLER(Cosh);
  MPS_DECLARE_OP_HANDLER(Tanh);
  MPS_DECLARE_OP_HANDLER(Asinh);
  MPS_DECLARE_OP_HANDLER(Acosh);
  MPS_DECLARE_OP_HANDLER(Atanh);
  MPS_DECLARE_OP_HANDLER(BitwiseNot);
  MPS_DECLARE_OP_HANDLER(Isnan);
  MPS_DECLARE_OP_HANDLER(Isinf);
  MPS_DECLARE_OP_HANDLER(Round);
  MPS_DECLARE_OP_HANDLER(LogicalNot);
  MPS_DECLARE_OP_HANDLER(NormCdf);
  // Clamp & selection ops
  MPS_DECLARE_OP_HANDLER(Clamp);
  MPS_DECLARE_OP_HANDLER(Where);
  // Convolution ops
  MPS_DECLARE_OP_HANDLER(Conv2D);
  MPS_DECLARE_OP_HANDLER(DepthwiseConv2D);
  // Indexing ops
  MPS_DECLARE_OP_HANDLER(IndexSelect);
  MPS_DECLARE_OP_HANDLER(Embedding);
  MPS_DECLARE_OP_HANDLER(IndexTensor);
  MPS_DECLARE_OP_HANDLER(IndexPut);
  MPS_DECLARE_OP_HANDLER(Scatter);
  // Linear algebra ops
  MPS_DECLARE_OP_HANDLER(MatMul);
  MPS_DECLARE_OP_HANDLER(Addmm);
  // Constant ops
  MPS_DECLARE_OP_HANDLER(Full);
  MPS_DECLARE_OP_HANDLER(FullLike);
  // Normalization ops
  MPS_DECLARE_OP_HANDLER(BatchNorm);
  MPS_DECLARE_OP_HANDLER(LayerNorm);
  // Reduce ops
  MPS_DECLARE_OP_HANDLER(Mean);
  // Shape ops
  MPS_DECLARE_OP_HANDLER(Permute);
  MPS_DECLARE_OP_HANDLER(View);
  MPS_DECLARE_OP_HANDLER(Expand);
  MPS_DECLARE_OP_HANDLER(Cat);
  MPS_DECLARE_OP_HANDLER(Squeeze);
  MPS_DECLARE_OP_HANDLER(Unsqueeze);
  MPS_DECLARE_OP_HANDLER(Select);
  MPS_DECLARE_OP_HANDLER(Slice);
  MPS_DECLARE_OP_HANDLER(PixelShuffle);
  MPS_DECLARE_OP_HANDLER(SplitWithSizes);
  MPS_DECLARE_OP_HANDLER(Cast);
  // Pooling ops
  MPS_DECLARE_OP_HANDLER(MaxPool2DWithIndices);
  MPS_DECLARE_OP_HANDLER(AvgPool2D);
  // Pad ops
  MPS_DECLARE_OP_HANDLER(ConstantPadND);
  // Range ops
  MPS_DECLARE_OP_HANDLER(Arange);
  // Quant-Dequant ops
  MPS_DECLARE_OP_HANDLER(DequantizePerChannelGroup);

  // Helper functions
  Error addNodeToMPSGraph(NodePtr nodePtr);
  // Two overloads: a per-node variant and a parameterless whole-model variant.
  // NOTE(review): semantics live in the implementation. Confirm before relying on them.
  Error compileMetalKernel(NodePtr nodePtr);
  MPSShape *getMPSShape(int32_t id);
  MPSShape *getMPSShape(const flatbuffers::Vector<int32_t> *shape);
  int64_t numel(const flatbuffers::Vector<int32_t> *shape);
  MPSDataType getMPSDataType(int32_t id);
  MPSDataType getMPSDataType(DataType serializedDataType);
  MPSGraphTensor *getMPSGraphTensor(int32_t id);
  NSData *getConstantData(int32_t id);
  // NOTE(review): presumably the (min, max) bounds for clamp-like nodes. TODO confirm.
  std::pair<float, float> getMinMaxValues(NodePtr nodePtr);
  Error compileMPSGraph();
  Error compileMetalKernel();

  // Data members. The declaration order is kept unchanged because it fixes the
  // constructor's member-initialization order. Default member initializers only
  // take effect for members the constructor does not initialize itself. They
  // guarantee no member is ever left indeterminate.

  // Each MPSGraph op produces at least one MPSGraphTensor, which is stored in
  // this structure. Other ops can reference a saved tensor by its AOT id
  // (1:1 mapping).
  std::vector<MPSGraphTensor *> _idToMPSGraphTensor;
  // Reverse mapping (tensor -> AOT id). It is owned by the caller and bound at
  // construction, so it must outlive this builder.
  std::unordered_map<MPSGraphTensor *, int32_t> &_mpsGraphTensorToId;
  // FlatBuffer serialized graph containing the nodes from the original model.
  const mpsgraph::MPSGraph *_flatBufferGraph = nullptr;
  // FlatBuffer raw bytes of the serialized MPS model.
  const void *_buffer_pointer = nullptr;
  size_t _num_bytes = 0;

  // NOTE(review): presumably true when the model is lowered through
  // compileMetalKernel() instead of an MPSGraph. Confirm.
  bool _metal_kernel = false;
  MPSGraph *_mpsGraph = nil;
  MPSGraphExecutable *_mpsGraphExecutable = nil;
  // Graph placeholder tensors -> their shaped types.
  NSMutableDictionary<MPSGraphTensor *, MPSGraphShapedType *> *_feeds = nil;
  // NOTE(review): presumably the graph outputs ("output feeds" above). Confirm.
  NSMutableArray<MPSGraphTensor *> *_targetTensors = nil;

  // NOTE(review): presumably points at the constant-data segment of the
  // serialized buffer. Confirm in the implementation.
  const uint8_t *_constant_data_ptr = nullptr;
};

#undef MPS_DECLARE_OP_HANDLER

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch