xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/MPSExecutor.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 //
2 //  Copyright (c) 2023 Apple Inc. All rights reserved.
3 //  Provided subject to the LICENSE file in the top level directory.
4 //
5 // clang-format off
6 #pragma once
7 
8 
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

#include <executorch/backends/apple/mps/runtime/operations/OperationUtils.h>
#include <executorch/backends/apple/mps/runtime/MPSStream.h>

#include <map>
#include <memory>
#include <unordered_map>
#include <vector>
18 
19 namespace executorch {
20 namespace backends {
21 namespace mps {
22 namespace delegate {
23 
// Holds an MPSGraphExecutable together with the shape descriptors,
// tensor-data arrays and CPU/GPU buffers for its inputs and outputs.
//
// NOTE(review): the destructor manually releases _inputsArray and
// _outputsArray, so an implicit copy of this object would release them twice.
// Copying is not disabled here to keep the interface unchanged; confirm that
// no caller copies instances.
class MPSExecutor {
 private:
  // Graph executable handed back by getMPSGraphExecutable().
  MPSGraphExecutable* _executable;
  // Shaped types of the inputs/outputs; their element counts back
  // getNumInputs() and getNumOutputs().
  NSArray<MPSGraphShapedType *> * _inputShapes;
  NSArray<MPSGraphShapedType *> * _outputShapes;

  // Tensor-data arrays; these are the only members released by the destructor.
  NSMutableArray<MPSGraphTensorData *>* _inputsArray;
  NSMutableArray<MPSGraphTensorData *>* _outputsArray;

  // Flag indicating whether to use shared memory or not.
  // The shared memory flag is set as follows (based on HW and target config):
  //   - True: Apple Silicon and macOS15+/iOS17+/iPadOS17+
  //   - False: Simulator or x86 or pre-macOS15/iOS17/iPadOS17
  bool _use_shared_mem;
  // NOTE(review): presumably records whether initDataBuffers() has already
  // run -- confirm against the implementation file.
  bool _buffers_initialized;

  // Input/Output GPU buffer pointers
  std::vector<id<MTLBuffer>> _inputGPUBuffers;
  std::vector<id<MTLBuffer>> _outputGPUBuffers;

  // Input/Output CPU buffer pointers
  std::vector<CPUBufferWrapper> _inputCPUBuffers;
  std::vector<CPUBufferWrapper> _outputCPUBuffers;

  // Maps a graph tensor to an integer id.
  std::unordered_map<MPSGraphTensor*, int32_t> _mpsGraphTensorToId;

 public:
  // Defined out of line.
  MPSExecutor();

  // Releases the input and output tensor-data arrays and clears the pointers.
  // Sending -release to nil is a no-op, so no nil checks are required.
  // NOTE(review): _executable, _inputShapes and _outputShapes are not released
  // here; their ownership is not visible in this header -- confirm they are
  // managed elsewhere.
  ~MPSExecutor() {
    [_inputsArray release];
    [_outputsArray release];

    _inputsArray = nil;
    _outputsArray = nil;
  }

  // Returns the number of entries in _inputShapes (0 when it is nil).
  inline size_t getNumInputs() const {
    return [_inputShapes count];
  }

  // Returns the number of entries in _outputShapes (0 when it is nil).
  inline size_t getNumOutputs() const {
    return [_outputShapes count];
  }

  // Returns the stored executable pointer without transferring ownership.
  inline MPSGraphExecutable* getMPSGraphExecutable() const {
    return _executable;
  }

  // The methods below are defined out of line; each reports its outcome
  // through the returned executorch::runtime::Error.
  ET_NODISCARD executorch::runtime::Error forward(std::vector<const executorch::aten::Tensor*>& outputs);

  ET_NODISCARD executorch::runtime::Error
  set_inputs_outputs(std::vector<const executorch::aten::Tensor*>& inputs, std::vector<const executorch::aten::Tensor*>& outputs);

  executorch::runtime::Error initDataBuffers();
  executorch::runtime::Error updateDataBuffers(std::vector<const executorch::aten::Tensor*>& inputs, std::vector<const executorch::aten::Tensor*>& outputs);
  executorch::runtime::Error syncOutputBuffers(std::vector<const executorch::aten::Tensor*>& outputs);

  // MPSCompiler is granted access to the private members above.
  friend class MPSCompiler;
};
87 
88 } // namespace delegate
89 } // namespace mps
90 } // namespace backends
91 } // namespace executorch
92 // clang-format on
93