1 #import <Metal/Metal.h>
2 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
3 #include <string>
4
// Throws a c10::Error when a Metal API call has reported a failure through an
// NSError. The exception message is `preamble` followed by the error's
// localized description, domain, code and user-info dictionary.
//
// `error` must be an expression of type NSError*; a nil value means "no
// error" and the macro does nothing. The argument is expanded several times,
// so pass a plain variable rather than an expression with side effects.
//
// A non-nil error is the exceptional path, hence the C10_UNLIKELY hint.
#define METAL_THROW_IF_ERROR(error, preamble)                      \
  do {                                                             \
    if C10_UNLIKELY(error) {                                       \
      throw c10::Error(                                            \
          {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},   \
          c10::str(                                                \
              preamble,                                            \
              " Error details: ",                                  \
              " Localized_description: ",                          \
              (error).localizedDescription.UTF8String,             \
              " Domain: ",                                         \
              (error).domain.UTF8String,                           \
              " Code: ",                                           \
              (error).code,                                        \
              " User Info: ",                                      \
              (error).userInfo.description.UTF8String));           \
    }                                                              \
  } while (false)
22
23 namespace at::native::metal::mpscnn {
24
// Dispatch sizes for a compute kernel launch. This is the return type of the
// spatialPointwiseKernelLaunchParams overloads declared in this namespace.
struct LaunchParams {
  // Size of one threadgroup, in threads.
  MTLSize threadsPerThreadgroup;
  // Size of the grid, in threadgroups.
  MTLSize threadgroupsPerGrid;
  // Size of the grid, in threads. Marked "iOS 11.0" by the original author;
  // NOTE(review): presumably the minimum OS for dispatching by thread count
  // -- confirm.
  MTLSize threadsPerGrid;
};
30
// Launch-parameter overload that takes the image directly.
//
// Only the declaration is visible here. NOTE(review): presumably the sizes
// are derived from the dimensions of `im` for the given `pipeline` -- confirm
// against the definition.
API_AVAILABLE(ios(11.0), macos(10.13))
LaunchParams spatialPointwiseKernelLaunchParams(
    id<MTLComputePipelineState> pipeline,
    MPSImage* im);
35
// Launch-parameter overload that takes the image dimensions explicitly
// (image count, feature channels, height, width) instead of an MPSImage.
//
// Only the declaration is visible here; see the definition for how the
// threadgroup and grid sizes are chosen for `pipeline`.
API_AVAILABLE(ios(11.0), macos(10.13))
LaunchParams spatialPointwiseKernelLaunchParams(
    id<MTLComputePipelineState> pipeline,
    NSUInteger numberOfImages,
    NSUInteger featureChannels,
    NSUInteger height,
    NSUInteger width);
43
// Chooses which of the two supplied kernel strings applies to `image`.
//
// Returns a copy of `arrayKernel` when the image has more than 4 feature
// channels or holds more than one image; otherwise returns a copy of
// `nonArrayKernel`.
// NOTE(review): presumably those are the cases where the image is backed by
// a texture array -- confirm against the MPSImage documentation.
API_AVAILABLE(ios(11.0), macos(10.13))
static inline std::string kernelFor(
    MPSImage* image,
    const std::string& arrayKernel,
    const std::string& nonArrayKernel) {
  const bool needsArrayKernel =
      image.featureChannels > 4 || image.numberOfImages > 1;
  return needsArrayKernel ? arrayKernel : nonArrayKernel;
}
54
// Computes the offset that lines MPSCNN's kernel window up with the window a
// padded convolution/pooling of size `kernel` and padding `pad` would read.
//
// The idea is to match the top-left input pixel that gets read (negative
// coordinates fall into the padding). With padding `pad` the original
// implementation starts at (-pad, -pad); per the original author's note,
// MPSCNN starts half a kernel before the offset. The difference between the
// two is the value returned here:
//
//   offset = kernel / 2 - pad        (integer division)
//
// Worked examples for a 3x3, stride-1 kernel:
//   pad = 1  ->  offset  0
//   pad = 0  ->  offset  1
//   pad = 2  ->  offset -1
static inline int computeMPSAlignOffset(int kernel, int pad) {
  const int halfKernel = kernel / 2;
  return halfKernel - pad;
}
70
71 } // namespace at::native::metal::mpscnn
72