#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.h>
#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>

@implementation MPSCNNFullyConnectedOp

// Back the `kernel` property with `_kernel`, which `linear:` assigns directly.
@synthesize kernel = _kernel;

/// Builds a fully connected op backed by an MPSCNNFullyConnected kernel.
///
/// @param params Layer geometry. KW/KH/IC/OC size the convolution descriptor,
///               W/H determine the kernel offset, and N sets the depth of the
///               clip rect.
/// @param w      Weight buffer passed to MPSCNNConvDataSource.
/// @param b      Bias buffer passed to MPSCNNConvDataSource.
/// @param t      Activation type fused into the kernel.
/// @return A new op whose `kernel` is the configured MPSCNNFullyConnected.
+ (MPSCNNFullyConnectedOp*)linear:(const Conv2DParams&)params
                          weights:(float*)w
                             bias:(float*)b
                     neuronFilter:(NeuronType)t
    API_AVAILABLE(ios(11.0), macos(10.13)) {
  MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
      cnnConvolutionDescriptorWithKernelWidth:params.KW
                                 kernelHeight:params.KH
                         inputFeatureChannels:params.IC
                        outputFeatureChannels:params.OC];
#if TARGET_OS_MACCATALYST
  // Catalyst fuses the activation through a neuron *descriptor*
  // (presumably because `desc.neuron` is not usable there — confirm).
  desc.fusedNeuronDescriptor = at::native::metal::neuronDescriptor(t);
#else
  // Only this branch consumes an MPSCNNNeuron. Creating it unconditionally
  // left an unused local (and an unused-variable warning) on Catalyst.
  desc.neuron = at::native::metal::neuron(t);
#endif
  desc.strideInPixelsX = 1;
  desc.strideInPixelsY = 1;

  MPSCNNFullyConnected* fc = nil;
  if (@available(iOS 11.0, *)) {
    MPSCNNConvDataSource* ds =
        [[MPSCNNConvDataSource alloc] initWithWeights:w Bias:b Desc:desc];
    fc = [[MPSCNNFullyConnected alloc]
        initWithDevice:[MetalContext sharedInstance].device
               weights:ds];
  } else {
    // The method is already annotated API_AVAILABLE(ios(11.0)), so this
    // branch should be unreachable; kept as a defensive runtime guard.
    TORCH_CHECK(
        false,
        "MPSCNNFullyConnectedOp is only available on iOS 11.0 and above");
  }
  // Output region: origin (0,0,0), size 1x1 spatially and params.N in depth.
  [fc setClipRect:MTLRegionMake3D(0, 0, 0, 1, 1, params.N)];
  // Offset the kernel window by half the input width/height.
  // NOTE(review): presumably this centers a W x H kernel over the whole input
  // so the single output pixel sees every input pixel — confirm KW==W, KH==H.
  [fc setOffset:{.x = static_cast<NSInteger>(params.W / 2),
                 .y = static_cast<NSInteger>(params.H / 2),
                 .z = 0}];
  MPSCNNFullyConnectedOp* op = [[MPSCNNFullyConnectedOp alloc] init];
  op->_kernel = fc;
  return op;
}

/// Encodes the fully connected kernel onto `cb`, reading from `src` and
/// writing into `dst`.
- (void)encode:(id<MTLCommandBuffer>)cb
         sourceImage:(MPSImage*)src
    destinationImage:(MPSImage*)dst {
  [_kernel encodeToCommandBuffer:cb sourceImage:src destinationImage:dst];
}

@end