xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #import <Metal/Metal.h>
2 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
3 #include <string>
4 
// Throws a c10::Error when a Metal API call has reported a failure through an
// NSError. Does nothing when `error` is nil.
//
//   error    - expression yielding the NSError* filled in by the Metal call.
//              It is expanded several times inside the macro, so pass a plain
//              variable rather than an expression with side effects.
//   preamble - leading text of the exception message; forwarded to c10::str
//              as its first argument.
//
// The exception message is `preamble` followed by the error's localized
// description, domain, code and user-info description. The source location
// recorded in the exception is the macro's expansion site.
//
// The failure branch is hinted as unlikely because a non-nil error is the
// exceptional path. `error` is parenthesized at every use so that arguments
// such as casts or conditional expressions expand correctly.
//
// NOTE(review): c10::Error, c10::str and C10_UNLIKELY are not brought in by
// the includes visible in this header; the including file is presumably
// expected to provide them - confirm.
#define METAL_THROW_IF_ERROR(error, preamble)                                  \
  do {                                                                         \
    if (C10_UNLIKELY(error)) {                                                 \
      throw c10::Error(                                                        \
          {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},               \
          c10::str(                                                            \
              preamble,                                                        \
              " Error details: ",                                              \
              " Localized_description: ",                                      \
              (error).localizedDescription.UTF8String,                         \
              " Domain: ", (error).domain.UTF8String,                          \
              " Code: ", (error).code,                                         \
              " User Info: ", (error).userInfo.description.UTF8String));       \
    }                                                                          \
  } while (false)
22 
23 namespace at::native::metal::mpscnn {
24 
// Bundle of the three MTLSize values returned by
// spatialPointwiseKernelLaunchParams(). Plain aggregate: member order is part
// of the interface for brace-initialization, so do not reorder.
struct LaunchParams {
  // Size of one threadgroup.
  MTLSize threadsPerThreadgroup;
  // Number of threadgroups in the grid.
  MTLSize threadgroupsPerGrid;
  // Size of the whole grid in threads. The original note on this member reads
  // "iOS 11.0" - presumably the minimum OS on which dispatching by total
  // thread count is available; TODO confirm.
  MTLSize threadsPerGrid;
};
30 
// Returns the LaunchParams to use when running `pipeline` over the image
// `im`. Declaration only; the definition lives outside this header.
// Available on iOS 11.0+ / macOS 10.13+.
// NOTE(review): presumably a convenience form of the overload that takes the
// image dimensions explicitly - confirm against the implementation.
API_AVAILABLE(ios(11.0), macos(10.13))
LaunchParams spatialPointwiseKernelLaunchParams(
    id<MTLComputePipelineState> pipeline,
    MPSImage* im);
35 
// Returns the LaunchParams to use when running `pipeline` over data described
// by an image count, a feature-channel count, a height and a width.
// Declaration only; the definition lives outside this header.
// Available on iOS 11.0+ / macOS 10.13+.
// NOTE(review): how the four extents map onto the grid dimensions is not
// visible from here - consult the implementation before relying on it.
API_AVAILABLE(ios(11.0), macos(10.13))
LaunchParams spatialPointwiseKernelLaunchParams(
    id<MTLComputePipelineState> pipeline,
    NSUInteger numberOfImages,
    NSUInteger featureChannels,
    NSUInteger height,
    NSUInteger width);
43 
// Chooses between two kernel names based on the layout of `image`.
//
// Returns a copy of `arrayKernel` when the image has more than 4 feature
// channels or contains more than one image; otherwise returns a copy of
// `nonArrayKernel`.
// NOTE(review): presumably the "array" name refers to the shader variant that
// reads a texture array, which is how MPS stores such images - confirm
// against the shader sources.
API_AVAILABLE(ios(11.0), macos(10.13))
static inline std::string kernelFor(
    MPSImage* image,
    const std::string& arrayKernel,
    const std::string& nonArrayKernel) {
  // Same test as "channels > 4 || images > 1", stated from the other side;
  // the short-circuit order (channels first, then image count) is unchanged.
  const bool fitsNonArrayKernel =
      image.featureChannels <= 4 && image.numberOfImages <= 1;
  return fitsNonArrayKernel ? nonArrayKernel : arrayKernel;
}
54 
// Computes the per-axis offset that reconciles a padding amount `pad` with
// a window positioned `kernel / 2` pixels before its offset (integer
// division).
//
// The result is simply (kernel / 2) - pad. Worked examples for a 3x3,
// stride-1 kernel:
//   pad = 1  ->  offset  0   (window starts at pixel -1)
//   pad = 0  ->  offset  1   (window starts at pixel  0)
//   pad = 2  ->  offset -1   (window starts at pixel -2)
//
// NOTE(review): this relies on MPSCNN placing the top-left of the window
// kernel / 2 pixels before the offset; verify for even kernel sizes.
static inline int computeMPSAlignOffset(int kernel, int pad) {
  return (kernel / 2) - pad;
}
70 
71 } // namespace at::native::metal::mpscnn
72