1 // 2 // Copyright (c) 2023 Apple Inc. All rights reserved. 3 // Provided subject to the LICENSE file in the top level directory. 4 // 5 6 #pragma once 7 8 // Obj-C headers 9 #include <Foundation/Foundation.h> 10 #include <Metal/Metal.h> 11 #include <MetalPerformanceShaders/MetalPerformanceShaders.h> 12 #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h> 13 // MPS Headers 14 #include <executorch/backends/apple/mps/runtime/MPSDevice.h> 15 // Runtime headers 16 #include <executorch/runtime/core/error.h> 17 18 #include <unordered_map> 19 20 namespace executorch { 21 namespace backends { 22 namespace mps { 23 namespace delegate { 24 25 enum class SyncType { 26 NONE, // no commit to command buffer 27 COMMIT, // commit and flush the command buffer 28 COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish 29 COMMIT_AND_CONTINUE, // commit and continue with a new underlying command 30 // buffer 31 COMMIT_ADAPTIVE, // commit adaptively based on available memory 32 }; 33 34 // Helper structure to copy data between CPU <-> GPU 35 struct CPUBufferWrapper { 36 void* srcBuffer; 37 void* dstBuffer; 38 size_t length; 39 size_t srcOffset; 40 size_t dstOffset; 41 union { 42 struct { 43 unsigned int srcCpu : 1; 44 unsigned int dstCpu : 1; 45 }; 46 uint16_t flags; 47 }; 48 }; 49 50 class MPSStream { 51 public: 52 MPSStream(); 53 54 ~MPSStream(); commandQueue()55 id<MTLCommandQueue> commandQueue() const { 56 return _commandQueue; 57 }; queue()58 dispatch_queue_t queue() const { 59 return _serialQueue; 60 } 61 62 bool hasLiveCommandBuffer(); 63 MPSCommandBuffer* commandBuffer(); 64 id<MTLComputeCommandEncoder> commandEncoder(); 65 void endKernelCoalescing(); 66 ET_NODISCARD executorch::runtime::Error synchronize(SyncType syncType); 67 bool commitAndContinueEnabled(); 68 void copy( 69 id<MTLBuffer> srcBuffer, 70 id<MTLBuffer> dstBuffer, 71 size_t length, 72 size_t srcOffset, 73 size_t dstOffset, 74 SyncType syncType = SyncType::NONE); 75 void copy( 76 std::vector<CPUBufferWrapper>& dataBuffers, 77 SyncType syncType = SyncType::NONE); 78 void copy_and_sync( 79 id<MTLBuffer> srcBuffer, 80 id<MTLBuffer> dstBuffer, 81 size_t length, 82 size_t srcOffset, 83 size_t dstOffset, 84 bool non_blocking); 85 void copy_and_sync( 86 std::vector<CPUBufferWrapper>& dataBuffers, 87 bool non_blocking); 88 89 private: 90 id<MTLCommandQueue> _commandQueue = nil; 91 MPSCommandBuffer* _commandBuffer = nil; 92 MPSCommandBuffer* _prevCommandBuffer = nil; 93 id<MTLComputeCommandEncoder> _commandEncoder = nil; 94 dispatch_queue_t _serialQueue = nullptr; 95 // CommitAndContinue is disabled by default 96 bool _enableCommitAndContinue = false; 97 // accumulated sizes of resources encoded on command buffer 98 size_t _commandBufferResourceSize = 0; 99 // unfortunately, there's no way to get the underlying buffer from 100 // an MPSGraphTensorData. so we need to keep a mapping of them here 101 std::unordered_map<MPSGraphTensorData*, void*> _activeResources{}; 102 103 // use synchronize() to access any of these commit functions outside MPSStream 104 void commit(); 105 void commitAndWait(); 106 void commitAndContinue(); 107 void flush(); 108 }; 109 110 /** 111 * Get the current MPS stream 112 */ 113 MPSStream* getCurrentMPSStream(); 114 115 /** 116 * Get the default MPS stream 117 */ 118 MPSStream* getDefaultMPSStream(); 119 120 //----------------------------------------------------------------- 121 // MPSStreamImpl 122 //----------------------------------------------------------------- 123 124 class MPSStreamImpl { 125 public: 126 /** 127 * Gets single instance of the MPSStream. 128 */ 129 static MPSStream* getInstance(); 130 131 private: 132 static MPSStream* _stream; 133 MPSStreamImpl(); 134 }; 135 136 } // namespace delegate 137 } // namespace mps 138 } // namespace backends 139 } // namespace executorch 140