#ifndef MPSImageWrapper_h
#define MPSImageWrapper_h

#import <ATen/native/metal/MetalCommandBuffer.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <c10/util/ArrayRef.h>

// Included directly for the std::vector<int64_t> member below rather than
// relying on c10/util/ArrayRef.h to pull them in transitively.
#include <cstdint>
#include <vector>

namespace at {
namespace native {
namespace metal {

/// Wraps the Metal-side storage used by the Metal backend.
///
/// It holds an `MPSImage`, an `MTLBuffer`, the `MetalCommandBuffer` it is
/// associated with, an `id<PTMetalCommandBuffer>` delegate, and the sizes the
/// image was created for.
///
/// NOTE(review): this class declares a destructor but no copy/move
/// constructors or assignment operators. The implicitly generated copy would
/// duplicate the `_image` / `_buffer` / `_commandBuffer` / `_delegate`
/// references. Confirm that instances are never copied (e.g. that they are only
/// held through a unique owner) before relying on that.
class API_AVAILABLE(ios(11.0), macos(10.13)) MPSImageWrapper {
 public:
  /// Creates a wrapper for a tensor with the given `sizes`.
  /// NOTE(review): presumably records `sizes` into `_imageSizes`; confirm
  /// against the implementation file.
  MPSImageWrapper(IntArrayRef sizes);
  ~MPSImageWrapper();

  /// Host -> device transfer of `float` data.
  /// NOTE(review): assumes `inputData` points to at least as many floats as the
  /// wrapped sizes describe — TODO confirm the expected element count/layout.
  void copyDataFromHost(const float* inputData);

  /// Device -> host transfer into `hostData`.
  /// NOTE(review): assumes `hostData` is large enough for the wrapped sizes —
  /// TODO confirm.
  void copyDataToHost(float* hostData);

  /// Allocates backing storage for a tensor of `sizes`.
  void allocateStorage(IntArrayRef sizes);

  /// Allocates storage for `sizes` tied to `commandBuffer`.
  /// NOTE(review): by name this is a temporary (command-buffer-scoped)
  /// allocation; confirm its lifetime in the implementation.
  void allocateTemporaryStorage(
      IntArrayRef sizes,
      MetalCommandBuffer* commandBuffer);

  /// Sets / returns the command buffer associated with this wrapper.
  void setCommandBuffer(MetalCommandBuffer* buffer);
  MetalCommandBuffer* commandBuffer() const;

  /// Sets / returns the wrapped `MPSImage`.
  void setImage(MPSImage* image);
  MPSImage* image() const;

  /// Returns the `MTLBuffer` currently held by the wrapper (may be nil).
  id<MTLBuffer> buffer() const;

  // NOTE(review): the semantics of synchronize(), prepare() and release() are
  // not visible from this header; see the implementation before relying on
  // ordering or idempotence.
  void synchronize();
  void prepare();
  void release();

 private:
  // Sizes recorded for the wrapped image.
  std::vector<int64_t> _imageSizes;
  MPSImage* _image = nil;
  id<MTLBuffer> _buffer = nil;
  MetalCommandBuffer* _commandBuffer = nil;
  // Delegate object conforming to PTMetalCommandBuffer.
  // NOTE(review): presumably registered with `_commandBuffer`; confirm.
  id<PTMetalCommandBuffer> _delegate = nil;
};

} // namespace metal
} // namespace native
} // namespace at

#endif /* MPSImageWrapper_h */