#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#import <ATen/native/metal/mpscnn/MPSImageWrapper.h>

using namespace at::native::metal;

// Objective-C object that is registered as a command-buffer subscriber on
// behalf of a C++ MPSImageWrapper and forwards the synchronization callbacks
// to it. The back-pointer to the wrapper is raw and non-owning, so it must be
// cleared (see -invalidate) before the wrapper goes away.
@interface MPSImageWrapperTrampoline : NSObject<PTMetalCommandBuffer>
// Creates a trampoline that forwards callbacks to `wrapper`.
+ (instancetype)newWithMPSImageWrapper:(MPSImageWrapper*)wrapper;
// Clears the back-pointer to the wrapper; subsequent callbacks become no-ops.
- (void)invalidate;
@end

@implementation MPSImageWrapperTrampoline {
  // Non-owning. Null once the wrapper has released its resources.
  MPSImageWrapper* _imageWrapper;
}

+ (instancetype)newWithMPSImageWrapper:(MPSImageWrapper*)wrapper {
  MPSImageWrapperTrampoline* trampoline = [MPSImageWrapperTrampoline new];
  trampoline->_imageWrapper = wrapper;
  return trampoline;
}

- (void)invalidate {
  _imageWrapper = nullptr;
}

- (void)dealloc {
  _imageWrapper = nullptr;
#if !__has_feature(objc_arc)
  [super dealloc];
#endif
}

// Forwards to MPSImageWrapper::prepare() while the wrapper is still attached.
- (void)beginSynchronization {
  if (_imageWrapper) {
    _imageWrapper->prepare();
  }
}

// On a command buffer error the wrapper's resources are released; the error
// itself is deliberately not propagated.
- (void)endSynchronization:(NSError*)error {
  if (!error) {
    return;
  }
  // Snapshot the pointer: release() invalidates this trampoline and drops the
  // wrapper's reference to it, so no ivar is touched after the call.
  MPSImageWrapper* wrapper = _imageWrapper;
  if (wrapper) {
    wrapper->release();
  }
  // T159183991: ignore error. We prefer to not crash the app.
}

@end

namespace at {
namespace native {
namespace metal {

// Records the image size derived from `sizes` and creates the trampoline that
// will later be registered with command buffers. No image is allocated here.
MPSImageWrapper::MPSImageWrapper(IntArrayRef sizes) {
  _imageSizes = computeImageSize(sizes);
  _delegate = [MPSImageWrapperTrampoline newWithMPSImageWrapper:this];
}

MPSImageWrapper::~MPSImageWrapper() {
  release();
}

// Creates a temporary image on the current command buffer, initialized from
// `inputData`, and subscribes this wrapper to that command buffer.
// Fails a TORCH_CHECK when `inputData` is null.
void MPSImageWrapper::copyDataFromHost(const float* inputData) {
  TORCH_CHECK(inputData);
  _commandBuffer = [MetalCommandBuffer currentBuffer];
  [_commandBuffer addSubscriber:_delegate];
  _image = createTemporaryImage(_commandBuffer, _imageSizes, inputData);
}

// Commits the command buffer (if it is still valid) and copies the staging
// buffer into `hostData`. Exactly `_buffer.length` bytes are written, so the
// caller must provide at least that much space.
// Fails a TORCH_CHECK when `hostData` is null, when there is no image, or when
// no staging buffer exists after synchronization.
void MPSImageWrapper::copyDataToHost(float* hostData) {
  // Mirrors the input check in copyDataFromHost; memcpy to null is undefined.
  TORCH_CHECK(hostData);
  TORCH_CHECK(_image);
  synchronize();
  TORCH_CHECK(_buffer);
  memcpy(hostData, _buffer.contents, _buffer.length);
}

MPSImage* MPSImageWrapper::image() const {
  return _image;
}

id<MTLBuffer> MPSImageWrapper::buffer() const {
  return _buffer;
}

// Adopts `commandBuffer` as the current command buffer and subscribes this
// wrapper to it. Fails a TORCH_CHECK when the buffer is nil or not valid.
void MPSImageWrapper::setCommandBuffer(MetalCommandBuffer* commandBuffer) {
  TORCH_CHECK(commandBuffer && commandBuffer.valid);
  _commandBuffer = commandBuffer;
  [_commandBuffer addSubscriber:_delegate];
}

MetalCommandBuffer* MPSImageWrapper::commandBuffer() const {
  return _commandBuffer;
}

// Recomputes the image size from `sizes` and allocates a static image.
void MPSImageWrapper::allocateStorage(IntArrayRef sizes) {
  _imageSizes = computeImageSize(sizes);
  _image = createStaticImage(_imageSizes);
}

// Adopts `commandBuffer`, recomputes the image size from `sizes` and
// allocates a temporary image on that command buffer.
void MPSImageWrapper::allocateTemporaryStorage(
    IntArrayRef sizes,
    MetalCommandBuffer* commandBuffer) {
  setCommandBuffer(commandBuffer);
  _imageSizes = computeImageSize(sizes);
  _image = createTemporaryImage(commandBuffer, _imageSizes);
}

// Replaces the wrapped image. A temporary image is only accepted while a
// valid command buffer is set.
void MPSImageWrapper::setImage(MPSImage* image) {
  TORCH_CHECK(image);
  if (image.isTemporaryImage) {
    TORCH_CHECK(_commandBuffer && _commandBuffer.valid);
  }
  _image = image;
}

// Invoked through the trampoline's -beginSynchronization. Lazily allocates
// the staging buffer, encodes a copy of the image into it, and swaps a
// temporary image that still has a non-zero readCount for a static image
// created from it.
void MPSImageWrapper::prepare() {
  // The wrapper can be subscribed to a command buffer before any image has
  // been attached (setCommandBuffer); there is nothing to stage in that case,
  // and messaging a nil image for its sizes must be avoided.
  if (!_image) {
    return;
  }
  if (!_buffer) {
    int64_t size_bytes =
        c10::multiply_integers([_image sizes]) * sizeof(float);
    _buffer = [[MetalContext sharedInstance].device
        newBufferWithLength:size_bytes
                    options:MTLResourceCPUCacheModeWriteCombined];
    TORCH_CHECK(_buffer, "Allocate GPU memory failed!");
  }
  copyImageToMetalBuffer(_commandBuffer, _buffer, _image);
  if (_image.isTemporaryImage && _image.readCount != 0) {
    _image =
        createStaticImage((MPSTemporaryImage*)_image, _commandBuffer, false);
  }
}

// Commits the current command buffer if there is one and it is still valid.
void MPSImageWrapper::synchronize() {
  if (_commandBuffer && _commandBuffer.valid) {
    [_commandBuffer commit];
  }
}

// Recycles the image, detaches from the current command buffer and drops all
// references held by the wrapper. Called by the destructor and by the
// trampoline when a command buffer finishes with an error.
void MPSImageWrapper::release() {
  [_image recycle];
  [_commandBuffer remove:(MPSTemporaryImage*)_image];
  [_commandBuffer removeSubscriber:_delegate];
  // Only the current command buffer is unsubscribed above. Sever the
  // trampoline's raw back-pointer so that any other holder of the trampoline
  // cannot call into this wrapper after it has been released or destroyed.
  [(MPSImageWrapperTrampoline*)_delegate invalidate];
  _delegate = nil;
  _commandBuffer = nil;
  _image = nil;
  _buffer = nil;
}

} // namespace metal
} // namespace native
} // namespace at