#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/ATen.h>
#include <c10/util/accumulate.h>

#include <climits>
#include <cstring>

namespace at {
namespace native {
namespace metal {

// Builds the Float16, shader read/write image descriptor shared by the static
// and temporary image factories. `sizes` is read as [N, C, H, W]; anything
// that is not exactly 4-D is rejected up front instead of indexing past the
// end of the array.
static MPSImageDescriptor* makeFloat16Descriptor(IntArrayRef sizes) {
  TORCH_CHECK(
      sizes.size() == 4,
      "Metal images require 4-D [N, C, H, W] sizes, but got ",
      sizes.size(),
      " dimension(s)");
  return [MPSImageDescriptor
      imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16
                                 width:sizes[3]
                                height:sizes[2]
                       featureChannels:sizes[1]
                        numberOfImages:sizes[0]
                                 usage:MTLTextureUsageShaderRead |
                                 MTLTextureUsageShaderWrite];
}

// Formats the "[N, C, H, W]" debug label attached to every image created in
// this file. Callers must have validated that `sizes` is 4-D.
static NSString* makeShapeLabel(IntArrayRef sizes) {
  return [NSString stringWithFormat:@"[%d, %d, %d, %d]",
                                    (int)sizes[0],
                                    (int)sizes[1],
                                    (int)sizes[2],
                                    (int)sizes[3]];
}

// Dispatches `state` with the spatial pointwise launch parameters derived
// from `image`, then closes the encoder.
static void dispatchPointwiseAndEnd(
    id<MTLComputeCommandEncoder> encoder,
    id<MTLComputePipelineState> state,
    MPSImage* image) {
  const auto& launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(state, image);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
}

// Encodes the "copy_nchw_to_metal" kernel on `commandBuffer`, reading from
// `src` (buffer index 0) and writing into the texture of `dst`.
static void encodeBufferToImage(
    id<MTLCommandBuffer> commandBuffer,
    id<MTLBuffer> src,
    MPSImage* dst) {
  // The pipeline state is looked up before the encoder is opened so that a
  // failure there cannot leave an encoder without a matching endEncoding.
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:mpscnn::kernelFor(
                                   dst,
                                   "copy_nchw_to_metal",
                                   "copy_nchw_to_metal_nonarray")
                     Constants:@[
                       @(dst.featureChannels),
                       @(dst.height),
                       @(dst.width)
                     ]];
  id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setBuffer:src offset:0 atIndex:0];
  [encoder setTexture:[dst texture] atIndex:0];
  dispatchPointwiseAndEnd(encoder, state, dst);
}

// Encodes the "copy_metal_to_nchw" kernel on `commandBuffer`, reading the
// texture of `src` and writing into `dst` (buffer index 0).
static void encodeImageToBuffer(
    id<MTLCommandBuffer> commandBuffer,
    MPSImage* src,
    id<MTLBuffer> dst) {
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:mpscnn::kernelFor(
                                   src,
                                   "copy_metal_to_nchw",
                                   "copy_metal_to_nchw_nonarray")
                     Constants:@[
                       @(src.featureChannels),
                       @(src.height),
                       @(src.width)
                     ]];
  id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setBuffer:dst offset:0 atIndex:0];
  [encoder setTexture:[src texture] atIndex:0];
  dispatchPointwiseAndEnd(encoder, state, src);
}

// Encodes the "copy" kernel on `commandBuffer`: texture 0 is `src`, texture 1
// is `dst`. Kernel selection and launch parameters both come from `src`.
static void encodeImageToImage(
    id<MTLCommandBuffer> commandBuffer,
    MPSImage* src,
    MPSImage* dst) {
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      pipelineState:mpscnn::kernelFor(src, "copy", "copy_nonarray")];
  id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setTexture:[src texture] atIndex:0];
  [encoder setTexture:[dst texture] atIndex:1];
  dispatchPointwiseAndEnd(encoder, state, src);
}

// Allocates an uninitialized Float16 MPSImage on the shared Metal device.
//
// sizes: image shape as [N, C, H, W]; must have exactly four entries.
// Returns the new image, labelled with its shape.
MPSImage* createStaticImage(IntArrayRef sizes) {
  MPSImageDescriptor* desc = makeFloat16Descriptor(sizes);
  MPSImage* image =
      [[MPSImage alloc] initWithDevice:[MetalContext sharedInstance].device
                       imageDescriptor:desc];
  image.label = makeShapeLabel(sizes);
  return image;
}

// Allocates a static MPSImage and fills it from a host float array.
//
// src:   host pointer; product(sizes) floats are read from it.
// sizes: image shape as [N, C, H, W]; must have exactly four entries.
//
// The data is staged in a fresh MTLBuffer, the NCHW-to-texture copy kernel is
// encoded on a new MetalCommandBuffer, and that command buffer is committed
// before returning.
// NOTE(review): whether -[MetalCommandBuffer commit] blocks until the GPU has
// finished is not visible from this file - confirm before reading the image
// back on the CPU.
MPSImage* createStaticImage(const float* src, IntArrayRef sizes) {
  // Creating the image first validates `sizes` before any staging memory is
  // allocated.
  MPSImage* output = createStaticImage(sizes);
  int64_t size_bytes = c10::multiply_integers(sizes) * sizeof(float);
  id<MTLBuffer> staging = [[MetalContext sharedInstance].device
      newBufferWithLength:size_bytes
                  options:MTLResourceCPUCacheModeWriteCombined];
  memcpy(staging.contents, src, size_bytes);
  MetalCommandBuffer* cb = [MetalCommandBuffer newBuffer];
  encodeBufferToImage(cb.buffer, staging, output);
  [cb commit];
  return output;
}

// Copies a temporary image into a newly allocated static MPSImage.
//
// image:              source temporary image.
// buffer:             command buffer the copy kernel is encoded on; must be
//                     non-null.
// waitUntilCompleted: when true, `buffer` is committed after encoding.
// NOTE(review): despite the parameter name, this function only calls -commit;
// any waiting presumably happens inside MetalCommandBuffer - confirm.
MPSImage* createStaticImage(
    MPSTemporaryImage* image,
    MetalCommandBuffer* buffer,
    bool waitUntilCompleted) {
  TORCH_CHECK(buffer);
  MPSImage* Y = createStaticImage([image sizes]);
  encodeImageToImage(buffer.buffer, image, Y);
  if (waitUntilCompleted) {
    [buffer commit];
  }
  return Y;
}

// Allocates an uninitialized Float16 MPSTemporaryImage tied to `buffer`.
//
// buffer: command buffer that owns the temporary image; must be non-null.
// sizes:  image shape as [N, C, H, W]; must have exactly four entries.
MPSTemporaryImage* createTemporaryImage(
    MetalCommandBuffer* buffer,
    IntArrayRef sizes) {
  TORCH_CHECK(buffer);
  MPSImageDescriptor* desc = makeFloat16Descriptor(sizes);
  MPSTemporaryImage* image =
      [MPSTemporaryImage temporaryImageWithCommandBuffer:buffer.buffer
                                         imageDescriptor:desc];
  // MPS releases a temporary image's storage once readCount reaches zero;
  // INT_MAX effectively disables that automatic recycling.
  image.readCount = INT_MAX;
  image.label = makeShapeLabel(sizes);
  // Register the image with the wrapper (presumably so it can be released
  // when the command buffer is done - see MetalCommandBuffer).
  [buffer add:image];
  return image;
}

// Allocates a temporary image on `buffer` and encodes a kernel that fills it
// from a host float array. Nothing is committed here.
//
// buffer: command buffer to encode on; must be non-null.
// sizes:  image shape as [N, C, H, W]; must have exactly four entries.
// src:    host pointer; product(sizes) floats are copied out of it into a
//         shared-storage MTLBuffer before this function returns.
MPSTemporaryImage* createTemporaryImage(
    MetalCommandBuffer* buffer,
    IntArrayRef sizes,
    const float* src) {
  TORCH_CHECK(buffer);
  // Creating the image first validates `sizes` before staging memory is
  // allocated.
  MPSTemporaryImage* output = createTemporaryImage(buffer, sizes);
  int64_t size_bytes = c10::multiply_integers(sizes) * sizeof(float);
  id<MTLBuffer> staging = [[MetalContext sharedInstance].device
      newBufferWithBytes:src
                  length:size_bytes
                 options:MTLResourceStorageModeShared];
  encodeBufferToImage(buffer.buffer, staging, output);
  return output;
}

// Allocates a temporary image on `buffer` with the same sizes as `image` and
// encodes a texture-to-texture copy into it. Nothing is committed here.
//
// buffer: command buffer to encode on; must be non-null.
// image:  source image.
MPSTemporaryImage* createTemporaryImage(
    MetalCommandBuffer* buffer,
    MPSImage* image) {
  TORCH_CHECK(buffer);
  MPSTemporaryImage* Y = createTemporaryImage(buffer, [image sizes]);
  encodeImageToImage(buffer.buffer, image, Y);
  return Y;
}

// Synchronously reads `image` back into host memory.
//
// dst:   destination; must have room for product([image sizes]) floats.
// image: source image.
//
// Uses its own command buffer from the shared command queue and blocks until
// the GPU has finished before copying into `dst`.
void copyImageToFloatBuffer(float* dst, MPSImage* image) {
  int64_t size_bytes = c10::multiply_integers([image sizes]) * sizeof(float);
  id<MTLBuffer> readback = [[MetalContext sharedInstance].device
      newBufferWithLength:size_bytes
                  options:MTLResourceCPUCacheModeDefaultCache];
  id<MTLCommandBuffer> cb =
      [MetalContext sharedInstance].commandQueue.commandBuffer;
  encodeImageToBuffer(cb, image, readback);
  [cb commit];
  [cb waitUntilCompleted];
  memcpy(dst, readback.contents, readback.length);
}

// Encodes a copy of `image` into the Metal buffer `dst` on `cmdBuffer`.
// The command buffer is neither committed nor waited on here.
//
// cmdBuffer: wrapper whose underlying MTLCommandBuffer must be non-null.
// dst:       destination buffer, bound at buffer index 0.
// image:     source image.
void copyImageToMetalBuffer(
    MetalCommandBuffer* cmdBuffer,
    id<MTLBuffer> dst,
    MPSImage* image) {
  TORCH_CHECK(cmdBuffer.buffer);
  encodeImageToBuffer(cmdBuffer.buffer, image, dst);
}

} // namespace metal
} // namespace native
} // namespace at