xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalCommandBuffer.h>
2#import <ATen/native/metal/MetalTensorUtils.h>
3#import <ATen/native/metal/MetalContext.h>
4#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
5#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
6#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
7#import <ATen/native/metal/mpscnn/MPSImageWrapper.h>
8
9using namespace at::native::metal;
/// Objective-C shim that forwards PTMetalCommandBuffer synchronization
/// callbacks to a C++ MPSImageWrapper. The wrapper creates one of these in its
/// constructor and keeps it as its `_delegate`.
@interface MPSImageWrapperTrampoline : NSObject<PTMetalCommandBuffer>
+ (instancetype)newWithMPSImageWrapper:(MPSImageWrapper*)wrapper;
@end

@implementation MPSImageWrapperTrampoline {
  // Non-owning back pointer to the wrapper that created this object; it is
  // never deleted here, only cleared in -dealloc.
  MPSImageWrapper* _imageWrapper;
}

/// Returns a new (+1, "new" method family) trampoline bound to `wrapper`.
+ (instancetype)newWithMPSImageWrapper:(MPSImageWrapper*)wrapper {
  MPSImageWrapperTrampoline* result = [[MPSImageWrapperTrampoline alloc] init];
  result->_imageWrapper = wrapper;
  return result;
}

- (void)dealloc {
  // Drop the raw back pointer; the wrapper itself is not owned.
  _imageWrapper = nullptr;
#if !__has_feature(objc_arc)
  [super dealloc];
#endif
}

/// Forwards to MPSImageWrapper::prepare() when a wrapper is attached.
- (void)beginSynchronization {
  if (!_imageWrapper) {
    return;
  }
  _imageWrapper->prepare();
}

/// On a command-buffer failure, releases the wrapper's resources.
/// T159183991: the error itself is deliberately ignored; we prefer not to
/// crash the app.
- (void)endSynchronization:(NSError*)error {
  if (error != nil && _imageWrapper != nullptr) {
    _imageWrapper->release();
  }
}

@end
48
49namespace at {
50namespace native {
51namespace metal {
52
// Records the image sizes derived from the tensor `sizes` and creates the
// trampoline delegate used to subscribe to command-buffer events. No image or
// buffer is allocated here.
MPSImageWrapper::MPSImageWrapper(IntArrayRef sizes) {
  this->_imageSizes = computeImageSize(sizes);
  this->_delegate = [MPSImageWrapperTrampoline newWithMPSImageWrapper:this];
}
57
// Destruction simply delegates to release(), which unsubscribes from the
// command buffer and drops every held Metal object.
MPSImageWrapper::~MPSImageWrapper() {
  this->release();
}
61
// Uploads host data: binds the wrapper to the current command buffer,
// subscribes the delegate, and creates a temporary image initialized from
// `inputData` (must be non-null).
void MPSImageWrapper::copyDataFromHost(const float* inputData) {
  TORCH_CHECK(inputData);
  this->_commandBuffer = [MetalCommandBuffer currentBuffer];
  [this->_commandBuffer addSubscriber:this->_delegate];
  this->_image =
      createTemporaryImage(this->_commandBuffer, this->_imageSizes, inputData);
}
68
// Copies the wrapped image back into host memory.
//
// Commits the bound command buffer via synchronize(); presumably this fires
// the delegate's beginSynchronization -> prepare(), which stages the image
// into _buffer (TODO confirm against MetalCommandBuffer). Then copies
// _buffer.length bytes into `hostData`.
//
// hostData: non-null destination, large enough for the whole staging buffer.
void MPSImageWrapper::copyDataToHost(float* hostData) {
  TORCH_CHECK(_image);
  // Mirror copyDataFromHost's pointer check and fail before doing any GPU
  // work: memcpy into a null destination is undefined behavior.
  TORCH_CHECK(hostData);
  synchronize();
  TORCH_CHECK(_buffer);
  memcpy(hostData, _buffer.contents, _buffer.length);
}
75
// Current backing image; nil before allocation and after release().
MPSImage* MPSImageWrapper::image() const {
  return this->_image;
}
79
// CPU-readable staging buffer allocated by prepare(); nil until then and
// after release().
id<MTLBuffer> MPSImageWrapper::buffer() const {
  return this->_buffer;
}
83
// Binds the wrapper to `commandBuffer` (must be non-nil and valid) and
// subscribes the delegate to that buffer's synchronization callbacks.
void MPSImageWrapper::setCommandBuffer(MetalCommandBuffer* commandBuffer) {
  TORCH_CHECK(commandBuffer && commandBuffer.valid);
  this->_commandBuffer = commandBuffer;
  [this->_commandBuffer addSubscriber:this->_delegate];
}
89
// Command buffer most recently associated with this wrapper; nil after
// release().
MetalCommandBuffer* MPSImageWrapper::commandBuffer() const {
  return this->_commandBuffer;
}
93
// Recomputes the image sizes from `sizes` and allocates a static image of
// that shape (not tied to any command buffer).
void MPSImageWrapper::allocateStorage(IntArrayRef sizes) {
  this->_imageSizes = computeImageSize(sizes);
  this->_image = createStaticImage(this->_imageSizes);
}
98
// Binds to `commandBuffer` (validated by setCommandBuffer), then allocates a
// temporary image of the shape derived from `sizes` on that buffer.
void MPSImageWrapper::allocateTemporaryStorage(
    IntArrayRef sizes,
    MetalCommandBuffer* commandBuffer) {
  this->setCommandBuffer(commandBuffer);
  this->_imageSizes = computeImageSize(sizes);
  this->_image = createTemporaryImage(commandBuffer, this->_imageSizes);
}
106
// Replaces the backing image. `image` must be non-nil; a temporary image is
// only accepted while a valid command buffer is bound.
void MPSImageWrapper::setImage(MPSImage* image) {
  TORCH_CHECK(image);
  const bool needsLiveBuffer = image.isTemporaryImage;
  if (needsLiveBuffer) {
    TORCH_CHECK(_commandBuffer && _commandBuffer.valid);
  }
  this->_image = image;
}
114
// Stages the current image into _buffer so the host can read it back
// (copyDataToHost memcpy's _buffer.contents). Called from the trampoline's
// beginSynchronization.
//
// Steps:
//  1. (Re)allocate _buffer when it is missing or its length no longer matches
//     the image's element count. A resized image must never be copied into a
//     stale buffer of the wrong length, and copyDataToHost copies exactly
//     _buffer.length bytes.
//  2. Encode the image -> buffer copy on _commandBuffer.
//  3. If the image is a temporary image whose readCount is still non-zero,
//     replace it with a static copy (presumably so it remains usable after
//     the command buffer's temporary images are recycled — TODO confirm).
void MPSImageWrapper::prepare() {
  const int64_t size_bytes =
      c10::multiply_integers([_image sizes]) * sizeof(float);
  if (!_buffer || _buffer.length != static_cast<NSUInteger>(size_bytes)) {
    // The CPU reads this buffer (memcpy in copyDataToHost). Write-combined
    // memory is intended for resources the CPU writes but never reads, so use
    // the default cached mode instead.
    _buffer = [[MetalContext sharedInstance].device
        newBufferWithLength:size_bytes
                    options:MTLResourceCPUCacheModeDefaultCache];
    TORCH_CHECK(_buffer, "Allocate GPU memory failed!");
  }
  copyImageToMetalBuffer(_commandBuffer, _buffer, _image);
  if (_image.isTemporaryImage && _image.readCount != 0) {
    _image =
        createStaticImage((MPSTemporaryImage*)_image, _commandBuffer, false);
  }
}
129
// Commits the bound command buffer; a no-op when there is no buffer or it is
// no longer valid.
void MPSImageWrapper::synchronize() {
  if (!_commandBuffer || !_commandBuffer.valid) {
    return;
  }
  [_commandBuffer commit];
}
135
// Drops every Metal object held by the wrapper and unsubscribes the delegate.
// Called from the destructor and, on command-buffer errors, from the
// trampoline's endSynchronization:.
//
// NOTE(review): nilling _delegate may deallocate the trampoline, but any other
// holder of the trampoline keeps a raw pointer to this wrapper that is not
// cleared here — confirm no such reference outlives the wrapper.
void MPSImageWrapper::release() {
  // Recycle the image (MPSImage+Tensor helper) and detach it from the buffer.
  [this->_image recycle];
  [this->_commandBuffer remove:(MPSTemporaryImage*)this->_image];
  // Stop receiving synchronization callbacks.
  [this->_commandBuffer removeSubscriber:this->_delegate];
  // Reset state in the original order; each assignment may trigger a dealloc.
  this->_delegate = nil;
  this->_commandBuffer = nil;
  this->_image = nil;
  this->_buffer = nil;
}
145
146}
147}
148}
149