xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalContext.h>
2#import <ATen/native/metal/mpscnn/MPSCNNConvOp.h>
3#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
4#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
5
6#include <c10/util/Exception.h>
7
// Weight/bias provider that +[MPSCNNConvOp conv2d:weights:bias:neuronFilter:]
// passes to MPSCNNConvolution as its data source. The object only records the
// pointers and descriptor it is given; this class neither copies nor frees
// the underlying buffers.
@implementation MPSCNNConvDataSource {
  void* _weights;
  float* _bias;
  MPSCNNConvolutionDescriptor* _descriptor;
}

// Stores the supplied weight pointer, bias pointer and descriptor unchanged.
- (id)initWithWeights:(void*)weights
                 Bias:(float*)bias
                 Desc:(MPSCNNConvolutionDescriptor*)desc
    API_AVAILABLE(ios(11.0), macos(10.13)) {
  self = [super init];
  if (self) {
    _weights = weights;
    _bias = bias;
    _descriptor = desc;
  }
  return self;
}

// Returns a new data source that refers to the same weight/bias pointers and
// the same descriptor object as the receiver (the buffers are not duplicated).
- (nonnull id)copyWithZone:(nullable NSZone*)zone {
  // Send -init so the copy goes through normal two-stage object creation
  // instead of being handed out straight from +allocWithZone:.
  MPSCNNConvDataSource* duplicate =
      [[MPSCNNConvDataSource allocWithZone:zone] init];
  if (duplicate) {
    duplicate->_weights = _weights;
    duplicate->_bias = _bias;
    duplicate->_descriptor = _descriptor;
  }
  return duplicate;
}

// Bias pointer as supplied at init time; nullptr after -purge.
- (float* _Nullable)biasTerms {
  return _bias;
}

// Weights are always reported as 32-bit floats.
- (MPSDataType)dataType API_AVAILABLE(ios(11.0), macos(10.13)) {
  return MPSDataTypeFloat32;
}

// No label is provided; an empty string is returned.
- (NSString* _Nullable)label {
  return @"";
}

// Nothing has to be loaded here, so this always reports success.
- (BOOL)load {
  return YES;
}

// Forgets the weight and bias pointers. The memory they pointed to is not
// released by this method.
- (void)purge {
  _bias = nullptr;
  _weights = nullptr;
}

// Weight pointer as supplied at init time; nullptr after -purge.
- (void*)weights {
  return _weights;
}

// Convolution descriptor supplied at init time.
- (MPSCNNConvolutionDescriptor* _Nonnull)descriptor {
  return _descriptor;
}

@end
65
// Wraps an MPSCNNConvolution kernel configured from Conv2DParams.
@implementation MPSCNNConvOp

@synthesize kernel = _kernel;

// Builds an MPSCNNConvOp for a 2D convolution.
//
// params: convolution geometry (channels, kernel size, stride, padding,
//         groups, dilation). Only dilation 1x1 is accepted.
// w:      pointer to the weights; forwarded unchanged to the data source.
// b:      pointer to the bias terms; forwarded unchanged to the data source.
// t:      neuron (activation) type to attach to the convolution descriptor.
//
// Raises (via TORCH_CHECK) when dilation is not 1 in both dimensions, when a
// grouped convolution has a per-group input channel count that is not a
// multiple of 4, when the required MPS classes are unavailable, or when the
// created kernel does not match the requested shape.
+ (MPSCNNConvOp*)conv2d:(const Conv2DParams&)params
                weights:(float*)w
                   bias:(float*)b
           neuronFilter:(NeuronType)t API_AVAILABLE(ios(11.0), macos(10.13)) {
  using namespace at::native::metal::mpscnn;
  // Both dilation factors must be checked separately: `DX == DY == 1` parses
  // as `(DX == DY) == 1` and would let any equal pair (e.g. 2x2) through.
  TORCH_CHECK(
      params.DX == 1 && params.DY == 1,
      "Dilated convolution is not supported yet.");
  const NSUInteger oC = params.OC;
  const NSUInteger iC = params.C;
  const NSUInteger kH = params.KH;
  const NSUInteger kW = params.KW;
  MPSCNNNeuron* neuron = at::native::metal::neuron(t);
  MPSCNNConvolutionDescriptor* desc = nil;
  if (params.isDepthwise()) {
    if (@available(iOS 11.0, *)) {
      desc = [MPSCNNDepthWiseConvolutionDescriptor
          cnnConvolutionDescriptorWithKernelWidth:kW
                                     kernelHeight:kH
                             inputFeatureChannels:iC
                            outputFeatureChannels:oC];

      desc.groups = 1;
#if TARGET_OS_MACCATALYST
      desc.fusedNeuronDescriptor = at::native::metal::neuronDescriptor(t);
#else
      desc.neuron = neuron;
#endif
    } else {
      TORCH_CHECK(
          false,
          "MPSCNNDepthWiseConvolutionDescriptor is only available on iOS 11.0 and above");
    }
  } else {
    if (params.G > 1) {
      // Adjacent string literals keep the message free of the source
      // indentation that a backslash line continuation would embed.
      TORCH_CHECK(
          params.IC % 4 == 0,
          "MPSCNNConvolution requires number of input "
          "channels in each group to be multiple of 4 for "
          "group > 1.");
    }
    if (@available(iOS 11.0, *)) {
      desc = [MPSCNNConvolutionDescriptor
          cnnConvolutionDescriptorWithKernelWidth:kW
                                     kernelHeight:kH
                             inputFeatureChannels:iC
                            outputFeatureChannels:oC];
      desc.groups = params.G;
#if TARGET_OS_MACCATALYST
      desc.fusedNeuronDescriptor = at::native::metal::neuronDescriptor(t);
#else
      desc.neuron = neuron;
#endif
    } else {
      TORCH_CHECK(
          false,
          "MPSCNNConvolutionDescriptor is only available on iOS 11.0 and above");
    }
  }
  desc.strideInPixelsX = params.SX;
  desc.strideInPixelsY = params.SY;
  // The data source just carries the raw pointers and the descriptor to MPS.
  id<MPSCNNConvolutionDataSource> dataSource =
      [[MPSCNNConvDataSource alloc] initWithWeights:w Bias:b Desc:desc];
  MPSCNNConvolution* conv = nil;
  if (@available(iOS 11.0, *)) {
    conv = [[MPSCNNConvolution alloc]
        initWithDevice:[MetalContext sharedInstance].device
               weights:dataSource];

  } else {
    TORCH_CHECK(
        false, "MPSCNNConvolution is only available on iOS 11.0 and above");
  }
  [conv setEdgeMode:MPSImageEdgeModeZero];
  // Offsets are derived from kernel size and padding by computeMPSAlignOffset
  // (defined elsewhere; presumably maps padding to the MPS kernel origin --
  // confirm against MPSCNNUtils).
  MPSOffset offset;
  offset.x = computeMPSAlignOffset(kW, params.PX);
  offset.y = computeMPSAlignOffset(kH, params.PY);
  offset.z = 0;
  [conv setOffset:offset];

  // Sanity-check that the kernel MPS created matches the requested shape.
  TORCH_CHECK(static_cast<int64_t>(conv.inputFeatureChannels) == params.IC * params.G);
  TORCH_CHECK(oC % conv.groups == 0);
  TORCH_CHECK(conv.outputFeatureChannels == oC);
  TORCH_CHECK(conv.kernelWidth == kW);
  TORCH_CHECK(conv.kernelHeight == kH);

  MPSCNNConvOp* op = [MPSCNNConvOp new];
  op->_kernel = conv;
  return op;
}

// Encodes the wrapped convolution kernel into `cb`, reading from `src` and
// writing to `dst`.
- (void)encode:(id<MTLCommandBuffer>)cb
         sourceImage:(MPSImage*)src
    destinationImage:(MPSImage*)dst {
  [_kernel encodeToCommandBuffer:cb sourceImage:src destinationImage:dst];
}

@end
170