1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp> 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <string> 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Pooling3d") 12*89c4ff92SAndroid Build Coastguard Worker { 13*89c4ff92SAndroid Build Coastguard Worker struct Pooling3dFixture : public ParserFlatbuffersSerializeFixture 14*89c4ff92SAndroid Build Coastguard Worker { Pooling3dFixturePooling3dFixture15*89c4ff92SAndroid Build Coastguard Worker explicit Pooling3dFixture(const std::string &inputShape, 16*89c4ff92SAndroid Build Coastguard Worker const std::string &outputShape, 17*89c4ff92SAndroid Build Coastguard Worker const std::string &dataType, 18*89c4ff92SAndroid Build Coastguard Worker const std::string &dataLayout, 19*89c4ff92SAndroid Build Coastguard Worker const std::string &poolingAlgorithm) 20*89c4ff92SAndroid Build Coastguard Worker { 21*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker inputIds: [0], 24*89c4ff92SAndroid Build Coastguard Worker outputIds: [2], 25*89c4ff92SAndroid Build Coastguard Worker layers: [ 26*89c4ff92SAndroid Build Coastguard Worker { 27*89c4ff92SAndroid Build Coastguard Worker layer_type: "InputLayer", 28*89c4ff92SAndroid Build Coastguard Worker layer: { 29*89c4ff92SAndroid Build Coastguard Worker base: { 30*89c4ff92SAndroid Build Coastguard Worker layerBindingId: 0, 31*89c4ff92SAndroid Build Coastguard Worker base: { 32*89c4ff92SAndroid Build Coastguard Worker index: 0, 33*89c4ff92SAndroid Build Coastguard Worker layerName: "InputLayer", 34*89c4ff92SAndroid Build Coastguard Worker layerType: "Input", 35*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 36*89c4ff92SAndroid Build Coastguard Worker index: 0, 37*89c4ff92SAndroid Build Coastguard Worker connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 38*89c4ff92SAndroid Build Coastguard Worker }], 39*89c4ff92SAndroid Build Coastguard Worker outputSlots: [ { 40*89c4ff92SAndroid Build Coastguard Worker index: 0, 41*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 42*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + inputShape + R"(, 43*89c4ff92SAndroid Build Coastguard Worker dataType: )" + dataType + R"( 44*89c4ff92SAndroid Build Coastguard Worker }}] 45*89c4ff92SAndroid Build Coastguard Worker } 46*89c4ff92SAndroid Build Coastguard Worker }}}, 47*89c4ff92SAndroid Build Coastguard Worker { 48*89c4ff92SAndroid Build Coastguard Worker layer_type: "Pooling3dLayer", 49*89c4ff92SAndroid Build Coastguard Worker layer: { 50*89c4ff92SAndroid Build Coastguard Worker base: { 51*89c4ff92SAndroid Build Coastguard Worker index: 1, 52*89c4ff92SAndroid Build Coastguard Worker layerName: "Pooling3dLayer", 53*89c4ff92SAndroid Build Coastguard Worker layerType: "Pooling3d", 54*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 55*89c4ff92SAndroid Build Coastguard Worker index: 0, 56*89c4ff92SAndroid Build Coastguard Worker connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 57*89c4ff92SAndroid Build Coastguard Worker }], 58*89c4ff92SAndroid Build Coastguard Worker outputSlots: [ { 59*89c4ff92SAndroid Build Coastguard Worker index: 0, 60*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 61*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + outputShape + R"(, 62*89c4ff92SAndroid Build Coastguard Worker dataType: )" + dataType + R"( 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker }}]}, 65*89c4ff92SAndroid Build Coastguard Worker descriptor: { 66*89c4ff92SAndroid Build Coastguard Worker poolType: )" + poolingAlgorithm + R"(, 67*89c4ff92SAndroid Build Coastguard Worker outputShapeRounding: "Floor", 68*89c4ff92SAndroid Build Coastguard Worker paddingMethod: Exclude, 69*89c4ff92SAndroid Build Coastguard Worker dataLayout: )" + dataLayout + R"(, 70*89c4ff92SAndroid Build Coastguard Worker padLeft: 0, 71*89c4ff92SAndroid Build Coastguard Worker padRight: 0, 72*89c4ff92SAndroid Build Coastguard Worker padTop: 0, 73*89c4ff92SAndroid Build Coastguard Worker padBottom: 0, 74*89c4ff92SAndroid Build Coastguard Worker padFront: 0, 75*89c4ff92SAndroid Build Coastguard Worker padBack: 0, 76*89c4ff92SAndroid Build Coastguard Worker poolWidth: 2, 77*89c4ff92SAndroid Build Coastguard Worker poolHeight: 2, 78*89c4ff92SAndroid Build Coastguard Worker poolDepth: 2, 79*89c4ff92SAndroid Build Coastguard Worker strideX: 2, 80*89c4ff92SAndroid Build Coastguard Worker strideY: 2, 81*89c4ff92SAndroid Build Coastguard Worker strideZ: 2 82*89c4ff92SAndroid Build Coastguard Worker } 83*89c4ff92SAndroid Build Coastguard Worker }}, 84*89c4ff92SAndroid Build Coastguard Worker { 85*89c4ff92SAndroid Build Coastguard Worker layer_type: "OutputLayer", 86*89c4ff92SAndroid Build Coastguard Worker layer: { 87*89c4ff92SAndroid Build Coastguard Worker base:{ 88*89c4ff92SAndroid Build Coastguard Worker layerBindingId: 0, 89*89c4ff92SAndroid Build Coastguard Worker base: { 90*89c4ff92SAndroid Build Coastguard Worker index: 2, 91*89c4ff92SAndroid Build Coastguard Worker layerName: "OutputLayer", 92*89c4ff92SAndroid Build Coastguard Worker layerType: "Output", 93*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 94*89c4ff92SAndroid Build Coastguard Worker index: 0, 95*89c4ff92SAndroid Build Coastguard Worker connection: {sourceLayerIndex:1, outputSlotIndex:0 }, 96*89c4ff92SAndroid Build Coastguard Worker }], 97*89c4ff92SAndroid Build Coastguard Worker outputSlots: [ { 98*89c4ff92SAndroid Build Coastguard Worker index: 0, 99*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 100*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + outputShape + R"(, 101*89c4ff92SAndroid Build Coastguard Worker dataType: )" + dataType + R"( 102*89c4ff92SAndroid Build Coastguard Worker }, 103*89c4ff92SAndroid Build Coastguard Worker }], 104*89c4ff92SAndroid Build Coastguard Worker }}}, 105*89c4ff92SAndroid Build Coastguard Worker }] 106*89c4ff92SAndroid Build Coastguard Worker } 107*89c4ff92SAndroid Build Coastguard Worker )"; 108*89c4ff92SAndroid Build Coastguard Worker SetupSingleInputSingleOutput("InputLayer", "OutputLayer"); 109*89c4ff92SAndroid Build Coastguard Worker } 110*89c4ff92SAndroid Build Coastguard Worker }; 111*89c4ff92SAndroid Build Coastguard Worker 112*89c4ff92SAndroid Build Coastguard Worker struct SimpleAvgPooling3dFixture : Pooling3dFixture 113*89c4ff92SAndroid Build Coastguard Worker { SimpleAvgPooling3dFixtureSimpleAvgPooling3dFixture114*89c4ff92SAndroid Build Coastguard Worker SimpleAvgPooling3dFixture() : Pooling3dFixture("[ 1, 2, 2, 2, 1 ]", 115*89c4ff92SAndroid Build Coastguard Worker "[ 1, 1, 1, 1, 1 ]", 116*89c4ff92SAndroid Build Coastguard Worker "Float32", "NDHWC", "Average") {} 117*89c4ff92SAndroid Build Coastguard Worker }; 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker struct SimpleAvgPooling3dFixture2 : Pooling3dFixture 120*89c4ff92SAndroid Build Coastguard Worker { SimpleAvgPooling3dFixture2SimpleAvgPooling3dFixture2121*89c4ff92SAndroid Build Coastguard Worker SimpleAvgPooling3dFixture2() : Pooling3dFixture("[ 1, 2, 2, 2, 1 ]", 122*89c4ff92SAndroid Build Coastguard Worker "[ 1, 1, 1, 1, 1 ]", 123*89c4ff92SAndroid Build Coastguard Worker "QuantisedAsymm8", "NDHWC", "Average") {} 124*89c4ff92SAndroid Build Coastguard Worker }; 125*89c4ff92SAndroid Build Coastguard Worker 126*89c4ff92SAndroid Build Coastguard Worker struct SimpleMaxPooling3dFixture : Pooling3dFixture 127*89c4ff92SAndroid Build Coastguard Worker { SimpleMaxPooling3dFixtureSimpleMaxPooling3dFixture128*89c4ff92SAndroid Build Coastguard Worker SimpleMaxPooling3dFixture() : Pooling3dFixture("[ 1, 1, 2, 2, 2 ]", 129*89c4ff92SAndroid Build Coastguard Worker "[ 1, 1, 1, 1, 1 ]", 130*89c4ff92SAndroid Build Coastguard Worker "Float32", "NCDHW", "Max") {} 131*89c4ff92SAndroid Build Coastguard Worker }; 132*89c4ff92SAndroid Build Coastguard Worker 133*89c4ff92SAndroid Build Coastguard Worker struct SimpleMaxPooling3dFixture2 : Pooling3dFixture 134*89c4ff92SAndroid Build Coastguard Worker { SimpleMaxPooling3dFixture2SimpleMaxPooling3dFixture2135*89c4ff92SAndroid Build Coastguard Worker SimpleMaxPooling3dFixture2() : Pooling3dFixture("[ 1, 1, 2, 2, 2 ]", 136*89c4ff92SAndroid Build Coastguard Worker "[ 1, 1, 1, 1, 1 ]", 137*89c4ff92SAndroid Build Coastguard Worker "QuantisedAsymm8", "NCDHW", "Max") {} 138*89c4ff92SAndroid Build Coastguard Worker }; 139*89c4ff92SAndroid Build Coastguard Worker 140*89c4ff92SAndroid Build Coastguard Worker struct SimpleL2Pooling3dFixture : Pooling3dFixture 141*89c4ff92SAndroid Build Coastguard Worker { SimpleL2Pooling3dFixtureSimpleL2Pooling3dFixture142*89c4ff92SAndroid Build Coastguard Worker SimpleL2Pooling3dFixture() : Pooling3dFixture("[ 1, 2, 2, 2, 1 ]", 143*89c4ff92SAndroid Build Coastguard Worker "[ 1, 1, 1, 1, 1 ]", 144*89c4ff92SAndroid Build Coastguard Worker "Float32", "NDHWC", "L2") {} 145*89c4ff92SAndroid Build Coastguard Worker }; 146*89c4ff92SAndroid Build Coastguard Worker 147*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleAvgPooling3dFixture, "Pooling3dFloat32Avg") 148*89c4ff92SAndroid Build Coastguard Worker { 149*89c4ff92SAndroid Build Coastguard Worker RunTest<5, armnn::DataType::Float32>(0, { 2, 3, 5, 2, 3, 2, 3, 4 }, { 3 }); 150*89c4ff92SAndroid Build Coastguard Worker } 151*89c4ff92SAndroid Build Coastguard Worker 152*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleAvgPooling3dFixture2, "Pooling3dQuantisedAsymm8Avg") 153*89c4ff92SAndroid Build Coastguard Worker { 154*89c4ff92SAndroid Build Coastguard Worker RunTest<5, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80, 50, 60, 70, 30 },{ 50 }); 155*89c4ff92SAndroid Build Coastguard Worker } 156*89c4ff92SAndroid Build Coastguard Worker 157*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleMaxPooling3dFixture, "Pooling3dFloat32Max") 158*89c4ff92SAndroid Build Coastguard Worker { 159*89c4ff92SAndroid Build Coastguard Worker RunTest<5, armnn::DataType::Float32>(0, { 2, 5, 5, 2, 1, 3, 4, 0 }, { 5 }); 160*89c4ff92SAndroid Build Coastguard Worker } 161*89c4ff92SAndroid Build Coastguard Worker 162*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleMaxPooling3dFixture2, "Pooling3dQuantisedAsymm8Max") 163*89c4ff92SAndroid Build Coastguard Worker { 164*89c4ff92SAndroid Build Coastguard Worker RunTest<5, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80, 10, 40, 0, 70 },{ 80 }); 165*89c4ff92SAndroid Build Coastguard Worker } 166*89c4ff92SAndroid Build Coastguard Worker 167*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleL2Pooling3dFixture, "Pooling3dFloat32L2") 168*89c4ff92SAndroid Build Coastguard Worker { 169*89c4ff92SAndroid Build Coastguard Worker RunTest<5, armnn::DataType::Float32>(0, { 2, 3, 5, 2, 4, 1, 1, 3 }, { 2.93683503112f }); 170*89c4ff92SAndroid Build Coastguard Worker } 171*89c4ff92SAndroid Build Coastguard Worker 172*89c4ff92SAndroid Build Coastguard Worker } 173*89c4ff92SAndroid Build Coastguard Worker 174