xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/MPSGraphBuilder.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 //
2 //  Copyright (c) 2023 Apple Inc. All rights reserved.
3 //  Provided subject to the LICENSE file in the top level directory.
4 //
5 
6 #pragma once
7 
8 // Obj-C headers
9 #include <Foundation/Foundation.h>
10 #include <Metal/Metal.h>
11 #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
12 #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
13 
14 // Runtime headers
15 #include <executorch/runtime/backend/interface.h>
16 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
17 
18 // MPS headers
19 #include <executorch/backends/apple/mps/runtime/operations/MPSGraphSequoiaOps.h>
20 #include <executorch/backends/apple/mps/runtime/operations/MPSGraphVenturaOps.h>
21 #include <executorch/backends/apple/mps/runtime/operations/OperationUtils.h>
22 #include <executorch/backends/apple/mps/schema_generated.h>
23 
24 #include <unordered_map>
25 #include <vector>
26 
27 namespace executorch {
28 namespace backends {
29 namespace mps {
30 namespace delegate {
31 
// Short aliases used throughout the builder: the runtime error code plus the
// data-type, tensor and node types from the `mpsgraph` namespace.
using Error = executorch::runtime::Error;
using DataType = mpsgraph::MPSDataType;
using TensorPtr = const mpsgraph::MPSTensor *;
using NodePtr = const mpsgraph::MPSNode *;

// Declares one op-builder member function: `_DEFINE_MPS_OP(Add);` expands to
// `Error mpsAddOp(NodePtr nodePtr);`.
// The expansion deliberately carries no trailing semicolon; the call site
// supplies it. Keeping one here as well would leave a stray empty declaration
// (`;;`) after every op, which clang reports under -Wextra-semi.
#define _DEFINE_MPS_OP(name) Error mps##name##Op(NodePtr nodePtr)
38 
/**
 * Helper class to construct a MPSGraph object from a serialized MPS FlatBuffer model.
 * It records all the input placeholders, lifted weights/biases and output feeds.
 *
 * The class keeps a reference member (`_mpsGraphTensorToId`), so whatever map
 * that reference is bound to must outlive the builder. Presumably it is bound
 * to the constructor's `mpsGraphTensorToId` argument; confirm in the
 * implementation file.
 */
class MPSGraphBuilder {
public:
  /**
   * @param buffer_pointer      Start of the serialized model bytes.
   * @param num_bytes           Size of that buffer, in bytes.
   * @param mpsGraphTensorToId  Caller-owned MPSGraphTensor -> id map, taken by
   *                            reference.
   */
  MPSGraphBuilder(const void *buffer_pointer, size_t num_bytes,
                  std::unordered_map<MPSGraphTensor *, int32_t> &mpsGraphTensorToId);
  ~MPSGraphBuilder() = default;

  // Entry point for building the model; reports failure through the returned
  // Error. (Implementation lives outside this header.)
  Error compileModel();
  // Accessors for the graph and the executable held by this builder.
  MPSGraph *getMPSGraph();
  MPSGraphExecutable *getMPSGraphExecutable();

private:
  // Every `_DEFINE_MPS_OP(Name);` line below declares
  // `Error mps<Name>Op(NodePtr nodePtr);`.

  // Input feeds & constant ops
  Error mpsGraphRankedPlaceholder(int32_t id);
  Error mpsConstantOp(int32_t id);

  // Activation ops
  _DEFINE_MPS_OP(HardTanh);
  _DEFINE_MPS_OP(ReLU);
  _DEFINE_MPS_OP(GELU);
  _DEFINE_MPS_OP(LeakyReLU);
  _DEFINE_MPS_OP(Softmax);
  _DEFINE_MPS_OP(LogSoftmax);

  // Arithmetic & bitwise binary ops
  _DEFINE_MPS_OP(Add);
  _DEFINE_MPS_OP(Sub);
  _DEFINE_MPS_OP(Mul);
  _DEFINE_MPS_OP(Div);
  _DEFINE_MPS_OP(Pow);
  _DEFINE_MPS_OP(Fmod);
  _DEFINE_MPS_OP(Remainder);
  _DEFINE_MPS_OP(BitwiseAnd);
  _DEFINE_MPS_OP(BitwiseOr);
  _DEFINE_MPS_OP(BitwiseXor);
  _DEFINE_MPS_OP(Minimum);

  // Comparison ops
  _DEFINE_MPS_OP(Eq);
  _DEFINE_MPS_OP(Ne);
  _DEFINE_MPS_OP(Ge);
  _DEFINE_MPS_OP(Gt);
  _DEFINE_MPS_OP(Le);
  _DEFINE_MPS_OP(Lt);

  // Unary ops
  _DEFINE_MPS_OP(Exp);
  _DEFINE_MPS_OP(Exp2);
  _DEFINE_MPS_OP(Reciprocal);
  _DEFINE_MPS_OP(Sqrt);
  _DEFINE_MPS_OP(Neg);
  _DEFINE_MPS_OP(Log);
  _DEFINE_MPS_OP(Log10);
  _DEFINE_MPS_OP(Log2);
  _DEFINE_MPS_OP(Erf);
  _DEFINE_MPS_OP(Floor);
  _DEFINE_MPS_OP(Ceil);
  _DEFINE_MPS_OP(Rsqrt);
  _DEFINE_MPS_OP(Sigmoid);
  _DEFINE_MPS_OP(Sin);
  _DEFINE_MPS_OP(Sign);
  _DEFINE_MPS_OP(Cos);
  _DEFINE_MPS_OP(Tan);
  _DEFINE_MPS_OP(Abs);
  _DEFINE_MPS_OP(Asin);
  _DEFINE_MPS_OP(Acos);
  _DEFINE_MPS_OP(Atan);
  _DEFINE_MPS_OP(Sinh);
  _DEFINE_MPS_OP(Cosh);
  _DEFINE_MPS_OP(Tanh);
  _DEFINE_MPS_OP(Asinh);
  _DEFINE_MPS_OP(Acosh);
  _DEFINE_MPS_OP(Atanh);
  _DEFINE_MPS_OP(BitwiseNot);
  _DEFINE_MPS_OP(Isnan);
  _DEFINE_MPS_OP(Isinf);
  _DEFINE_MPS_OP(Round);
  _DEFINE_MPS_OP(LogicalNot);
  _DEFINE_MPS_OP(NormCdf);

  // Clamp / where ops
  _DEFINE_MPS_OP(Clamp);
  _DEFINE_MPS_OP(Where);

  // Convolution ops
  _DEFINE_MPS_OP(Conv2D);
  _DEFINE_MPS_OP(DepthwiseConv2D);

  // Indexing ops
  _DEFINE_MPS_OP(IndexSelect);
  _DEFINE_MPS_OP(Embedding);
  _DEFINE_MPS_OP(IndexTensor);
  _DEFINE_MPS_OP(IndexPut);
  _DEFINE_MPS_OP(Scatter);

  // Linear algebra ops
  _DEFINE_MPS_OP(MatMul);
  _DEFINE_MPS_OP(Addmm);

  // Constant ops
  _DEFINE_MPS_OP(Full);
  _DEFINE_MPS_OP(FullLike);

  // Normalization ops
  _DEFINE_MPS_OP(BatchNorm);
  _DEFINE_MPS_OP(LayerNorm);

  // Reduce ops
  _DEFINE_MPS_OP(Mean);

  // Shape ops
  _DEFINE_MPS_OP(Permute);
  _DEFINE_MPS_OP(View);
  _DEFINE_MPS_OP(Expand);
  _DEFINE_MPS_OP(Cat);
  _DEFINE_MPS_OP(Squeeze);
  _DEFINE_MPS_OP(Unsqueeze);
  _DEFINE_MPS_OP(Select);
  _DEFINE_MPS_OP(Slice);
  _DEFINE_MPS_OP(PixelShuffle);
  _DEFINE_MPS_OP(SplitWithSizes);
  _DEFINE_MPS_OP(Cast);

  // Pooling ops
  _DEFINE_MPS_OP(MaxPool2DWithIndices);
  _DEFINE_MPS_OP(AvgPool2D);

  // Pad ops
  _DEFINE_MPS_OP(ConstantPadND);

  // Range ops
  _DEFINE_MPS_OP(Arange);

  // Quant-Dequant ops
  _DEFINE_MPS_OP(DequantizePerChannelGroup);

  // Helper functions
  Error addNodeToMPSGraph(NodePtr nodePtr);
  Error compileMetalKernel(NodePtr nodePtr);
  // Shape / element-count / data-type lookups, either by AOT id or directly
  // from the serialized value.
  MPSShape *getMPSShape(int32_t id);
  MPSShape *getMPSShape(const flatbuffers::Vector<int32_t> *shape);
  int64_t numel(const flatbuffers::Vector<int32_t> *shape);
  MPSDataType getMPSDataType(int32_t id);
  MPSDataType getMPSDataType(DataType serializedDataType);
  MPSGraphTensor *getMPSGraphTensor(int32_t id);
  NSData *getConstantData(int32_t id);
  std::pair<float, float> getMinMaxValues(NodePtr nodePtr);
  Error compileMPSGraph();
  Error compileMetalKernel();

  // Each MPSGraph op results in at least one MPSGraphTensor being
  // produced, which is stored in this structure. Other ops
  // can reference the saved tensor by its AOT id (1:1 mapping).
  std::vector<MPSGraphTensor *> _idToMPSGraphTensor;
  // Reference member: the referenced map is not owned by the builder.
  std::unordered_map<MPSGraphTensor *, int32_t> &_mpsGraphTensorToId;
  // FlatBuffer serialized graph containing the nodes from the original model.
  const mpsgraph::MPSGraph *_flatBufferGraph = nullptr;
  // FlatBuffer raw bytes of the serialized MPS model.
  const void *_buffer_pointer = nullptr;
  size_t _num_bytes = 0;

  // The members below carry default initializers so that a member the
  // constructor does not explicitly set starts out null/false instead of
  // holding an indeterminate value. A constructor mem-initializer still
  // takes precedence over these defaults.
  bool _metal_kernel = false;
  MPSGraph *_mpsGraph = nullptr;
  MPSGraphExecutable *_mpsGraphExecutable = nullptr;
  NSMutableDictionary<MPSGraphTensor *, MPSGraphShapedType *> *_feeds = nullptr;
  NSMutableArray<MPSGraphTensor *> *_targetTensors = nullptr;

  const uint8_t *_constant_data_ptr = nullptr;
};

// The declaration macro is only meaningful inside the class above.
#undef _DEFINE_MPS_OP
198 
199 } // namespace delegate
200 } // namespace mps
201 } // namespace backends
202 } // namespace executorch
203