1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "LayersFwd.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include <Network.hpp> 8*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp> 9*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp> 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer") 14*89c4ff92SAndroid Build Coastguard Worker { 15*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 16*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::optimizations; 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoConvolution2dLayer") 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker Graph graph; 21*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 2, 2, 3}; 22*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 6, 6, 3}; 23*89c4ff92SAndroid Build Coastguard Worker const unsigned int weightsShape[] = {1, 2, 3, 3}; 24*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 2, 1, 1}; 25*89c4ff92SAndroid Build Coastguard Worker 26*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32); 27*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::Float32); 28*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo(4, weightsShape, DataType::Float32, 1.0f, 0, true); 29*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32); 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input"); 32*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo); 33*89c4ff92SAndroid Build Coastguard Worker 34*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 35*89c4ff92SAndroid Build Coastguard Worker {2, 2}, 36*89c4ff92SAndroid Build Coastguard Worker {2, 2}, 37*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 38*89c4ff92SAndroid Build Coastguard Worker 39*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer = graph.AddLayer<PadLayer>(padDescriptor, "pad"); 40*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().SetTensorInfo(paddedInfo); 41*89c4ff92SAndroid Build Coastguard Worker 42*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convolution2dDescriptor; 43*89c4ff92SAndroid Build Coastguard Worker convolution2dDescriptor.m_BiasEnabled = false; 44*89c4ff92SAndroid Build Coastguard Worker convolution2dDescriptor.m_StrideX = 1; 45*89c4ff92SAndroid Build Coastguard Worker convolution2dDescriptor.m_StrideY = 1; 46*89c4ff92SAndroid Build Coastguard Worker convolution2dDescriptor.m_DataLayout = DataLayout::NHWC; 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(18); 49*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(weightsInfo, weightsVector); 50*89c4ff92SAndroid Build Coastguard Worker 51*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* weightsLayer = graph.AddLayer<ConstantLayer>("Weights"); 52*89c4ff92SAndroid Build Coastguard Worker weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weights); 53*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo); 54*89c4ff92SAndroid Build Coastguard Worker 55*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* conv2dLayer = graph.AddLayer<Convolution2dLayer>(convolution2dDescriptor, "conv2d"); 56*89c4ff92SAndroid Build Coastguard Worker conv2dLayer->GetOutputSlot().SetTensorInfo(outputInfo); 57*89c4ff92SAndroid Build Coastguard Worker 58*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output"); 59*89c4ff92SAndroid Build Coastguard Worker 60*89c4ff92SAndroid Build Coastguard Worker // Connect up layers - input -> pad -> conv2d -> output 61*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(padLayer->GetInputSlot(0)); 62*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().Connect(conv2dLayer->GetInputSlot(0)); 63*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot().Connect(conv2dLayer->GetInputSlot(1)); 64*89c4ff92SAndroid Build Coastguard Worker conv2dLayer->GetOutputSlot().Connect(output->GetInputSlot(0)); 65*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0102(const Layer* const layer)66*89c4ff92SAndroid Build Coastguard Worker auto checkSimpleConv2d = [](const Layer* const layer)->bool { 67*89c4ff92SAndroid Build Coastguard Worker const auto conv2dLayer = static_cast<const Convolution2dLayer*>(layer); 68*89c4ff92SAndroid Build Coastguard Worker const auto conv2dLayerParams = conv2dLayer->GetParameters(); 69*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<Convolution2dLayer>(layer) && (layer->GetNameStr() == "conv2d") && 70*89c4ff92SAndroid Build Coastguard Worker (conv2dLayerParams.m_PadLeft == 0) && (conv2dLayerParams.m_PadRight == 0) && 71*89c4ff92SAndroid Build Coastguard Worker (conv2dLayerParams.m_PadTop == 0) && (conv2dLayerParams.m_PadBottom == 0) && 72*89c4ff92SAndroid Build Coastguard Worker (conv2dLayerParams.m_StrideX == 1) && (conv2dLayerParams.m_StrideY == 1) && 73*89c4ff92SAndroid Build Coastguard Worker (conv2dLayerParams.m_BiasEnabled == false) && (conv2dLayerParams.m_DataLayout == DataLayout::NHWC); 74*89c4ff92SAndroid Build Coastguard Worker }; 75*89c4ff92SAndroid Build Coastguard Worker 76*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<InputLayer>, 77*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ConstantLayer>, 78*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 79*89c4ff92SAndroid Build Coastguard Worker checkSimpleConv2d, 80*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 81*89c4ff92SAndroid Build Coastguard Worker 82*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(FoldPadIntoConvolution2d())); 83*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0202(const Layer* const layer)84*89c4ff92SAndroid Build Coastguard Worker auto checkPadFoldedIntoConv2d = [](const Layer* const layer)->bool { 85*89c4ff92SAndroid Build Coastguard Worker const auto conv2dLayer = static_cast<const Convolution2dLayer*>(layer); 86*89c4ff92SAndroid Build Coastguard Worker const auto conv2dLayerParams = conv2dLayer->GetParameters(); 87*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<Convolution2dLayer>(layer) && (layer->GetNameStr() == "folded-pad-into-conv2d") && 88*89c4ff92SAndroid Build Coastguard Worker (conv2dLayerParams.m_PadLeft == 2) && (conv2dLayerParams.m_PadRight == 2) && 89*89c4ff92SAndroid Build Coastguard Worker (conv2dLayerParams.m_PadTop == 2) && (conv2dLayerParams.m_PadBottom == 2) && 90*89c4ff92SAndroid Build Coastguard Worker (conv2dLayerParams.m_StrideX == 1) && (conv2dLayerParams.m_StrideY == 1) && 91*89c4ff92SAndroid Build Coastguard Worker (conv2dLayerParams.m_BiasEnabled == false) && (conv2dLayerParams.m_DataLayout == DataLayout::NHWC); 92*89c4ff92SAndroid Build Coastguard Worker }; 93*89c4ff92SAndroid Build Coastguard Worker 94*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<InputLayer>, 95*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ConstantLayer>, 96*89c4ff92SAndroid Build Coastguard Worker checkPadFoldedIntoConv2d, 97*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 98*89c4ff92SAndroid Build Coastguard Worker } 99*89c4ff92SAndroid Build Coastguard Worker 100*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoDepthwiseConvolution2dLayer") 101*89c4ff92SAndroid Build Coastguard Worker { 102*89c4ff92SAndroid Build Coastguard Worker Graph graph; 103*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 2, 2, 3}; 104*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 6, 6, 3}; 105*89c4ff92SAndroid Build Coastguard Worker const unsigned int weightsShape[] = {1, 2, 3, 3}; 106*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 2, 1, 3}; 107*89c4ff92SAndroid Build Coastguard Worker 108*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32); 109*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::Float32); 110*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo(4, weightsShape, DataType::Float32, 1.0f, 0, true); 111*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32); 112*89c4ff92SAndroid Build Coastguard Worker 113*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input"); 114*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo); 115*89c4ff92SAndroid Build Coastguard Worker 116*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 117*89c4ff92SAndroid Build Coastguard Worker {2, 2}, 118*89c4ff92SAndroid Build Coastguard Worker {2, 2}, 119*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 120*89c4ff92SAndroid Build Coastguard Worker 121*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer = graph.AddLayer<PadLayer>(padDescriptor, "pad"); 122*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().SetTensorInfo(paddedInfo); 123*89c4ff92SAndroid Build Coastguard Worker 124*89c4ff92SAndroid Build Coastguard Worker DepthwiseConvolution2dDescriptor depthwiseConvolution2dDescriptor; 125*89c4ff92SAndroid Build Coastguard Worker depthwiseConvolution2dDescriptor.m_BiasEnabled = false; 126*89c4ff92SAndroid Build Coastguard Worker depthwiseConvolution2dDescriptor.m_StrideX = 1; 127*89c4ff92SAndroid Build Coastguard Worker depthwiseConvolution2dDescriptor.m_StrideY = 1; 128*89c4ff92SAndroid Build Coastguard Worker depthwiseConvolution2dDescriptor.m_DataLayout = DataLayout::NHWC; 129*89c4ff92SAndroid Build Coastguard Worker 130*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(18); 131*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(weightsInfo, weightsVector); 132*89c4ff92SAndroid Build Coastguard Worker 133*89c4ff92SAndroid Build Coastguard Worker auto* weightsLayer = graph.AddLayer<ConstantLayer>("weights"); 134*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot().SetTensorInfo(weightsInfo); 135*89c4ff92SAndroid Build Coastguard Worker weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weights); 136*89c4ff92SAndroid Build Coastguard Worker 137*89c4ff92SAndroid Build Coastguard Worker auto* depthwiseConv2dLayer = graph.AddLayer<DepthwiseConvolution2dLayer>(depthwiseConvolution2dDescriptor, 138*89c4ff92SAndroid Build Coastguard Worker "depthwiseConv2d"); 139*89c4ff92SAndroid Build Coastguard Worker depthwiseConv2dLayer->GetOutputSlot().SetTensorInfo(outputInfo); 140*89c4ff92SAndroid Build Coastguard Worker 141*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output"); 142*89c4ff92SAndroid Build Coastguard Worker 143*89c4ff92SAndroid Build Coastguard Worker // Connect up layers - input -> pad -> depthwiseConv2d -> output 144*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(padLayer->GetInputSlot(0)); 145*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().Connect(depthwiseConv2dLayer->GetInputSlot(0)); 146*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot().Connect(depthwiseConv2dLayer->GetInputSlot(1)); 147*89c4ff92SAndroid Build Coastguard Worker depthwiseConv2dLayer->GetOutputSlot().Connect(output->GetInputSlot(0)); 148*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0302(const Layer* const layer)149*89c4ff92SAndroid Build Coastguard Worker auto checkSimpleDepthwiseConv2d = [](const Layer* const layer)->bool { 150*89c4ff92SAndroid Build Coastguard Worker const auto depthwiseConv2dLayer = static_cast<const DepthwiseConvolution2dLayer*>(layer); 151*89c4ff92SAndroid Build Coastguard Worker const auto depthwiseConv2dLayerParams = depthwiseConv2dLayer->GetParameters(); 152*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<DepthwiseConvolution2dLayer>(layer) && (layer->GetNameStr() == "depthwiseConv2d") && 153*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_PadLeft == 0) && (depthwiseConv2dLayerParams.m_PadRight == 0) && 154*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_PadTop == 0) && (depthwiseConv2dLayerParams.m_PadBottom == 0) && 155*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_StrideX == 1) && (depthwiseConv2dLayerParams.m_StrideY == 1) && 156*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_BiasEnabled == false) && 157*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_DataLayout == DataLayout::NHWC); 158*89c4ff92SAndroid Build Coastguard Worker }; 159*89c4ff92SAndroid Build Coastguard Worker 160*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<InputLayer>, 161*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ConstantLayer>, 162*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 163*89c4ff92SAndroid Build Coastguard Worker checkSimpleDepthwiseConv2d, 164*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 165*89c4ff92SAndroid Build Coastguard Worker 166*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, MakeOptimizations(FoldPadIntoDepthwiseConvolution2d())); 167*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0402(const Layer* const layer)168*89c4ff92SAndroid Build Coastguard Worker auto checkPadFoldedIntoDepthwiseConv2d = [](const Layer* const layer)->bool { 169*89c4ff92SAndroid Build Coastguard Worker const auto depthwiseConv2dLayer = static_cast<const DepthwiseConvolution2dLayer*>(layer); 170*89c4ff92SAndroid Build Coastguard Worker const auto depthwiseConv2dLayerParams = depthwiseConv2dLayer->GetParameters(); 171*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<DepthwiseConvolution2dLayer>(layer) && 172*89c4ff92SAndroid Build Coastguard Worker (layer->GetNameStr() == "folded-pad-into-depthwiseConv2d") && 173*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_PadLeft == 2) && (depthwiseConv2dLayerParams.m_PadRight == 2) && 174*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_PadTop == 2) && (depthwiseConv2dLayerParams.m_PadBottom == 2) && 175*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_StrideX == 1) && (depthwiseConv2dLayerParams.m_StrideY == 1) && 176*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_BiasEnabled == false) && 177*89c4ff92SAndroid Build Coastguard Worker (depthwiseConv2dLayerParams.m_DataLayout == DataLayout::NHWC); 178*89c4ff92SAndroid Build Coastguard Worker }; 179*89c4ff92SAndroid Build Coastguard Worker 180*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<InputLayer>, 181*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ConstantLayer>, 182*89c4ff92SAndroid Build Coastguard Worker checkPadFoldedIntoDepthwiseConv2d, 183*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 184*89c4ff92SAndroid Build Coastguard Worker } 185*89c4ff92SAndroid Build Coastguard Worker 186*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoPooling2dLayer") 187*89c4ff92SAndroid Build Coastguard Worker { 188*89c4ff92SAndroid Build Coastguard Worker Graph graph; 189*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 2, 2, 3}; 190*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 4, 4, 3}; 191*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 2, 2, 3}; 192*89c4ff92SAndroid Build Coastguard Worker 193*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32); 194*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::Float32); 195*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32); 196*89c4ff92SAndroid Build Coastguard Worker 197*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input"); 198*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo); 199*89c4ff92SAndroid Build Coastguard Worker 200*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 201*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 202*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 203*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 204*89c4ff92SAndroid Build Coastguard Worker 205*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer = graph.AddLayer<PadLayer>(padDescriptor, "pad"); 206*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().SetTensorInfo(paddedInfo); 207*89c4ff92SAndroid Build Coastguard Worker 208*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor pooling2dDescriptor; 209*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolType = PoolingAlgorithm::Average; 210*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolWidth = 3; 211*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolHeight = 3; 212*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideX = 1; 213*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideY = 1; 214*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_DataLayout = DataLayout::NHWC; 215*89c4ff92SAndroid Build Coastguard Worker 216*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* pool2dLayer = graph.AddLayer<Pooling2dLayer>(pooling2dDescriptor, "pool2d"); 217*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().SetTensorInfo(outputInfo); 218*89c4ff92SAndroid Build Coastguard Worker 219*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output"); 220*89c4ff92SAndroid Build Coastguard Worker 221*89c4ff92SAndroid Build Coastguard Worker // Connect up layers - input -> pad -> pool2d -> output 222*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(padLayer->GetInputSlot(0)); 223*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().Connect(pool2dLayer->GetInputSlot(0)); 224*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().Connect(output->GetInputSlot(0)); 225*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0502(const Layer* const layer) 226*89c4ff92SAndroid Build Coastguard Worker auto checkSimplePool2d = [&](const Layer* const layer) { 227*89c4ff92SAndroid Build Coastguard Worker const auto pool2dLayer = static_cast<const Pooling2dLayer*>(layer); 228*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<Pooling2dLayer>(layer) && (layer->GetNameStr() == "pool2d") && 229*89c4ff92SAndroid Build Coastguard Worker (pool2dLayer->GetParameters() == pooling2dDescriptor); 230*89c4ff92SAndroid Build Coastguard Worker }; 231*89c4ff92SAndroid Build Coastguard Worker 232*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 233*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 234*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 235*89c4ff92SAndroid Build Coastguard Worker checkSimplePool2d, 236*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 237*89c4ff92SAndroid Build Coastguard Worker 238*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, MakeOptimizations(FoldPadIntoPooling2d())); 239*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0602(const Layer* const layer) 240*89c4ff92SAndroid Build Coastguard Worker auto checkPadFoldedIntoPool2d = [&](const Layer* const layer) { 241*89c4ff92SAndroid Build Coastguard Worker if (!IsLayerOfType<Pooling2dLayer>(layer) || (layer->GetNameStr() != "folded-pad-into-pool2d")) 242*89c4ff92SAndroid Build Coastguard Worker { 243*89c4ff92SAndroid Build Coastguard Worker return false; 244*89c4ff92SAndroid Build Coastguard Worker } 245*89c4ff92SAndroid Build Coastguard Worker 246*89c4ff92SAndroid Build Coastguard Worker const auto pool2dLayer = static_cast<const Pooling2dLayer*>(layer); 247*89c4ff92SAndroid Build Coastguard Worker const Pooling2dDescriptor pool2dLayerParams = pool2dLayer->GetParameters(); 248*89c4ff92SAndroid Build Coastguard Worker 249*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor pool2dLayerParamsNoPad = pool2dLayerParams; 250*89c4ff92SAndroid Build Coastguard Worker pool2dLayerParamsNoPad.m_PadLeft = 0; 251*89c4ff92SAndroid Build Coastguard Worker pool2dLayerParamsNoPad.m_PadRight = 0; 252*89c4ff92SAndroid Build Coastguard Worker pool2dLayerParamsNoPad.m_PadTop = 0; 253*89c4ff92SAndroid Build Coastguard Worker pool2dLayerParamsNoPad.m_PadBottom = 0; 254*89c4ff92SAndroid Build Coastguard Worker // If we fold then PaddingMethod will be set to Ignore. The original will be Exclude. 255*89c4ff92SAndroid Build Coastguard Worker pool2dLayerParamsNoPad.m_PaddingMethod = PaddingMethod::Exclude; 256*89c4ff92SAndroid Build Coastguard Worker 257*89c4ff92SAndroid Build Coastguard Worker return (pool2dLayerParamsNoPad == pooling2dDescriptor) && (pool2dLayerParams.m_PadLeft == 1) && 258*89c4ff92SAndroid Build Coastguard Worker (pool2dLayerParams.m_PadRight == 1) && (pool2dLayerParams.m_PadTop == 1) && 259*89c4ff92SAndroid Build Coastguard Worker (pool2dLayerParams.m_PadBottom == 1) && (pool2dLayerParams.m_PaddingMethod == PaddingMethod::IgnoreValue); 260*89c4ff92SAndroid Build Coastguard Worker }; 261*89c4ff92SAndroid Build Coastguard Worker 262*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 263*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 264*89c4ff92SAndroid Build Coastguard Worker checkPadFoldedIntoPool2d, 265*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 266*89c4ff92SAndroid Build Coastguard Worker } 267*89c4ff92SAndroid Build Coastguard Worker 268*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoPooling2d_PadWithMultipleOutputsShouldNotBeOptimized") 269*89c4ff92SAndroid Build Coastguard Worker { 270*89c4ff92SAndroid Build Coastguard Worker // In this test case we'll setup a pad layer with two outputs. One goes to a polling layers and the other 271*89c4ff92SAndroid Build Coastguard Worker // goes to an output layer. FoldPadLayerIntoPooling2d should not optimize this graph as it uses the 272*89c4ff92SAndroid Build Coastguard Worker // OptimizeForExclusiveConnection method. 273*89c4ff92SAndroid Build Coastguard Worker Graph graph; 274*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 2, 2, 3}; 275*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 4, 4, 3}; 276*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 2, 2, 3}; 277*89c4ff92SAndroid Build Coastguard Worker 278*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32); 279*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::Float32); 280*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32); 281*89c4ff92SAndroid Build Coastguard Worker 282*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input"); 283*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo); 284*89c4ff92SAndroid Build Coastguard Worker 285*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 286*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 287*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 288*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 289*89c4ff92SAndroid Build Coastguard Worker 290*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer = graph.AddLayer<PadLayer>(padDescriptor, "pad"); 291*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().SetTensorInfo(paddedInfo); 292*89c4ff92SAndroid Build Coastguard Worker 293*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor pooling2dDescriptor; 294*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolType = PoolingAlgorithm::Average; 295*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolWidth = 3; 296*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolHeight = 3; 297*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideX = 1; 298*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideY = 1; 299*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_DataLayout = DataLayout::NHWC; 300*89c4ff92SAndroid Build Coastguard Worker 301*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* pool2dLayer = graph.AddLayer<Pooling2dLayer>(pooling2dDescriptor, "pool2d"); 302*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().SetTensorInfo(outputInfo); 303*89c4ff92SAndroid Build Coastguard Worker 304*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output"); 305*89c4ff92SAndroid Build Coastguard Worker 306*89c4ff92SAndroid Build Coastguard Worker // Connect up layers - input -> pad -> pool2d -> output 307*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(padLayer->GetInputSlot(0)); 308*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().Connect(pool2dLayer->GetInputSlot(0)); 309*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().Connect(output->GetInputSlot(0)); 310*89c4ff92SAndroid Build Coastguard Worker 311*89c4ff92SAndroid Build Coastguard Worker // Add the alternative branch from the pas layer to an output layer. 312*89c4ff92SAndroid Build Coastguard Worker Layer* secondOutput = graph.AddLayer<OutputLayer>(1, "dummy output"); 313*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().Connect(secondOutput->GetInputSlot(0)); 314*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0702(const Layer* const layer) 315*89c4ff92SAndroid Build Coastguard Worker auto checkSimplePool2d = [&](const Layer* const layer) { 316*89c4ff92SAndroid Build Coastguard Worker const auto pool2dLayer = static_cast<const Pooling2dLayer*>(layer); 317*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<Pooling2dLayer>(layer) && (layer->GetNameStr() == "pool2d") && 318*89c4ff92SAndroid Build Coastguard Worker (pool2dLayer->GetParameters() == pooling2dDescriptor); 319*89c4ff92SAndroid Build Coastguard Worker }; 320*89c4ff92SAndroid Build Coastguard Worker 321*89c4ff92SAndroid Build Coastguard Worker // Initial sequence. 322*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 323*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 324*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 325*89c4ff92SAndroid Build Coastguard Worker checkSimplePool2d, 326*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>, 327*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 328*89c4ff92SAndroid Build Coastguard Worker 329*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, MakeOptimizations(FoldPadIntoPooling2d())); 330*89c4ff92SAndroid Build Coastguard Worker 331*89c4ff92SAndroid Build Coastguard Worker // The network should not change. 332*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 333*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 334*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 335*89c4ff92SAndroid Build Coastguard Worker checkSimplePool2d, 336*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>, 337*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 338*89c4ff92SAndroid Build Coastguard Worker } 339*89c4ff92SAndroid Build Coastguard Worker 340*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoPooling2dLayer_PoolingLayerWithExcludePaddingShouldNotTakeMorePadding") 341*89c4ff92SAndroid Build Coastguard Worker { 342*89c4ff92SAndroid Build Coastguard Worker // In this test setup input, Pad layer, Pooling layer that includes padding, output layer. The optimization 343*89c4ff92SAndroid Build Coastguard Worker // should not work as the pooling layer already includes and existing pad and specifies PaddingMethod::Exclude. 344*89c4ff92SAndroid Build Coastguard Worker Graph graph; 345*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 2, 2, 3}; 346*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 4, 4, 3}; 347*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 2, 2, 3}; 348*89c4ff92SAndroid Build Coastguard Worker 349*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32); 350*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::Float32); 351*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32); 352*89c4ff92SAndroid Build Coastguard Worker 353*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input"); 354*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo); 355*89c4ff92SAndroid Build Coastguard Worker 356*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 357*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 358*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 359*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 360*89c4ff92SAndroid Build Coastguard Worker 361*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer = graph.AddLayer<PadLayer>(padDescriptor, "pad"); 362*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().SetTensorInfo(paddedInfo); 363*89c4ff92SAndroid Build Coastguard Worker 364*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor pooling2dDescriptor; 365*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolType = PoolingAlgorithm::Average; 366*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolWidth = 3; 367*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolHeight = 3; 368*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideX = 1; 369*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideY = 1; 370*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_DataLayout = DataLayout::NHWC; 371*89c4ff92SAndroid Build Coastguard Worker // Include a pad with the pooling layer. This should prevent the optimization working. 372*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PadLeft = 1; 373*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PadRight = 1; 374*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PadTop = 1; 375*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PadBottom = 1; 376*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PaddingMethod = PaddingMethod::Exclude; 377*89c4ff92SAndroid Build Coastguard Worker 378*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* pool2dLayer = graph.AddLayer<Pooling2dLayer>(pooling2dDescriptor, "pool2d"); 379*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().SetTensorInfo(outputInfo); 380*89c4ff92SAndroid Build Coastguard Worker 381*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output"); 382*89c4ff92SAndroid Build Coastguard Worker 383*89c4ff92SAndroid Build Coastguard Worker // Connect up layers - input -> pad -> pool2d -> output 384*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(padLayer->GetInputSlot(0)); 385*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().Connect(pool2dLayer->GetInputSlot(0)); 386*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().Connect(output->GetInputSlot(0)); 387*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0802(const Layer* const layer) 388*89c4ff92SAndroid Build Coastguard Worker auto checkSimplePool2d = [&](const Layer* const layer) { 389*89c4ff92SAndroid Build Coastguard Worker const auto pool2dLayer = static_cast<const Pooling2dLayer*>(layer); 390*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<Pooling2dLayer>(layer) && (layer->GetNameStr() == "pool2d") && 391*89c4ff92SAndroid Build Coastguard Worker (pool2dLayer->GetParameters() == pooling2dDescriptor); 392*89c4ff92SAndroid Build Coastguard Worker }; 393*89c4ff92SAndroid Build Coastguard Worker 394*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 395*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 396*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 397*89c4ff92SAndroid Build Coastguard Worker checkSimplePool2d, 398*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 399*89c4ff92SAndroid Build Coastguard Worker 400*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, MakeOptimizations(FoldPadIntoPooling2d())); 401*89c4ff92SAndroid Build Coastguard Worker 402*89c4ff92SAndroid Build Coastguard Worker // The optimization should not have modified the graph. 403*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 404*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 405*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 406*89c4ff92SAndroid Build Coastguard Worker checkSimplePool2d, 407*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 408*89c4ff92SAndroid Build Coastguard Worker } 409*89c4ff92SAndroid Build Coastguard Worker 410*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoPooling2dLayer_MaxPoolingLayerWithLargePadValueShouldNotBeFolded") 411*89c4ff92SAndroid Build Coastguard Worker { 412*89c4ff92SAndroid Build Coastguard Worker // In this test setup input, Pad layer with a large pad value, Max Pooling layer, output layer. The optimization 413*89c4ff92SAndroid Build Coastguard Worker // should not work as the pad value will modify the result of the max pooling layer. 414*89c4ff92SAndroid Build Coastguard Worker Graph graph; 415*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 2, 2, 3}; 416*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 4, 4, 3}; 417*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 2, 2, 3}; 418*89c4ff92SAndroid Build Coastguard Worker 419*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32); 420*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::Float32); 421*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32); 422*89c4ff92SAndroid Build Coastguard Worker 423*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input"); 424*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo); 425*89c4ff92SAndroid Build Coastguard Worker 426*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 427*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 428*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 429*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 430*89c4ff92SAndroid Build Coastguard Worker // For Max pooling of a float a pad value of 0 is more than enough to stop the fold happening. 431*89c4ff92SAndroid Build Coastguard Worker // Set this to -std::numeric_limits<float>::infinity() to make the fold happen. 432*89c4ff92SAndroid Build Coastguard Worker padDescriptor.m_PadValue = 0; 433*89c4ff92SAndroid Build Coastguard Worker 434*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer = graph.AddLayer<PadLayer>(padDescriptor, "pad"); 435*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().SetTensorInfo(paddedInfo); 436*89c4ff92SAndroid Build Coastguard Worker 437*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor pooling2dDescriptor; 438*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolType = PoolingAlgorithm::Max; 439*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolWidth = 3; 440*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolHeight = 3; 441*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideX = 1; 442*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideY = 1; 443*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_DataLayout = DataLayout::NHWC; 444*89c4ff92SAndroid Build Coastguard Worker 445*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* pool2dLayer = graph.AddLayer<Pooling2dLayer>(pooling2dDescriptor, "pool2d"); 446*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().SetTensorInfo(outputInfo); 447*89c4ff92SAndroid Build Coastguard Worker 448*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output"); 449*89c4ff92SAndroid Build Coastguard Worker 450*89c4ff92SAndroid Build Coastguard Worker // Connect up layers - input -> pad -> pool2d -> output 451*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(padLayer->GetInputSlot(0)); 452*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().Connect(pool2dLayer->GetInputSlot(0)); 453*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().Connect(output->GetInputSlot(0)); 454*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0902(const Layer* const layer) 455*89c4ff92SAndroid Build Coastguard Worker auto checkSimplePool2d = [&](const Layer* const layer) { 456*89c4ff92SAndroid Build Coastguard Worker const auto pool2dLayer = static_cast<const Pooling2dLayer*>(layer); 457*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<Pooling2dLayer>(layer) && (layer->GetNameStr() == "pool2d") && 458*89c4ff92SAndroid Build Coastguard Worker (pool2dLayer->GetParameters() == pooling2dDescriptor); 459*89c4ff92SAndroid Build Coastguard Worker }; 460*89c4ff92SAndroid Build Coastguard Worker 461*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 462*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 463*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 464*89c4ff92SAndroid Build Coastguard Worker checkSimplePool2d, 465*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 466*89c4ff92SAndroid Build Coastguard Worker 467*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, MakeOptimizations(FoldPadIntoPooling2d())); 468*89c4ff92SAndroid Build Coastguard Worker 469*89c4ff92SAndroid Build Coastguard Worker // The optimization should not have modified the graph. 470*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 471*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 472*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 473*89c4ff92SAndroid Build Coastguard Worker checkSimplePool2d, 474*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 475*89c4ff92SAndroid Build Coastguard Worker } 476*89c4ff92SAndroid Build Coastguard Worker 477*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoPooling2dLayer_QuantizedAveragePoolingShouldNotBeFolded") 478*89c4ff92SAndroid Build Coastguard Worker { 479*89c4ff92SAndroid Build Coastguard Worker Graph graph; 480*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 2, 2, 3}; 481*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 4, 4, 3}; 482*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 2, 2, 3}; 483*89c4ff92SAndroid Build Coastguard Worker 484*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::QAsymmU8); 485*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::QAsymmU8); 486*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::QAsymmU8); 487*89c4ff92SAndroid Build Coastguard Worker 488*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input"); 489*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo); 490*89c4ff92SAndroid Build Coastguard Worker 491*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 492*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 493*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 494*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 495*89c4ff92SAndroid Build Coastguard Worker 496*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer = graph.AddLayer<PadLayer>(padDescriptor, "pad"); 497*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().SetTensorInfo(paddedInfo); 498*89c4ff92SAndroid Build Coastguard Worker 499*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor pooling2dDescriptor; 500*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolType = PoolingAlgorithm::Average; 501*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolWidth = 3; 502*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolHeight = 3; 503*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideX = 1; 504*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideY = 1; 505*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_DataLayout = DataLayout::NHWC; 506*89c4ff92SAndroid Build Coastguard Worker 507*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* pool2dLayer = graph.AddLayer<Pooling2dLayer>(pooling2dDescriptor, "pool2d"); 508*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().SetTensorInfo(outputInfo); 509*89c4ff92SAndroid Build Coastguard Worker 510*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output"); 511*89c4ff92SAndroid Build Coastguard Worker 512*89c4ff92SAndroid Build Coastguard Worker // Connect up layers - input -> pad -> pool2d -> output 513*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(padLayer->GetInputSlot(0)); 514*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().Connect(pool2dLayer->GetInputSlot(0)); 515*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot().Connect(output->GetInputSlot(0)); 516*89c4ff92SAndroid Build Coastguard Worker __anonf6df302a0a02(const Layer* const layer) 517*89c4ff92SAndroid Build Coastguard Worker auto checkSimplePool2d = [&](const Layer* const layer) { 518*89c4ff92SAndroid Build Coastguard Worker const auto pool2dLayer = static_cast<const Pooling2dLayer*>(layer); 519*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<Pooling2dLayer>(layer) && (layer->GetNameStr() == "pool2d") && 520*89c4ff92SAndroid Build Coastguard Worker (pool2dLayer->GetParameters() == pooling2dDescriptor); 521*89c4ff92SAndroid Build Coastguard Worker }; 522*89c4ff92SAndroid Build Coastguard Worker 523*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 524*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 525*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 526*89c4ff92SAndroid Build Coastguard Worker checkSimplePool2d, 527*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 528*89c4ff92SAndroid Build Coastguard Worker 529*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, MakeOptimizations(FoldPadIntoPooling2d())); 530*89c4ff92SAndroid Build Coastguard Worker 531*89c4ff92SAndroid Build Coastguard Worker // The optimization should not have modified the graph. 532*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), 533*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>, 534*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<PadLayer>, 535*89c4ff92SAndroid Build Coastguard Worker checkSimplePool2d, 536*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>)); 537*89c4ff92SAndroid Build Coastguard Worker } 538*89c4ff92SAndroid Build Coastguard Worker 539*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNNREF_ENABLED) 540*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoPooling2dLayer_ExecuteInferenceWithAndWithoutOptimization") 541*89c4ff92SAndroid Build Coastguard Worker { 542*89c4ff92SAndroid Build Coastguard Worker // The idea of this test to run a simple pad+pool2d network twice. Once 543*89c4ff92SAndroid Build Coastguard Worker // with FoldPadLayerIntoPooling2dLayer enabled and a second time with it 544*89c4ff92SAndroid Build Coastguard Worker // avoided. The output tensors of each should match. 545*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 4, 4, 2}; 546*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 6, 6, 2}; 547*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 4, 4, 2}; 548*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData({2.0f, 2.0f, 6.0f, 6.0f, 549*89c4ff92SAndroid Build Coastguard Worker 4.0f, 4.0f, 8.0f, 8.0f, 550*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 14.0f, 16.0f, 551*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 16.0f, 14.0f, 552*89c4ff92SAndroid Build Coastguard Worker 553*89c4ff92SAndroid Build Coastguard Worker 18.0f, 20.0f, 24.0f, 22.0f, 554*89c4ff92SAndroid Build Coastguard Worker 20.0f, 18.0f, 22.0f, 24.0f, 555*89c4ff92SAndroid Build Coastguard Worker 26.0f, 28.0f, 0.0f, 0.0f, 556*89c4ff92SAndroid Build Coastguard Worker 26.0f, 28.0f, 0.0f, 0.0f, 557*89c4ff92SAndroid Build Coastguard Worker }); 558*89c4ff92SAndroid Build Coastguard Worker try 559*89c4ff92SAndroid Build Coastguard Worker { 560*89c4ff92SAndroid Build Coastguard Worker // Create a network of input, pad, pooling 2D, output. 561*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = INetwork::Create(); 562*89c4ff92SAndroid Build Coastguard Worker 563*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* inputLayer = network->AddInputLayer(0); 564*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32); 565*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); 566*89c4ff92SAndroid Build Coastguard Worker 567*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 568*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 569*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 570*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 571*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* padLayer = network->AddPadLayer(padDescriptor, "Pad"); 572*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::Float32); 573*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot(0).SetTensorInfo(paddedInfo); 574*89c4ff92SAndroid Build Coastguard Worker 575*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor pooling2dDescriptor; 576*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolType = PoolingAlgorithm::Average; 577*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolWidth = 3; 578*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_PoolHeight = 3; 579*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideX = 1; 580*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_StrideY = 1; 581*89c4ff92SAndroid Build Coastguard Worker pooling2dDescriptor.m_DataLayout = DataLayout::NHWC; 582*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* pool2dLayer = network->AddPooling2dLayer(pooling2dDescriptor, "Pool2D"); 583*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32); 584*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); 585*89c4ff92SAndroid Build Coastguard Worker 586*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* outputLayer = network->AddOutputLayer(0); 587*89c4ff92SAndroid Build Coastguard Worker 588*89c4ff92SAndroid Build Coastguard Worker // Connect layers 589*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(padLayer->GetInputSlot(0)); 590*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot(0).Connect(pool2dLayer->GetInputSlot(0)); 591*89c4ff92SAndroid Build Coastguard Worker pool2dLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); 592*89c4ff92SAndroid Build Coastguard Worker 593*89c4ff92SAndroid Build Coastguard Worker // Create ArmNN runtime 594*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr run = IRuntime::Create(IRuntime::CreationOptions()); // default options 595*89c4ff92SAndroid Build Coastguard Worker // Optimise the network 596*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optimizedNetwork = Optimize(*network, {Compute::CpuRef}, run->GetDeviceSpec()); 597*89c4ff92SAndroid Build Coastguard Worker // Load network into runtime 598*89c4ff92SAndroid Build Coastguard Worker NetworkId networkIdentifier; 599*89c4ff92SAndroid Build Coastguard Worker CHECK(run->LoadNetwork(networkIdentifier, std::move(optimizedNetwork)) == Status::Success); 600*89c4ff92SAndroid Build Coastguard Worker 601*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0); 602*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true); 603*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors{{0, ConstTensor(inputTensorInfo, inputData.data())}}; 604*89c4ff92SAndroid Build Coastguard Worker 605*89c4ff92SAndroid Build Coastguard Worker // Set the initial values of the data to different values to the golden data just in case the inference fails. 606*89c4ff92SAndroid Build Coastguard Worker std::vector<float> optimizedData(32, -std::numeric_limits<float>::infinity()); 607*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors{{0, Tensor(outputInfo, optimizedData.data())}}; 608*89c4ff92SAndroid Build Coastguard Worker // Execute network 609*89c4ff92SAndroid Build Coastguard Worker run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors); 610*89c4ff92SAndroid Build Coastguard Worker // Unload it. 611*89c4ff92SAndroid Build Coastguard Worker run->UnloadNetwork(networkIdentifier); 612*89c4ff92SAndroid Build Coastguard Worker 613*89c4ff92SAndroid Build Coastguard Worker // In this second case the pad will have two outputs, one connected to the pooling layer the second connected to 614*89c4ff92SAndroid Build Coastguard Worker // a second output layer. This will prevent the FoldPadLayerIntoPooling2dLayer optimization from working. 615*89c4ff92SAndroid Build Coastguard Worker // A previous test, FoldPadLayerIntoPooling2d_PadWithMultipleOutputsShouldNotBeOptimized, has proved that doing 616*89c4ff92SAndroid Build Coastguard Worker // this will avoid the optimization. 617*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* dummyOutputLayer = network->AddOutputLayer(1); 618*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot(0).Connect(dummyOutputLayer->GetInputSlot(0)); 619*89c4ff92SAndroid Build Coastguard Worker 620*89c4ff92SAndroid Build Coastguard Worker // Optimize and load and execute it a second time. 621*89c4ff92SAndroid Build Coastguard Worker optimizedNetwork = Optimize(*network, {Compute::CpuRef}, run->GetDeviceSpec()); 622*89c4ff92SAndroid Build Coastguard Worker CHECK(run->LoadNetwork(networkIdentifier, std::move(optimizedNetwork)) == Status::Success); 623*89c4ff92SAndroid Build Coastguard Worker std::vector<float> goldenData(32, 0.0f); 624*89c4ff92SAndroid Build Coastguard Worker std::vector<float> padOutputData(72, 0.0f); 625*89c4ff92SAndroid Build Coastguard Worker OutputTensors goldenTensors{{0, Tensor(outputInfo, goldenData.data())}, 626*89c4ff92SAndroid Build Coastguard Worker {1, Tensor(paddedInfo, padOutputData.data())}}; 627*89c4ff92SAndroid Build Coastguard Worker run->EnqueueWorkload(networkIdentifier, inputTensors, goldenTensors); 628*89c4ff92SAndroid Build Coastguard Worker 629*89c4ff92SAndroid Build Coastguard Worker // Now we can compare goldenData against optimizedData. They should be the same. 630*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(goldenData.begin(), goldenData.end(), optimizedData.begin())); 631*89c4ff92SAndroid Build Coastguard Worker } 632*89c4ff92SAndroid Build Coastguard Worker catch (const std::exception& e) 633*89c4ff92SAndroid Build Coastguard Worker { 634*89c4ff92SAndroid Build Coastguard Worker std::cerr << e.what() << std::endl; 635*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, e.what()); 636*89c4ff92SAndroid Build Coastguard Worker } 637*89c4ff92SAndroid Build Coastguard Worker } 638*89c4ff92SAndroid Build Coastguard Worker 639*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoConv2dLayer_ExecuteInferenceWithAndWithoutOptimization") 640*89c4ff92SAndroid Build Coastguard Worker { 641*89c4ff92SAndroid Build Coastguard Worker // The idea of this test to run a simple pad+conv2d network twice. Once 642*89c4ff92SAndroid Build Coastguard Worker // with FoldPadLayerIntoConv2dLayer enabled and a second time with it 643*89c4ff92SAndroid Build Coastguard Worker // avoided. The output tensors of each should match. 644*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 4, 4, 3}; // NHWCin 645*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 6, 6, 3}; 646*89c4ff92SAndroid Build Coastguard Worker const unsigned int weightsShape[] = {4, 2, 2, 3}; // CoutHWCin 647*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 5, 5, 4}; // NHWCout 648*89c4ff92SAndroid Build Coastguard Worker 649*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData({2.0f, 2.0f, 6.0f, 6.0f, 650*89c4ff92SAndroid Build Coastguard Worker 4.0f, 4.0f, 8.0f, 8.0f, 651*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 14.0f, 16.0f, 652*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 16.0f, 14.0f, 653*89c4ff92SAndroid Build Coastguard Worker 654*89c4ff92SAndroid Build Coastguard Worker 18.0f, 20.0f, 24.0f, 22.0f, 655*89c4ff92SAndroid Build Coastguard Worker 20.0f, 18.0f, 22.0f, 24.0f, 656*89c4ff92SAndroid Build Coastguard Worker 26.0f, 28.0f, 0.0f, 0.0f, 657*89c4ff92SAndroid Build Coastguard Worker 26.0f, 28.0f, 0.0f, 0.0f, 658*89c4ff92SAndroid Build Coastguard Worker 659*89c4ff92SAndroid Build Coastguard Worker 2.0f, 2.0f, 6.0f, 6.0f, 660*89c4ff92SAndroid Build Coastguard Worker 4.0f, 4.0f, 8.0f, 8.0f, 661*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 14.0f, 16.0f, 662*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 16.0f, 14.0f, 663*89c4ff92SAndroid Build Coastguard Worker }); 664*89c4ff92SAndroid Build Coastguard Worker try 665*89c4ff92SAndroid Build Coastguard Worker { 666*89c4ff92SAndroid Build Coastguard Worker // Create a network of input, pad, pooling 2D, output. 667*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = INetwork::Create(); 668*89c4ff92SAndroid Build Coastguard Worker 669*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* inputLayer = network->AddInputLayer(0); 670*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32); 671*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); 672*89c4ff92SAndroid Build Coastguard Worker 673*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 674*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 675*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 676*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 677*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* padLayer = network->AddPadLayer(padDescriptor, "Pad"); 678*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::Float32); 679*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot(0).SetTensorInfo(paddedInfo); 680*89c4ff92SAndroid Build Coastguard Worker 681*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convDescriptor; 682*89c4ff92SAndroid Build Coastguard Worker convDescriptor.m_DataLayout = DataLayout::NHWC; 683*89c4ff92SAndroid Build Coastguard Worker convDescriptor.m_StrideX = 1; 684*89c4ff92SAndroid Build Coastguard Worker convDescriptor.m_StrideY = 1; 685*89c4ff92SAndroid Build Coastguard Worker convDescriptor.m_BiasEnabled = true; 686*89c4ff92SAndroid Build Coastguard Worker 687*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 688*89c4ff92SAndroid Build Coastguard Worker 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 689*89c4ff92SAndroid Build Coastguard Worker 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 690*89c4ff92SAndroid Build Coastguard Worker 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42}; 691*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo(4, weightsShape, DataType::Float32, 1.0f, 0, true); 692*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(weightsInfo, weightsData); 693*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasVector = {5, 6, 7, 8}; 694*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasInfo({4}, DataType::Float32, 1.0f, 0, true); 695*89c4ff92SAndroid Build Coastguard Worker ConstTensor bias(biasInfo, biasVector); 696*89c4ff92SAndroid Build Coastguard Worker 697*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* conv2dLayer = network->AddConvolution2dLayer(convDescriptor, "Conv2D"); 698*89c4ff92SAndroid Build Coastguard Worker 699*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32); 700*89c4ff92SAndroid Build Coastguard Worker conv2dLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); 701*89c4ff92SAndroid Build Coastguard Worker 702*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* outputLayer = network->AddOutputLayer(0); 703*89c4ff92SAndroid Build Coastguard Worker 704*89c4ff92SAndroid Build Coastguard Worker // Connect layers 705*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(padLayer->GetInputSlot(0)); 706*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot(0).Connect(conv2dLayer->GetInputSlot(0)); 707*89c4ff92SAndroid Build Coastguard Worker conv2dLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); 708*89c4ff92SAndroid Build Coastguard Worker 709*89c4ff92SAndroid Build Coastguard Worker auto weightsLayer = network->AddConstantLayer(weights, "Weights"); 710*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weights.GetInfo()); 711*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(conv2dLayer->GetInputSlot(1)); 712*89c4ff92SAndroid Build Coastguard Worker 713*89c4ff92SAndroid Build Coastguard Worker auto biasLayer = network->AddConstantLayer(bias, "Bias"); 714*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(bias.GetInfo()); 715*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(conv2dLayer->GetInputSlot(2)); 716*89c4ff92SAndroid Build Coastguard Worker 717*89c4ff92SAndroid Build Coastguard Worker // Create ArmNN runtime 718*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr run = IRuntime::Create(IRuntime::CreationOptions()); // default options 719*89c4ff92SAndroid Build Coastguard Worker // Optimise the network 720*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optimizedNetwork = Optimize(*network, {Compute::CpuRef}, run->GetDeviceSpec()); 721*89c4ff92SAndroid Build Coastguard Worker // Load network into runtime 722*89c4ff92SAndroid Build Coastguard Worker NetworkId networkIdentifier; 723*89c4ff92SAndroid Build Coastguard Worker CHECK(run->LoadNetwork(networkIdentifier, std::move(optimizedNetwork)) == Status::Success); 724*89c4ff92SAndroid Build Coastguard Worker 725*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0); 726*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true); 727*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors{{0, ConstTensor(inputTensorInfo, inputData.data())}}; 728*89c4ff92SAndroid Build Coastguard Worker 729*89c4ff92SAndroid Build Coastguard Worker // Set the initial values of the data to different values to the golden data just in case the inference fails. 730*89c4ff92SAndroid Build Coastguard Worker std::vector<float> optimizedData(100, -std::numeric_limits<float>::infinity()); 731*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors{{0, Tensor(outputInfo, optimizedData.data())}}; 732*89c4ff92SAndroid Build Coastguard Worker // Execute network 733*89c4ff92SAndroid Build Coastguard Worker run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors); 734*89c4ff92SAndroid Build Coastguard Worker // Unload it. 735*89c4ff92SAndroid Build Coastguard Worker run->UnloadNetwork(networkIdentifier); 736*89c4ff92SAndroid Build Coastguard Worker 737*89c4ff92SAndroid Build Coastguard Worker // In this second case the pad will have two outputs, one connected to the conv layer the second connected to 738*89c4ff92SAndroid Build Coastguard Worker // a second output layer. This will prevent the FoldPadLayerIntoConv2dLayer optimization from working. 739*89c4ff92SAndroid Build Coastguard Worker // A previous test, FoldPadLayerIntoConv2d_PadWithMultipleOutputsShouldNotBeOptimized, has proved that doing 740*89c4ff92SAndroid Build Coastguard Worker // this will avoid the optimization. 741*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* dummyOutputLayer = network->AddOutputLayer(1); 742*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot(0).Connect(dummyOutputLayer->GetInputSlot(0)); 743*89c4ff92SAndroid Build Coastguard Worker 744*89c4ff92SAndroid Build Coastguard Worker // Optimize and load and execute it a second time. 745*89c4ff92SAndroid Build Coastguard Worker optimizedNetwork = Optimize(*network, {Compute::CpuRef}, run->GetDeviceSpec()); 746*89c4ff92SAndroid Build Coastguard Worker CHECK(run->LoadNetwork(networkIdentifier, std::move(optimizedNetwork)) == Status::Success); 747*89c4ff92SAndroid Build Coastguard Worker std::vector<float> goldenData(100, 0.0f); 748*89c4ff92SAndroid Build Coastguard Worker std::vector<float> padOutputData(108, 0.0f); 749*89c4ff92SAndroid Build Coastguard Worker OutputTensors goldenTensors{{0, Tensor(outputInfo, goldenData.data())}, 750*89c4ff92SAndroid Build Coastguard Worker {1, Tensor(paddedInfo, padOutputData.data())}}; 751*89c4ff92SAndroid Build Coastguard Worker run->EnqueueWorkload(networkIdentifier, inputTensors, goldenTensors); 752*89c4ff92SAndroid Build Coastguard Worker 753*89c4ff92SAndroid Build Coastguard Worker // Now we can compare goldenData against optimizedData. They should be the same. 754*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(goldenData.begin(), goldenData.end(), optimizedData.begin())); 755*89c4ff92SAndroid Build Coastguard Worker } 756*89c4ff92SAndroid Build Coastguard Worker catch (const std::exception& e) 757*89c4ff92SAndroid Build Coastguard Worker { 758*89c4ff92SAndroid Build Coastguard Worker std::cerr << e.what() << std::endl; 759*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, e.what()); 760*89c4ff92SAndroid Build Coastguard Worker } 761*89c4ff92SAndroid Build Coastguard Worker } 762*89c4ff92SAndroid Build Coastguard Worker 763*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FoldPadLayerIntoDepthwiseConv2dLayer_ExecuteInferenceWithAndWithoutOptimization") 764*89c4ff92SAndroid Build Coastguard Worker { 765*89c4ff92SAndroid Build Coastguard Worker // The idea of this test to run a simple pad+depthwiseconv2d network twice. Once 766*89c4ff92SAndroid Build Coastguard Worker // with FoldPadLayerIntoDeptwiseConv2dLayer enabled and a second time with it 767*89c4ff92SAndroid Build Coastguard Worker // avoided. The output tensors of each should match. 768*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = {1, 4, 4, 3}; // NHWCin 769*89c4ff92SAndroid Build Coastguard Worker const unsigned int paddedShape[] = {1, 6, 6, 3}; 770*89c4ff92SAndroid Build Coastguard Worker const unsigned int weightsShape[] = {1, 2, 2, 12}; // 1HWCout 771*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = {1, 5, 5, 12}; // NHWCout 772*89c4ff92SAndroid Build Coastguard Worker 773*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData({2.0f, 2.0f, 6.0f, 6.0f, 774*89c4ff92SAndroid Build Coastguard Worker 4.0f, 4.0f, 8.0f, 8.0f, 775*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 14.0f, 16.0f, 776*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 16.0f, 14.0f, 777*89c4ff92SAndroid Build Coastguard Worker 778*89c4ff92SAndroid Build Coastguard Worker 18.0f, 20.0f, 24.0f, 22.0f, 779*89c4ff92SAndroid Build Coastguard Worker 20.0f, 18.0f, 22.0f, 24.0f, 780*89c4ff92SAndroid Build Coastguard Worker 26.0f, 28.0f, 0.0f, 0.0f, 781*89c4ff92SAndroid Build Coastguard Worker 26.0f, 28.0f, 0.0f, 0.0f, 782*89c4ff92SAndroid Build Coastguard Worker 783*89c4ff92SAndroid Build Coastguard Worker 2.0f, 2.0f, 6.0f, 6.0f, 784*89c4ff92SAndroid Build Coastguard Worker 4.0f, 4.0f, 8.0f, 8.0f, 785*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 14.0f, 16.0f, 786*89c4ff92SAndroid Build Coastguard Worker 10.0f, 12.0f, 16.0f, 14.0f, 787*89c4ff92SAndroid Build Coastguard Worker }); 788*89c4ff92SAndroid Build Coastguard Worker try 789*89c4ff92SAndroid Build Coastguard Worker { 790*89c4ff92SAndroid Build Coastguard Worker // Create a network of input, pad, pooling 2D, output. 791*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = INetwork::Create(); 792*89c4ff92SAndroid Build Coastguard Worker 793*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* inputLayer = network->AddInputLayer(0); 794*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32); 795*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); 796*89c4ff92SAndroid Build Coastguard Worker 797*89c4ff92SAndroid Build Coastguard Worker PadDescriptor padDescriptor({{0, 0}, 798*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 799*89c4ff92SAndroid Build Coastguard Worker {1, 1}, 800*89c4ff92SAndroid Build Coastguard Worker {0, 0}}); 801*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* padLayer = network->AddPadLayer(padDescriptor, "Pad"); 802*89c4ff92SAndroid Build Coastguard Worker TensorInfo paddedInfo(4, paddedShape, DataType::Float32); 803*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot(0).SetTensorInfo(paddedInfo); 804*89c4ff92SAndroid Build Coastguard Worker 805*89c4ff92SAndroid Build Coastguard Worker DepthwiseConvolution2dDescriptor convDescriptor; 806*89c4ff92SAndroid Build Coastguard Worker convDescriptor.m_DataLayout = DataLayout::NHWC; 807*89c4ff92SAndroid Build Coastguard Worker convDescriptor.m_StrideX = 1; 808*89c4ff92SAndroid Build Coastguard Worker convDescriptor.m_StrideY = 1; 809*89c4ff92SAndroid Build Coastguard Worker convDescriptor.m_BiasEnabled = true; 810*89c4ff92SAndroid Build Coastguard Worker 811*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 812*89c4ff92SAndroid Build Coastguard Worker 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 813*89c4ff92SAndroid Build Coastguard Worker 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 814*89c4ff92SAndroid Build Coastguard Worker 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42}; 815*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo(4, weightsShape, DataType::Float32, 0.0f, 0, true); 816*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(weightsInfo, weightsData); 817*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasVector = {5, 6, 7, 8, 9, 10, 11, 12, 5, 6, 7, 8}; 818*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasInfo({12}, DataType::Float32, 0.0f, 0, true); 819*89c4ff92SAndroid Build Coastguard Worker ConstTensor bias(biasInfo, biasVector); 820*89c4ff92SAndroid Build Coastguard Worker 821*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* conv2dLayer = network->AddDepthwiseConvolution2dLayer(convDescriptor, 822*89c4ff92SAndroid Build Coastguard Worker "DepthwiseConv2D"); 823*89c4ff92SAndroid Build Coastguard Worker 824*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32); 825*89c4ff92SAndroid Build Coastguard Worker conv2dLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); 826*89c4ff92SAndroid Build Coastguard Worker 827*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* outputLayer = network->AddOutputLayer(0); 828*89c4ff92SAndroid Build Coastguard Worker 829*89c4ff92SAndroid Build Coastguard Worker // Connect layers 830*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(padLayer->GetInputSlot(0)); 831*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot(0).Connect(conv2dLayer->GetInputSlot(0)); 832*89c4ff92SAndroid Build Coastguard Worker conv2dLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); 833*89c4ff92SAndroid Build Coastguard Worker 834*89c4ff92SAndroid Build Coastguard Worker auto weightsLayer = network->AddConstantLayer(weights, "Weights"); 835*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weights.GetInfo()); 836*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(conv2dLayer->GetInputSlot(1)); 837*89c4ff92SAndroid Build Coastguard Worker 838*89c4ff92SAndroid Build Coastguard Worker auto biasLayer = network->AddConstantLayer(bias, "Bias"); 839*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(bias.GetInfo()); 840*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(conv2dLayer->GetInputSlot(2)); 841*89c4ff92SAndroid Build Coastguard Worker 842*89c4ff92SAndroid Build Coastguard Worker // Create ArmNN runtime 843*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr run = IRuntime::Create(IRuntime::CreationOptions()); // default options 844*89c4ff92SAndroid Build Coastguard Worker // Optimise the network 845*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optimizedNetwork = Optimize(*network, {Compute::CpuRef}, run->GetDeviceSpec()); 846*89c4ff92SAndroid Build Coastguard Worker // Load network into runtime 847*89c4ff92SAndroid Build Coastguard Worker NetworkId networkIdentifier; 848*89c4ff92SAndroid Build Coastguard Worker CHECK(run->LoadNetwork(networkIdentifier, std::move(optimizedNetwork)) == Status::Success); 849*89c4ff92SAndroid Build Coastguard Worker 850*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0); 851*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true); 852*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors{{0, ConstTensor(inputTensorInfo, inputData.data())}}; 853*89c4ff92SAndroid Build Coastguard Worker 854*89c4ff92SAndroid Build Coastguard Worker // Set the initial values of the data to different values to the golden data just in case the inference fails. 855*89c4ff92SAndroid Build Coastguard Worker std::vector<float> optimizedData(300, -std::numeric_limits<float>::infinity()); 856*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors{{0, Tensor(outputInfo, optimizedData.data())}}; 857*89c4ff92SAndroid Build Coastguard Worker // Execute network 858*89c4ff92SAndroid Build Coastguard Worker run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors); 859*89c4ff92SAndroid Build Coastguard Worker // Unload it. 860*89c4ff92SAndroid Build Coastguard Worker run->UnloadNetwork(networkIdentifier); 861*89c4ff92SAndroid Build Coastguard Worker 862*89c4ff92SAndroid Build Coastguard Worker // In this second case the pad will have two outputs, one connected to the conv layer the second connected to 863*89c4ff92SAndroid Build Coastguard Worker // a second output layer. This will prevent the FoldPadLayerIntoDepthwiseConv2dLayer optimization from working. 864*89c4ff92SAndroid Build Coastguard Worker // A previous test, FoldPadLayerIntoDepthwiseConv2d_PadWithMultipleOutputsShouldNotBeOptimized, has proved that 865*89c4ff92SAndroid Build Coastguard Worker // doing this will avoid the optimization. 866*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* dummyOutputLayer = network->AddOutputLayer(1); 867*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot(0).Connect(dummyOutputLayer->GetInputSlot(0)); 868*89c4ff92SAndroid Build Coastguard Worker 869*89c4ff92SAndroid Build Coastguard Worker // Optimize and load and execute it a second time. 870*89c4ff92SAndroid Build Coastguard Worker optimizedNetwork = Optimize(*network, {Compute::CpuRef}, run->GetDeviceSpec()); 871*89c4ff92SAndroid Build Coastguard Worker CHECK(run->LoadNetwork(networkIdentifier, std::move(optimizedNetwork)) == Status::Success); 872*89c4ff92SAndroid Build Coastguard Worker std::vector<float> goldenData(300, 0.0f); 873*89c4ff92SAndroid Build Coastguard Worker std::vector<float> padOutputData(108, 0.0f); 874*89c4ff92SAndroid Build Coastguard Worker OutputTensors goldenTensors{{0, Tensor(outputInfo, goldenData.data())}, 875*89c4ff92SAndroid Build Coastguard Worker {1, Tensor(paddedInfo, padOutputData.data())}}; 876*89c4ff92SAndroid Build Coastguard Worker run->EnqueueWorkload(networkIdentifier, inputTensors, goldenTensors); 877*89c4ff92SAndroid Build Coastguard Worker 878*89c4ff92SAndroid Build Coastguard Worker // Now we can compare goldenData against optimizedData. They should be the same. 879*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(goldenData.begin(), goldenData.end(), optimizedData.begin())); 880*89c4ff92SAndroid Build Coastguard Worker } 881*89c4ff92SAndroid Build Coastguard Worker catch (const std::exception& e) 882*89c4ff92SAndroid Build Coastguard Worker { 883*89c4ff92SAndroid Build Coastguard Worker std::cerr << e.what() << std::endl; 884*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, e.what()); 885*89c4ff92SAndroid Build Coastguard Worker } 886*89c4ff92SAndroid Build Coastguard Worker } 887*89c4ff92SAndroid Build Coastguard Worker #endif 888*89c4ff92SAndroid Build Coastguard Worker 889*89c4ff92SAndroid Build Coastguard Worker }