// xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializePooling3d.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Pooling3d")
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker struct Pooling3dFixture : public ParserFlatbuffersSerializeFixture
14*89c4ff92SAndroid Build Coastguard Worker {
Pooling3dFixturePooling3dFixture15*89c4ff92SAndroid Build Coastguard Worker     explicit Pooling3dFixture(const std::string &inputShape,
16*89c4ff92SAndroid Build Coastguard Worker                               const std::string &outputShape,
17*89c4ff92SAndroid Build Coastguard Worker                               const std::string &dataType,
18*89c4ff92SAndroid Build Coastguard Worker                               const std::string &dataLayout,
19*89c4ff92SAndroid Build Coastguard Worker                               const std::string &poolingAlgorithm)
20*89c4ff92SAndroid Build Coastguard Worker     {
21*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
22*89c4ff92SAndroid Build Coastguard Worker     {
23*89c4ff92SAndroid Build Coastguard Worker             inputIds: [0],
24*89c4ff92SAndroid Build Coastguard Worker             outputIds: [2],
25*89c4ff92SAndroid Build Coastguard Worker             layers: [
26*89c4ff92SAndroid Build Coastguard Worker             {
27*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "InputLayer",
28*89c4ff92SAndroid Build Coastguard Worker                 layer: {
29*89c4ff92SAndroid Build Coastguard Worker                       base: {
30*89c4ff92SAndroid Build Coastguard Worker                             layerBindingId: 0,
31*89c4ff92SAndroid Build Coastguard Worker                             base: {
32*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
33*89c4ff92SAndroid Build Coastguard Worker                                 layerName: "InputLayer",
34*89c4ff92SAndroid Build Coastguard Worker                                 layerType: "Input",
35*89c4ff92SAndroid Build Coastguard Worker                                 inputSlots: [{
36*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
37*89c4ff92SAndroid Build Coastguard Worker                                     connection: {sourceLayerIndex:0, outputSlotIndex:0 },
38*89c4ff92SAndroid Build Coastguard Worker                                 }],
39*89c4ff92SAndroid Build Coastguard Worker                                 outputSlots: [ {
40*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
41*89c4ff92SAndroid Build Coastguard Worker                                     tensorInfo: {
42*89c4ff92SAndroid Build Coastguard Worker                                         dimensions: )" + inputShape + R"(,
43*89c4ff92SAndroid Build Coastguard Worker                                         dataType: )" + dataType + R"(
44*89c4ff92SAndroid Build Coastguard Worker                                         }}]
45*89c4ff92SAndroid Build Coastguard Worker                                 }
46*89c4ff92SAndroid Build Coastguard Worker                 }}},
47*89c4ff92SAndroid Build Coastguard Worker                 {
48*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "Pooling3dLayer",
49*89c4ff92SAndroid Build Coastguard Worker                 layer: {
50*89c4ff92SAndroid Build Coastguard Worker                       base: {
51*89c4ff92SAndroid Build Coastguard Worker                            index: 1,
52*89c4ff92SAndroid Build Coastguard Worker                            layerName: "Pooling3dLayer",
53*89c4ff92SAndroid Build Coastguard Worker                            layerType: "Pooling3d",
54*89c4ff92SAndroid Build Coastguard Worker                            inputSlots: [{
55*89c4ff92SAndroid Build Coastguard Worker                                   index: 0,
56*89c4ff92SAndroid Build Coastguard Worker                                   connection: {sourceLayerIndex:0, outputSlotIndex:0 },
57*89c4ff92SAndroid Build Coastguard Worker                            }],
58*89c4ff92SAndroid Build Coastguard Worker                            outputSlots: [ {
59*89c4ff92SAndroid Build Coastguard Worker                                   index: 0,
60*89c4ff92SAndroid Build Coastguard Worker                                   tensorInfo: {
61*89c4ff92SAndroid Build Coastguard Worker                                        dimensions: )" + outputShape + R"(,
62*89c4ff92SAndroid Build Coastguard Worker                                        dataType: )" + dataType + R"(
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker                            }}]},
65*89c4ff92SAndroid Build Coastguard Worker                       descriptor: {
66*89c4ff92SAndroid Build Coastguard Worker                            poolType: )" + poolingAlgorithm + R"(,
67*89c4ff92SAndroid Build Coastguard Worker                            outputShapeRounding: "Floor",
68*89c4ff92SAndroid Build Coastguard Worker                            paddingMethod: Exclude,
69*89c4ff92SAndroid Build Coastguard Worker                            dataLayout: )" + dataLayout + R"(,
70*89c4ff92SAndroid Build Coastguard Worker                            padLeft: 0,
71*89c4ff92SAndroid Build Coastguard Worker                            padRight: 0,
72*89c4ff92SAndroid Build Coastguard Worker                            padTop: 0,
73*89c4ff92SAndroid Build Coastguard Worker                            padBottom: 0,
74*89c4ff92SAndroid Build Coastguard Worker                            padFront: 0,
75*89c4ff92SAndroid Build Coastguard Worker                            padBack: 0,
76*89c4ff92SAndroid Build Coastguard Worker                            poolWidth: 2,
77*89c4ff92SAndroid Build Coastguard Worker                            poolHeight: 2,
78*89c4ff92SAndroid Build Coastguard Worker                            poolDepth: 2,
79*89c4ff92SAndroid Build Coastguard Worker                            strideX: 2,
80*89c4ff92SAndroid Build Coastguard Worker                            strideY: 2,
81*89c4ff92SAndroid Build Coastguard Worker                            strideZ: 2
82*89c4ff92SAndroid Build Coastguard Worker                            }
83*89c4ff92SAndroid Build Coastguard Worker                 }},
84*89c4ff92SAndroid Build Coastguard Worker                 {
85*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "OutputLayer",
86*89c4ff92SAndroid Build Coastguard Worker                 layer: {
87*89c4ff92SAndroid Build Coastguard Worker                     base:{
88*89c4ff92SAndroid Build Coastguard Worker                           layerBindingId: 0,
89*89c4ff92SAndroid Build Coastguard Worker                           base: {
90*89c4ff92SAndroid Build Coastguard Worker                                 index: 2,
91*89c4ff92SAndroid Build Coastguard Worker                                 layerName: "OutputLayer",
92*89c4ff92SAndroid Build Coastguard Worker                                 layerType: "Output",
93*89c4ff92SAndroid Build Coastguard Worker                                 inputSlots: [{
94*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
95*89c4ff92SAndroid Build Coastguard Worker                                     connection: {sourceLayerIndex:1, outputSlotIndex:0 },
96*89c4ff92SAndroid Build Coastguard Worker                                 }],
97*89c4ff92SAndroid Build Coastguard Worker                                 outputSlots: [ {
98*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
99*89c4ff92SAndroid Build Coastguard Worker                                     tensorInfo: {
100*89c4ff92SAndroid Build Coastguard Worker                                         dimensions: )" + outputShape + R"(,
101*89c4ff92SAndroid Build Coastguard Worker                                         dataType: )" + dataType + R"(
102*89c4ff92SAndroid Build Coastguard Worker                                     },
103*89c4ff92SAndroid Build Coastguard Worker                             }],
104*89c4ff92SAndroid Build Coastguard Worker                         }}},
105*89c4ff92SAndroid Build Coastguard Worker             }]
106*89c4ff92SAndroid Build Coastguard Worker      }
107*89c4ff92SAndroid Build Coastguard Worker  )";
108*89c4ff92SAndroid Build Coastguard Worker         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
109*89c4ff92SAndroid Build Coastguard Worker     }
110*89c4ff92SAndroid Build Coastguard Worker };
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker struct SimpleAvgPooling3dFixture : Pooling3dFixture
113*89c4ff92SAndroid Build Coastguard Worker {
SimpleAvgPooling3dFixtureSimpleAvgPooling3dFixture114*89c4ff92SAndroid Build Coastguard Worker     SimpleAvgPooling3dFixture() : Pooling3dFixture("[ 1, 2, 2, 2, 1 ]",
115*89c4ff92SAndroid Build Coastguard Worker                                                    "[ 1, 1, 1, 1, 1 ]",
116*89c4ff92SAndroid Build Coastguard Worker                                                    "Float32", "NDHWC", "Average") {}
117*89c4ff92SAndroid Build Coastguard Worker };
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker struct SimpleAvgPooling3dFixture2 : Pooling3dFixture
120*89c4ff92SAndroid Build Coastguard Worker {
SimpleAvgPooling3dFixture2SimpleAvgPooling3dFixture2121*89c4ff92SAndroid Build Coastguard Worker     SimpleAvgPooling3dFixture2() : Pooling3dFixture("[ 1, 2, 2, 2, 1 ]",
122*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 1, 1, 1, 1, 1 ]",
123*89c4ff92SAndroid Build Coastguard Worker                                                     "QuantisedAsymm8", "NDHWC", "Average") {}
124*89c4ff92SAndroid Build Coastguard Worker };
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker struct SimpleMaxPooling3dFixture : Pooling3dFixture
127*89c4ff92SAndroid Build Coastguard Worker {
SimpleMaxPooling3dFixtureSimpleMaxPooling3dFixture128*89c4ff92SAndroid Build Coastguard Worker     SimpleMaxPooling3dFixture() : Pooling3dFixture("[ 1, 1, 2, 2, 2 ]",
129*89c4ff92SAndroid Build Coastguard Worker                                                    "[ 1, 1, 1, 1, 1 ]",
130*89c4ff92SAndroid Build Coastguard Worker                                                    "Float32", "NCDHW", "Max") {}
131*89c4ff92SAndroid Build Coastguard Worker };
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker struct SimpleMaxPooling3dFixture2 : Pooling3dFixture
134*89c4ff92SAndroid Build Coastguard Worker {
SimpleMaxPooling3dFixture2SimpleMaxPooling3dFixture2135*89c4ff92SAndroid Build Coastguard Worker     SimpleMaxPooling3dFixture2() : Pooling3dFixture("[ 1, 1, 2, 2, 2 ]",
136*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 1, 1, 1, 1, 1 ]",
137*89c4ff92SAndroid Build Coastguard Worker                                                     "QuantisedAsymm8", "NCDHW", "Max") {}
138*89c4ff92SAndroid Build Coastguard Worker };
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker struct SimpleL2Pooling3dFixture : Pooling3dFixture
141*89c4ff92SAndroid Build Coastguard Worker {
SimpleL2Pooling3dFixtureSimpleL2Pooling3dFixture142*89c4ff92SAndroid Build Coastguard Worker     SimpleL2Pooling3dFixture() : Pooling3dFixture("[ 1, 2, 2, 2, 1 ]",
143*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 1, 1, 1, 1, 1 ]",
144*89c4ff92SAndroid Build Coastguard Worker                                                   "Float32", "NDHWC", "L2") {}
145*89c4ff92SAndroid Build Coastguard Worker };
146*89c4ff92SAndroid Build Coastguard Worker 
147*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleAvgPooling3dFixture, "Pooling3dFloat32Avg")
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker     RunTest<5, armnn::DataType::Float32>(0, { 2, 3, 5, 2, 3, 2, 3, 4 }, { 3 });
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleAvgPooling3dFixture2, "Pooling3dQuantisedAsymm8Avg")
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker     RunTest<5, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80, 50, 60, 70, 30 },{ 50 });
155*89c4ff92SAndroid Build Coastguard Worker }
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleMaxPooling3dFixture, "Pooling3dFloat32Max")
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker     RunTest<5, armnn::DataType::Float32>(0, { 2, 5, 5, 2, 1, 3, 4, 0 }, { 5 });
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleMaxPooling3dFixture2, "Pooling3dQuantisedAsymm8Max")
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker     RunTest<5, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80, 10, 40, 0, 70 },{ 80 });
165*89c4ff92SAndroid Build Coastguard Worker }
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleL2Pooling3dFixture, "Pooling3dFloat32L2")
168*89c4ff92SAndroid Build Coastguard Worker {
169*89c4ff92SAndroid Build Coastguard Worker     RunTest<5, armnn::DataType::Float32>(0, { 2, 3, 5, 2, 4, 1, 1, 3 }, { 2.93683503112f });
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker 
174