xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/PoolingOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
/// Lowers an MPSMaxPool2DWithIndices node into MPSGraph.
///
/// Builds an explicitly padded, NCHW pooling descriptor from the node's kernel,
/// stride, dilation, padding and ceil_mode fields. It then asks MPSGraph for max
/// pooling that also returns int32 indices (GlobalFlatten2D index mode). The two
/// resulting tensors are registered under output1_id and output2_id.
///
/// @param nodePtr Serialized node whose union member is MPSMaxPool2DWithIndices.
/// @return Error::Ok once both outputs have been registered.
Error
MPSGraphBuilder::mpsMaxPool2DWithIndicesOp(NodePtr nodePtr) {
  auto poolNode = nodePtr->mpsnode_union_as_MPSMaxPool2DWithIndices();
  const auto inputId = poolNode->input1_id();
  const auto valuesId = poolNode->output1_id();
  const auto indicesId = poolNode->output2_id();
  ET_LOG(Debug, "%s: %d -> (%d, %d)", __FUNCTION__, inputId, valuesId, indicesId);

  MPSGraphPooling2DOpDescriptor* poolDesc =
    [MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:poolNode->kernel_width()
                                                kernelHeight:poolNode->kernel_height()
                                                   strideInX:poolNode->stride_width()
                                                   strideInY:poolNode->stride_height()
                                             dilationRateInX:poolNode->dilation_width()
                                             dilationRateInY:poolNode->dilation_height()
                                                 paddingLeft:poolNode->padding_left()
                                                paddingRight:poolNode->padding_right()
                                                  paddingTop:poolNode->padding_top()
                                               paddingBottom:poolNode->padding_bottom()
                                                paddingStyle:MPSGraphPaddingStyleExplicit
                                                  dataLayout:MPSGraphTensorNamedDataLayoutNCHW];
  poolDesc.ceilMode = poolNode->ceil_mode();

  // The index-returning properties are newer API; availability warnings are
  // silenced here (the unknown-option guard keeps older clang versions quiet).
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#pragma clang diagnostic ignored "-Wunguarded-availability-new"
  poolDesc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D;
  poolDesc.returnIndicesDataType = MPSDataTypeInt32;
#pragma clang diagnostic pop

  MPSGraphTensor* sourceTensor = getMPSGraphTensor(inputId);
  NSArray<MPSGraphTensor*>* pooled =
    [_mpsGraph maxPooling2DReturnIndicesWithSourceTensor:sourceTensor
                                              descriptor:poolDesc
                                                    name:@"MaxPool2DWithIndices"];

  // Element 0 is registered as output1 (pooled values) and element 1 as
  // output2 (indices).
  MPSGraphTensor* pooledValues = pooled[0];
  MPSGraphTensor* pooledIndices = pooled[1];
  _idToMPSGraphTensor[valuesId] = pooledValues;
  _idToMPSGraphTensor[indicesId] = pooledIndices;
  return Error::Ok;
}
56
/// Lowers an MPSAvgPool2D node into MPSGraph.
///
/// Builds an explicitly padded, NCHW pooling descriptor from the node's kernel,
/// stride, dilation, padding and ceil_mode fields. It average-pools the tensor
/// registered under input1_id and registers the result under output1_id.
/// count_include_pad controls whether zero padding contributes to the average.
/// A non-zero divisor_override is emulated by rescaling the pooled result,
/// because MPSGraph has no custom-divisor option.
///
/// @param nodePtr Serialized node whose union member is MPSAvgPool2D.
/// @return Error::Ok once the output tensor has been registered.
Error
MPSGraphBuilder::mpsAvgPool2DOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSAvgPool2D();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output1_id()
  );

  MPSGraphPooling2DOpDescriptor* desc =
    [MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:graphNode->kernel_width()
                                                kernelHeight:graphNode->kernel_height()
                                                   strideInX:graphNode->stride_width()
                                                   strideInY:graphNode->stride_height()
                                             dilationRateInX:graphNode->dilation_width()
                                             dilationRateInY:graphNode->dilation_height()
                                                 paddingLeft:graphNode->padding_left()
                                                paddingRight:graphNode->padding_right()
                                                  paddingTop:graphNode->padding_top()
                                               paddingBottom:graphNode->padding_bottom()
                                                paddingStyle:MPSGraphPaddingStyleExplicit
                                                  dataLayout:MPSGraphTensorNamedDataLayoutNCHW];
  // Honor ceil_mode, as the max-pool lowering does. Without this, the output
  // spatial size is always computed with floor rounding.
  desc.ceilMode = graphNode->ceil_mode();

  const bool useDivisor = graphNode->divisor_override() != 0;

  // When overriding the divisor, zero padding must be counted in the average so
  // that every window is divided by the full kernel area. The rescale below
  // relies on that.
  desc.includeZeroPadToAverage = useDivisor ? true : graphNode->count_include_pad();

  MPSGraphTensor* avgPoolTensor = [_mpsGraph avgPooling2DWithSourceTensor:getMPSGraphTensor(graphNode->input1_id())
                                                               descriptor:desc
                                                                     name:@"AvgPool2DTensor"];
  if (useDivisor) {
    // MPSGraph computes sum / (kH * kW). Multiplying by (kH * kW) / divisor_override
    // yields sum / divisor_override.
    // NOTE(review): windows that run past the padded input under ceil_mode may be
    // averaged over fewer elements by MPSGraph; confirm against PyTorch semantics.
    const float scale = float(graphNode->kernel_height() * graphNode->kernel_width()) /
                        (float)graphNode->divisor_override();
    // The scale constant must share the pooled tensor's element type. MPSGraph
    // binary ops reject mixed element types (e.g. a float16 graph multiplied by a
    // float32 constant).
    MPSGraphTensor* scaleTensor = [_mpsGraph constantWithScalar:scale
                                                          shape:@[@1]
                                                       dataType:avgPoolTensor.dataType];
    avgPoolTensor = [_mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor
                                               secondaryTensor:scaleTensor
                                                          name:@"AvgPool2DTensor/divisor_override"];
  }

  _idToMPSGraphTensor[graphNode->output1_id()] = avgPoolTensor;

  return Error::Ok;
}
104
105
106
107} // namespace delegate
108} // namespace mps
109} // namespace backends
110} // namespace executorch
111