xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/ConvolutionOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
/// Lowers an MPSDepthwiseConv2D node into an MPSGraph depthwise convolution.
///
/// The 2D depthwise convolution is expressed with MPSGraph's 3D depthwise
/// convolution API. The leading spatial dimension uses stride 1, dilation 1
/// and zero padding, so only the y/x parameters from the node take effect.
/// When the node has a bias (input3_id() != -1), the bias is added to the
/// convolution output.
///
/// @param nodePtr Serialized graph node holding an MPSDepthwiseConv2D payload.
/// @return Error::Ok once the result tensor is registered under output_id().
///         Aborts via ET_CHECK if the weight tensor has rank 3 (the Conv1D
///         case), which this lowering does not handle.
Error
MPSGraphBuilder::mpsDepthwiseConv2DOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSDepthwiseConv2D();
  ET_LOG(
    Debug, "%s: (%d, %d, %d) -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->input2_id(),
    graphNode->input3_id(),
    graphNode->output_id()
  );

  // Rank-3 weights would mean a depthwise Conv1D, which is not supported here.
  bool isConv1D = ([getMPSShape(graphNode->input2_id()) count] == 3);
  ET_CHECK(!isConv1D);

  // This file is built without ARC (see the explicit autorelease below), so
  // every NSNumber handed to the descriptor must be autoreleased. Boxed
  // expressions are; [[NSNumber alloc] initWithInteger:] would return an owned
  // object that is never released. The lambda keeps the NSInteger boxing type.
  auto asNumber = [](NSInteger value) -> NSNumber* { return @(value); };

  MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor =
    [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];

  depthWiseConv3dDescriptor.strides =
      @[ @1, asNumber(graphNode->stride_y()), asNumber(graphNode->stride_x()) ];

  depthWiseConv3dDescriptor.dilationRates =
      @[ @1, asNumber(graphNode->dilation_y()), asNumber(graphNode->dilation_x()) ];

  // Explicit padding as (before, after) pairs; the leading dimension is unpadded.
  depthWiseConv3dDescriptor.paddingStyle = MPSGraphPaddingStyleExplicit;
  depthWiseConv3dDescriptor.paddingValues = @[
    @0,
    @0,
    asNumber(graphNode->padding_top()),
    asNumber(graphNode->padding_bottom()),
    asNumber(graphNode->padding_left()),
    asNumber(graphNode->padding_right())
  ];
  depthWiseConv3dDescriptor.channelDimensionIndex = -3LL;

  // Swap the two outermost weight dimensions (dims 0 and 1 for rank-4 weights).
  // NOTE(review): presumably converts the serialized weight layout into what the
  // 3D depthwise op expects for channelDimensionIndex -3 — confirm against the
  // exporter.
  MPSGraphTensor* weightTransposeTensor = [_mpsGraph transposeTensor:getMPSGraphTensor(graphNode->input2_id())
                                                          dimension:-3
                                                      withDimension:-4
                                                               name:nil];
  MPSGraphTensor* depthwiseConvTensor = [_mpsGraph depthwiseConvolution3DWithSourceTensor:getMPSGraphTensor(graphNode->input1_id())
                                                                            weightsTensor:weightTransposeTensor
                                                                               descriptor:depthWiseConv3dDescriptor
                                                                                     name:nil];
  // Bias is optional
  if (graphNode->input3_id() != -1) {
    // Insert singleton axes 0, 2 and 3 so the bias broadcasts per channel
    // (assumes a rank-1 bias of length C, giving [1, C, 1, 1] — TODO confirm).
    MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->input3_id());
    biasTensor = [_mpsGraph expandDimsOfTensor:biasTensor
                                          axes:@[@0, @2, @3]
                                          name:nil];
    depthwiseConvTensor = [_mpsGraph additionWithPrimaryTensor:depthwiseConvTensor
                                               secondaryTensor:biasTensor
                                                          name:@"depthwiseConv2DWithBiasAdd"];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = depthwiseConvTensor;
  return Error::Ok;
}
71
/// Lowers an MPSConv2D node into an MPSGraph 2D convolution.
///
/// When the weight tensor has rank 3, the node is treated as a Conv1D:
///  - a unit axis is inserted at dimension 2 of both the input and the weights;
///  - the 2D convolution runs;
///  - that axis is squeezed back out of the result.
/// Weights are permuted from OIHW to HWIO to match the descriptor's
/// weightsLayout. An optional bias (input3_id() != -1) is added afterwards.
///
/// @param nodePtr Serialized graph node holding an MPSConv2D payload.
/// @return Error::Ok once the result tensor is recorded under output_id().
Error
MPSGraphBuilder::mpsConv2DOp(NodePtr nodePtr) {
  auto conv = nodePtr->mpsnode_union_as_MPSConv2D();
  ET_LOG(
    Debug, "%s: (%d, %d, %d) -> %d",
    __FUNCTION__,
    conv->input1_id(),
    conv->input2_id(),
    conv->input3_id(),
    conv->output_id()
  );

  MPSGraphTensor* source = getMPSGraphTensor(conv->input1_id());
  MPSGraphTensor* weights = getMPSGraphTensor(conv->input2_id());

  // Conv1D is emulated as a Conv2D over an extra singleton spatial axis.
  const bool lowerAsConv1D = (weights.shape.count == 3);
  if (lowerAsConv1D) {
    source = [_mpsGraph expandDimsOfTensor:source axis:2 name:@"unsqueezeInput"];
    weights = [_mpsGraph expandDimsOfTensor:weights axis:2 name:@"unsqueezeWeight"];
  }

  MPSGraphConvolution2DOpDescriptor* descriptor =
    [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:conv->stride_x()
                                                     strideInY:conv->stride_y()
                                               dilationRateInX:conv->dilation_x()
                                               dilationRateInY:conv->dilation_y()
                                                        groups:conv->groups()
                                                   paddingLeft:conv->padding_left()
                                                  paddingRight:conv->padding_right()
                                                    paddingTop:conv->padding_top()
                                                 paddingBottom:conv->padding_bottom()
                                                  paddingStyle:MPSGraphPaddingStyleExplicit
                                                    dataLayout:MPSGraphTensorNamedDataLayoutNCHW
                                                 weightsLayout:MPSGraphTensorNamedDataLayoutHWIO];

  // The weights arrive as OIHW; reorder them to the HWIO layout declared above.
  MPSGraphTensor* weightsHWIO = permuteTensor(_mpsGraph, weights, @[@2, @3, @1, @0]);
  MPSGraphTensor* result = [_mpsGraph convolution2DWithSourceTensor:source
                                                      weightsTensor:weightsHWIO
                                                         descriptor:descriptor
                                                               name:@"conv2D"];

  // Optional bias: insert singleton axes 0, 2 and 3 so it broadcasts against
  // the rank-4 convolution output instead of mis-broadcasting.
  const bool hasBias = (conv->input3_id() != -1);
  if (hasBias) {
    MPSGraphTensor* bias = [_mpsGraph expandDimsOfTensor:getMPSGraphTensor(conv->input3_id())
                                                    axes:@[@0,@2,@3]
                                                    name:nil];
    result = [_mpsGraph additionWithPrimaryTensor:result
                                  secondaryTensor:bias
                                             name:@"conv2DWithBiasAdd"];
  }

  // Undo the Conv1D unsqueeze so the output rank matches the original input.
  if (lowerAsConv1D) {
    result = [_mpsGraph squeezeTensor:result axis:2 name:@"squeeze"];
  }

  _idToMPSGraphTensor[conv->output_id()] = result;
  return Error::Ok;
}
138
139
140} // namespace delegate
141} // namespace mps
142} // namespace backends
143} // namespace executorch
144