xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
2
3namespace at::native::metal::mpscnn {
4
// Integer ceiling division: returns the smallest q such that q * y >= x.
//
// x: dividend.
// y: divisor; must be non-zero (division by zero is undefined behavior).
//
// Computed as quotient plus a remainder check rather than (x + y - 1) / y,
// because the latter wraps around in unsigned arithmetic when x is close to
// the maximum value of uint and then yields a result that is too small.
static auto divRoundUp(uint x, uint y) -> uint {
  const uint quotient = x / y;
  const uint remainder = x % y;
  return remainder == 0 ? quotient : quotient + 1;
}
8
// Convenience overload: reads the batch size, channel count and spatial
// extent from `im` and forwards them, together with `pipeline`, to the
// overload that takes explicit dimensions.
LaunchParams spatialPointwiseKernelLaunchParams(
    id<MTLComputePipelineState> pipeline,
    MPSImage* im) {
  const auto imageCount = im.numberOfImages;
  const auto channelCount = im.featureChannels;
  const auto rows = im.height;
  const auto cols = im.width;
  return spatialPointwiseKernelLaunchParams(
      pipeline, imageCount, channelCount, rows, cols);
}
15
// Builds the launch configuration for a kernel that is dispatched once per
// spatial position of each group of four feature channels of each image.
//
// pipeline:        not read by this function; the threadgroup shape below is
//                  a fixed 8 x 4 x 1 rather than being queried from it.
// numberOfImages:  number of images in the batch.
// featureChannels: channels per image; rounded up to a multiple of 4 when
//                  computing the grid depth (presumably because four channels
//                  share one texture slice -- confirm against the kernels).
// height, width:   spatial extent of each image.
//
// Returns {threadsPerThreadgroup, threadgroupsPerGrid, threadsPerGrid}.
LaunchParams spatialPointwiseKernelLaunchParams(
    id<MTLComputePipelineState> pipeline,
    NSUInteger numberOfImages,
    NSUInteger featureChannels,
    NSUInteger height,
    NSUInteger width) {
  // Grid depth: one layer per image per group of up to four channels.
  const auto gridDepth = numberOfImages * divRoundUp(featureChannels, 4);

  // Fixed threadgroup shape: 8 wide (threadExecutionWidth) by 4 high
  // (maxThreadsPerThreadgroup / threadExecutionWidth), one layer deep.
  const auto groupShape = MTLSizeMake(8, 4, 1);

  // Enough threadgroups to cover the whole width x height plane; partial
  // groups at the right/bottom edges are rounded up to a full group.
  const auto groupCount = MTLSizeMake(
      divRoundUp(width, groupShape.width),
      divRoundUp(height, groupShape.height),
      gridDepth);

  // Exact thread count, with no rounding to the threadgroup shape.
  const auto threadCount = MTLSizeMake(width, height, gridDepth);

  return {groupShape, groupCount, threadCount};
}
34
35} // namespace at::native::metal::mpscnn
36