#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNConvOp.h>
#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>

#include <c10/util/Exception.h>

// Data source handed to MPSCNNConvolution. It only stores the weight and bias
// pointers and the descriptor it was given; it neither copies nor frees the
// underlying buffers.
@implementation MPSCNNConvDataSource {
  void* _weights;
  float* _bias;
  MPSCNNConvolutionDescriptor* _descriptor;
}

// Stores the given pointers and descriptor as-is.
- (id)initWithWeights:(void*)weights
                 Bias:(float*)bias
                 Desc:(MPSCNNConvolutionDescriptor*)desc
    API_AVAILABLE(ios(11.0), macos(10.13)) {
  self = [super init];
  if (self) {
    _weights = weights;
    _bias = bias;
    _descriptor = desc;
  }
  return self;
}

// Returns a new data source that shares the same weight/bias pointers and
// descriptor as the receiver (shallow copy).
- (nonnull id)copyWithZone:(nullable NSZone*)zone {
  // The freshly allocated object must be initialized before it is used.
  MPSCNNConvDataSource* dataSource =
      [[MPSCNNConvDataSource allocWithZone:zone] init];
  dataSource->_weights = _weights;
  dataSource->_bias = _bias;
  dataSource->_descriptor = _descriptor;
  return dataSource;
}

// Bias pointer as given at init time; nullptr after -purge.
- (float* _Nullable)biasTerms {
  return _bias;
}

// Weights are always reported as 32-bit floats.
- (MPSDataType)dataType API_AVAILABLE(ios(11.0), macos(10.13)) {
  return MPSDataTypeFloat32;
}

- (NSString* _Nullable)label {
  return @"";
}

// Nothing needs to be loaded; succeeds as long as a weight pointer is still
// held. After -purge has dropped the pointers there is nothing to serve, so
// report failure instead of claiming success with null weights.
- (BOOL)load {
  return _weights != nullptr;
}

// Drops the references to the weight and bias buffers (does not free them).
- (void)purge {
  _bias = nullptr;
  _weights = nullptr;
}

// Weight pointer as given at init time; nullptr after -purge.
- (void*)weights {
  return _weights;
}

- (MPSCNNConvolutionDescriptor* _Nonnull)descriptor {
  return _descriptor;
}

@end

@implementation MPSCNNConvOp

@synthesize kernel = _kernel;

// Builds an MPSCNNConvolution kernel for the given parameters and wraps it in
// an MPSCNNConvOp.
//
// params: convolution geometry (kernel size, channels, groups, stride,
//         padding, dilation). Only dilation 1x1 is accepted.
// w / b:  weight and bias pointers, forwarded untouched to the data source.
// t:      neuron (activation) type fused into the convolution descriptor.
//
// Fails via TORCH_CHECK on unsupported dilation, on a grouped convolution
// whose per-group input channel count is not a multiple of 4, on OS versions
// without the required MPS classes, or if the created kernel does not match
// the requested shape.
+ (MPSCNNConvOp*)conv2d:(const Conv2DParams&)params
                weights:(float*)w
                   bias:(float*)b
           neuronFilter:(NeuronType)t API_AVAILABLE(ios(11.0), macos(10.13)) {
  using namespace at::native::metal::mpscnn;
  // Both dilation factors must be exactly 1. (A chained `DX == DY == 1`
  // would compare the boolean result of `DX == DY` against 1 and accept any
  // equal pair such as 2x2.)
  TORCH_CHECK(
      params.DX == 1 && params.DY == 1,
      "Dilated convolution is not supported yet.");
  const NSUInteger oC = params.OC;
  const NSUInteger iC = params.C;
  const NSUInteger kH = params.KH;
  const NSUInteger kW = params.KW;
  MPSCNNNeuron* neuron = at::native::metal::neuron(t);
  MPSCNNConvolutionDescriptor* desc = nil;
  if (params.isDepthwise()) {
    if (@available(iOS 11.0, *)) {
      desc = [MPSCNNDepthWiseConvolutionDescriptor
          cnnConvolutionDescriptorWithKernelWidth:kW
                                     kernelHeight:kH
                             inputFeatureChannels:iC
                            outputFeatureChannels:oC];

      desc.groups = 1;
#if TARGET_OS_MACCATALYST
      desc.fusedNeuronDescriptor = at::native::metal::neuronDescriptor(t);
#else
      desc.neuron = neuron;
#endif
    } else {
      TORCH_CHECK(
          false,
          "MPSCNNDepthWiseConvolutionDescriptor is only available on iOS 11.0 and above");
    }
  } else {
    if (params.G > 1) {
      // Adjacent literals instead of backslash continuations, so that source
      // indentation does not leak into the message text.
      TORCH_CHECK(
          params.IC % 4 == 0,
          "MPSCNNConvolution requires number of input "
          "channels in each group to be multiple of 4 for "
          "group > 1.");
    }
    if (@available(iOS 11.0, *)) {
      desc = [MPSCNNConvolutionDescriptor
          cnnConvolutionDescriptorWithKernelWidth:kW
                                     kernelHeight:kH
                             inputFeatureChannels:iC
                            outputFeatureChannels:oC];
      desc.groups = params.G;
#if TARGET_OS_MACCATALYST
      desc.fusedNeuronDescriptor = at::native::metal::neuronDescriptor(t);
#else
      desc.neuron = neuron;
#endif
    } else {
      TORCH_CHECK(
          false,
          "MPSCNNConvolutionDescriptor is only available on iOS 11.0 and above");
    }
  }
  desc.strideInPixelsX = params.SX;
  desc.strideInPixelsY = params.SY;
  id<MPSCNNConvolutionDataSource> dataSource =
      [[MPSCNNConvDataSource alloc] initWithWeights:(float*)w
                                               Bias:(float*)b
                                               Desc:desc];
  MPSCNNConvolution* conv = nil;
  if (@available(iOS 11.0, *)) {
    conv = [[MPSCNNConvolution alloc]
        initWithDevice:[MetalContext sharedInstance].device
               weights:dataSource];

  } else {
    TORCH_CHECK(
        false, "MPSCNNConvolution is only available on iOS 11.0 and above");
  }
  [conv setEdgeMode:MPSImageEdgeModeZero];
  // Offsets are derived from kernel size and padding by computeMPSAlignOffset
  // (defined elsewhere; see MPSCNNUtils).
  MPSOffset offset;
  offset.x = computeMPSAlignOffset(kW, params.PX);
  offset.y = computeMPSAlignOffset(kH, params.PY);
  offset.z = 0;
  [conv setOffset:offset];

  // Sanity-check that the kernel MPS created matches what was requested.
  TORCH_CHECK(
      static_cast<int64_t>(conv.inputFeatureChannels) ==
      params.IC * params.G);
  TORCH_CHECK(oC % conv.groups == 0);
  TORCH_CHECK(conv.outputFeatureChannels == oC);
  TORCH_CHECK(conv.kernelWidth == kW);
  TORCH_CHECK(conv.kernelHeight == kH);

  MPSCNNConvOp* op = [[MPSCNNConvOp alloc] init];
  op->_kernel = conv;
  return op;
}

// Encodes the convolution kernel into the command buffer, reading from src
// and writing to dst.
- (void)encode:(id<MTLCommandBuffer>)cb
         sourceImage:(MPSImage*)src
    destinationImage:(MPSImage*)dst {
  [_kernel encodeToCommandBuffer:cb sourceImage:src destinationImage:dst];
}

@end