//  Copyright © 2022 Apple Inc.

#pragma once

#include <cstdint>
#include <utility>

#include <c10/core/DeviceGuard.h>
#include <c10/util/Exception.h>
#include <c10/core/Stream.h>
#include <ATen/mps/MPSDevice.h>

#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
#else
// Opaque stand-ins so that the handle types below can be named when this
// header is seen by a translation unit that is not compiled as Objective-C.
// NOTE(review): the class below still refers to names that have no stand-in
// here (id<MTLBuffer>, MPSGraph, NSDictionary, MPSCommandBuffer,
// MTLCommandBufferHandler, the MPSGraph*Descriptor types) and contains an
// Objective-C message send -- presumably this header is only ever included
// from Objective-C++ sources; confirm.
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLComputeCommandEncoder_t;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
// Deliberately no trailing semicolon: `nil` has to remain usable inside
// expressions (function arguments, comparisons), not only as a full statement.
#define nil NULL
#endif


namespace at::mps {

//-----------------------------------------------------------------
//  MPSStream
//-----------------------------------------------------------------

/// Selects what happens to the stream's command buffer when work is submitted.
enum class SyncType {
  NONE,               // no commit to command buffer
  COMMIT,             // commit and flush the command buffer
  COMMIT_AND_WAIT,    // flush and wait for command buffer execution to finish
  COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
  COMMIT_ADAPTIVE,    // commit adaptively based on available memory
};

/// Wraps a Stream together with the Metal command queue, the command
/// buffer / compute encoder state and the serial dispatch queue stored in the
/// private members below.
class TORCH_API MPSStream
{
public:
  enum Unchecked { UNCHECKED };

  /// Construct an MPSStream from a Stream.  This construction is checked,
  /// and will raise an error if the Stream is not, in fact, an MPS stream.
  explicit MPSStream(Stream stream);

  ~MPSStream();

  /// Returns the Metal command queue held by this stream.
  MTLCommandQueue_t commandQueue() const { return _commandQueue; }

  /// Returns the serial dispatch queue held by this stream.
  dispatch_queue_t queue() const { return _serialQueue; }

  // The operations below are only declared here; their definitions live
  // outside this header.
  MPSCommandBuffer* commandBuffer();
  MTLComputeCommandEncoder_t commandEncoder();
  void endKernelCoalescing();
  /// Public entry point for the private commit functions declared at the
  /// bottom of this class.
  void synchronize(SyncType syncType);
  void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
  void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
            size_t length, size_t srcOffset, size_t dstOffset,
            uint64_t profileId, SyncType syncType = SyncType::NONE);
  void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
                     size_t length, size_t srcOffset, size_t dstOffset,
                     bool non_blocking, uint64_t profileId);
  void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
  void addCompletedHandler(MTLCommandBufferHandler block);

  /// Get the MPS device index that this stream is associated with.
  c10::DeviceIndex device_index() const { return _stream.device_index(); }

  /// Returns the same command queue as commandQueue().
  MTLCommandQueue_t stream() const { return _commandQueue; }

  /// Returns the Metal device that the command queue reports as its owner.
  MTLDevice_t device() const { return [_commandQueue device]; }

  /// Explicit conversion to Stream.
  Stream unwrap() const { return _stream; }

private:
  Stream _stream;
  MTLCommandQueue_t _commandQueue = nil;
  MPSCommandBuffer* _commandBuffer = nil;
  MPSCommandBuffer* _prevCommandBuffer = nil;
  MTLComputeCommandEncoder_t _commandEncoder = nil;
  MPSGraphExecutionDescriptor *_executionDescriptor = nil;
  MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
  dispatch_queue_t _serialQueue = nullptr;
  // CommitAndContinue is enabled by default
  bool _enableCommitAndContinue = true;

  // use synchronize() to access any of these commit functions outside MPSStream
  void commit();
  void commitAndWait();
  void commitAndContinue();
  void flush();
};

/**
 * Get the current MPS stream
 */
TORCH_API MPSStream* getCurrentMPSStream();

/**
 * Get the default MPS stream
 */
TORCH_API MPSStream* getDefaultMPSStream();

//-----------------------------------------------------------------
//  MPSStreamImpl
//-----------------------------------------------------------------

class TORCH_API MPSStreamImpl
{
 public:
  /**
   * Gets single instance of the MPSStream.
   */
  static MPSStream* getInstance();

 private:
  // Class-wide MPSStream pointer; getInstance() is the only public accessor
  // declared for it.
  static MPSStream* _stream;
  // Private constructor: MPSStreamImpl cannot be instantiated from outside.
  MPSStreamImpl();
};

} // namespace at::mps