//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

// Lowers an MPSMaxPool2DWithIndices node to an MPSGraph 2D max-pooling op that
// also produces the argmax indices.
//
// The descriptor is built in NCHW layout with explicit left/right/top/bottom
// padding, kernel, stride, dilation and ceil_mode taken from the node.
// Results are registered as:
//   output1_id -> pooled values   (outputs[0] of the MPSGraph call)
//   output2_id -> argmax indices  (outputs[1]; Int32, GlobalFlatten2D mode)
// Always returns Error::Ok.
Error
MPSGraphBuilder::mpsMaxPool2DWithIndicesOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSMaxPool2DWithIndices();
  ET_LOG(
    Debug, "%s: %d -> (%d, %d)",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output1_id(),
    graphNode->output2_id()
  );

  MPSGraphPooling2DOpDescriptor* desc =
    [MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:graphNode->kernel_width()
                                                kernelHeight:graphNode->kernel_height()
                                                   strideInX:graphNode->stride_width()
                                                   strideInY:graphNode->stride_height()
                                             dilationRateInX:graphNode->dilation_width()
                                             dilationRateInY:graphNode->dilation_height()
                                                 paddingLeft:graphNode->padding_left()
                                                paddingRight:graphNode->padding_right()
                                                  paddingTop:graphNode->padding_top()
                                               paddingBottom:graphNode->padding_bottom()
                                                paddingStyle:MPSGraphPaddingStyleExplicit
                                                  dataLayout:MPSGraphTensorNamedDataLayoutNCHW];
  desc.ceilMode = graphNode->ceil_mode();
  // The return-indices properties are newer MPSGraph API. The availability
  // warnings are silenced here; presumably the OS-version gating happens
  // elsewhere in the delegate (NOTE(review): confirm).
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#pragma clang diagnostic ignored "-Wunguarded-availability-new"
  desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D;
  desc.returnIndicesDataType = MPSDataTypeInt32;
#pragma clang diagnostic pop

  NSArray<MPSGraphTensor*>* outputs =
    [_mpsGraph maxPooling2DReturnIndicesWithSourceTensor:getMPSGraphTensor(graphNode->input1_id())
                                              descriptor:desc
                                                    name:@"MaxPool2DWithIndices"];

  _idToMPSGraphTensor[graphNode->output1_id()] = outputs[0];
  _idToMPSGraphTensor[graphNode->output2_id()] = outputs[1];
  return Error::Ok;
}

// Lowers an MPSAvgPool2D node to an MPSGraph 2D average-pooling op.
//
// The descriptor is built in NCHW layout with explicit padding, kernel, stride,
// dilation and ceil_mode taken from the node.
//
// A non-zero divisor_override is emulated, because MPSGraph has no
// custom-divisor option:
//   1. Zero padding is forced into the average.
//   2. The result is rescaled by (kernel_height * kernel_width) / divisor_override.
//
// The final tensor is registered under output1_id. Always returns Error::Ok.
Error
MPSGraphBuilder::mpsAvgPool2DOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSAvgPool2D();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output1_id()
  );

  MPSGraphPooling2DOpDescriptor* desc =
    [MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:graphNode->kernel_width()
                                                kernelHeight:graphNode->kernel_height()
                                                   strideInX:graphNode->stride_width()
                                                   strideInY:graphNode->stride_height()
                                             dilationRateInX:graphNode->dilation_width()
                                             dilationRateInY:graphNode->dilation_height()
                                                 paddingLeft:graphNode->padding_left()
                                                paddingRight:graphNode->padding_right()
                                                  paddingTop:graphNode->padding_top()
                                               paddingBottom:graphNode->padding_bottom()
                                                paddingStyle:MPSGraphPaddingStyleExplicit
                                                  dataLayout:MPSGraphTensorNamedDataLayoutNCHW];
  // ceil_mode changes the output spatial size. Forward it (as the max-pool
  // lowering does) so the produced shape matches the exported graph.
  desc.ceilMode = graphNode->ceil_mode();

  const bool useDivisor = graphNode->divisor_override() != 0;

  // If overriding the divisor, zero pads must be included in the average for
  // correct behavior (see the rescale below).
  desc.includeZeroPadToAverage = useDivisor ? true : graphNode->count_include_pad();

  MPSGraphTensor* avgPoolTensor =
    [_mpsGraph avgPooling2DWithSourceTensor:getMPSGraphTensor(graphNode->input1_id())
                                 descriptor:desc
                                       name:@"AvgPool2DTensor"];
  if (useDivisor) {
    // Rescale the average, because MPSGraph does not support a custom divisor
    // directly. With zero padding included, the average is assumed to be
    // sum / (kernel_height * kernel_width). Multiplying by
    // (kernel_height * kernel_width) / divisor_override therefore yields
    // sum / divisor_override.
    // NOTE(review): for ceil_mode windows that overhang the padded input,
    // confirm that MPSGraph still divides by the full kernel area.
    const float divisor =
      float(graphNode->kernel_height() * graphNode->kernel_width()) /
      (float)graphNode->divisor_override();
    // Build the constant in the pooled tensor's own dtype. MPSGraph arithmetic
    // does not promote types, so a hard-coded Float32 constant would mismatch
    // Float16 inputs.
    MPSGraphTensor* constantTensor = [_mpsGraph constantWithScalar:divisor
                                                             shape:@[@1]
                                                          dataType:avgPoolTensor.dataType];
    avgPoolTensor = [_mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor
                                               secondaryTensor:constantTensor
                                                          name:@"AvgPool2DTensor/divisor_override"];
  }

  _idToMPSGraphTensor[graphNode->output1_id()] = avgPoolTensor;

  return Error::Ok;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch