//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//
// clang-format off
#pragma once

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

#include <executorch/backends/apple/mps/runtime/operations/OperationUtils.h>
#include <executorch/backends/apple/mps/runtime/MPSStream.h>

#include <memory>
#include <unordered_map>
#include <vector>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

class MPSExecutor {
 private:
  // Compiled graph executable, populated by MPSCompiler (declared a friend below).
  MPSGraphExecutable* _executable;

  // Shapes/dtypes of the executable's feeds and results.
  NSArray<MPSGraphShapedType*>* _inputShapes;
  NSArray<MPSGraphShapedType*>* _outputShapes;

  // MPSGraphTensorData wrappers handed to the executable at run time.
  NSMutableArray<MPSGraphTensorData*>* _inputsArray;
  NSMutableArray<MPSGraphTensorData*>* _outputsArray;

  // Flag whether to use shared memory or not.
  // The shared memory flag is set as follows (based on HW and target config):
  //  - true:  Apple Silicon and macOS15+/iOS17+/iPadOS17+
  //  - false: Simulator, x86, or pre-macOS15/iOS17/iPadOS17
  bool _use_shared_mem;
  bool _buffers_initialized;

  // Input/output GPU buffer pointers.
  std::vector<id<MTLBuffer>> _inputGPUBuffers;
  std::vector<id<MTLBuffer>> _outputGPUBuffers;

  // Input/output CPU buffer pointers.
  std::vector<CPUBufferWrapper> _inputCPUBuffers;
  std::vector<CPUBufferWrapper> _outputCPUBuffers;

  // Mapping from graph tensors to their ids.
  std::unordered_map<MPSGraphTensor*, int32_t> _mpsGraphTensorToId;

 public:
  MPSExecutor();

  ~MPSExecutor() {
    if (_inputsArray) {
      [_inputsArray release];
      _inputsArray = nil;
    }
    if (_outputsArray) {
      [_outputsArray release];
      _outputsArray = nil;
    }
  }

  inline size_t getNumInputs() {
    return [_inputShapes count];
  }

  inline size_t getNumOutputs() {
    return [_outputShapes count];
  }

  inline MPSGraphExecutable* getMPSGraphExecutable() {
    return _executable;
  }

  // Executes the compiled MPSGraphExecutable; results are made available through `outputs`.
  ET_NODISCARD executorch::runtime::Error forward(std::vector<const executorch::aten::Tensor*>& outputs);

  // Binds the delegate's input/output tensors to the executable's feeds and results.
  ET_NODISCARD executorch::runtime::Error
  set_inputs_outputs(std::vector<const executorch::aten::Tensor*>& inputs, std::vector<const executorch::aten::Tensor*>& outputs);

  // Allocates the GPU/CPU data buffers used to exchange data with the graph.
  executorch::runtime::Error initDataBuffers();
  // Refreshes the data buffers with the current input/output tensor data.
  executorch::runtime::Error updateDataBuffers(std::vector<const executorch::aten::Tensor*>& inputs, std::vector<const executorch::aten::Tensor*>& outputs);
  // Makes the GPU results visible to the output tensors on the CPU side.
  executorch::runtime::Error syncOutputBuffers(std::vector<const executorch::aten::Tensor*>& outputs);

  friend class MPSCompiler;
};

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch
// clang-format on
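
// Usage sketch (illustrative only, not part of the API): a minimal call sequence
// for an MPSExecutor that has already been populated by MPSCompiler. The
// `executor`, `inputs`, and `outputs` names below are assumptions standing in
// for whatever the delegate's execute path provides.
//
//   std::vector<const executorch::aten::Tensor*> inputs = /* delegate inputs */;
//   std::vector<const executorch::aten::Tensor*> outputs = /* delegate outputs */;
//
//   // Bind the tensors to the executable's feeds/results, then run it.
//   executorch::runtime::Error err = executor->set_inputs_outputs(inputs, outputs);
//   if (err != executorch::runtime::Error::Ok) {
//     return err;
//   }
//   return executor->forward(outputs);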