xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/Pooling2dEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker template<typename armnn::DataType DataType>
CreatePooling2dNetwork(const armnn::TensorShape & inputShape,const armnn::TensorShape & outputShape,PaddingMethod padMethod=PaddingMethod::Exclude,PoolingAlgorithm poolAlg=PoolingAlgorithm::Max,const float qScale=1.0f,const int32_t qOffset=0)21*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreatePooling2dNetwork(const armnn::TensorShape& inputShape,
22*89c4ff92SAndroid Build Coastguard Worker                                           const armnn::TensorShape& outputShape,
23*89c4ff92SAndroid Build Coastguard Worker                                           PaddingMethod padMethod = PaddingMethod::Exclude,
24*89c4ff92SAndroid Build Coastguard Worker                                           PoolingAlgorithm poolAlg = PoolingAlgorithm::Max,
25*89c4ff92SAndroid Build Coastguard Worker                                           const float qScale = 1.0f,
26*89c4ff92SAndroid Build Coastguard Worker                                           const int32_t qOffset = 0)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network(INetwork::Create());
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(inputShape, DataType, qScale, qOffset, true);
31*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset, true);
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor descriptor;
34*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PoolType = poolAlg;
35*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PoolWidth = descriptor.m_PoolHeight = 3;
36*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX = descriptor.m_StrideY = 1;
37*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft = 1;
38*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight = 1;
39*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop = 1;
40*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom = 1;
41*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PaddingMethod = padMethod;
42*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout = DataLayout::NHWC;
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pool = network->AddPooling2dLayer(descriptor, "pool");
45*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = network->AddInputLayer(0, "input");
46*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = network->AddOutputLayer(0, "output");
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     Connect(input, pool, inputTensorInfo, 0, 0);
49*89c4ff92SAndroid Build Coastguard Worker     Connect(pool, output, outputTensorInfo, 0, 0);
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     return network;
52*89c4ff92SAndroid Build Coastguard Worker }
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
MaxPool2dEndToEnd(const std::vector<armnn::BackendId> & backends,PaddingMethod padMethod=PaddingMethod::Exclude)55*89c4ff92SAndroid Build Coastguard Worker void MaxPool2dEndToEnd(const std::vector<armnn::BackendId>& backends,
56*89c4ff92SAndroid Build Coastguard Worker                        PaddingMethod padMethod = PaddingMethod::Exclude)
57*89c4ff92SAndroid Build Coastguard Worker {
58*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& inputShape = { 1, 3, 3, 1 };
59*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& outputShape = { 1, 3, 3, 1 };
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreatePooling2dNetwork<ArmnnType>(inputShape, outputShape, padMethod);
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     CHECK(network);
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData{ 1, 2, 3,
66*89c4ff92SAndroid Build Coastguard Worker                               4, 5, 6,
67*89c4ff92SAndroid Build Coastguard Worker                               7, 8, 9 };
68*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutput{ 5, 6, 6,
69*89c4ff92SAndroid Build Coastguard Worker                                    8, 9, 9,
70*89c4ff92SAndroid Build Coastguard Worker                                    8, 9, 9 };
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
73*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput } };
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network), inputTensorData, expectedOutputData, backends);
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
MaxPool2dEndToEndFloat16(const std::vector<armnn::BackendId> & backends)79*89c4ff92SAndroid Build Coastguard Worker void MaxPool2dEndToEndFloat16(const std::vector<armnn::BackendId>& backends)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker     using namespace half_float::literal;
82*89c4ff92SAndroid Build Coastguard Worker     using Half = half_float::half;
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& inputShape = { 1, 3, 3, 1 };
85*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& outputShape = { 1, 3, 3, 1 };
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreatePooling2dNetwork<ArmnnType>(inputShape, outputShape);
88*89c4ff92SAndroid Build Coastguard Worker     CHECK(network);
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker     std::vector<Half> inputData{ 1._h, 2._h, 3._h,
91*89c4ff92SAndroid Build Coastguard Worker                                  4._h, 5._h, 6._h,
92*89c4ff92SAndroid Build Coastguard Worker                                  7._h, 8._h, 9._h };
93*89c4ff92SAndroid Build Coastguard Worker     std::vector<Half> expectedOutput{ 5._h, 6._h, 6._h,
94*89c4ff92SAndroid Build Coastguard Worker                                       8._h, 9._h, 9._h,
95*89c4ff92SAndroid Build Coastguard Worker                                       8._h, 9._h, 9._h };
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<Half>> inputTensorData = { { 0, inputData } };
98*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<Half>> expectedOutputData = { { 0, expectedOutput } };
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network), inputTensorData, expectedOutputData, backends);
101*89c4ff92SAndroid Build Coastguard Worker }
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
AvgPool2dEndToEnd(const std::vector<armnn::BackendId> & backends,PaddingMethod padMethod=PaddingMethod::Exclude)104*89c4ff92SAndroid Build Coastguard Worker void AvgPool2dEndToEnd(const std::vector<armnn::BackendId>& backends,
105*89c4ff92SAndroid Build Coastguard Worker                        PaddingMethod padMethod = PaddingMethod::Exclude)
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& inputShape =  { 1, 3, 3, 1 };
108*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& outputShape =  { 1, 3, 3, 1 };
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreatePooling2dNetwork<ArmnnType>(
111*89c4ff92SAndroid Build Coastguard Worker         inputShape, outputShape, padMethod, PoolingAlgorithm::Average);
112*89c4ff92SAndroid Build Coastguard Worker     CHECK(network);
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData{ 1, 2, 3,
115*89c4ff92SAndroid Build Coastguard Worker                               4, 5, 6,
116*89c4ff92SAndroid Build Coastguard Worker                               7, 8, 9 };
117*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutput;
118*89c4ff92SAndroid Build Coastguard Worker     if (padMethod == PaddingMethod::Exclude)
119*89c4ff92SAndroid Build Coastguard Worker     {
120*89c4ff92SAndroid Build Coastguard Worker         expectedOutput  = { 3.f , 3.5f, 4.f ,
121*89c4ff92SAndroid Build Coastguard Worker                             4.5f, 5.f , 5.5f,
122*89c4ff92SAndroid Build Coastguard Worker                             6.f , 6.5f, 7.f  };
123*89c4ff92SAndroid Build Coastguard Worker     }
124*89c4ff92SAndroid Build Coastguard Worker     else
125*89c4ff92SAndroid Build Coastguard Worker     {
126*89c4ff92SAndroid Build Coastguard Worker         expectedOutput  = { 1.33333f, 2.33333f, 1.77778f,
127*89c4ff92SAndroid Build Coastguard Worker                             3.f     , 5.f     , 3.66667f,
128*89c4ff92SAndroid Build Coastguard Worker                             2.66667f, 4.33333f, 3.11111f };
129*89c4ff92SAndroid Build Coastguard Worker     }
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
132*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput } };
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
135*89c4ff92SAndroid Build Coastguard Worker                                                 inputTensorData,
136*89c4ff92SAndroid Build Coastguard Worker                                                 expectedOutputData,
137*89c4ff92SAndroid Build Coastguard Worker                                                 backends,
138*89c4ff92SAndroid Build Coastguard Worker                                                 0.00001f);
139*89c4ff92SAndroid Build Coastguard Worker }
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
AvgPool2dEndToEndFloat16(const std::vector<armnn::BackendId> & backends,PaddingMethod padMethod=PaddingMethod::IgnoreValue)142*89c4ff92SAndroid Build Coastguard Worker void AvgPool2dEndToEndFloat16(const std::vector<armnn::BackendId>& backends,
143*89c4ff92SAndroid Build Coastguard Worker                               PaddingMethod padMethod = PaddingMethod::IgnoreValue)
144*89c4ff92SAndroid Build Coastguard Worker {
145*89c4ff92SAndroid Build Coastguard Worker     using namespace half_float::literal;
146*89c4ff92SAndroid Build Coastguard Worker     using Half = half_float::half;
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& inputShape =  { 1, 3, 3, 1 };
149*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& outputShape =  { 1, 3, 3, 1 };
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreatePooling2dNetwork<ArmnnType>(
152*89c4ff92SAndroid Build Coastguard Worker         inputShape, outputShape, padMethod, PoolingAlgorithm::Average);
153*89c4ff92SAndroid Build Coastguard Worker     CHECK(network);
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     std::vector<Half> inputData{ 1._h, 2._h, 3._h,
156*89c4ff92SAndroid Build Coastguard Worker                                  4._h, 5._h, 6._h,
157*89c4ff92SAndroid Build Coastguard Worker                                  7._h, 8._h, 9._h };
158*89c4ff92SAndroid Build Coastguard Worker     std::vector<Half> expectedOutput{ 1.33333_h, 2.33333_h, 1.77778_h,
159*89c4ff92SAndroid Build Coastguard Worker                                       3._h     , 5._h     , 3.66667_h,
160*89c4ff92SAndroid Build Coastguard Worker                                       2.66667_h, 4.33333_h, 3.11111_h };
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<Half>> inputTensorData = { { 0, inputData } };
163*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<Half>> expectedOutputData = { { 0, expectedOutput } };
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
166*89c4ff92SAndroid Build Coastguard Worker                                                 inputTensorData,
167*89c4ff92SAndroid Build Coastguard Worker                                                 expectedOutputData,
168*89c4ff92SAndroid Build Coastguard Worker                                                 backends,
169*89c4ff92SAndroid Build Coastguard Worker                                                 0.00001f);
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
173