//
//  Copyright (c) 2023 Apple Inc. All rights reserved.
//  Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

/**
 * Adds a depthwise 2D convolution node to the MPSGraph being built.
 *
 * The op is expressed through MPSGraph's depthwise *3D* convolution API: the
 * stride, dilation and padding arrays therefore carry one extra leading entry
 * (1 for stride/dilation, a 0/0 pair for padding) ahead of the Y and X values
 * read from the serialized node.
 *
 * @param nodePtr Serialized node; must hold an MPSDepthwiseConv2D union member.
 *                input1_id is the source tensor, input2_id the weights and
 *                input3_id the optional bias (-1 when there is no bias).
 * @return Error::Ok once the result has been stored in _idToMPSGraphTensor
 *         under the node's output_id.
 *
 * @note A rank-3 weight shape (the 1D-convolution case) is not supported here
 *       and trips ET_CHECK.
 */
Error
MPSGraphBuilder::mpsDepthwiseConv2DOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSDepthwiseConv2D();
  ET_LOG(
    Debug, "%s: (%d, %d, %d) -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->input2_id(),
    graphNode->input3_id(),
    graphNode->output_id()
  );

  bool isConv1D = ([getMPSShape(graphNode->input2_id()) count] == 3);
  ET_CHECK(!isConv1D);

  MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor =
    [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];

  // This function uses explicit -autorelease, i.e. manual reference counting.
  // Boxed expressions (@(...)) return autoreleased NSNumbers, so the array
  // literals below do not strand +1 references the way alloc/init would.
  depthWiseConv3dDescriptor.strides = @[
    @1,
    @(graphNode->stride_y()),
    @(graphNode->stride_x())
  ];

  depthWiseConv3dDescriptor.dilationRates = @[
    @1,
    @(graphNode->dilation_y()),
    @(graphNode->dilation_x())
  ];

  // Explicit padding, listed as (before, after) pairs: a 0/0 pair for the
  // extra leading dimension, then top/bottom, then left/right.
  depthWiseConv3dDescriptor.paddingStyle = MPSGraphPaddingStyleExplicit;
  depthWiseConv3dDescriptor.paddingValues = @[
    @0,
    @0,
    @(graphNode->padding_top()),
    @(graphNode->padding_bottom()),
    @(graphNode->padding_left()),
    @(graphNode->padding_right())
  ];
  depthWiseConv3dDescriptor.channelDimensionIndex = -3LL;

  // Swap weight dimensions -3 and -4 before handing the weights to MPSGraph.
  MPSGraphTensor* weightTransposeTensor = [_mpsGraph transposeTensor:getMPSGraphTensor(graphNode->input2_id())
                                                           dimension:-3
                                                       withDimension:-4
                                                                name:nil];
  MPSGraphTensor* depthwiseConvTensor = [_mpsGraph depthwiseConvolution3DWithSourceTensor:getMPSGraphTensor(graphNode->input1_id())
                                                                            weightsTensor:weightTransposeTensor
                                                                               descriptor:depthWiseConv3dDescriptor
                                                                                     name:nil];

  // Bias is optional; an id of -1 means the node has none.
  if (graphNode->input3_id() != -1) {
    // Need to add correct dimension to bias to avoid broadcasting issues:
    // unit axes are inserted at positions 0, 2 and 3 before the addition.
    MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->input3_id());
    biasTensor = [_mpsGraph expandDimsOfTensor:biasTensor
                                          axes:@[@0, @2, @3]
                                          name:nil];
    depthwiseConvTensor = [_mpsGraph additionWithPrimaryTensor:depthwiseConvTensor
                                               secondaryTensor:biasTensor
                                                          name:@"depthwiseConv2DWithBiasAdd"];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = depthwiseConvTensor;
  return Error::Ok;
}

/**
 * Adds a (grouped) 2D convolution node to the MPSGraph being built.
 *
 * A rank-3 weight tensor is treated as a 1D convolution: source and weights
 * get a unit axis inserted at position 2, the 2D convolution runs, and the
 * same axis is squeezed out of the result afterwards.
 *
 * @param nodePtr Serialized node; must hold an MPSConv2D union member.
 *                input1_id is the source tensor, input2_id the weights and
 *                input3_id the optional bias (-1 when there is no bias).
 * @return Error::Ok once the result has been stored in _idToMPSGraphTensor
 *         under the node's output_id.
 */
Error
MPSGraphBuilder::mpsConv2DOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSConv2D();
  ET_LOG(
    Debug, "%s: (%d, %d, %d) -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->input2_id(),
    graphNode->input3_id(),
    graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->input2_id());

  // Rank-3 weights mean the node is really a 1D convolution.
  bool isConv1D = ([weightTensor.shape count] == 3);
  if (isConv1D) {
    inputTensor = [_mpsGraph expandDimsOfTensor:inputTensor
                                           axis:2
                                           name:@"unsqueezeInput"];
    weightTensor = [_mpsGraph expandDimsOfTensor:weightTensor
                                            axis:2
                                            name:@"unsqueezeWeight"];
  }

  MPSGraphConvolution2DOpDescriptor* desc =
    [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:graphNode->stride_x()
                                                     strideInY:graphNode->stride_y()
                                               dilationRateInX:graphNode->dilation_x()
                                               dilationRateInY:graphNode->dilation_y()
                                                        groups:graphNode->groups()
                                                   paddingLeft:graphNode->padding_left()
                                                  paddingRight:graphNode->padding_right()
                                                    paddingTop:graphNode->padding_top()
                                                 paddingBottom:graphNode->padding_bottom()
                                                  paddingStyle:MPSGraphPaddingStyleExplicit
                                                    dataLayout:MPSGraphTensorNamedDataLayoutNCHW
                                                 weightsLayout:MPSGraphTensorNamedDataLayoutHWIO];

  // Convert weights from OIHW to HWIO, the layout declared in the descriptor.
  MPSGraphTensor* weightTransposeTensor = permuteTensor(_mpsGraph, weightTensor, @[@2, @3, @1, @0]);
  MPSGraphTensor* conv2DTensor = [_mpsGraph convolution2DWithSourceTensor:inputTensor
                                                            weightsTensor:weightTransposeTensor
                                                               descriptor:desc
                                                                     name:@"conv2D"];

  // Bias is optional; an id of -1 means the node has none.
  if (graphNode->input3_id() != -1) {
    // Need to add correct dimension to bias to avoid broadcasting issues:
    // unit axes are inserted at positions 0, 2 and 3 before the addition.
    MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->input3_id());
    biasTensor = [_mpsGraph expandDimsOfTensor:biasTensor
                                          axes:@[@0, @2, @3]
                                          name:nil];
    conv2DTensor = [_mpsGraph additionWithPrimaryTensor:conv2DTensor
                                        secondaryTensor:biasTensor
                                                   name:@"conv2DWithBiasAdd"];
  }

  // Undo the axis inserted for the 1D case.
  if (isConv1D) {
    conv2DTensor = [_mpsGraph squeezeTensor:conv2DTensor
                                       axis:2
                                       name:@"squeeze"];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = conv2DTensor;
  return Error::Ok;
}


} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch