#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>

namespace at::native::metal::mpscnn {

// Ceiling division: returns the smallest n such that n * y >= x.
//
// Takes NSUInteger (not a 32-bit `uint`) because every caller in this file
// passes NSUInteger sizes; a 32-bit parameter would silently truncate them.
// Written as `x / y + (x % y != 0)` rather than `(x + y - 1) / y` so it cannot
// wrap around for x close to NSUIntegerMax.
// Precondition: y != 0 (all callers below pass non-zero constants).
static inline NSUInteger divRoundUp(NSUInteger x, NSUInteger y) {
  return x / y + ((x % y != 0) ? 1 : 0);
}

// Convenience overload: reads the batch size, channel count and spatial
// dimensions from `im` and forwards them to the explicit-size overload.
// NOTE(review): if `im` is nil, Objective-C messaging yields 0 for each
// property, which produces an empty grid.
LaunchParams spatialPointwiseKernelLaunchParams(
    id<MTLComputePipelineState> pipeline,
    MPSImage* im) {
  return spatialPointwiseKernelLaunchParams(
      pipeline, im.numberOfImages, im.featureChannels, im.height, im.width);
}

// Computes dispatch sizes for a kernel that runs one thread per spatial
// position (x = width, y = height) and per (image, 4-channel group) pair
// along z.
//
// Returns, in order: threads per threadgroup (fixed 8 x 4 x 1), the number of
// threadgroups needed to cover the grid (rounded up in x and y), and the exact
// thread-grid size.
//
// NOTE: `pipeline` is not consulted here. The threadgroup shape is hard-coded
// rather than derived from the pipeline's execution width / max threads.
// NOTE(review): grouping channels by 4 presumably mirrors MPSImage's storage
// of feature channels in slices of 4 — confirm against the kernels.
LaunchParams spatialPointwiseKernelLaunchParams(
    id<MTLComputePipelineState> pipeline,
    NSUInteger numberOfImages,
    NSUInteger featureChannels,
    NSUInteger height,
    NSUInteger width) {
  const auto threadsPerThreadgroup = MTLSizeMake(
      8 /* threadExecutionWidth */,
      4 /* maxThreadsPerThreadgroup / threadExecutionWidth */,
      1);

  // One z-layer per (image, 4-channel group); shared by both grid sizes below.
  const NSUInteger depth = numberOfImages * divRoundUp(featureChannels, 4);

  const auto threadgroupsPerGrid = MTLSizeMake(
      divRoundUp(width, threadsPerThreadgroup.width),
      divRoundUp(height, threadsPerThreadgroup.height),
      depth);
  const auto threadsPerGrid = MTLSizeMake(width, height, depth);
  return {threadsPerThreadgroup, threadgroupsPerGrid, threadsPerGrid};
}

} // namespace at::native::metal::mpscnn