xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/MPSStream.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
//
//  Copyright (c) 2023 Apple Inc. All rights reserved.
//  Provided subject to the LICENSE file in the top level directory.
//
5 
6 #pragma once
7 
// Obj-C headers
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
// MPS Headers
#include <executorch/backends/apple/mps/runtime/MPSDevice.h>
// Runtime headers
#include <executorch/runtime/core/error.h>

// C++ standard library headers (size_t, uint16_t, std::vector are used below)
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>
19 
20 namespace executorch {
21 namespace backends {
22 namespace mps {
23 namespace delegate {
24 
/// Commit policy accepted by MPSStream::synchronize() and by the
/// MPSStream::copy() overloads.
enum class SyncType {
  /// Leave the command buffer uncommitted.
  NONE = 0,
  /// Commit the command buffer and flush it.
  COMMIT = 1,
  /// Flush, then wait until the command buffer has finished executing.
  COMMIT_AND_WAIT = 2,
  /// Commit, then continue recording on a fresh underlying command buffer.
  COMMIT_AND_CONTINUE = 3,
  /// Decide whether to commit based on the memory that is available.
  COMMIT_ADAPTIVE = 4,
};
33 
34 // Helper structure to copy data between CPU <-> GPU
35 struct CPUBufferWrapper {
36   void* srcBuffer;
37   void* dstBuffer;
38   size_t length;
39   size_t srcOffset;
40   size_t dstOffset;
41   union {
42     struct {
43       unsigned int srcCpu : 1;
44       unsigned int dstCpu : 1;
45     };
46     uint16_t flags;
47   };
48 };
49 
/// A stream of GPU work: holds a Metal command queue, the MPS command buffer
/// and compute command encoder currently in use, and a dispatch queue.
class MPSStream {
 public:
  MPSStream();

  ~MPSStream();

  // The stream stores raw Metal / dispatch handles and declares its own
  // destructor. A member-wise copy (or move) would leave two objects aliasing
  // the same command queue, command buffer, encoder and dispatch queue.
  // Streams are handed out by pointer (see getCurrentMPSStream() and
  // MPSStreamImpl::getInstance()), so copying and moving are disabled.
  MPSStream(const MPSStream&) = delete;
  MPSStream& operator=(const MPSStream&) = delete;
  MPSStream(MPSStream&&) = delete;
  MPSStream& operator=(MPSStream&&) = delete;

  /// Returns this stream's Metal command queue (`_commandQueue`).
  id<MTLCommandQueue> commandQueue() const {
    return _commandQueue;
  }

  /// Returns the dispatch queue held by this stream (`_serialQueue`).
  dispatch_queue_t queue() const {
    return _serialQueue;
  }

  bool hasLiveCommandBuffer();
  MPSCommandBuffer* commandBuffer();
  id<MTLComputeCommandEncoder> commandEncoder();
  void endKernelCoalescing();

  /// Applies `syncType` to the pending work. This is the public entry point
  /// for the private commit helpers declared below.
  ET_NODISCARD executorch::runtime::Error synchronize(SyncType syncType);

  bool commitAndContinueEnabled();

  /// Copies between two Metal buffers; `syncType` defaults to SyncType::NONE.
  void copy(
      id<MTLBuffer> srcBuffer,
      id<MTLBuffer> dstBuffer,
      size_t length,
      size_t srcOffset,
      size_t dstOffset,
      SyncType syncType = SyncType::NONE);

  /// Performs each copy described in `dataBuffers`; `syncType` defaults to
  /// SyncType::NONE.
  void copy(
      std::vector<CPUBufferWrapper>& dataBuffers,
      SyncType syncType = SyncType::NONE);

  void copy_and_sync(
      id<MTLBuffer> srcBuffer,
      id<MTLBuffer> dstBuffer,
      size_t length,
      size_t srcOffset,
      size_t dstOffset,
      bool non_blocking);

  void copy_and_sync(
      std::vector<CPUBufferWrapper>& dataBuffers,
      bool non_blocking);

 private:
  id<MTLCommandQueue> _commandQueue = nil;
  MPSCommandBuffer* _commandBuffer = nil;
  MPSCommandBuffer* _prevCommandBuffer = nil;
  id<MTLComputeCommandEncoder> _commandEncoder = nil;
  dispatch_queue_t _serialQueue = nullptr;
  // CommitAndContinue is disabled by default
  bool _enableCommitAndContinue = false;
  // accumulated sizes of resources encoded on command buffer
  size_t _commandBufferResourceSize = 0;
  // unfortunately, there's no way to get the underlying buffer from
  // an MPSGraphTensorData. so we need to keep a mapping of them here
  std::unordered_map<MPSGraphTensorData*, void*> _activeResources{};

  // use synchronize() to access any of these commit functions outside MPSStream
  void commit();
  void commitAndWait();
  void commitAndContinue();
  void flush();
};
109 
/// Returns the current MPS stream.
/// NOTE(review): defined outside this header; whether it can ever differ from
/// getDefaultMPSStream() should be confirmed in the implementation file.
MPSStream *getCurrentMPSStream();

/// Returns the default MPS stream.
MPSStream *getDefaultMPSStream();
119 
//-----------------------------------------------------------------
//  MPSStreamImpl
//-----------------------------------------------------------------

/// Holder of the single MPSStream instance. Its constructor is private, so
/// callers only ever use the static accessor.
class MPSStreamImpl {
  // Declarations before the first access specifier of a `class` are private.
  static MPSStream* _stream;
  MPSStreamImpl();

 public:
  /// Returns the single MPSStream instance.
  static MPSStream* getInstance();
};
135 
136 } // namespace delegate
137 } // namespace mps
138 } // namespace backends
139 } // namespace executorch
140