1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersSerializeFixture.hpp" 7 #include <armnnDeserializer/IDeserializer.hpp> 8 9 #include <string> 10 11 TEST_SUITE("Deserializer_Pooling2d") 12 { 13 struct Pooling2dFixture : public ParserFlatbuffersSerializeFixture 14 { Pooling2dFixturePooling2dFixture15 explicit Pooling2dFixture(const std::string &inputShape, 16 const std::string &outputShape, 17 const std::string &dataType, 18 const std::string &dataLayout, 19 const std::string &poolingAlgorithm) 20 { 21 m_JsonString = R"( 22 { 23 inputIds: [0], 24 outputIds: [2], 25 layers: [ 26 { 27 layer_type: "InputLayer", 28 layer: { 29 base: { 30 layerBindingId: 0, 31 base: { 32 index: 0, 33 layerName: "InputLayer", 34 layerType: "Input", 35 inputSlots: [{ 36 index: 0, 37 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 38 }], 39 outputSlots: [ { 40 index: 0, 41 tensorInfo: { 42 dimensions: )" + inputShape + R"(, 43 dataType: )" + dataType + R"( 44 }}] 45 } 46 }}}, 47 { 48 layer_type: "Pooling2dLayer", 49 layer: { 50 base: { 51 index: 1, 52 layerName: "Pooling2dLayer", 53 layerType: "Pooling2d", 54 inputSlots: [{ 55 index: 0, 56 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 57 }], 58 outputSlots: [ { 59 index: 0, 60 tensorInfo: { 61 dimensions: )" + outputShape + R"(, 62 dataType: )" + dataType + R"( 63 64 }}]}, 65 descriptor: { 66 poolType: )" + poolingAlgorithm + R"(, 67 outputShapeRounding: "Floor", 68 paddingMethod: Exclude, 69 dataLayout: )" + dataLayout + R"(, 70 padLeft: 0, 71 padRight: 0, 72 padTop: 0, 73 padBottom: 0, 74 poolWidth: 2, 75 poolHeight: 2, 76 strideX: 2, 77 strideY: 2 78 } 79 }}, 80 { 81 layer_type: "OutputLayer", 82 layer: { 83 base:{ 84 layerBindingId: 0, 85 base: { 86 index: 2, 87 layerName: "OutputLayer", 88 layerType: "Output", 89 inputSlots: [{ 90 index: 0, 91 connection: {sourceLayerIndex:1, outputSlotIndex:0 }, 92 }], 93 outputSlots: [ { 94 index: 0, 95 tensorInfo: { 96 dimensions: )" + outputShape + R"(, 97 dataType: )" + dataType + R"( 98 }, 99 }], 100 }}}, 101 }] 102 } 103 )"; 104 SetupSingleInputSingleOutput("InputLayer", "OutputLayer"); 105 } 106 }; 107 108 struct SimpleAvgPooling2dFixture : Pooling2dFixture 109 { SimpleAvgPooling2dFixtureSimpleAvgPooling2dFixture110 SimpleAvgPooling2dFixture() : Pooling2dFixture("[ 1, 2, 2, 1 ]", 111 "[ 1, 1, 1, 1 ]", 112 "Float32", "NHWC", "Average") {} 113 }; 114 115 struct SimpleAvgPooling2dFixture2 : Pooling2dFixture 116 { SimpleAvgPooling2dFixture2SimpleAvgPooling2dFixture2117 SimpleAvgPooling2dFixture2() : Pooling2dFixture("[ 1, 2, 2, 1 ]", 118 "[ 1, 1, 1, 1 ]", 119 "QuantisedAsymm8", "NHWC", "Average") {} 120 }; 121 122 struct SimpleMaxPooling2dFixture : Pooling2dFixture 123 { SimpleMaxPooling2dFixtureSimpleMaxPooling2dFixture124 SimpleMaxPooling2dFixture() : Pooling2dFixture("[ 1, 1, 2, 2 ]", 125 "[ 1, 1, 1, 1 ]", 126 "Float32", "NCHW", "Max") {} 127 }; 128 129 struct SimpleMaxPooling2dFixture2 : Pooling2dFixture 130 { SimpleMaxPooling2dFixture2SimpleMaxPooling2dFixture2131 SimpleMaxPooling2dFixture2() : Pooling2dFixture("[ 1, 1, 2, 2 ]", 132 "[ 1, 1, 1, 1 ]", 133 "QuantisedAsymm8", "NCHW", "Max") {} 134 }; 135 136 struct SimpleL2Pooling2dFixture : Pooling2dFixture 137 { SimpleL2Pooling2dFixtureSimpleL2Pooling2dFixture138 SimpleL2Pooling2dFixture() : Pooling2dFixture("[ 1, 2, 2, 1 ]", 139 "[ 1, 1, 1, 1 ]", 140 "Float32", "NHWC", "L2") {} 141 }; 142 143 TEST_CASE_FIXTURE(SimpleAvgPooling2dFixture, "Pooling2dFloat32Avg") 144 { 145 RunTest<4, armnn::DataType::Float32>(0, { 2, 3, 5, 2 }, { 3 }); 146 } 147 148 TEST_CASE_FIXTURE(SimpleAvgPooling2dFixture2, "Pooling2dQuantisedAsymm8Avg") 149 { 150 RunTest<4, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80 },{ 50 }); 151 } 152 153 TEST_CASE_FIXTURE(SimpleMaxPooling2dFixture, "Pooling2dFloat32Max") 154 { 155 RunTest<4, armnn::DataType::Float32>(0, { 2, 5, 5, 2 }, { 5 }); 156 } 157 158 TEST_CASE_FIXTURE(SimpleMaxPooling2dFixture2, "Pooling2dQuantisedAsymm8Max") 159 { 160 RunTest<4, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80 },{ 80 }); 161 } 162 163 TEST_CASE_FIXTURE(SimpleL2Pooling2dFixture, "Pooling2dFloat32L2") 164 { 165 RunTest<4, armnn::DataType::Float32>(0, { 2, 3, 5, 2 }, { 3.2403703f }); 166 } 167 168 } 169 170