xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalContext.h>
2#import <ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.h>
3#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
4
@implementation MPSCNNFullyConnectedOp

// Explicitly synthesized and backed by `_kernel`, which +linear:... assigns
// directly below.
@synthesize kernel = _kernel;

/// Builds a fully-connected op from raw weight and bias buffers.
///
/// A convolution descriptor of size KW x KH mapping IC input feature channels
/// to OC output feature channels (stride 1, activation `t` fused in) is
/// wrapped in an MPSCNNConvDataSource and handed to MPSCNNFullyConnected.
///
/// @param params Shape parameters; this method reads KW, KH, IC, OC, W, H, N.
/// @param w Weight buffer forwarded to MPSCNNConvDataSource (its expected
///          layout is defined by that class, not visible here).
/// @param b Bias buffer forwarded to MPSCNNConvDataSource.
/// @param t Activation to fuse into the kernel.
/// @return A new op whose `kernel` is the configured MPSCNNFullyConnected.
///         Fails a TORCH_CHECK if the iOS 11 availability check does not pass.
+ (MPSCNNFullyConnectedOp*)linear:(const Conv2DParams&)params
                          weights:(float*)w
                             bias:(float*)b
                     neuronFilter:(NeuronType)t
    API_AVAILABLE(ios(11.0), macos(10.13)) {
  MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
      cnnConvolutionDescriptorWithKernelWidth:params.KW
                                 kernelHeight:params.KH
                         inputFeatureChannels:params.IC
                        outputFeatureChannels:params.OC];
#if TARGET_OS_MACCATALYST
  // Catalyst builds attach the activation as a neuron *descriptor*.
  desc.fusedNeuronDescriptor = at::native::metal::neuronDescriptor(t);
#else
  // Other builds attach an MPSCNNNeuron object. It is only needed on this
  // path, so it is created here rather than unconditionally (which left an
  // unused local on Catalyst builds).
  MPSCNNNeuron* neuron = at::native::metal::neuron(t);
  desc.neuron = neuron;
#endif
  desc.strideInPixelsX = 1;
  desc.strideInPixelsY = 1;

  MPSCNNFullyConnected* fc = nil;
  if (@available(iOS 11.0, *)) {
    MPSCNNConvDataSource* ds =
        [[MPSCNNConvDataSource alloc] initWithWeights:w Bias:b Desc:desc];
    fc = [[MPSCNNFullyConnected alloc]
        initWithDevice:[MetalContext sharedInstance].device
               weights:ds];
  } else {
    TORCH_CHECK(
        false,
        "MPSCNNFullyConnectedOp is only available on iOS 11.0 and above");
  }
  // Restrict the destination to a single 1x1 pixel with depth params.N.
  // NOTE(review): presumably N is the number of batch images to process;
  // confirm against the MPSCNNKernel clipRect documentation.
  fc.clipRect = MTLRegionMake3D(0, 0, 0, 1, 1, params.N);
  // Centre the kernel on the input. NOTE(review): this assumes the kernel
  // spans the whole input (KW == W, KH == H) so one output pixel sees every
  // input pixel; confirm with callers.
  fc.offset = {.x = static_cast<NSInteger>(params.W / 2),
               .y = static_cast<NSInteger>(params.H / 2),
               .z = 0};
  MPSCNNFullyConnectedOp* kernel = [[MPSCNNFullyConnectedOp alloc] init];
  kernel->_kernel = fc;
  return kernel;
}

/// Encodes the fully-connected kernel into `cb`, reading from `src` and
/// writing to `dst`.
- (void)encode:(id<MTLCommandBuffer>)cb
         sourceImage:(MPSImage*)src
    destinationImage:(MPSImage*)dst {
  [_kernel encodeToCommandBuffer:cb sourceImage:src destinationImage:dst];
}

@end
57@end
58