xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSStream.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //  Copyright © 2022 Apple Inc.
2 
3 #pragma once
4 
5 #include <cstdint>
6 #include <utility>
7 
8 #include <c10/core/DeviceGuard.h>
9 #include <c10/util/Exception.h>
10 #include <c10/core/Stream.h>
11 #include <ATen/mps/MPSDevice.h>
12 
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
// Objective-C(++) translation units see the real Metal protocol types.
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
#else
// Plain C++ translation units get opaque pointer stand-ins for the Metal types.
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLComputeCommandEncoder_t;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
// `nil` must expand to a bare null-pointer expression. A trailing semicolon in
// the macro (as in `#define nil NULL;`) breaks every use of `nil` inside an
// expression, e.g. `f(nil)` or `x == nil`. Guarded so an existing definition
// is not redefined.
#ifndef nil
#define nil nullptr
#endif
#endif
34 
35 
36 namespace at::mps {
37 
38 //-----------------------------------------------------------------
39 //  MPSStream
40 //-----------------------------------------------------------------
41 
/// How much synchronization an MPSStream operation performs on its command
/// buffer after the work has been encoded.
enum class SyncType : int {
  /// Do not commit the command buffer.
  NONE,
  /// Commit and flush the command buffer.
  COMMIT,
  /// Flush, then wait until the command buffer has finished executing.
  COMMIT_AND_WAIT,
  /// Commit, then continue with a new underlying command buffer.
  COMMIT_AND_CONTINUE,
  /// Commit adaptively, depending on the memory that is available.
  COMMIT_ADAPTIVE,
};
48 };
49 
/// An MPS command stream. It holds a Metal command queue together with the
/// command buffer(s), compute command encoder, MPSGraph descriptors and serial
/// dispatch queue it uses when encoding work.
class TORCH_API MPSStream
{
public:
  enum Unchecked { UNCHECKED };

  /// Construct a MPSStream from a Stream.  This construction is checked,
  /// and will raise an error if the Stream is not, in fact, a MPS stream.
  explicit MPSStream(Stream stream);

  ~MPSStream();

  // NOTE(review): the destructor presumably releases the raw Metal/MPS handles
  // stored below (confirm in MPSStream.mm). An implicit copy would duplicate
  // those handles and risk a double release, so copying is disabled; streams
  // are handed out by pointer (see getCurrentMPSStream / getDefaultMPSStream).
  MPSStream(const MPSStream&) = delete;
  MPSStream& operator=(const MPSStream&) = delete;

  /// The Metal command queue backing this stream.
  MTLCommandQueue_t commandQueue() const { return _commandQueue; }
  /// The serial dispatch queue associated with this stream.
  dispatch_queue_t queue() const { return _serialQueue; }

  /// The command buffer currently used by this stream.
  MPSCommandBuffer* commandBuffer();
  /// The compute command encoder currently used by this stream.
  MTLComputeCommandEncoder_t commandEncoder();
  /// Ends kernel coalescing on this stream (presumably ends the active
  /// compute encoder — confirm in MPSStream.mm).
  void endKernelCoalescing();
  /// Commits/waits according to `syncType`; this is the public entry point to
  /// the private commit helpers.
  void synchronize(SyncType syncType);
  /// Fills `length` bytes of `buffer`, starting at `offset`, with `value`.
  void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
  /// Copies `length` bytes from `srcBuffer` + `srcOffset` to
  /// `dstBuffer` + `dstOffset`.
  void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
            size_t length, size_t srcOffset, size_t dstOffset,
            uint64_t profileId, SyncType syncType = SyncType::NONE);
  /// Copies like copy(), with synchronization controlled by `non_blocking`.
  void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
                     size_t length, size_t srcOffset, size_t dstOffset,
                     bool non_blocking, uint64_t profileId);
  /// Encodes `mpsGraph` with the given feeds/results onto this stream.
  void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
  /// Registers `block` as a completion handler for the stream's command buffer.
  void addCompletedHandler(MTLCommandBufferHandler block);

  /// Get the MPS device index that this stream is associated with.
  c10::DeviceIndex device_index() const { return _stream.device_index(); }

  /// The Metal command queue backing this stream (same as commandQueue()).
  MTLCommandQueue_t stream() const { return _commandQueue; }

  /// The Metal device that owns the command queue. Uses an Objective-C message
  /// send, so it is only usable from Objective-C++ translation units.
  MTLDevice_t device() const { return [_commandQueue device]; }

  /// Explicit conversion to Stream.
  Stream unwrap() const { return _stream; }

private:
  Stream _stream;
  MTLCommandQueue_t _commandQueue = nil;
  MPSCommandBuffer* _commandBuffer = nil;
  MPSCommandBuffer* _prevCommandBuffer = nil;
  MTLComputeCommandEncoder_t _commandEncoder = nil;
  MPSGraphExecutionDescriptor *_executionDescriptor = nil;
  MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
  dispatch_queue_t _serialQueue = nullptr;
  // CommitAndContinue is enabled by default
  bool _enableCommitAndContinue = true;

  // use synchronize() to access any of these commit functions outside MPSStream
  void commit();
  void commitAndWait();
  void commitAndContinue();
  void flush();
};
104 };
105 
106 /**
107  * Get the current MPS stream
108  */
109 TORCH_API MPSStream* getCurrentMPSStream();
110 
111 /**
112  * Get the default MPS stream
113  */
114 TORCH_API MPSStream* getDefaultMPSStream();
115 
116 //-----------------------------------------------------------------
117 //  MPSStreamImpl
118 //-----------------------------------------------------------------
119 
120 class TORCH_API MPSStreamImpl
121 {
122  public:
123   /**
124    * Gets single instance of the MPSStream.
125    */
126   static MPSStream* getInstance();
127 
128  private:
129   static MPSStream* _stream;
130   MPSStreamImpl();
131 };
132 
133 } // namespace at::mps
134