xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalContext.h>
2#import <ATen/native/metal/MetalTensorUtils.h>
3#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
4#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
5#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
6
7#include <ATen/ATen.h>
8#include <c10/util/accumulate.h>
9
10namespace at {
11namespace native {
12namespace metal {
13
// Allocates a regular (non-temporary) MPSImage with half-float feature
// channels on the shared MetalContext device.
//
// `sizes` is read as {N, C, H, W}: N becomes numberOfImages, C becomes
// featureChannels, H the height and W the width. The image is created with
// shader read and shader write usage, and its label is set to the four
// dimensions formatted as "[N, C, H, W]".
//
// Raises (via TORCH_CHECK) when `sizes` does not hold exactly four entries.
MPSImage* createStaticImage(IntArrayRef sizes) {
  // The four indexed reads below would run past the end of a shorter array.
  TORCH_CHECK(
      sizes.size() == 4,
      "createStaticImage expects 4 sizes (N, C, H, W), but got ",
      sizes.size());
  int64_t N = sizes[0];
  int64_t C = sizes[1];
  int64_t H = sizes[2];
  int64_t W = sizes[3];
  MPSImageDescriptor* desc = [MPSImageDescriptor
      imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16
                                 width:W
                                height:H
                       featureChannels:C
                        numberOfImages:N
                                 usage:MTLTextureUsageShaderRead |
                                 MTLTextureUsageShaderWrite];
  MPSImage* image =
      [[MPSImage alloc] initWithDevice:[MetalContext sharedInstance].device
                       imageDescriptor:desc];
  // The dimensions are narrowed to int only for the label text.
  image.label = [NSString
      stringWithFormat:@"[%d, %d, %d, %d]", (int)N, (int)C, (int)H, (int)W];
  return image;
}
34
// Creates a static MPSImage with the given `sizes` and fills it from `src`.
//
// `src` must point to at least multiply_integers(sizes) floats; that many
// floats are copied into a staging MTLBuffer, which is then bound as buffer 0
// of the "copy_nchw_to_metal" / "copy_nchw_to_metal_nonarray" kernel with the
// new image's texture as texture 0. The kernel is specialized with the
// image's featureChannels, height and width. The work is encoded on a new
// MetalCommandBuffer, which is committed before returning.
//
// Raises (via TORCH_CHECK) when `src` is null or the staging buffer cannot be
// allocated.
MPSImage* createStaticImage(const float* src, IntArrayRef sizes) {
  TORCH_CHECK(src, "createStaticImage: src must not be null");
  int64_t size_bytes = c10::multiply_integers(sizes) * sizeof(float);
  id<MTLBuffer> buff = [[MetalContext sharedInstance].device
      newBufferWithLength:size_bytes
                  options:MTLResourceCPUCacheModeWriteCombined];
  // A nil buffer would make the memcpy below write through a null pointer.
  TORCH_CHECK(
      buff,
      "createStaticImage: failed to allocate a staging buffer of ",
      size_bytes,
      " bytes");
  memcpy(buff.contents, src, size_bytes);
  MPSImage* output = createStaticImage(sizes);
  // NOTE(review): kernelFor presumably picks between the two kernel names
  // based on the image (array vs. non-array texture) - confirm in MPSCNNUtils.
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:metal::mpscnn::kernelFor(
                                   output,
                                   "copy_nchw_to_metal",
                                   "copy_nchw_to_metal_nonarray")
                     Constants:@[
                       @(output.featureChannels),
                       @(output.height),
                       @(output.width)
                     ]];
  MetalCommandBuffer* cb = [MetalCommandBuffer newBuffer];
  id<MTLComputeCommandEncoder> encoder = [cb.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setBuffer:buff offset:0 atIndex:0];
  [encoder setTexture:[output texture] atIndex:0];
  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, output);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  // NOTE(review): whether this call blocks until the copy has finished is
  // decided by MetalCommandBuffer's commit, which is not visible here.
  [cb commit];
  return output;
}
65
// Copies a temporary image into a newly created static MPSImage that has the
// same sizes, using the "copy" / "copy_nonarray" kernel.
//
// The copy is encoded on `buffer`'s underlying command buffer: the source
// texture is bound at index 0 and the destination texture at index 1, and the
// launch parameters are derived from the source image. When
// `waitUntilCompleted` is true the wrapper's commit is invoked before
// returning; otherwise the work is only encoded.
//
// Raises (via TORCH_CHECK) when `buffer` is nil.
MPSImage* createStaticImage(
    MPSTemporaryImage* image,
    MetalCommandBuffer* buffer,
    bool waitUntilCompleted) {
  TORCH_CHECK(buffer);
  MPSImage* staticCopy = createStaticImage([image sizes]);
  id<MTLComputeCommandEncoder> copyEncoder =
      [[buffer buffer] computeCommandEncoder];
  MetalContext* context = [MetalContext sharedInstance];
  id<MTLComputePipelineState> copyPipeline = [context
      pipelineState:metal::mpscnn::kernelFor(image, "copy", "copy_nonarray")];

  [copyEncoder setComputePipelineState:copyPipeline];
  [copyEncoder setTexture:image.texture atIndex:0];
  [copyEncoder setTexture:staticCopy.texture atIndex:1];

  const auto launch =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(copyPipeline, image);
  [copyEncoder dispatchThreadgroups:launch.threadgroupsPerGrid
              threadsPerThreadgroup:launch.threadsPerThreadgroup];
  [copyEncoder endEncoding];

  // Without a wait request the caller keeps control over when to commit.
  if (!waitUntilCompleted) {
    return staticCopy;
  }
  [buffer commit];
  return staticCopy;
}
90
// Creates an MPSTemporaryImage with half-float feature channels on the
// command buffer wrapped by `buffer`.
//
// `sizes` is read as {N, C, H, W}: N becomes numberOfImages, C becomes
// featureChannels, H the height and W the width. The image is created with
// shader read and shader write usage, labelled "[N, C, H, W]", and handed to
// `buffer` through -add: before being returned.
//
// Raises (via TORCH_CHECK) when `buffer` is nil or `sizes` does not hold
// exactly four entries.
MPSTemporaryImage* createTemporaryImage(
    MetalCommandBuffer* buffer,
    IntArrayRef sizes) {
  TORCH_CHECK(buffer);
  // The four indexed reads below would run past the end of a shorter array.
  TORCH_CHECK(
      sizes.size() == 4,
      "createTemporaryImage expects 4 sizes (N, C, H, W), but got ",
      sizes.size());
  int64_t N = sizes[0];
  int64_t C = sizes[1];
  int64_t H = sizes[2];
  int64_t W = sizes[3];
  MPSImageDescriptor* desc = [MPSImageDescriptor
      imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16
                                 width:W
                                height:H
                       featureChannels:C
                        numberOfImages:N
                                 usage:MTLTextureUsageShaderRead |
                                 MTLTextureUsageShaderWrite];
  MPSTemporaryImage* image =
      [MPSTemporaryImage temporaryImageWithCommandBuffer:buffer.buffer
                                         imageDescriptor:desc];
  // NOTE(review): presumably set this high so MPS does not recycle the
  // image's storage after a limited number of reads - confirm against the
  // MPSTemporaryImage readCount documentation.
  image.readCount = INT_MAX;
  // The dimensions are narrowed to int only for the label text.
  image.label = [NSString
      stringWithFormat:@"[%d, %d, %d, %d]", (int)N, (int)C, (int)H, (int)W];
  // NOTE(review): presumably lets the wrapper track the image's lifetime -
  // confirm in MetalCommandBuffer.
  [buffer add:image];
  return image;
}
116
// Creates a temporary image with the given `sizes` on `buffer` and encodes a
// copy that fills it from `src`.
//
// `src` must point to at least multiply_integers(sizes) floats; they are
// copied into a shared-storage MTLBuffer, which is bound as buffer 0 of the
// "copy_nchw_to_metal" / "copy_nchw_to_metal_nonarray" kernel with the new
// image's texture as texture 0. The kernel is specialized with the image's
// featureChannels, height and width. The work is only encoded on `buffer`;
// this function does not commit it.
//
// Raises (via TORCH_CHECK) when `buffer` is nil, `src` is null, or the
// staging buffer cannot be allocated.
MPSTemporaryImage* createTemporaryImage(
    MetalCommandBuffer* buffer,
    IntArrayRef sizes,
    const float* src) {
  TORCH_CHECK(buffer);
  TORCH_CHECK(src, "createTemporaryImage: src must not be null");
  int64_t size_bytes = c10::multiply_integers(sizes) * sizeof(float);
  id<MTLBuffer> buff = [[MetalContext sharedInstance].device
      newBufferWithBytes:src
                  length:size_bytes
                 options:MTLResourceStorageModeShared];
  // Fail here rather than binding a nil buffer to the kernel below.
  TORCH_CHECK(
      buff,
      "createTemporaryImage: failed to allocate a staging buffer of ",
      size_bytes,
      " bytes");
  MPSTemporaryImage* output = createTemporaryImage(buffer, sizes);
  // NOTE(review): kernelFor presumably picks between the two kernel names
  // based on the image (array vs. non-array texture) - confirm in MPSCNNUtils.
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:metal::mpscnn::kernelFor(
                                   output,
                                   "copy_nchw_to_metal",
                                   "copy_nchw_to_metal_nonarray")
                     Constants:@[
                       @(output.featureChannels),
                       @(output.height),
                       @(output.width)
                     ]];
  id<MTLComputeCommandEncoder> encoder = [buffer.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setBuffer:buff offset:0 atIndex:0];
  [encoder setTexture:[output texture] atIndex:0];
  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, output);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  return output;
}
149
// Creates a temporary image on `buffer` with the same sizes as `image` and
// encodes a "copy" / "copy_nonarray" kernel that reads `image` (texture 0)
// and writes the new temporary image (texture 1).
//
// Launch parameters are derived from the source image. The work is only
// encoded on `buffer`; this function does not commit it.
//
// Raises (via TORCH_CHECK) when `buffer` is nil.
MPSTemporaryImage* createTemporaryImage(
    MetalCommandBuffer* buffer,
    MPSImage* image) {
  TORCH_CHECK(buffer);
  MPSTemporaryImage* tempCopy = createTemporaryImage(buffer, [image sizes]);
  id<MTLComputeCommandEncoder> copyEncoder =
      [[buffer buffer] computeCommandEncoder];
  MetalContext* context = [MetalContext sharedInstance];
  id<MTLComputePipelineState> copyPipeline = [context
      pipelineState:metal::mpscnn::kernelFor(image, "copy", "copy_nonarray")];
  [copyEncoder setComputePipelineState:copyPipeline];
  [copyEncoder setTexture:image.texture atIndex:0];
  [copyEncoder setTexture:tempCopy.texture atIndex:1];

  const auto launch =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(copyPipeline, image);
  [copyEncoder dispatchThreadgroups:launch.threadgroupsPerGrid
              threadsPerThreadgroup:launch.threadsPerThreadgroup];
  [copyEncoder endEncoding];
  return tempCopy;
}
169
// Copies the contents of `image` into the host float array `dst`.
//
// `dst` must have room for multiply_integers([image sizes]) floats. A staging
// MTLBuffer of that size is bound as buffer 0 of the "copy_metal_to_nchw" /
// "copy_metal_to_nchw_nonarray" kernel with the image's texture as texture 0;
// the kernel is specialized with the image's featureChannels, height and
// width. The work runs on a fresh command buffer from the shared command
// queue, and this function blocks until that command buffer has finished
// before copying the staging buffer into `dst`.
//
// Raises (via TORCH_CHECK) when `dst` is null, the staging buffer cannot be
// allocated, or the command buffer does not finish in the completed state.
void copyImageToFloatBuffer(float* dst, MPSImage* image) {
  TORCH_CHECK(dst, "copyImageToFloatBuffer: dst must not be null");
  int64_t size_bytes = c10::multiply_integers([image sizes]) * sizeof(float);
  id<MTLBuffer> buffer = [[MetalContext sharedInstance].device
      newBufferWithLength:size_bytes
                  options:MTLResourceCPUCacheModeDefaultCache];
  // A nil buffer would make the final memcpy read through a null pointer.
  TORCH_CHECK(
      buffer,
      "copyImageToFloatBuffer: failed to allocate a staging buffer of ",
      size_bytes,
      " bytes");

  id<MTLCommandBuffer> cb =
      [MetalContext sharedInstance].commandQueue.commandBuffer;
  id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:metal::mpscnn::kernelFor(
                                   image,
                                   "copy_metal_to_nchw",
                                   "copy_metal_to_nchw_nonarray")
                     Constants:@[
                       @(image.featureChannels),
                       @(image.height),
                       @(image.width)
                     ]];

  [encoder setComputePipelineState:state];
  [encoder setBuffer:buffer offset:0 atIndex:0];
  [encoder setTexture:[image texture] atIndex:0];

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, image);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  [cb commit];
  [cb waitUntilCompleted];
  // If the GPU work failed, the staging buffer was not (fully) written, so do
  // not hand its contents to the caller as if they were the image data.
  TORCH_CHECK(
      cb.status == MTLCommandBufferStatusCompleted,
      "copyImageToFloatBuffer: the GPU copy did not complete successfully");
  memcpy(dst, buffer.contents, buffer.length);
}
203
// Encodes a copy of `image` into the Metal buffer `dst` on `cmdBuffer`.
//
// `dst` is bound as buffer 0 of the "copy_metal_to_nchw" /
// "copy_metal_to_nchw_nonarray" kernel with the image's texture as texture 0;
// the kernel is specialized with the image's featureChannels, height and
// width. The work is only encoded; this function does not commit `cmdBuffer`.
//
// Raises (via TORCH_CHECK) when `cmdBuffer` has no underlying command buffer,
// when `dst` is nil, or when `dst` is smaller than
// multiply_integers([image sizes]) floats.
void copyImageToMetalBuffer(
    MetalCommandBuffer* cmdBuffer,
    id<MTLBuffer> dst,
    MPSImage* image) {
  TORCH_CHECK(cmdBuffer.buffer);
  TORCH_CHECK(dst, "copyImageToMetalBuffer: dst must not be null");
  // copyImageToFloatBuffer sizes its destination for this same kernel as one
  // float per image element; require at least that much room here so the
  // kernel cannot write past the end of `dst`.
  const int64_t required_bytes =
      c10::multiply_integers([image sizes]) * sizeof(float);
  TORCH_CHECK(
      static_cast<int64_t>(dst.length) >= required_bytes,
      "copyImageToMetalBuffer: dst holds ",
      static_cast<int64_t>(dst.length),
      " bytes but the image needs ",
      required_bytes);
  id<MTLComputeCommandEncoder> encoder =
      [cmdBuffer.buffer computeCommandEncoder];
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:metal::mpscnn::kernelFor(
                                   image,
                                   "copy_metal_to_nchw",
                                   "copy_metal_to_nchw_nonarray")
                     Constants:@[
                       @(image.featureChannels),
                       @(image.height),
                       @(image.width)
                     ]];

  [encoder setComputePipelineState:state];
  [encoder setBuffer:dst offset:0 atIndex:0];
  [encoder setTexture:[image texture] atIndex:0];

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, image);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
}
232
233}
234}
235}
236